diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index d7cf83c14..7616af43e 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -49,7 +49,7 @@ CREATE TABLE IF NOT EXISTS roomserver_rooms ( // Same as insertEventTypeNIDSQL const insertRoomNIDSQL = "" + - "INSERT INTO roomserver_rooms (room_id) VALUES ($1)" + + "INSERT INTO roomserver_rooms (room_id, room_version) VALUES ($1, $2)" + " ON CONFLICT ON CONSTRAINT roomserver_room_id_unique" + " DO NOTHING RETURNING (room_nid)" @@ -93,11 +93,11 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { } func (s *roomStatements) insertRoomNID( - ctx context.Context, txn *sql.Tx, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, roomVersion string, ) (types.RoomNID, error) { var roomNID int64 stmt := common.TxStmt(txn, s.insertRoomNIDStmt) - err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) + err := stmt.QueryRowContext(ctx, roomID, roomVersion).Scan(&roomNID) return types.RoomNID(roomNID), err } diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 785c069c2..0a56136da 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -18,6 +18,7 @@ package postgres import ( "context" "database/sql" + "encoding/json" // Import the postgres database driver. _ "github.com/lib/pq" @@ -70,7 +71,17 @@ func (d *Database) StoreEvent( } } - if roomNID, err = d.assignRoomNID(ctx, nil, event.RoomID()); err != nil { + roomVersion := "" + if event.Type() == gomatrixserverlib.MRoomCreate { + var createContent gomatrixserverlib.CreateContent + if err := json.Unmarshal(event.Content(), createContent); err == nil { + if createContent.RoomVersion != nil { + roomVersion = *createContent.RoomVersion + } + } + } + + if roomNID, err = d.assignRoomNID(ctx, nil, event.RoomID(), roomVersion); err != nil { return 0, types.StateAtEvent{}, err } @@ -123,13 +134,13 @@ func (d *Database) StoreEvent( } func (d *Database) assignRoomNID( - ctx context.Context, txn *sql.Tx, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, roomVersion string, ) (types.RoomNID, error) { // Check if we already have a numeric ID in the database. roomNID, err := d.statements.selectRoomNID(ctx, txn, roomID) if err == sql.ErrNoRows { // We don't have a numeric ID so insert one into the database. - roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID) + roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID, roomVersion) if err == sql.ErrNoRows { // We raced with another insert so run the select again. roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) @@ -516,7 +527,7 @@ func (d *Database) MembershipUpdater( } }() - roomNID, err := d.assignRoomNID(ctx, txn, roomID) + roomNID, err := d.assignRoomNID(ctx, txn, roomID, "") if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 8f88806b2..7f1284323 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -38,7 +38,7 @@ const roomsSchema = ` // Same as insertEventTypeNIDSQL const insertRoomNIDSQL = ` - INSERT INTO roomserver_rooms (room_id) VALUES ($1) + INSERT INTO roomserver_rooms (room_id, room_version) VALUES ($1, $2) ON CONFLICT DO NOTHING; ` @@ -82,11 +82,11 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { } func (s *roomStatements) insertRoomNID( - ctx context.Context, txn *sql.Tx, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, roomVersion string, ) (types.RoomNID, error) { var err error insertStmt := common.TxStmt(txn, s.insertRoomNIDStmt) - if _, err = insertStmt.ExecContext(ctx, roomID); err == nil { + if _, err = insertStmt.ExecContext(ctx, roomID, roomVersion); err == nil { return s.selectRoomNID(ctx, txn, roomID) } else { return types.RoomNID(0), err diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index e05cecb14..7fbacfa39 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -18,6 +18,7 @@ package sqlite3 import ( "context" "database/sql" + "encoding/json" "errors" "net/url" @@ -91,7 +92,17 @@ func (d *Database) StoreEvent( } } - if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID()); err != nil { + roomVersion := "" + if event.Type() == gomatrixserverlib.MRoomCreate { + var createContent gomatrixserverlib.CreateContent + if err := json.Unmarshal(event.Content(), createContent); err == nil { + if createContent.RoomVersion != nil { + roomVersion = *createContent.RoomVersion + } + } + } + + if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion); err != nil { return err } @@ -151,13 +162,13 @@ func (d *Database) StoreEvent( } func (d *Database) assignRoomNID( - ctx context.Context, txn *sql.Tx, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, roomVersion string, ) (roomNID types.RoomNID, err error) { // Check if we already have a numeric ID in the database. roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) if err == sql.ErrNoRows { // We don't have a numeric ID so insert one into the database. - roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID) + roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID, roomVersion) if err == nil { // Now get the numeric ID back out of the database roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) @@ -651,7 +662,7 @@ func (d *Database) MembershipUpdater( } }() - roomNID, err := d.assignRoomNID(ctx, txn, roomID) + roomNID, err := d.assignRoomNID(ctx, txn, roomID, "") if err != nil { return nil, err }