From bc8735ebc63bc78287304493ebcf3b5778f45b98 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 17 Mar 2020 15:53:23 +0000 Subject: [PATCH] Retrieve room version where known in roomserver --- roomserver/input/events.go | 4 ++ roomserver/input/latest_events.go | 6 ++- roomserver/query/query.go | 48 +++++++++++++++------ roomserver/storage/interface.go | 1 + roomserver/storage/postgres/events_table.go | 13 ++++++ roomserver/storage/postgres/storage.go | 8 ++++ roomserver/storage/sqlite3/events_table.go | 13 ++++++ roomserver/storage/sqlite3/storage.go | 8 ++++ 8 files changed, 87 insertions(+), 14 deletions(-) diff --git a/roomserver/input/events.go b/roomserver/input/events.go index 7fbc5d8a9..8f9958068 100644 --- a/roomserver/input/events.go +++ b/roomserver/input/events.go @@ -72,6 +72,10 @@ type RoomEventDatabase interface { ctx context.Context, transactionID string, sessionID int64, userID string, ) (string, error) + // Look up the room version for a given room. + GetRoomVersionForRoom( + ctx context.Context, roomNID types.RoomNID, + ) (gomatrixserverlib.RoomVersion, error) } // OutputRoomEventWriter has the APIs needed to write an event to the output logs. diff --git a/roomserver/input/latest_events.go b/roomserver/input/latest_events.go index 9a99ad76f..4bb066e31 100644 --- a/roomserver/input/latest_events.go +++ b/roomserver/input/latest_events.go @@ -253,8 +253,10 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) latestEventIDs[i] = u.latest[i].EventID } - // TODO: Room version here - roomVersion := gomatrixserverlib.RoomVersionV1 + roomVersion, err := u.db.GetRoomVersionForRoom(u.ctx, u.roomNID) + if err != nil { + return nil, err + } ore := api.OutputNewRoomEvent{ Event: u.event.Headered(roomVersion), diff --git a/roomserver/query/query.go b/roomserver/query/query.go index 52b678ac3..107149ad3 100644 --- a/roomserver/query/query.go +++ b/roomserver/query/query.go @@ -93,6 +93,10 @@ type RoomserverQueryAPIDatabase interface { GetRoomVersionForRoom( ctx context.Context, roomNID types.RoomNID, ) (gomatrixserverlib.RoomVersion, error) + // Look up the room NID that an event ID appears in. + GetRoomNIDForEventID( + ctx context.Context, eventID string, + ) (types.RoomNID, error) } // RoomserverQueryAPI is an implementation of api.RoomserverQueryAPI @@ -234,8 +238,15 @@ func (r *RoomserverQueryAPI) QueryEventsByID( } for _, event := range events { - // TODO: Room version here - roomVersion := gomatrixserverlib.RoomVersionV1 + roomNID, nerr := r.DB.GetRoomNIDForEventID(ctx, event.EventID()) + if nerr != nil { + return nerr + } + + roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, roomNID) + if verr != nil { + return verr + } response.Events = append(response.Events, event.Headered(roomVersion)) } @@ -516,8 +527,15 @@ func (r *RoomserverQueryAPI) QueryMissingEvents( response.Events = make([]gomatrixserverlib.HeaderedEvent, 0, len(loadedEvents)-len(eventsToFilter)) for _, event := range loadedEvents { if !eventsToFilter[event.EventID()] { - // TODO: Room version here - roomVersion := gomatrixserverlib.RoomVersionV1 + roomNID, nerr := r.DB.GetRoomNIDForEventID(ctx, event.EventID()) + if nerr != nil { + return nerr + } + + roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, roomNID) + if verr != nil { + return verr + } response.Events = append(response.Events, event.Headered(roomVersion)) } @@ -562,8 +580,15 @@ func (r *RoomserverQueryAPI) QueryBackfill( } for _, event := range loadedEvents { - // TODO: Room version here - roomVersion := gomatrixserverlib.RoomVersionV1 + roomNID, nerr := r.DB.GetRoomNIDForEventID(ctx, event.EventID()) + if nerr != nil { + return nerr + } + + roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, roomNID) + if verr != nil { + return verr + } response.Events = append(response.Events, event.Headered(roomVersion)) } @@ -653,6 +678,11 @@ func (r *RoomserverQueryAPI) QueryStateAndAuthChain( } response.RoomExists = true + roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, roomNID) + if err != nil { + return err + } + prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs) if err != nil { switch err.(type) { @@ -683,16 +713,10 @@ func (r *RoomserverQueryAPI) QueryStateAndAuthChain( } for _, event := range stateEvents { - // TODO: Room version here - roomVersion := gomatrixserverlib.RoomVersionV1 - response.StateEvents = append(response.StateEvents, event.Headered(roomVersion)) } for _, event := range authEvents { - // TODO: Room version here - roomVersion := gomatrixserverlib.RoomVersionV1 - response.AuthChainEvents = append(response.AuthChainEvents, event.Headered(roomVersion)) } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 7f32b53f8..37a833c55 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -38,6 +38,7 @@ type Database interface { GetInvitesForUser(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserIDs []types.EventStateKeyNID, err error) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error GetRoomIDForAlias(ctx context.Context, alias string) (string, error) + GetRoomNIDForEventID(ctx context.Context, eventID string) (types.RoomNID, error) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) GetCreatorIDForAlias(ctx context.Context, alias string) (string, error) RemoveRoomAlias(ctx context.Context, alias string) error diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index d9b269bc8..f33809f6c 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -116,6 +116,9 @@ const bulkSelectEventNIDSQL = "" + const selectMaxEventDepthSQL = "" + "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)" +const selectRoomNIDForEventIDSQL = "" + + "SELECT room_nid FROM roomserver_events WHERE event_id = $1" + type eventStatements struct { insertEventStmt *sql.Stmt selectEventStmt *sql.Stmt @@ -130,6 +133,7 @@ type eventStatements struct { bulkSelectEventIDStmt *sql.Stmt bulkSelectEventNIDStmt *sql.Stmt selectMaxEventDepthStmt *sql.Stmt + selectRoomNIDForEventIDStmt *sql.Stmt } func (s *eventStatements) prepare(db *sql.DB) (err error) { @@ -152,6 +156,7 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) { {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, {&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL}, + {&s.selectRoomNIDForEventIDStmt, selectRoomNIDForEventIDSQL}, }.prepare(db) } @@ -424,3 +429,11 @@ func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array { } return nids } + +func (s *eventStatements) selectRoomNIDForEventID( + ctx context.Context, txn *sql.Tx, eventID string, +) (roomNID types.RoomNID, err error) { + selectStmt := common.TxStmt(txn, s.selectRoomNIDForEventIDStmt) + err = selectStmt.QueryRowContext(ctx, eventID).Scan(&roomNID) + return +} diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index b2b4159c9..9f4e9df78 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -747,6 +747,14 @@ func (d *Database) GetRoomVersionForRoom( ) } +func (d *Database) GetRoomNIDForEventID( + ctx context.Context, eventID string, +) (types.RoomNID, error) { + return d.statements.selectRoomNIDForEventID( + ctx, nil, eventID, + ) +} + type transaction struct { ctx context.Context txn *sql.Tx diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 4fa095913..14f2b94c1 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -96,6 +96,9 @@ const bulkSelectEventNIDSQL = "" + const selectMaxEventDepthSQL = "" + "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" +const selectRoomNIDForEventIDSQL = "" + + "SELECT room_nid FROM roomserver_events WHERE event_id = $1" + type eventStatements struct { db *sql.DB insertEventStmt *sql.Stmt @@ -111,6 +114,7 @@ type eventStatements struct { bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt bulkSelectEventNIDStmt *sql.Stmt + selectRoomNIDForEventIDStmt *sql.Stmt } func (s *eventStatements) prepare(db *sql.DB) (err error) { @@ -134,6 +138,7 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) { {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, + {&s.selectRoomNIDForEventIDStmt, selectRoomNIDForEventIDSQL}, }.prepare(db) } @@ -476,3 +481,11 @@ func eventNIDsAsArray(eventNIDs []types.EventNID) string { b, _ := json.Marshal(eventNIDs) return string(b) } + +func (s *eventStatements) selectRoomNIDForEventID( + ctx context.Context, txn *sql.Tx, eventID string, +) (roomNID types.RoomNID, err error) { + selectStmt := common.TxStmt(txn, s.selectRoomNIDForEventIDStmt) + err = selectStmt.QueryRowContext(ctx, eventID).Scan(&roomNID) + return +} diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index b912b1c0e..ea082748e 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -902,6 +902,14 @@ func (d *Database) GetRoomVersionForRoom( ) } +func (d *Database) GetRoomNIDForEventID( + ctx context.Context, eventID string, +) (types.RoomNID, error) { + return d.statements.selectRoomNIDForEventID( + ctx, nil, eventID, + ) +} + type transaction struct { ctx context.Context txn *sql.Tx