diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 0caa8199a..ecc35f37a 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -116,6 +116,9 @@ const bulkSelectEventNIDSQL = "" + const selectMaxEventDepthSQL = "" + "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)" +const selectRoomNIDForEventNIDSQL = "" + + "SELECT room_nid FROM roomserver_events WHERE event_nid = $1" + type eventStatements struct { insertEventStmt *sql.Stmt selectEventStmt *sql.Stmt @@ -130,6 +133,7 @@ type eventStatements struct { bulkSelectEventIDStmt *sql.Stmt bulkSelectEventNIDStmt *sql.Stmt selectMaxEventDepthStmt *sql.Stmt + selectRoomNIDForEventNIDStmt *sql.Stmt } func (s *eventStatements) prepare(db *sql.DB) (err error) { @@ -152,6 +156,7 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) { {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, {&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL}, + {&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, }.prepare(db) } @@ -417,6 +422,14 @@ func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []t return result, nil } +func (s *eventStatements) selectRoomNIDForEventNID( + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, +) (roomNID types.RoomNID, err error) { + selectStmt := common.TxStmt(txn, s.selectRoomNIDForEventNIDStmt) + err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&roomNID) + return +} + func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array { nids := make([]int64, len(eventNIDs)) for i := range eventNIDs { diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 3bed09477..9098b482f 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -254,12 +254,21 @@ func (d *Database) Events( } results := make([]types.Event, len(eventJSONs)) for i, eventJSON := range eventJSONs { + var roomNID types.RoomNID + var roomVersion gomatrixserverlib.RoomVersion result := &results[i] result.EventNID = eventJSON.EventNID // TODO: Use NewEventFromTrustedJSON for efficiency - // TODO: Room version here + roomNID, err = d.statements.selectRoomNIDForEventNID(ctx, nil, eventJSON.EventNID) + if err != nil { + return nil, err + } + roomVersion, err = d.statements.selectRoomVersionForRoomNID(ctx, nil, roomNID) + if err != nil { + return nil, err + } result.Event, err = gomatrixserverlib.NewEventFromUntrustedJSON( - eventJSON.EventJSON, gomatrixserverlib.RoomVersionV1, + eventJSON.EventJSON, roomVersion, ) if err != nil { return nil, err diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 1e4ed448f..d881fa91f 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -96,6 +96,9 @@ const bulkSelectEventNIDSQL = "" + const selectMaxEventDepthSQL = "" + "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" +const selectRoomNIDForEventNIDSQL = "" + + "SELECT room_nid FROM roomserver_events WHERE event_nid = $1" + type eventStatements struct { db *sql.DB insertEventStmt *sql.Stmt @@ -111,6 +114,7 @@ type eventStatements struct { bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt bulkSelectEventNIDStmt *sql.Stmt + selectRoomNIDForEventNIDStmt *sql.Stmt } func (s *eventStatements) prepare(db *sql.DB) (err error) { @@ -134,6 +138,7 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) { {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, + {&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, }.prepare(db) } @@ -472,6 +477,14 @@ func (s *eventStatements) selectMaxEventDepth(ctx context.Context, txn *sql.Tx, return result, nil } +func (s *eventStatements) selectRoomNIDForEventNID( + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, +) (roomNID types.RoomNID, err error) { + selectStmt := common.TxStmt(txn, s.selectRoomNIDForEventNIDStmt) + err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&roomNID) + return +} + func eventNIDsAsArray(eventNIDs []types.EventNID) string { b, _ := json.Marshal(eventNIDs) return string(b) diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 70fd979d6..2bbc8d57f 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -307,12 +307,21 @@ func (d *Database) Events( } results = make([]types.Event, len(eventJSONs)) for i, eventJSON := range eventJSONs { + var roomNID types.RoomNID + var roomVersion gomatrixserverlib.RoomVersion result := &results[i] result.EventNID = eventJSON.EventNID // TODO: Use NewEventFromTrustedJSON for efficiency - // TODO: Room version here + roomNID, err = d.statements.selectRoomNIDForEventNID(ctx, txn, eventJSON.EventNID) + if err != nil { + return err + } + roomVersion, err = d.statements.selectRoomVersionForRoomNID(ctx, txn, roomNID) + if err != nil { + return err + } result.Event, err = gomatrixserverlib.NewEventFromUntrustedJSON( - eventJSON.EventJSON, gomatrixserverlib.RoomVersionV1, + eventJSON.EventJSON, roomVersion, ) if err != nil { return nil