Retrieve room version where known in roomserver

This commit is contained in:
Neil Alexander 2020-03-17 15:53:23 +00:00
parent a66c701b29
commit bc8735ebc6
8 changed files with 87 additions and 14 deletions

View file

@ -72,6 +72,10 @@ type RoomEventDatabase interface {
ctx context.Context, transactionID string, ctx context.Context, transactionID string,
sessionID int64, userID string, sessionID int64, userID string,
) (string, error) ) (string, error)
// Look up the room version for a given room.
GetRoomVersionForRoom(
ctx context.Context, roomNID types.RoomNID,
) (gomatrixserverlib.RoomVersion, error)
} }
// OutputRoomEventWriter has the APIs needed to write an event to the output logs. // OutputRoomEventWriter has the APIs needed to write an event to the output logs.

View file

@ -253,8 +253,10 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error)
latestEventIDs[i] = u.latest[i].EventID latestEventIDs[i] = u.latest[i].EventID
} }
// TODO: Room version here roomVersion, err := u.db.GetRoomVersionForRoom(u.ctx, u.roomNID)
roomVersion := gomatrixserverlib.RoomVersionV1 if err != nil {
return nil, err
}
ore := api.OutputNewRoomEvent{ ore := api.OutputNewRoomEvent{
Event: u.event.Headered(roomVersion), Event: u.event.Headered(roomVersion),

View file

@ -93,6 +93,10 @@ type RoomserverQueryAPIDatabase interface {
GetRoomVersionForRoom( GetRoomVersionForRoom(
ctx context.Context, roomNID types.RoomNID, ctx context.Context, roomNID types.RoomNID,
) (gomatrixserverlib.RoomVersion, error) ) (gomatrixserverlib.RoomVersion, error)
// Look up the room NID that an event ID appears in.
GetRoomNIDForEventID(
ctx context.Context, eventID string,
) (types.RoomNID, error)
} }
// RoomserverQueryAPI is an implementation of api.RoomserverQueryAPI // RoomserverQueryAPI is an implementation of api.RoomserverQueryAPI
@ -234,8 +238,15 @@ func (r *RoomserverQueryAPI) QueryEventsByID(
} }
for _, event := range events { for _, event := range events {
// TODO: Room version here roomNID, nerr := r.DB.GetRoomNIDForEventID(ctx, event.EventID())
roomVersion := gomatrixserverlib.RoomVersionV1 if nerr != nil {
return nerr
}
roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, roomNID)
if verr != nil {
return verr
}
response.Events = append(response.Events, event.Headered(roomVersion)) response.Events = append(response.Events, event.Headered(roomVersion))
} }
@ -516,8 +527,15 @@ func (r *RoomserverQueryAPI) QueryMissingEvents(
response.Events = make([]gomatrixserverlib.HeaderedEvent, 0, len(loadedEvents)-len(eventsToFilter)) response.Events = make([]gomatrixserverlib.HeaderedEvent, 0, len(loadedEvents)-len(eventsToFilter))
for _, event := range loadedEvents { for _, event := range loadedEvents {
if !eventsToFilter[event.EventID()] { if !eventsToFilter[event.EventID()] {
// TODO: Room version here roomNID, nerr := r.DB.GetRoomNIDForEventID(ctx, event.EventID())
roomVersion := gomatrixserverlib.RoomVersionV1 if nerr != nil {
return nerr
}
roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, roomNID)
if verr != nil {
return verr
}
response.Events = append(response.Events, event.Headered(roomVersion)) response.Events = append(response.Events, event.Headered(roomVersion))
} }
@ -562,8 +580,15 @@ func (r *RoomserverQueryAPI) QueryBackfill(
} }
for _, event := range loadedEvents { for _, event := range loadedEvents {
// TODO: Room version here roomNID, nerr := r.DB.GetRoomNIDForEventID(ctx, event.EventID())
roomVersion := gomatrixserverlib.RoomVersionV1 if nerr != nil {
return nerr
}
roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, roomNID)
if verr != nil {
return verr
}
response.Events = append(response.Events, event.Headered(roomVersion)) response.Events = append(response.Events, event.Headered(roomVersion))
} }
@ -653,6 +678,11 @@ func (r *RoomserverQueryAPI) QueryStateAndAuthChain(
} }
response.RoomExists = true response.RoomExists = true
roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, roomNID)
if err != nil {
return err
}
prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs) prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs)
if err != nil { if err != nil {
switch err.(type) { switch err.(type) {
@ -683,16 +713,10 @@ func (r *RoomserverQueryAPI) QueryStateAndAuthChain(
} }
for _, event := range stateEvents { for _, event := range stateEvents {
// TODO: Room version here
roomVersion := gomatrixserverlib.RoomVersionV1
response.StateEvents = append(response.StateEvents, event.Headered(roomVersion)) response.StateEvents = append(response.StateEvents, event.Headered(roomVersion))
} }
for _, event := range authEvents { for _, event := range authEvents {
// TODO: Room version here
roomVersion := gomatrixserverlib.RoomVersionV1
response.AuthChainEvents = append(response.AuthChainEvents, event.Headered(roomVersion)) response.AuthChainEvents = append(response.AuthChainEvents, event.Headered(roomVersion))
} }

View file

@ -38,6 +38,7 @@ type Database interface {
GetInvitesForUser(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserIDs []types.EventStateKeyNID, err error) GetInvitesForUser(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserIDs []types.EventStateKeyNID, err error)
SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error
GetRoomIDForAlias(ctx context.Context, alias string) (string, error) GetRoomIDForAlias(ctx context.Context, alias string) (string, error)
GetRoomNIDForEventID(ctx context.Context, eventID string) (types.RoomNID, error)
GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error)
GetCreatorIDForAlias(ctx context.Context, alias string) (string, error) GetCreatorIDForAlias(ctx context.Context, alias string) (string, error)
RemoveRoomAlias(ctx context.Context, alias string) error RemoveRoomAlias(ctx context.Context, alias string) error

View file

@ -116,6 +116,9 @@ const bulkSelectEventNIDSQL = "" +
const selectMaxEventDepthSQL = "" + const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)" "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)"
const selectRoomNIDForEventIDSQL = "" +
"SELECT room_nid FROM roomserver_events WHERE event_id = $1"
type eventStatements struct { type eventStatements struct {
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
selectEventStmt *sql.Stmt selectEventStmt *sql.Stmt
@ -130,6 +133,7 @@ type eventStatements struct {
bulkSelectEventIDStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt
bulkSelectEventNIDStmt *sql.Stmt bulkSelectEventNIDStmt *sql.Stmt
selectMaxEventDepthStmt *sql.Stmt selectMaxEventDepthStmt *sql.Stmt
selectRoomNIDForEventIDStmt *sql.Stmt
} }
func (s *eventStatements) prepare(db *sql.DB) (err error) { func (s *eventStatements) prepare(db *sql.DB) (err error) {
@ -152,6 +156,7 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) {
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
{&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL}, {&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL},
{&s.selectRoomNIDForEventIDStmt, selectRoomNIDForEventIDSQL},
}.prepare(db) }.prepare(db)
} }
@ -424,3 +429,11 @@ func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array {
} }
return nids return nids
} }
func (s *eventStatements) selectRoomNIDForEventID(
ctx context.Context, txn *sql.Tx, eventID string,
) (roomNID types.RoomNID, err error) {
selectStmt := common.TxStmt(txn, s.selectRoomNIDForEventIDStmt)
err = selectStmt.QueryRowContext(ctx, eventID).Scan(&roomNID)
return
}

View file

@ -747,6 +747,14 @@ func (d *Database) GetRoomVersionForRoom(
) )
} }
func (d *Database) GetRoomNIDForEventID(
ctx context.Context, eventID string,
) (types.RoomNID, error) {
return d.statements.selectRoomNIDForEventID(
ctx, nil, eventID,
)
}
type transaction struct { type transaction struct {
ctx context.Context ctx context.Context
txn *sql.Tx txn *sql.Tx

View file

@ -96,6 +96,9 @@ const bulkSelectEventNIDSQL = "" +
const selectMaxEventDepthSQL = "" + const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)"
const selectRoomNIDForEventIDSQL = "" +
"SELECT room_nid FROM roomserver_events WHERE event_id = $1"
type eventStatements struct { type eventStatements struct {
db *sql.DB db *sql.DB
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
@ -111,6 +114,7 @@ type eventStatements struct {
bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventReferenceStmt *sql.Stmt
bulkSelectEventIDStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt
bulkSelectEventNIDStmt *sql.Stmt bulkSelectEventNIDStmt *sql.Stmt
selectRoomNIDForEventIDStmt *sql.Stmt
} }
func (s *eventStatements) prepare(db *sql.DB) (err error) { func (s *eventStatements) prepare(db *sql.DB) (err error) {
@ -134,6 +138,7 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) {
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
{&s.selectRoomNIDForEventIDStmt, selectRoomNIDForEventIDSQL},
}.prepare(db) }.prepare(db)
} }
@ -476,3 +481,11 @@ func eventNIDsAsArray(eventNIDs []types.EventNID) string {
b, _ := json.Marshal(eventNIDs) b, _ := json.Marshal(eventNIDs)
return string(b) return string(b)
} }
func (s *eventStatements) selectRoomNIDForEventID(
ctx context.Context, txn *sql.Tx, eventID string,
) (roomNID types.RoomNID, err error) {
selectStmt := common.TxStmt(txn, s.selectRoomNIDForEventIDStmt)
err = selectStmt.QueryRowContext(ctx, eventID).Scan(&roomNID)
return
}

View file

@ -902,6 +902,14 @@ func (d *Database) GetRoomVersionForRoom(
) )
} }
func (d *Database) GetRoomNIDForEventID(
ctx context.Context, eventID string,
) (types.RoomNID, error) {
return d.statements.selectRoomNIDForEventID(
ctx, nil, eventID,
)
}
type transaction struct { type transaction struct {
ctx context.Context ctx context.Context
txn *sql.Tx txn *sql.Tx