diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 20db7ef7f..98df4708d 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -46,4 +46,5 @@ type Database interface { GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) GetRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) + GetRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) } diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index 6bb96f1de..0fd9d5b53 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -68,6 +68,9 @@ const updateLatestEventNIDsSQL = "" + const selectRoomVersionForRoomIDSQL = "" + "SELECT room_version FROM roomserver_rooms WHERE room_id = $1" +const selectRoomVersionForRoomNIDSQL = "" + + "SELECT room_version FROM roomserver_rooms WHERE room_nid = $1" + type roomStatements struct { insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt @@ -75,6 +78,7 @@ type roomStatements struct { selectLatestEventNIDsForUpdateStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt selectRoomVersionForRoomIDStmt *sql.Stmt + selectRoomVersionForRoomNIDStmt *sql.Stmt } func (s *roomStatements) prepare(db *sql.DB) (err error) { @@ -89,6 +93,7 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, {&s.selectRoomVersionForRoomIDStmt, selectRoomVersionForRoomIDSQL}, + {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, }.prepare(db) } @@ -173,3 +178,12 @@ func (s *roomStatements) selectRoomVersionForRoomID( err := stmt.QueryRowContext(ctx, roomID).Scan(&roomVersion) return roomVersion, err } + +func (s *roomStatements) selectRoomVersionForRoomNID( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +) (gomatrixserverlib.RoomVersion, error) { + var roomVersion gomatrixserverlib.RoomVersion + stmt := common.TxStmt(txn, s.selectRoomVersionForRoomNIDStmt) + err := stmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion) + return roomVersion, err +} diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 9bb6de9d7..084b83ce6 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -747,6 +747,14 @@ func (d *Database) GetRoomVersionForRoom( ) } +func (d *Database) GetRoomVersionForRoomNID( + ctx context.Context, roomNID types.RoomNID, +) (gomatrixserverlib.RoomVersion, error) { + return d.statements.selectRoomVersionForRoomNID( + ctx, nil, roomNID, + ) +} + type transaction struct { ctx context.Context txn *sql.Tx diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 49fa07ea8..512b98137 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -57,6 +57,9 @@ const updateLatestEventNIDsSQL = "" + const selectRoomVersionForRoomIDSQL = "" + "SELECT room_version FROM roomserver_rooms WHERE room_id = $1" +const selectRoomVersionForRoomNIDSQL = "" + + "SELECT room_version FROM roomserver_rooms WHERE room_nid = $1" + type roomStatements struct { insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt @@ -64,6 +67,7 @@ type roomStatements struct { selectLatestEventNIDsForUpdateStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt selectRoomVersionForRoomIDStmt *sql.Stmt + selectRoomVersionForRoomNIDStmt *sql.Stmt } func (s *roomStatements) prepare(db *sql.DB) (err error) { @@ -78,6 +82,7 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, {&s.selectRoomVersionForRoomIDStmt, selectRoomVersionForRoomIDSQL}, + {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, }.prepare(db) } @@ -165,3 +170,12 @@ func (s *roomStatements) selectRoomVersionForRoomID( err := stmt.QueryRowContext(ctx, roomID).Scan(&roomVersion) return roomVersion, err } + +func (s *roomStatements) selectRoomVersionForRoomNID( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +) (gomatrixserverlib.RoomVersion, error) { + var roomVersion gomatrixserverlib.RoomVersion + stmt := common.TxStmt(txn, s.selectRoomVersionForRoomNIDStmt) + err := stmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion) + return roomVersion, err +} diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index ae09a88ad..28d608ca5 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -901,6 +901,14 @@ func (d *Database) GetRoomVersionForRoom( ) } +func (d *Database) GetRoomVersionForRoomNID( + ctx context.Context, roomNID types.RoomNID, +) (gomatrixserverlib.RoomVersion, error) { + return d.statements.selectRoomVersionForRoomNID( + ctx, nil, roomNID, + ) +} + type transaction struct { ctx context.Context txn *sql.Tx