Merge branch 'master' into neilalexander/keydb

This commit is contained in:
Neil Alexander 2020-05-27 09:46:33 +01:00 committed by GitHub
commit d3baee1d37
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 668 additions and 869 deletions

View file

@ -1,6 +1,6 @@
## Peer-to-peer Matrix ## Peer-to-peer Matrix
These are the instructions for setting up P2P Dendrite, current as of March 2020. There's both Go stuff and JS stuff to do to set this up. These are the instructions for setting up P2P Dendrite, current as of May 2020. There's both Go stuff and JS stuff to do to set this up.
### Dendrite ### Dendrite
@ -28,14 +28,13 @@ Then use `/ip4/127.0.0.1/tcp/9090/ws/p2p-websocket-star/`.
### Riot-web ### Riot-web
You need to check out these repos: You need to check out this repo:
``` ```
$ git clone git@github.com:matrix-org/go-http-js-libp2p.git $ git clone git@github.com:matrix-org/go-http-js-libp2p.git
$ git clone git@github.com:matrix-org/go-sqlite3-js.git
``` ```
Make sure to `yarn install` in both of these repos. Then: Make sure to `yarn install` in the repo. Then:
- `$ cp "$(go env GOROOT)/misc/wasm/wasm_exec.js" ./src/vector/` - `$ cp "$(go env GOROOT)/misc/wasm/wasm_exec.js" ./src/vector/`
- Comment out the lines in `wasm_exec.js` which contains: - Comment out the lines in `wasm_exec.js` which contains:
@ -49,7 +48,6 @@ if (!global.fs && global.require) {
- Add the following symlinks: they HAVE to be symlinks as the diff in `webpack.config.js` references specific paths. - Add the following symlinks: they HAVE to be symlinks as the diff in `webpack.config.js` references specific paths.
``` ```
$ cd node_modules $ cd node_modules
$ ln -s ../../go-sqlite-js # NB: NOT go-sqlite3-js
$ ln -s ../../go-http-js-libp2p $ ln -s ../../go-http-js-libp2p
``` ```
@ -65,14 +63,7 @@ You need a Chrome and a Firefox running to test locally as service workers don't
Assuming you've `yarn start`ed Riot-Web, go to `http://localhost:8080` and register with `http://localhost:8080` as your HS URL. Assuming you've `yarn start`ed Riot-Web, go to `http://localhost:8080` and register with `http://localhost:8080` as your HS URL.
You can join rooms by room alias e.g `/join #foo:bar`. You can:
- join rooms by room alias e.g `/join #foo:bar`.
### Known issues - invite specific users to a room.
- explore the published room list. All members of the room can re-publish aliases (unlike Synapse).
- When registering you may be unable to find the server, it'll seem flakey. This happens because the SW, particularly in Firefox,
gets killed after 30s of inactivity. When you are not registered, you aren't doing `/sync` calls to keep the SW alive, so if you
don't register for a while and idle on the page, the HS will disappear. To fix, unregister the SW, and then refresh the page.
- The libp2p layer has rate limits, so frequent Federation traffic may cause the connection to drop and messages to not be transferred.
I guess in other words, don't send too much traffic?

View file

@ -426,10 +426,9 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
} }
func (s *eventStatements) SelectRoomNIDForEventNID( func (s *eventStatements) SelectRoomNIDForEventNID(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ctx context.Context, eventNID types.EventNID,
) (roomNID types.RoomNID, err error) { ) (roomNID types.RoomNID, err error) {
selectStmt := internal.TxStmt(txn, s.selectRoomNIDForEventNIDStmt) err = s.selectRoomNIDForEventNIDStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&roomNID)
err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&roomNID)
return return
} }

View file

@ -20,6 +20,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
@ -63,19 +64,20 @@ type previousEventStatements struct {
selectPreviousEventExistsStmt *sql.Stmt selectPreviousEventExistsStmt *sql.Stmt
} }
func (s *previousEventStatements) prepare(db *sql.DB) (err error) { func NewPostgresPreviousEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
_, err = db.Exec(previousEventSchema) s := &previousEventStatements{}
_, err := db.Exec(previousEventSchema)
if err != nil { if err != nil {
return return nil, err
} }
return statementList{ return s, statementList{
{&s.insertPreviousEventStmt, insertPreviousEventSQL}, {&s.insertPreviousEventStmt, insertPreviousEventSQL},
{&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL},
}.prepare(db) }.prepare(db)
} }
func (s *previousEventStatements) insertPreviousEvent( func (s *previousEventStatements) InsertPreviousEvent(
ctx context.Context, ctx context.Context,
txn *sql.Tx, txn *sql.Tx,
previousEventID string, previousEventID string,
@ -91,7 +93,7 @@ func (s *previousEventStatements) insertPreviousEvent(
// Check if the event reference exists // Check if the event reference exists
// Returns sql.ErrNoRows if the event reference doesn't exist. // Returns sql.ErrNoRows if the event reference doesn't exist.
func (s *previousEventStatements) selectPreviousEventExists( func (s *previousEventStatements) SelectPreviousEventExists(
ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte, ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte,
) error { ) error {
var ok int64 var ok int64

View file

@ -20,6 +20,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
) )
const roomAliasesSchema = ` const roomAliasesSchema = `
@ -59,12 +60,13 @@ type roomAliasesStatements struct {
deleteRoomAliasStmt *sql.Stmt deleteRoomAliasStmt *sql.Stmt
} }
func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) { func NewPostgresRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
_, err = db.Exec(roomAliasesSchema) s := &roomAliasesStatements{}
_, err := db.Exec(roomAliasesSchema)
if err != nil { if err != nil {
return return nil, err
} }
return statementList{ return s, statementList{
{&s.insertRoomAliasStmt, insertRoomAliasSQL}, {&s.insertRoomAliasStmt, insertRoomAliasSQL},
{&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL},
{&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL}, {&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL},
@ -73,14 +75,14 @@ func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) {
}.prepare(db) }.prepare(db)
} }
func (s *roomAliasesStatements) insertRoomAlias( func (s *roomAliasesStatements) InsertRoomAlias(
ctx context.Context, alias string, roomID string, creatorUserID string, ctx context.Context, alias string, roomID string, creatorUserID string,
) (err error) { ) (err error) {
_, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID) _, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID)
return return
} }
func (s *roomAliasesStatements) selectRoomIDFromAlias( func (s *roomAliasesStatements) SelectRoomIDFromAlias(
ctx context.Context, alias string, ctx context.Context, alias string,
) (roomID string, err error) { ) (roomID string, err error) {
err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID)
@ -90,7 +92,7 @@ func (s *roomAliasesStatements) selectRoomIDFromAlias(
return return
} }
func (s *roomAliasesStatements) selectAliasesFromRoomID( func (s *roomAliasesStatements) SelectAliasesFromRoomID(
ctx context.Context, roomID string, ctx context.Context, roomID string,
) ([]string, error) { ) ([]string, error) {
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
@ -111,7 +113,7 @@ func (s *roomAliasesStatements) selectAliasesFromRoomID(
return aliases, rows.Err() return aliases, rows.Err()
} }
func (s *roomAliasesStatements) selectCreatorIDFromAlias( func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
ctx context.Context, alias string, ctx context.Context, alias string,
) (creatorID string, err error) { ) (creatorID string, err error) {
err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID)
@ -121,7 +123,7 @@ func (s *roomAliasesStatements) selectCreatorIDFromAlias(
return return
} }
func (s *roomAliasesStatements) deleteRoomAlias( func (s *roomAliasesStatements) DeleteRoomAlias(
ctx context.Context, alias string, ctx context.Context, alias string,
) (err error) { ) (err error) {
_, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias) _, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias)

View file

@ -22,6 +22,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -82,12 +83,13 @@ type roomStatements struct {
selectRoomVersionForRoomNIDStmt *sql.Stmt selectRoomVersionForRoomNIDStmt *sql.Stmt
} }
func (s *roomStatements) prepare(db *sql.DB) (err error) { func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) {
_, err = db.Exec(roomsSchema) s := &roomStatements{}
_, err := db.Exec(roomsSchema)
if err != nil { if err != nil {
return return nil, err
} }
return statementList{ return s, statementList{
{&s.insertRoomNIDStmt, insertRoomNIDSQL}, {&s.insertRoomNIDStmt, insertRoomNIDSQL},
{&s.selectRoomNIDStmt, selectRoomNIDSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL},
{&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL},
@ -98,7 +100,7 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) {
}.prepare(db) }.prepare(db)
} }
func (s *roomStatements) insertRoomNID( func (s *roomStatements) InsertRoomNID(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, roomVersion gomatrixserverlib.RoomVersion, roomID string, roomVersion gomatrixserverlib.RoomVersion,
) (types.RoomNID, error) { ) (types.RoomNID, error) {
@ -108,7 +110,7 @@ func (s *roomStatements) insertRoomNID(
return types.RoomNID(roomNID), err return types.RoomNID(roomNID), err
} }
func (s *roomStatements) selectRoomNID( func (s *roomStatements) SelectRoomNID(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) (types.RoomNID, error) { ) (types.RoomNID, error) {
var roomNID int64 var roomNID int64
@ -117,8 +119,8 @@ func (s *roomStatements) selectRoomNID(
return types.RoomNID(roomNID), err return types.RoomNID(roomNID), err
} }
func (s *roomStatements) selectLatestEventNIDs( func (s *roomStatements) SelectLatestEventNIDs(
ctx context.Context, roomNID types.RoomNID, ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
) ([]types.EventNID, types.StateSnapshotNID, error) { ) ([]types.EventNID, types.StateSnapshotNID, error) {
var nids pq.Int64Array var nids pq.Int64Array
var stateSnapshotNID int64 var stateSnapshotNID int64
@ -134,7 +136,7 @@ func (s *roomStatements) selectLatestEventNIDs(
return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil
} }
func (s *roomStatements) selectLatestEventsNIDsForUpdate( func (s *roomStatements) SelectLatestEventsNIDsForUpdate(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) { ) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) {
var nids pq.Int64Array var nids pq.Int64Array
@ -152,7 +154,7 @@ func (s *roomStatements) selectLatestEventsNIDsForUpdate(
return eventNIDs, types.EventNID(lastEventSentNID), types.StateSnapshotNID(stateSnapshotNID), nil return eventNIDs, types.EventNID(lastEventSentNID), types.StateSnapshotNID(stateSnapshotNID), nil
} }
func (s *roomStatements) updateLatestEventNIDs( func (s *roomStatements) UpdateLatestEventNIDs(
ctx context.Context, ctx context.Context,
txn *sql.Tx, txn *sql.Tx,
roomNID types.RoomNID, roomNID types.RoomNID,
@ -171,7 +173,7 @@ func (s *roomStatements) updateLatestEventNIDs(
return err return err
} }
func (s *roomStatements) selectRoomVersionForRoomID( func (s *roomStatements) SelectRoomVersionForRoomID(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) (gomatrixserverlib.RoomVersion, error) { ) (gomatrixserverlib.RoomVersion, error) {
var roomVersion gomatrixserverlib.RoomVersion var roomVersion gomatrixserverlib.RoomVersion
@ -183,12 +185,11 @@ func (s *roomStatements) selectRoomVersionForRoomID(
return roomVersion, err return roomVersion, err
} }
func (s *roomStatements) selectRoomVersionForRoomNID( func (s *roomStatements) SelectRoomVersionForRoomNID(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ctx context.Context, roomNID types.RoomNID,
) (gomatrixserverlib.RoomVersion, error) { ) (gomatrixserverlib.RoomVersion, error) {
var roomVersion gomatrixserverlib.RoomVersion var roomVersion gomatrixserverlib.RoomVersion
stmt := internal.TxStmt(txn, s.selectRoomVersionForRoomNIDStmt) err := s.selectRoomVersionForRoomNIDStmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion)
err := stmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return roomVersion, errors.New("room not found") return roomVersion, errors.New("room not found")
} }

View file

@ -38,14 +38,8 @@ func (s *statements) prepare(db *sql.DB) error {
var err error var err error
for _, prepare := range []func(db *sql.DB) error{ for _, prepare := range []func(db *sql.DB) error{
s.roomStatements.prepare,
s.stateSnapshotStatements.prepare,
s.stateBlockStatements.prepare,
s.previousEventStatements.prepare,
s.roomAliasesStatements.prepare,
s.inviteStatements.prepare, s.inviteStatements.prepare,
s.membershipStatements.prepare, s.membershipStatements.prepare,
s.transactionStatements.prepare,
} { } {
if err = prepare(db); err != nil { if err = prepare(db); err != nil {
return err return err

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -87,13 +88,14 @@ type stateBlockStatements struct {
bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt
} }
func (s *stateBlockStatements) prepare(db *sql.DB) (err error) { func NewPostgresStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
_, err = db.Exec(stateDataSchema) s := &stateBlockStatements{}
_, err := db.Exec(stateDataSchema)
if err != nil { if err != nil {
return return nil, err
} }
return statementList{ return s, statementList{
{&s.insertStateDataStmt, insertStateDataSQL}, {&s.insertStateDataStmt, insertStateDataSQL},
{&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL},
{&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL},
@ -101,11 +103,15 @@ func (s *stateBlockStatements) prepare(db *sql.DB) (err error) {
}.prepare(db) }.prepare(db)
} }
func (s *stateBlockStatements) bulkInsertStateData( func (s *stateBlockStatements) BulkInsertStateData(
ctx context.Context, ctx context.Context,
stateBlockNID types.StateBlockNID, txn *sql.Tx,
entries []types.StateEntry, entries []types.StateEntry,
) error { ) (types.StateBlockNID, error) {
stateBlockNID, err := s.selectNextStateBlockNID(ctx)
if err != nil {
return 0, err
}
for _, entry := range entries { for _, entry := range entries {
_, err := s.insertStateDataStmt.ExecContext( _, err := s.insertStateDataStmt.ExecContext(
ctx, ctx,
@ -115,10 +121,10 @@ func (s *stateBlockStatements) bulkInsertStateData(
int64(entry.EventNID), int64(entry.EventNID),
) )
if err != nil { if err != nil {
return err return 0, err
} }
} }
return nil return stateBlockNID, nil
} }
func (s *stateBlockStatements) selectNextStateBlockNID( func (s *stateBlockStatements) selectNextStateBlockNID(
@ -129,7 +135,7 @@ func (s *stateBlockStatements) selectNextStateBlockNID(
return types.StateBlockNID(stateBlockNID), err return types.StateBlockNID(stateBlockNID), err
} }
func (s *stateBlockStatements) bulkSelectStateBlockEntries( func (s *stateBlockStatements) BulkSelectStateBlockEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID, ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) { ) ([]types.StateEntryList, error) {
nids := make([]int64, len(stateBlockNIDs)) nids := make([]int64, len(stateBlockNIDs))
@ -180,7 +186,7 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries(
return results, err return results, err
} }
func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries( func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries(
ctx context.Context, ctx context.Context,
stateBlockNIDs []types.StateBlockNID, stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple, stateKeyTuples []types.StateKeyTuple,

View file

@ -21,6 +21,7 @@ import (
"fmt" "fmt"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
@ -64,30 +65,31 @@ type stateSnapshotStatements struct {
bulkSelectStateBlockNIDsStmt *sql.Stmt bulkSelectStateBlockNIDsStmt *sql.Stmt
} }
func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) { func NewPostgresStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
_, err = db.Exec(stateSnapshotSchema) s := &stateSnapshotStatements{}
_, err := db.Exec(stateSnapshotSchema)
if err != nil { if err != nil {
return return nil, err
} }
return statementList{ return s, statementList{
{&s.insertStateStmt, insertStateSQL}, {&s.insertStateStmt, insertStateSQL},
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
}.prepare(db) }.prepare(db)
} }
func (s *stateSnapshotStatements) insertState( func (s *stateSnapshotStatements) InsertState(
ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID,
) (stateNID types.StateSnapshotNID, err error) { ) (stateNID types.StateSnapshotNID, err error) {
nids := make([]int64, len(stateBlockNIDs)) nids := make([]int64, len(stateBlockNIDs))
for i := range stateBlockNIDs { for i := range stateBlockNIDs {
nids[i] = int64(stateBlockNIDs[i]) nids[i] = int64(stateBlockNIDs[i])
} }
err = s.insertStateStmt.QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID) err = txn.Stmt(s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID)
return return
} }
func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs( func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID, ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) { ) ([]types.StateBlockNIDList, error) {
nids := make([]int64, len(stateNIDs)) nids := make([]int64, len(stateNIDs))

View file

@ -18,14 +18,12 @@ package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
// Import the postgres database driver. // Import the postgres database driver.
_ "github.com/lib/pq" _ "github.com/lib/pq"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
@ -40,10 +38,14 @@ type Database struct {
eventTypes tables.EventTypes eventTypes tables.EventTypes
eventStateKeys tables.EventStateKeys eventStateKeys tables.EventStateKeys
eventJSON tables.EventJSON eventJSON tables.EventJSON
rooms tables.Rooms
transactions tables.Transactions
prevEvents tables.PreviousEvents
db *sql.DB db *sql.DB
} }
// Open a postgres database. // Open a postgres database.
// nolint: gocyclo
func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, error) { func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, error) {
var d Database var d Database
var err error var err error
@ -69,164 +71,63 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database,
if err != nil { if err != nil {
return nil, err return nil, err
} }
d.rooms, err = NewPostgresRoomsTable(d.db)
if err != nil {
return nil, err
}
d.transactions, err = NewPostgresTransactionsTable(d.db)
if err != nil {
return nil, err
}
stateBlock, err := NewPostgresStateBlockTable(d.db)
if err != nil {
return nil, err
}
stateSnapshot, err := NewPostgresStateSnapshotTable(d.db)
if err != nil {
return nil, err
}
roomAliases, err := NewPostgresRoomAliasesTable(d.db)
if err != nil {
return nil, err
}
d.prevEvents, err = NewPostgresPreviousEventsTable(d.db)
if err != nil {
return nil, err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db,
EventTypesTable: d.eventTypes, EventTypesTable: d.eventTypes,
EventStateKeysTable: d.eventStateKeys, EventStateKeysTable: d.eventStateKeys,
EventJSONTable: d.eventJSON, EventJSONTable: d.eventJSON,
EventsTable: d.events, EventsTable: d.events,
RoomsTable: d.rooms,
TransactionsTable: d.transactions,
StateBlockTable: stateBlock,
StateSnapshotTable: stateSnapshot,
PrevEventsTable: d.prevEvents,
RoomAliasesTable: roomAliases,
} }
return &d, nil return &d, nil
} }
// StoreEvent implements input.EventDatabase
func (d *Database) StoreEvent(
ctx context.Context, event gomatrixserverlib.Event,
txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID,
) (types.RoomNID, types.StateAtEvent, error) {
var (
roomNID types.RoomNID
eventTypeNID types.EventTypeNID
eventStateKeyNID types.EventStateKeyNID
eventNID types.EventNID
stateNID types.StateSnapshotNID
err error
)
if txnAndSessionID != nil {
if err = d.statements.insertTransaction(
ctx, txnAndSessionID.TransactionID,
txnAndSessionID.SessionID, event.Sender(), event.EventID(),
); err != nil {
return 0, types.StateAtEvent{}, err
}
}
// TODO: Here we should aim to have two different code paths for new rooms
// vs existing ones.
// Get the default room version. If the client doesn't supply a room_version
// then we will use our configured default to create the room.
// https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom
// Note that the below logic depends on the m.room.create event being the
// first event that is persisted to the database when creating or joining a
// room.
var roomVersion gomatrixserverlib.RoomVersion
if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil {
return 0, types.StateAtEvent{}, err
}
if roomNID, err = d.assignRoomNID(ctx, nil, event.RoomID(), roomVersion); err != nil {
return 0, types.StateAtEvent{}, err
}
if eventTypeNID, err = d.assignEventTypeNID(ctx, event.Type()); err != nil {
return 0, types.StateAtEvent{}, err
}
eventStateKey := event.StateKey()
// Assigned a numeric ID for the state_key if there is one present.
// Otherwise set the numeric ID for the state_key to 0.
if eventStateKey != nil {
if eventStateKeyNID, err = d.assignStateKeyNID(ctx, nil, *eventStateKey); err != nil {
return 0, types.StateAtEvent{}, err
}
}
if eventNID, stateNID, err = d.events.InsertEvent(
ctx,
nil,
roomNID,
eventTypeNID,
eventStateKeyNID,
event.EventID(),
event.EventReference().EventSHA256,
authEventNIDs,
event.Depth(),
); err != nil {
if err == sql.ErrNoRows {
// We've already inserted the event so select the numeric event ID
eventNID, stateNID, err = d.events.SelectEvent(ctx, nil, event.EventID())
}
if err != nil {
return 0, types.StateAtEvent{}, err
}
}
if err = d.eventJSON.InsertEventJSON(ctx, nil, eventNID, event.JSON()); err != nil {
return 0, types.StateAtEvent{}, err
}
return roomNID, types.StateAtEvent{
BeforeStateSnapshotNID: stateNID,
StateEntry: types.StateEntry{
StateKeyTuple: types.StateKeyTuple{
EventTypeNID: eventTypeNID,
EventStateKeyNID: eventStateKeyNID,
},
EventNID: eventNID,
},
}, nil
}
func extractRoomVersionFromCreateEvent(event gomatrixserverlib.Event) (
gomatrixserverlib.RoomVersion, error,
) {
var err error
var roomVersion gomatrixserverlib.RoomVersion
// Look for m.room.create events.
if event.Type() != gomatrixserverlib.MRoomCreate {
return gomatrixserverlib.RoomVersion(""), nil
}
roomVersion = gomatrixserverlib.RoomVersionV1
var createContent gomatrixserverlib.CreateContent
// The m.room.create event contains an optional "room_version" key in
// the event content, so we need to unmarshal that first.
if err = json.Unmarshal(event.Content(), &createContent); err != nil {
return gomatrixserverlib.RoomVersion(""), err
}
// A room version was specified in the event content?
if createContent.RoomVersion != nil {
roomVersion = gomatrixserverlib.RoomVersion(*createContent.RoomVersion)
}
return roomVersion, err
}
func (d *Database) assignRoomNID( func (d *Database) assignRoomNID(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, roomVersion gomatrixserverlib.RoomVersion, roomID string, roomVersion gomatrixserverlib.RoomVersion,
) (types.RoomNID, error) { ) (types.RoomNID, error) {
// Check if we already have a numeric ID in the database. // Check if we already have a numeric ID in the database.
roomNID, err := d.statements.selectRoomNID(ctx, txn, roomID) roomNID, err := d.rooms.SelectRoomNID(ctx, txn, roomID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// We don't have a numeric ID so insert one into the database. // We don't have a numeric ID so insert one into the database.
roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID, roomVersion) roomNID, err = d.rooms.InsertRoomNID(ctx, txn, roomID, roomVersion)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// We raced with another insert so run the select again. // We raced with another insert so run the select again.
roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) roomNID, err = d.rooms.SelectRoomNID(ctx, txn, roomID)
} }
} }
return roomNID, err return roomNID, err
} }
func (d *Database) assignEventTypeNID(
ctx context.Context, eventType string,
) (eventTypeNID types.EventTypeNID, err error) {
err = internal.WithTransaction(d.db, func(txn *sql.Tx) error {
// Check if we already have a numeric ID in the database.
eventTypeNID, err = d.eventTypes.SelectEventTypeNID(ctx, txn, eventType)
if err == sql.ErrNoRows {
// We don't have a numeric ID so insert one into the database.
eventTypeNID, err = d.eventTypes.InsertEventTypeNID(ctx, txn, eventType)
if err == sql.ErrNoRows {
// We raced with another insert so run the select again.
eventTypeNID, err = d.eventTypes.SelectEventTypeNID(ctx, txn, eventType)
}
}
return err
})
return eventTypeNID, err
}
func (d *Database) assignStateKeyNID( func (d *Database) assignStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string, ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) { ) (types.EventStateKeyNID, error) {
@ -243,73 +144,6 @@ func (d *Database) assignStateKeyNID(
return eventStateKeyNID, err return eventStateKeyNID, err
} }
// Events implements input.EventDatabase
func (d *Database) Events(
ctx context.Context, eventNIDs []types.EventNID,
) ([]types.Event, error) {
eventJSONs, err := d.eventJSON.BulkSelectEventJSON(ctx, eventNIDs)
if err != nil {
return nil, err
}
results := make([]types.Event, len(eventJSONs))
for i, eventJSON := range eventJSONs {
var roomNID types.RoomNID
var roomVersion gomatrixserverlib.RoomVersion
result := &results[i]
result.EventNID = eventJSON.EventNID
roomNID, err = d.events.SelectRoomNIDForEventNID(ctx, nil, eventJSON.EventNID)
if err != nil {
return nil, err
}
roomVersion, err = d.statements.selectRoomVersionForRoomNID(ctx, nil, roomNID)
if err != nil {
return nil, err
}
result.Event, err = gomatrixserverlib.NewEventFromTrustedJSON(
eventJSON.EventJSON, false, roomVersion,
)
if err != nil {
return nil, err
}
}
return results, nil
}
// AddState implements input.EventDatabase
func (d *Database) AddState(
ctx context.Context,
roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (types.StateSnapshotNID, error) {
if len(state) > 0 {
stateBlockNID, err := d.statements.selectNextStateBlockNID(ctx)
if err != nil {
return 0, err
}
if err = d.statements.bulkInsertStateData(ctx, stateBlockNID, state); err != nil {
return 0, err
}
stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID)
}
return d.statements.insertState(ctx, roomNID, stateBlockNIDs)
}
// StateBlockNIDs implements state.RoomStateDatabase
func (d *Database) StateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
return d.statements.bulkSelectStateBlockNIDs(ctx, stateNIDs)
}
// StateEntries implements state.RoomStateDatabase
func (d *Database) StateEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
return d.statements.bulkSelectStateBlockEntries(ctx, stateBlockNIDs)
}
// GetLatestEventsForUpdate implements input.EventDatabase // GetLatestEventsForUpdate implements input.EventDatabase
func (d *Database) GetLatestEventsForUpdate( func (d *Database) GetLatestEventsForUpdate(
ctx context.Context, roomNID types.RoomNID, ctx context.Context, roomNID types.RoomNID,
@ -319,7 +153,7 @@ func (d *Database) GetLatestEventsForUpdate(
return nil, err return nil, err
} }
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
d.statements.selectLatestEventsNIDsForUpdate(ctx, txn, roomNID) d.rooms.SelectLatestEventsNIDsForUpdate(ctx, txn, roomNID)
if err != nil { if err != nil {
txn.Rollback() // nolint: errcheck txn.Rollback() // nolint: errcheck
return nil, err return nil, err
@ -342,18 +176,6 @@ func (d *Database) GetLatestEventsForUpdate(
}, nil }, nil
} }
// GetTransactionEventID implements input.EventDatabase
func (d *Database) GetTransactionEventID(
ctx context.Context, transactionID string,
sessionID int64, userID string,
) (string, error) {
eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, sessionID, userID)
if err == sql.ErrNoRows {
return "", nil
}
return eventID, err
}
type roomRecentEventsUpdater struct { type roomRecentEventsUpdater struct {
transaction transaction
d *Database d *Database
@ -387,7 +209,7 @@ func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotN
// StorePreviousEvents implements types.RoomRecentEventsUpdater // StorePreviousEvents implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
for _, ref := range previousEventReferences { for _, ref := range previousEventReferences {
if err := u.d.statements.insertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { if err := u.d.prevEvents.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
return err return err
} }
} }
@ -396,7 +218,7 @@ func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, p
// IsReferenced implements types.RoomRecentEventsUpdater // IsReferenced implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
err := u.d.statements.selectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) err := u.d.prevEvents.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
if err == nil { if err == nil {
return true, nil return true, nil
} }
@ -415,7 +237,7 @@ func (u *roomRecentEventsUpdater) SetLatestEvents(
for i := range latest { for i := range latest {
eventNIDs[i] = latest[i].EventNID eventNIDs[i] = latest[i].EventNID
} }
return u.d.statements.updateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) return u.d.rooms.UpdateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID)
} }
// HasEventBeenSent implements types.RoomRecentEventsUpdater // HasEventBeenSent implements types.RoomRecentEventsUpdater
@ -432,55 +254,6 @@ func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventSta
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal) return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal)
} }
// RoomNID implements query.RoomserverQueryAPIDB
func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) {
roomNID, err := d.statements.selectRoomNID(ctx, nil, roomID)
if err == sql.ErrNoRows {
return 0, nil
}
return roomNID, err
}
// RoomNIDExcludingStubs implements query.RoomserverQueryAPIDB
func (d *Database) RoomNIDExcludingStubs(ctx context.Context, roomID string) (roomNID types.RoomNID, err error) {
roomNID, err = d.RoomNID(ctx, roomID)
if err != nil {
return
}
latestEvents, _, err := d.statements.selectLatestEventNIDs(ctx, roomNID)
if err != nil {
return
}
if len(latestEvents) == 0 {
roomNID = 0
return
}
return
}
// LatestEventIDs implements query.RoomserverQueryAPIDatabase
func (d *Database) LatestEventIDs(
ctx context.Context, roomNID types.RoomNID,
) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) {
err = internal.WithTransaction(d.db, func(txn *sql.Tx) error {
var eventNIDs []types.EventNID
eventNIDs, currentStateSnapshotNID, err = d.statements.selectLatestEventNIDs(ctx, roomNID)
if err != nil {
return err
}
references, err = d.events.BulkSelectEventReference(ctx, txn, eventNIDs)
if err != nil {
return err
}
depth, err = d.events.SelectMaxEventDepth(ctx, txn, eventNIDs)
if err != nil {
return err
}
return nil
})
return
}
// GetInvitesForUser implements query.RoomserverQueryAPIDatabase // GetInvitesForUser implements query.RoomserverQueryAPIDatabase
func (d *Database) GetInvitesForUser( func (d *Database) GetInvitesForUser(
ctx context.Context, ctx context.Context,
@ -490,44 +263,6 @@ func (d *Database) GetInvitesForUser(
return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID)
} }
// SetRoomAlias implements alias.RoomserverAliasAPIDB
func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error {
return d.statements.insertRoomAlias(ctx, alias, roomID, creatorUserID)
}
// GetRoomIDForAlias implements alias.RoomserverAliasAPIDB
func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) {
return d.statements.selectRoomIDFromAlias(ctx, alias)
}
// GetAliasesForRoomID implements alias.RoomserverAliasAPIDB
func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) {
return d.statements.selectAliasesFromRoomID(ctx, roomID)
}
// GetCreatorIDForAlias implements alias.RoomserverAliasAPIDB
func (d *Database) GetCreatorIDForAlias(
ctx context.Context, alias string,
) (string, error) {
return d.statements.selectCreatorIDFromAlias(ctx, alias)
}
// RemoveRoomAlias implements alias.RoomserverAliasAPIDB
func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
return d.statements.deleteRoomAlias(ctx, alias)
}
// StateEntriesForTuples implements state.RoomStateDatabase
func (d *Database) StateEntriesForTuples(
ctx context.Context,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) {
return d.statements.bulkSelectFilteredStateBlockEntries(
ctx, stateBlockNIDs, stateKeyTuples,
)
}
// MembershipUpdater implements input.RoomEventDatabase // MembershipUpdater implements input.RoomEventDatabase
func (d *Database) MembershipUpdater( func (d *Database) MembershipUpdater(
ctx context.Context, roomID, targetUserID string, ctx context.Context, roomID, targetUserID string,
@ -733,37 +468,6 @@ func (d *Database) GetMembershipEventNIDsForRoom(
return d.statements.selectMembershipsFromRoom(ctx, roomNID, localOnly) return d.statements.selectMembershipsFromRoom(ctx, roomNID, localOnly)
} }
// EventsFromIDs implements query.RoomserverQueryAPIEventDB
func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
nidMap, err := d.EventNIDs(ctx, eventIDs)
if err != nil {
return nil, err
}
var nids []types.EventNID
for _, nid := range nidMap {
nids = append(nids, nid)
}
return d.Events(ctx, nids)
}
func (d *Database) GetRoomVersionForRoom(
ctx context.Context, roomID string,
) (gomatrixserverlib.RoomVersion, error) {
return d.statements.selectRoomVersionForRoomID(
ctx, nil, roomID,
)
}
func (d *Database) GetRoomVersionForRoomNID(
ctx context.Context, roomNID types.RoomNID,
) (gomatrixserverlib.RoomVersion, error) {
return d.statements.selectRoomVersionForRoomNID(
ctx, nil, roomNID,
)
}
type transaction struct { type transaction struct {
ctx context.Context ctx context.Context
txn *sql.Tx txn *sql.Tx

View file

@ -18,6 +18,8 @@ package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
) )
const transactionsSchema = ` const transactionsSchema = `
@ -51,20 +53,21 @@ type transactionStatements struct {
selectTransactionEventIDStmt *sql.Stmt selectTransactionEventIDStmt *sql.Stmt
} }
func (s *transactionStatements) prepare(db *sql.DB) (err error) { func NewPostgresTransactionsTable(db *sql.DB) (tables.Transactions, error) {
_, err = db.Exec(transactionsSchema) s := &transactionStatements{}
_, err := db.Exec(transactionsSchema)
if err != nil { if err != nil {
return return nil, err
} }
return statementList{ return s, statementList{
{&s.insertTransactionStmt, insertTransactionSQL}, {&s.insertTransactionStmt, insertTransactionSQL},
{&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL}, {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL},
}.prepare(db) }.prepare(db)
} }
func (s *transactionStatements) insertTransaction( func (s *transactionStatements) InsertTransaction(
ctx context.Context, ctx context.Context, txn *sql.Tx,
transactionID string, transactionID string,
sessionID int64, sessionID int64,
userID string, userID string,
@ -76,7 +79,7 @@ func (s *transactionStatements) insertTransaction(
return return
} }
func (s *transactionStatements) selectTransactionEventID( func (s *transactionStatements) SelectTransactionEventID(
ctx context.Context, ctx context.Context,
transactionID string, transactionID string,
sessionID int64, sessionID int64,

View file

@ -2,16 +2,28 @@ package shared
import ( import (
"context" "context"
"database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
) )
type Database struct { type Database struct {
DB *sql.DB
EventsTable tables.Events EventsTable tables.Events
EventJSONTable tables.EventJSON EventJSONTable tables.EventJSON
EventTypesTable tables.EventTypes EventTypesTable tables.EventTypes
EventStateKeysTable tables.EventStateKeys EventStateKeysTable tables.EventStateKeys
RoomsTable tables.Rooms
TransactionsTable tables.Transactions
StateSnapshotTable tables.StateSnapshot
StateBlockTable tables.StateBlock
RoomAliasesTable tables.RoomAliases
PrevEventsTable tables.PreviousEvents
} }
// EventTypeNIDs implements state.RoomStateDatabase // EventTypeNIDs implements state.RoomStateDatabase
@ -42,6 +54,42 @@ func (d *Database) StateEntriesForEventIDs(
return d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs) return d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs)
} }
// StateEntriesForTuples implements state.RoomStateDatabase
func (d *Database) StateEntriesForTuples(
ctx context.Context,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) {
return d.StateBlockTable.BulkSelectFilteredStateBlockEntries(
ctx, stateBlockNIDs, stateKeyTuples,
)
}
// AddState implements input.EventDatabase
func (d *Database) AddState(
ctx context.Context,
roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) {
err = internal.WithTransaction(d.DB, func(txn *sql.Tx) error {
if len(state) > 0 {
var stateBlockNID types.StateBlockNID
stateBlockNID, err = d.StateBlockTable.BulkInsertStateData(ctx, txn, state)
if err != nil {
return err
}
stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID)
}
stateNID, err = d.StateSnapshotTable.InsertState(ctx, txn, roomNID, stateBlockNIDs)
return err
})
if err != nil {
return 0, err
}
return
}
// EventNIDs implements query.RoomserverQueryAPIDatabase // EventNIDs implements query.RoomserverQueryAPIDatabase
func (d *Database) EventNIDs( func (d *Database) EventNIDs(
ctx context.Context, eventIDs []string, ctx context.Context, eventIDs []string,
@ -77,3 +125,337 @@ func (d *Database) EventIDs(
) (map[types.EventNID]string, error) { ) (map[types.EventNID]string, error) {
return d.EventsTable.BulkSelectEventID(ctx, eventNIDs) return d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
} }
// EventsFromIDs implements query.RoomserverQueryAPIEventDB
func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
nidMap, err := d.EventNIDs(ctx, eventIDs)
if err != nil {
return nil, err
}
var nids []types.EventNID
for _, nid := range nidMap {
nids = append(nids, nid)
}
return d.Events(ctx, nids)
}
// RoomNID implements query.RoomserverQueryAPIDB
func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) {
roomNID, err := d.RoomsTable.SelectRoomNID(ctx, nil, roomID)
if err == sql.ErrNoRows {
return 0, nil
}
return roomNID, err
}
// RoomNIDExcludingStubs implements query.RoomserverQueryAPIDB
func (d *Database) RoomNIDExcludingStubs(ctx context.Context, roomID string) (roomNID types.RoomNID, err error) {
roomNID, err = d.RoomNID(ctx, roomID)
if err != nil {
return
}
latestEvents, _, err := d.RoomsTable.SelectLatestEventNIDs(ctx, nil, roomNID)
if err != nil {
return
}
if len(latestEvents) == 0 {
roomNID = 0
return
}
return
}
// LatestEventIDs implements query.RoomserverQueryAPIDatabase
func (d *Database) LatestEventIDs(
ctx context.Context, roomNID types.RoomNID,
) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) {
err = internal.WithTransaction(d.DB, func(txn *sql.Tx) error {
var eventNIDs []types.EventNID
eventNIDs, currentStateSnapshotNID, err = d.RoomsTable.SelectLatestEventNIDs(ctx, txn, roomNID)
if err != nil {
return err
}
references, err = d.EventsTable.BulkSelectEventReference(ctx, txn, eventNIDs)
if err != nil {
return err
}
depth, err = d.EventsTable.SelectMaxEventDepth(ctx, txn, eventNIDs)
if err != nil {
return err
}
return nil
})
return
}
// StateBlockNIDs implements state.RoomStateDatabase
func (d *Database) StateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, stateNIDs)
}
// StateEntries implements state.RoomStateDatabase
func (d *Database) StateEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
return d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs)
}
func (d *Database) GetRoomVersionForRoom(
ctx context.Context, roomID string,
) (gomatrixserverlib.RoomVersion, error) {
return d.RoomsTable.SelectRoomVersionForRoomID(
ctx, nil, roomID,
)
}
func (d *Database) GetRoomVersionForRoomNID(
ctx context.Context, roomNID types.RoomNID,
) (gomatrixserverlib.RoomVersion, error) {
return d.RoomsTable.SelectRoomVersionForRoomNID(
ctx, roomNID,
)
}
// SetRoomAlias implements alias.RoomserverAliasAPIDB
func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error {
return d.RoomAliasesTable.InsertRoomAlias(ctx, alias, roomID, creatorUserID)
}
// GetRoomIDForAlias implements alias.RoomserverAliasAPIDB
func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) {
return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, alias)
}
// GetAliasesForRoomID implements alias.RoomserverAliasAPIDB
func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) {
return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, roomID)
}
// GetCreatorIDForAlias implements alias.RoomserverAliasAPIDB
func (d *Database) GetCreatorIDForAlias(
ctx context.Context, alias string,
) (string, error) {
return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, alias)
}
// RemoveRoomAlias implements alias.RoomserverAliasAPIDB
func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
return d.RoomAliasesTable.DeleteRoomAlias(ctx, alias)
}
// Events implements input.EventDatabase
func (d *Database) Events(
ctx context.Context, eventNIDs []types.EventNID,
) ([]types.Event, error) {
eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs)
if err != nil {
return nil, err
}
results := make([]types.Event, len(eventJSONs))
for i, eventJSON := range eventJSONs {
var roomNID types.RoomNID
var roomVersion gomatrixserverlib.RoomVersion
result := &results[i]
result.EventNID = eventJSON.EventNID
roomNID, err = d.EventsTable.SelectRoomNIDForEventNID(ctx, eventJSON.EventNID)
if err != nil {
return nil, err
}
roomVersion, err = d.RoomsTable.SelectRoomVersionForRoomNID(ctx, roomNID)
if err != nil {
return nil, err
}
result.Event, err = gomatrixserverlib.NewEventFromTrustedJSON(
eventJSON.EventJSON, false, roomVersion,
)
if err != nil {
return nil, err
}
}
return results, nil
}
// GetTransactionEventID implements input.EventDatabase
func (d *Database) GetTransactionEventID(
ctx context.Context, transactionID string,
sessionID int64, userID string,
) (string, error) {
eventID, err := d.TransactionsTable.SelectTransactionEventID(ctx, transactionID, sessionID, userID)
if err == sql.ErrNoRows {
return "", nil
}
return eventID, err
}
// StoreEvent implements input.EventDatabase
func (d *Database) StoreEvent(
ctx context.Context, event gomatrixserverlib.Event,
txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID,
) (types.RoomNID, types.StateAtEvent, error) {
var (
roomNID types.RoomNID
eventTypeNID types.EventTypeNID
eventStateKeyNID types.EventStateKeyNID
eventNID types.EventNID
stateNID types.StateSnapshotNID
err error
)
err = internal.WithTransaction(d.DB, func(txn *sql.Tx) error {
if txnAndSessionID != nil {
if err = d.TransactionsTable.InsertTransaction(
ctx, txn, txnAndSessionID.TransactionID,
txnAndSessionID.SessionID, event.Sender(), event.EventID(),
); err != nil {
return err
}
}
// TODO: Here we should aim to have two different code paths for new rooms
// vs existing ones.
// Get the default room version. If the client doesn't supply a room_version
// then we will use our configured default to create the room.
// https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom
// Note that the below logic depends on the m.room.create event being the
// first event that is persisted to the database when creating or joining a
// room.
var roomVersion gomatrixserverlib.RoomVersion
if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil {
return err
}
if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion); err != nil {
return err
}
if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil {
return err
}
eventStateKey := event.StateKey()
// Assigned a numeric ID for the state_key if there is one present.
// Otherwise set the numeric ID for the state_key to 0.
if eventStateKey != nil {
if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil {
return err
}
}
if eventNID, stateNID, err = d.EventsTable.InsertEvent(
ctx,
txn,
roomNID,
eventTypeNID,
eventStateKeyNID,
event.EventID(),
event.EventReference().EventSHA256,
authEventNIDs,
event.Depth(),
); err != nil {
if err == sql.ErrNoRows {
// We've already inserted the event so select the numeric event ID
eventNID, stateNID, err = d.EventsTable.SelectEvent(ctx, txn, event.EventID())
}
if err != nil {
return err
}
}
if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil {
return err
}
return nil
})
if err != nil {
return 0, types.StateAtEvent{}, err
}
return roomNID, types.StateAtEvent{
BeforeStateSnapshotNID: stateNID,
StateEntry: types.StateEntry{
StateKeyTuple: types.StateKeyTuple{
EventTypeNID: eventTypeNID,
EventStateKeyNID: eventStateKeyNID,
},
EventNID: eventNID,
},
}, nil
}
func (d *Database) assignRoomNID(
ctx context.Context, txn *sql.Tx,
roomID string, roomVersion gomatrixserverlib.RoomVersion,
) (types.RoomNID, error) {
// Check if we already have a numeric ID in the database.
roomNID, err := d.RoomsTable.SelectRoomNID(ctx, txn, roomID)
if err == sql.ErrNoRows {
// We don't have a numeric ID so insert one into the database.
roomNID, err = d.RoomsTable.InsertRoomNID(ctx, txn, roomID, roomVersion)
if err == sql.ErrNoRows {
// We raced with another insert so run the select again.
roomNID, err = d.RoomsTable.SelectRoomNID(ctx, txn, roomID)
}
}
return roomNID, err
}
func (d *Database) assignEventTypeNID(
ctx context.Context, txn *sql.Tx, eventType string,
) (eventTypeNID types.EventTypeNID, err error) {
// Check if we already have a numeric ID in the database.
eventTypeNID, err = d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType)
if err == sql.ErrNoRows {
// We don't have a numeric ID so insert one into the database.
eventTypeNID, err = d.EventTypesTable.InsertEventTypeNID(ctx, txn, eventType)
if err == sql.ErrNoRows {
// We raced with another insert so run the select again.
eventTypeNID, err = d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType)
}
}
return
}
func (d *Database) assignStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) {
// Check if we already have a numeric ID in the database.
eventStateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey)
if err == sql.ErrNoRows {
// We don't have a numeric ID so insert one into the database.
eventStateKeyNID, err = d.EventStateKeysTable.InsertEventStateKeyNID(ctx, txn, eventStateKey)
if err == sql.ErrNoRows {
// We raced with another insert so run the select again.
eventStateKeyNID, err = d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey)
}
}
return eventStateKeyNID, err
}
func extractRoomVersionFromCreateEvent(event gomatrixserverlib.Event) (
gomatrixserverlib.RoomVersion, error,
) {
var err error
var roomVersion gomatrixserverlib.RoomVersion
// Look for m.room.create events.
if event.Type() != gomatrixserverlib.MRoomCreate {
return gomatrixserverlib.RoomVersion(""), nil
}
roomVersion = gomatrixserverlib.RoomVersionV1
var createContent gomatrixserverlib.CreateContent
// The m.room.create event contains an optional "room_version" key in
// the event content, so we need to unmarshal that first.
if err = json.Unmarshal(event.Content(), &createContent); err != nil {
return gomatrixserverlib.RoomVersion(""), err
}
// A room version was specified in the event content?
if createContent.RoomVersion != nil {
roomVersion = gomatrixserverlib.RoomVersion(*createContent.RoomVersion)
}
return roomVersion, err
}

View file

@ -469,10 +469,9 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
} }
func (s *eventStatements) SelectRoomNIDForEventNID( func (s *eventStatements) SelectRoomNIDForEventNID(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ctx context.Context, eventNID types.EventNID,
) (roomNID types.RoomNID, err error) { ) (roomNID types.RoomNID, err error) {
selectStmt := internal.TxStmt(txn, s.selectRoomNIDForEventNIDStmt) err = s.selectRoomNIDForEventNIDStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&roomNID)
err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&roomNID)
return return
} }

View file

@ -20,6 +20,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
@ -55,19 +56,20 @@ type previousEventStatements struct {
selectPreviousEventExistsStmt *sql.Stmt selectPreviousEventExistsStmt *sql.Stmt
} }
func (s *previousEventStatements) prepare(db *sql.DB) (err error) { func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
_, err = db.Exec(previousEventSchema) s := &previousEventStatements{}
_, err := db.Exec(previousEventSchema)
if err != nil { if err != nil {
return return nil, err
} }
return statementList{ return s, statementList{
{&s.insertPreviousEventStmt, insertPreviousEventSQL}, {&s.insertPreviousEventStmt, insertPreviousEventSQL},
{&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL},
}.prepare(db) }.prepare(db)
} }
func (s *previousEventStatements) insertPreviousEvent( func (s *previousEventStatements) InsertPreviousEvent(
ctx context.Context, ctx context.Context,
txn *sql.Tx, txn *sql.Tx,
previousEventID string, previousEventID string,
@ -83,7 +85,7 @@ func (s *previousEventStatements) insertPreviousEvent(
// Check if the event reference exists // Check if the event reference exists
// Returns sql.ErrNoRows if the event reference doesn't exist. // Returns sql.ErrNoRows if the event reference doesn't exist.
func (s *previousEventStatements) selectPreviousEventExists( func (s *previousEventStatements) SelectPreviousEventExists(
ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte, ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte,
) error { ) error {
var ok int64 var ok int64

View file

@ -20,6 +20,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
) )
const roomAliasesSchema = ` const roomAliasesSchema = `
@ -60,12 +61,13 @@ type roomAliasesStatements struct {
deleteRoomAliasStmt *sql.Stmt deleteRoomAliasStmt *sql.Stmt
} }
func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) { func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
_, err = db.Exec(roomAliasesSchema) s := &roomAliasesStatements{}
_, err := db.Exec(roomAliasesSchema)
if err != nil { if err != nil {
return return nil, err
} }
return statementList{ return s, statementList{
{&s.insertRoomAliasStmt, insertRoomAliasSQL}, {&s.insertRoomAliasStmt, insertRoomAliasSQL},
{&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL},
{&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL}, {&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL},
@ -74,31 +76,28 @@ func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) {
}.prepare(db) }.prepare(db)
} }
func (s *roomAliasesStatements) insertRoomAlias( func (s *roomAliasesStatements) InsertRoomAlias(
ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string, ctx context.Context, alias string, roomID string, creatorUserID string,
) (err error) { ) (err error) {
insertStmt := internal.TxStmt(txn, s.insertRoomAliasStmt) _, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID)
_, err = insertStmt.ExecContext(ctx, alias, roomID, creatorUserID)
return return
} }
func (s *roomAliasesStatements) selectRoomIDFromAlias( func (s *roomAliasesStatements) SelectRoomIDFromAlias(
ctx context.Context, txn *sql.Tx, alias string, ctx context.Context, alias string,
) (roomID string, err error) { ) (roomID string, err error) {
selectStmt := internal.TxStmt(txn, s.selectRoomIDFromAliasStmt) err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID)
err = selectStmt.QueryRowContext(ctx, alias).Scan(&roomID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", nil return "", nil
} }
return return
} }
func (s *roomAliasesStatements) selectAliasesFromRoomID( func (s *roomAliasesStatements) SelectAliasesFromRoomID(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, roomID string,
) (aliases []string, err error) { ) (aliases []string, err error) {
aliases = []string{} aliases = []string{}
selectStmt := internal.TxStmt(txn, s.selectAliasesFromRoomIDStmt) rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
rows, err := selectStmt.QueryContext(ctx, roomID)
if err != nil { if err != nil {
return return
} }
@ -117,21 +116,19 @@ func (s *roomAliasesStatements) selectAliasesFromRoomID(
return return
} }
func (s *roomAliasesStatements) selectCreatorIDFromAlias( func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
ctx context.Context, txn *sql.Tx, alias string, ctx context.Context, alias string,
) (creatorID string, err error) { ) (creatorID string, err error) {
selectStmt := internal.TxStmt(txn, s.selectCreatorIDFromAliasStmt) err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID)
err = selectStmt.QueryRowContext(ctx, alias).Scan(&creatorID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", nil return "", nil
} }
return return
} }
func (s *roomAliasesStatements) deleteRoomAlias( func (s *roomAliasesStatements) DeleteRoomAlias(
ctx context.Context, txn *sql.Tx, alias string, ctx context.Context, alias string,
) (err error) { ) (err error) {
deleteStmt := internal.TxStmt(txn, s.deleteRoomAliasStmt) _, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias)
_, err = deleteStmt.ExecContext(ctx, alias)
return return
} }

View file

@ -22,6 +22,7 @@ import (
"errors" "errors"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -71,12 +72,13 @@ type roomStatements struct {
selectRoomVersionForRoomNIDStmt *sql.Stmt selectRoomVersionForRoomNIDStmt *sql.Stmt
} }
func (s *roomStatements) prepare(db *sql.DB) (err error) { func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
_, err = db.Exec(roomsSchema) s := &roomStatements{}
_, err := db.Exec(roomsSchema)
if err != nil { if err != nil {
return return nil, err
} }
return statementList{ return s, statementList{
{&s.insertRoomNIDStmt, insertRoomNIDSQL}, {&s.insertRoomNIDStmt, insertRoomNIDSQL},
{&s.selectRoomNIDStmt, selectRoomNIDSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL},
{&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL},
@ -87,20 +89,20 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) {
}.prepare(db) }.prepare(db)
} }
func (s *roomStatements) insertRoomNID( func (s *roomStatements) InsertRoomNID(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, roomVersion gomatrixserverlib.RoomVersion, roomID string, roomVersion gomatrixserverlib.RoomVersion,
) (types.RoomNID, error) { ) (types.RoomNID, error) {
var err error var err error
insertStmt := internal.TxStmt(txn, s.insertRoomNIDStmt) insertStmt := internal.TxStmt(txn, s.insertRoomNIDStmt)
if _, err = insertStmt.ExecContext(ctx, roomID, roomVersion); err == nil { if _, err = insertStmt.ExecContext(ctx, roomID, roomVersion); err == nil {
return s.selectRoomNID(ctx, txn, roomID) return s.SelectRoomNID(ctx, txn, roomID)
} else { } else {
return types.RoomNID(0), err return types.RoomNID(0), err
} }
} }
func (s *roomStatements) selectRoomNID( func (s *roomStatements) SelectRoomNID(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) (types.RoomNID, error) { ) (types.RoomNID, error) {
var roomNID int64 var roomNID int64
@ -109,7 +111,7 @@ func (s *roomStatements) selectRoomNID(
return types.RoomNID(roomNID), err return types.RoomNID(roomNID), err
} }
func (s *roomStatements) selectLatestEventNIDs( func (s *roomStatements) SelectLatestEventNIDs(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
) ([]types.EventNID, types.StateSnapshotNID, error) { ) ([]types.EventNID, types.StateSnapshotNID, error) {
var eventNIDs []types.EventNID var eventNIDs []types.EventNID
@ -126,7 +128,7 @@ func (s *roomStatements) selectLatestEventNIDs(
return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil
} }
func (s *roomStatements) selectLatestEventsNIDsForUpdate( func (s *roomStatements) SelectLatestEventsNIDsForUpdate(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) { ) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) {
var eventNIDs []types.EventNID var eventNIDs []types.EventNID
@ -144,7 +146,7 @@ func (s *roomStatements) selectLatestEventsNIDsForUpdate(
return eventNIDs, types.EventNID(lastEventSentNID), types.StateSnapshotNID(stateSnapshotNID), nil return eventNIDs, types.EventNID(lastEventSentNID), types.StateSnapshotNID(stateSnapshotNID), nil
} }
func (s *roomStatements) updateLatestEventNIDs( func (s *roomStatements) UpdateLatestEventNIDs(
ctx context.Context, ctx context.Context,
txn *sql.Tx, txn *sql.Tx,
roomNID types.RoomNID, roomNID types.RoomNID,
@ -163,7 +165,7 @@ func (s *roomStatements) updateLatestEventNIDs(
return err return err
} }
func (s *roomStatements) selectRoomVersionForRoomID( func (s *roomStatements) SelectRoomVersionForRoomID(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) (gomatrixserverlib.RoomVersion, error) { ) (gomatrixserverlib.RoomVersion, error) {
var roomVersion gomatrixserverlib.RoomVersion var roomVersion gomatrixserverlib.RoomVersion
@ -175,12 +177,11 @@ func (s *roomStatements) selectRoomVersionForRoomID(
return roomVersion, err return roomVersion, err
} }
func (s *roomStatements) selectRoomVersionForRoomNID( func (s *roomStatements) SelectRoomVersionForRoomNID(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ctx context.Context, roomNID types.RoomNID,
) (gomatrixserverlib.RoomVersion, error) { ) (gomatrixserverlib.RoomVersion, error) {
var roomVersion gomatrixserverlib.RoomVersion var roomVersion gomatrixserverlib.RoomVersion
stmt := internal.TxStmt(txn, s.selectRoomVersionForRoomNIDStmt) err := s.selectRoomVersionForRoomNIDStmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion)
err := stmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return roomVersion, errors.New("room not found") return roomVersion, errors.New("room not found")
} }

View file

@ -38,14 +38,8 @@ func (s *statements) prepare(db *sql.DB) error {
var err error var err error
for _, prepare := range []func(db *sql.DB) error{ for _, prepare := range []func(db *sql.DB) error{
s.roomStatements.prepare,
s.stateSnapshotStatements.prepare,
s.stateBlockStatements.prepare,
s.previousEventStatements.prepare,
s.roomAliasesStatements.prepare,
s.inviteStatements.prepare, s.inviteStatements.prepare,
s.membershipStatements.prepare, s.membershipStatements.prepare,
s.transactionStatements.prepare,
} { } {
if err = prepare(db); err != nil { if err = prepare(db); err != nil {
return err return err

View file

@ -23,6 +23,7 @@ import (
"strings" "strings"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -77,14 +78,15 @@ type stateBlockStatements struct {
bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt
} }
func (s *stateBlockStatements) prepare(db *sql.DB) (err error) { func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
s := &stateBlockStatements{}
s.db = db s.db = db
_, err = db.Exec(stateDataSchema) _, err := db.Exec(stateDataSchema)
if err != nil { if err != nil {
return return nil, err
} }
return statementList{ return s, statementList{
{&s.insertStateDataStmt, insertStateDataSQL}, {&s.insertStateDataStmt, insertStateDataSQL},
{&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL},
{&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL},
@ -92,7 +94,7 @@ func (s *stateBlockStatements) prepare(db *sql.DB) (err error) {
}.prepare(db) }.prepare(db)
} }
func (s *stateBlockStatements) bulkInsertStateData( func (s *stateBlockStatements) BulkInsertStateData(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
entries []types.StateEntry, entries []types.StateEntry,
) (types.StateBlockNID, error) { ) (types.StateBlockNID, error) {
@ -120,19 +122,18 @@ func (s *stateBlockStatements) bulkInsertStateData(
return stateBlockNID, nil return stateBlockNID, nil
} }
func (s *stateBlockStatements) bulkSelectStateBlockEntries( func (s *stateBlockStatements) BulkSelectStateBlockEntries(
ctx context.Context, txn *sql.Tx, stateBlockNIDs []types.StateBlockNID, ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) { ) ([]types.StateEntryList, error) {
nids := make([]interface{}, len(stateBlockNIDs)) nids := make([]interface{}, len(stateBlockNIDs))
for k, v := range stateBlockNIDs { for k, v := range stateBlockNIDs {
nids[k] = v nids[k] = v
} }
selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", internal.QueryVariadic(len(nids)), 1) selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", internal.QueryVariadic(len(nids)), 1)
selectPrep, err := s.db.Prepare(selectOrig) selectStmt, err := s.db.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt := internal.TxStmt(txn, selectPrep)
rows, err := selectStmt.QueryContext(ctx, nids...) rows, err := selectStmt.QueryContext(ctx, nids...)
if err != nil { if err != nil {
return nil, err return nil, err
@ -174,8 +175,8 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries(
return results, nil return results, nil
} }
func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries( func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries(
ctx context.Context, txn *sql.Tx, // nolint: unparam ctx context.Context,
stateBlockNIDs []types.StateBlockNID, stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple, stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) { ) ([]types.StateEntryList, error) {

View file

@ -23,6 +23,7 @@ import (
"strings" "strings"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
@ -51,20 +52,21 @@ type stateSnapshotStatements struct {
bulkSelectStateBlockNIDsStmt *sql.Stmt bulkSelectStateBlockNIDsStmt *sql.Stmt
} }
func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) { func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
s := &stateSnapshotStatements{}
s.db = db s.db = db
_, err = db.Exec(stateSnapshotSchema) _, err := db.Exec(stateSnapshotSchema)
if err != nil { if err != nil {
return return nil, err
} }
return statementList{ return s, statementList{
{&s.insertStateStmt, insertStateSQL}, {&s.insertStateStmt, insertStateSQL},
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
}.prepare(db) }.prepare(db)
} }
func (s *stateSnapshotStatements) insertState( func (s *stateSnapshotStatements) InsertState(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID,
) (stateNID types.StateSnapshotNID, err error) { ) (stateNID types.StateSnapshotNID, err error) {
stateBlockNIDsJSON, err := json.Marshal(stateBlockNIDs) stateBlockNIDsJSON, err := json.Marshal(stateBlockNIDs)
@ -82,15 +84,15 @@ func (s *stateSnapshotStatements) insertState(
return return
} }
func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs( func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID, ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) { ) ([]types.StateBlockNIDList, error) {
nids := make([]interface{}, len(stateNIDs)) nids := make([]interface{}, len(stateNIDs))
for k, v := range stateNIDs { for k, v := range stateNIDs {
nids[k] = v nids[k] = v
} }
selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", internal.QueryVariadic(len(nids)), 1) selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", internal.QueryVariadic(len(nids)), 1)
selectStmt, err := txn.Prepare(selectOrig) selectStmt, err := s.db.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -18,14 +18,12 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"errors" "errors"
"net/url" "net/url"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
@ -41,10 +39,14 @@ type Database struct {
eventJSON tables.EventJSON eventJSON tables.EventJSON
eventTypes tables.EventTypes eventTypes tables.EventTypes
eventStateKeys tables.EventStateKeys eventStateKeys tables.EventStateKeys
rooms tables.Rooms
transactions tables.Transactions
prevEvents tables.PreviousEvents
db *sql.DB db *sql.DB
} }
// Open a sqlite database. // Open a sqlite database.
// nolint: gocyclo
func Open(dataSourceName string) (*Database, error) { func Open(dataSourceName string) (*Database, error) {
var d Database var d Database
uri, err := url.Parse(dataSourceName) uri, err := url.Parse(dataSourceName)
@ -89,163 +91,58 @@ func Open(dataSourceName string) (*Database, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
d.rooms, err = NewSqliteRoomsTable(d.db)
if err != nil {
return nil, err
}
d.transactions, err = NewSqliteTransactionsTable(d.db)
if err != nil {
return nil, err
}
stateBlock, err := NewSqliteStateBlockTable(d.db)
if err != nil {
return nil, err
}
stateSnapshot, err := NewSqliteStateSnapshotTable(d.db)
if err != nil {
return nil, err
}
d.prevEvents, err = NewSqlitePrevEventsTable(d.db)
if err != nil {
return nil, err
}
roomAliases, err := NewSqliteRoomAliasesTable(d.db)
if err != nil {
return nil, err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db,
EventsTable: d.events, EventsTable: d.events,
EventTypesTable: d.eventTypes, EventTypesTable: d.eventTypes,
EventStateKeysTable: d.eventStateKeys, EventStateKeysTable: d.eventStateKeys,
EventJSONTable: d.eventJSON, EventJSONTable: d.eventJSON,
RoomsTable: d.rooms,
TransactionsTable: d.transactions,
StateBlockTable: stateBlock,
StateSnapshotTable: stateSnapshot,
PrevEventsTable: d.prevEvents,
RoomAliasesTable: roomAliases,
} }
return &d, nil return &d, nil
} }
// StoreEvent implements input.EventDatabase
func (d *Database) StoreEvent(
ctx context.Context, event gomatrixserverlib.Event,
txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID,
) (types.RoomNID, types.StateAtEvent, error) {
var (
roomNID types.RoomNID
eventTypeNID types.EventTypeNID
eventStateKeyNID types.EventStateKeyNID
eventNID types.EventNID
stateNID types.StateSnapshotNID
err error
)
err = internal.WithTransaction(d.db, func(txn *sql.Tx) error {
if txnAndSessionID != nil {
if err = d.statements.insertTransaction(
ctx, txn, txnAndSessionID.TransactionID,
txnAndSessionID.SessionID, event.Sender(), event.EventID(),
); err != nil {
return err
}
}
// TODO: Here we should aim to have two different code paths for new rooms
// vs existing ones.
// Get the default room version. If the client doesn't supply a room_version
// then we will use our configured default to create the room.
// https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom
// Note that the below logic depends on the m.room.create event being the
// first event that is persisted to the database when creating or joining a
// room.
var roomVersion gomatrixserverlib.RoomVersion
if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil {
return err
}
if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion); err != nil {
return err
}
if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil {
return err
}
eventStateKey := event.StateKey()
// Assigned a numeric ID for the state_key if there is one present.
// Otherwise set the numeric ID for the state_key to 0.
if eventStateKey != nil {
if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil {
return err
}
}
if eventNID, stateNID, err = d.events.InsertEvent(
ctx,
txn,
roomNID,
eventTypeNID,
eventStateKeyNID,
event.EventID(),
event.EventReference().EventSHA256,
authEventNIDs,
event.Depth(),
); err != nil {
if err == sql.ErrNoRows {
// We've already inserted the event so select the numeric event ID
eventNID, stateNID, err = d.events.SelectEvent(ctx, txn, event.EventID())
}
if err != nil {
return err
}
}
if err = d.eventJSON.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil {
return err
}
return nil
})
if err != nil {
return 0, types.StateAtEvent{}, err
}
return roomNID, types.StateAtEvent{
BeforeStateSnapshotNID: stateNID,
StateEntry: types.StateEntry{
StateKeyTuple: types.StateKeyTuple{
EventTypeNID: eventTypeNID,
EventStateKeyNID: eventStateKeyNID,
},
EventNID: eventNID,
},
}, nil
}
func extractRoomVersionFromCreateEvent(event gomatrixserverlib.Event) (
gomatrixserverlib.RoomVersion, error,
) {
var err error
var roomVersion gomatrixserverlib.RoomVersion
// Look for m.room.create events.
if event.Type() != gomatrixserverlib.MRoomCreate {
return gomatrixserverlib.RoomVersion(""), nil
}
roomVersion = gomatrixserverlib.RoomVersionV1
var createContent gomatrixserverlib.CreateContent
// The m.room.create event contains an optional "room_version" key in
// the event content, so we need to unmarshal that first.
if err = json.Unmarshal(event.Content(), &createContent); err != nil {
return gomatrixserverlib.RoomVersion(""), err
}
// A room version was specified in the event content?
if createContent.RoomVersion != nil {
roomVersion = gomatrixserverlib.RoomVersion(*createContent.RoomVersion)
}
return roomVersion, err
}
func (d *Database) assignRoomNID( func (d *Database) assignRoomNID(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, roomVersion gomatrixserverlib.RoomVersion, roomID string, roomVersion gomatrixserverlib.RoomVersion,
) (roomNID types.RoomNID, err error) { ) (roomNID types.RoomNID, err error) {
// Check if we already have a numeric ID in the database. // Check if we already have a numeric ID in the database.
roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) roomNID, err = d.rooms.SelectRoomNID(ctx, txn, roomID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// We don't have a numeric ID so insert one into the database. // We don't have a numeric ID so insert one into the database.
roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID, roomVersion) roomNID, err = d.rooms.InsertRoomNID(ctx, txn, roomID, roomVersion)
if err == nil { if err == nil {
// Now get the numeric ID back out of the database // Now get the numeric ID back out of the database
roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) roomNID, err = d.rooms.SelectRoomNID(ctx, txn, roomID)
}
}
return
}
func (d *Database) assignEventTypeNID(
ctx context.Context, txn *sql.Tx, eventType string,
) (eventTypeNID types.EventTypeNID, err error) {
// Check if we already have a numeric ID in the database.
eventTypeNID, err = d.eventTypes.SelectEventTypeNID(ctx, txn, eventType)
if err == sql.ErrNoRows {
// We don't have a numeric ID so insert one into the database.
eventTypeNID, err = d.eventTypes.InsertEventTypeNID(ctx, txn, eventType)
if err == sql.ErrNoRows {
// We raced with another insert so run the select again.
eventTypeNID, err = d.eventTypes.SelectEventTypeNID(ctx, txn, eventType)
} }
} }
return return
@ -267,94 +164,6 @@ func (d *Database) assignStateKeyNID(
return return
} }
// Events implements input.EventDatabase
func (d *Database) Events(
ctx context.Context, eventNIDs []types.EventNID,
) ([]types.Event, error) {
var eventJSONs []tables.EventJSONPair
var err error
var results []types.Event
eventJSONs, err = d.eventJSON.BulkSelectEventJSON(ctx, eventNIDs)
if err != nil || len(eventJSONs) == 0 {
return nil, nil
}
err = internal.WithTransaction(d.db, func(txn *sql.Tx) error {
results = make([]types.Event, len(eventJSONs))
for i, eventJSON := range eventJSONs {
var roomNID types.RoomNID
var roomVersion gomatrixserverlib.RoomVersion
result := &results[i]
result.EventNID = eventJSON.EventNID
roomNID, err = d.events.SelectRoomNIDForEventNID(ctx, txn, eventJSON.EventNID)
if err != nil {
return err
}
roomVersion, err = d.statements.selectRoomVersionForRoomNID(ctx, txn, roomNID)
if err != nil {
return err
}
result.Event, err = gomatrixserverlib.NewEventFromTrustedJSON(
eventJSON.EventJSON, false, roomVersion,
)
if err != nil {
return nil
}
}
return nil
})
if err != nil {
return []types.Event{}, err
}
return results, nil
}
// AddState implements input.EventDatabase
func (d *Database) AddState(
ctx context.Context,
roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) {
err = internal.WithTransaction(d.db, func(txn *sql.Tx) error {
if len(state) > 0 {
var stateBlockNID types.StateBlockNID
stateBlockNID, err = d.statements.bulkInsertStateData(ctx, txn, state)
if err != nil {
return err
}
stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID)
}
stateNID, err = d.statements.insertState(ctx, txn, roomNID, stateBlockNIDs)
return err
})
if err != nil {
return 0, err
}
return
}
// StateBlockNIDs implements state.RoomStateDatabase
func (d *Database) StateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
) (sl []types.StateBlockNIDList, err error) {
err = internal.WithTransaction(d.db, func(txn *sql.Tx) error {
sl, err = d.statements.bulkSelectStateBlockNIDs(ctx, txn, stateNIDs)
return err
})
return
}
// StateEntries implements state.RoomStateDatabase
func (d *Database) StateEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) (sel []types.StateEntryList, err error) {
err = internal.WithTransaction(d.db, func(txn *sql.Tx) error {
sel, err = d.statements.bulkSelectStateBlockEntries(ctx, txn, stateBlockNIDs)
return err
})
return
}
// GetLatestEventsForUpdate implements input.EventDatabase // GetLatestEventsForUpdate implements input.EventDatabase
func (d *Database) GetLatestEventsForUpdate( func (d *Database) GetLatestEventsForUpdate(
ctx context.Context, roomNID types.RoomNID, ctx context.Context, roomNID types.RoomNID,
@ -364,7 +173,7 @@ func (d *Database) GetLatestEventsForUpdate(
return nil, err return nil, err
} }
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
d.statements.selectLatestEventsNIDsForUpdate(ctx, txn, roomNID) d.rooms.SelectLatestEventsNIDsForUpdate(ctx, txn, roomNID)
if err != nil { if err != nil {
txn.Rollback() // nolint: errcheck txn.Rollback() // nolint: errcheck
return nil, err return nil, err
@ -396,18 +205,6 @@ func (d *Database) GetLatestEventsForUpdate(
}, nil }, nil
} }
// GetTransactionEventID implements input.EventDatabase
func (d *Database) GetTransactionEventID(
ctx context.Context, transactionID string,
sessionID int64, userID string,
) (string, error) {
eventID, err := d.statements.selectTransactionEventID(ctx, nil, transactionID, sessionID, userID)
if err == sql.ErrNoRows {
return "", nil
}
return eventID, err
}
type roomRecentEventsUpdater struct { type roomRecentEventsUpdater struct {
transaction transaction
d *Database d *Database
@ -442,7 +239,7 @@ func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotN
func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
err := internal.WithTransaction(u.d.db, func(txn *sql.Tx) error { err := internal.WithTransaction(u.d.db, func(txn *sql.Tx) error {
for _, ref := range previousEventReferences { for _, ref := range previousEventReferences {
if err := u.d.statements.insertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { if err := u.d.prevEvents.InsertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
return err return err
} }
} }
@ -454,7 +251,7 @@ func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, p
// IsReferenced implements types.RoomRecentEventsUpdater // IsReferenced implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (res bool, err error) { func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (res bool, err error) {
err = internal.WithTransaction(u.d.db, func(txn *sql.Tx) error { err = internal.WithTransaction(u.d.db, func(txn *sql.Tx) error {
err := u.d.statements.selectPreviousEventExists(u.ctx, txn, eventReference.EventID, eventReference.EventSHA256) err := u.d.prevEvents.SelectPreviousEventExists(u.ctx, txn, eventReference.EventID, eventReference.EventSHA256)
if err == nil { if err == nil {
res = true res = true
err = nil err = nil
@ -478,7 +275,7 @@ func (u *roomRecentEventsUpdater) SetLatestEvents(
for i := range latest { for i := range latest {
eventNIDs[i] = latest[i].EventNID eventNIDs[i] = latest[i].EventNID
} }
return u.d.statements.updateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) return u.d.rooms.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID)
}) })
return err return err
} }
@ -508,59 +305,6 @@ func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventSta
return return
} }
// RoomNID implements query.RoomserverQueryAPIDB
func (d *Database) RoomNID(ctx context.Context, roomID string) (roomNID types.RoomNID, err error) {
err = internal.WithTransaction(d.db, func(txn *sql.Tx) error {
roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID)
if err == sql.ErrNoRows {
roomNID = 0
err = nil
}
return err
})
return
}
// RoomNIDExcludingStubs implements query.RoomserverQueryAPIDB
func (d *Database) RoomNIDExcludingStubs(ctx context.Context, roomID string) (roomNID types.RoomNID, err error) {
roomNID, err = d.RoomNID(ctx, roomID)
if err != nil {
return
}
latestEvents, _, err := d.statements.selectLatestEventNIDs(ctx, nil, roomNID)
if err != nil {
return
}
if len(latestEvents) == 0 {
roomNID = 0
return
}
return
}
// LatestEventIDs implements query.RoomserverQueryAPIDatabase
func (d *Database) LatestEventIDs(
ctx context.Context, roomNID types.RoomNID,
) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) {
err = internal.WithTransaction(d.db, func(txn *sql.Tx) error {
var eventNIDs []types.EventNID
eventNIDs, currentStateSnapshotNID, err = d.statements.selectLatestEventNIDs(ctx, txn, roomNID)
if err != nil {
return err
}
references, err = d.events.BulkSelectEventReference(ctx, txn, eventNIDs)
if err != nil {
return err
}
depth, err = d.events.SelectMaxEventDepth(ctx, txn, eventNIDs)
if err != nil {
return err
}
return nil
})
return
}
// GetInvitesForUser implements query.RoomserverQueryAPIDatabase // GetInvitesForUser implements query.RoomserverQueryAPIDatabase
func (d *Database) GetInvitesForUser( func (d *Database) GetInvitesForUser(
ctx context.Context, ctx context.Context,
@ -570,44 +314,6 @@ func (d *Database) GetInvitesForUser(
return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID)
} }
// SetRoomAlias implements alias.RoomserverAliasAPIDB
func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error {
return d.statements.insertRoomAlias(ctx, nil, alias, roomID, creatorUserID)
}
// GetRoomIDForAlias implements alias.RoomserverAliasAPIDB
func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) {
return d.statements.selectRoomIDFromAlias(ctx, nil, alias)
}
// GetAliasesForRoomID implements alias.RoomserverAliasAPIDB
func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) {
return d.statements.selectAliasesFromRoomID(ctx, nil, roomID)
}
// GetCreatorIDForAlias implements alias.RoomserverAliasAPIDB
func (d *Database) GetCreatorIDForAlias(
ctx context.Context, alias string,
) (string, error) {
return d.statements.selectCreatorIDFromAlias(ctx, nil, alias)
}
// RemoveRoomAlias implements alias.RoomserverAliasAPIDB
func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
return d.statements.deleteRoomAlias(ctx, nil, alias)
}
// StateEntriesForTuples implements state.RoomStateDatabase
func (d *Database) StateEntriesForTuples(
ctx context.Context,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) {
return d.statements.bulkSelectFilteredStateBlockEntries(
ctx, nil, stateBlockNIDs, stateKeyTuples,
)
}
// MembershipUpdater implements input.RoomEventDatabase // MembershipUpdater implements input.RoomEventDatabase
func (d *Database) MembershipUpdater( func (d *Database) MembershipUpdater(
ctx context.Context, roomID, targetUserID string, ctx context.Context, roomID, targetUserID string,
@ -844,37 +550,6 @@ func (d *Database) GetMembershipEventNIDsForRoom(
return return
} }
// EventsFromIDs implements query.RoomserverQueryAPIEventDB
func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
nidMap, err := d.EventNIDs(ctx, eventIDs)
if err != nil {
return nil, err
}
var nids []types.EventNID
for _, nid := range nidMap {
nids = append(nids, nid)
}
return d.Events(ctx, nids)
}
func (d *Database) GetRoomVersionForRoom(
ctx context.Context, roomID string,
) (gomatrixserverlib.RoomVersion, error) {
return d.statements.selectRoomVersionForRoomID(
ctx, nil, roomID,
)
}
func (d *Database) GetRoomVersionForRoomNID(
ctx context.Context, roomNID types.RoomNID,
) (gomatrixserverlib.RoomVersion, error) {
return d.statements.selectRoomVersionForRoomNID(
ctx, nil, roomNID,
)
}
type transaction struct { type transaction struct {
ctx context.Context ctx context.Context
txn *sql.Tx txn *sql.Tx

View file

@ -20,6 +20,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
) )
const transactionsSchema = ` const transactionsSchema = `
@ -46,19 +47,20 @@ type transactionStatements struct {
selectTransactionEventIDStmt *sql.Stmt selectTransactionEventIDStmt *sql.Stmt
} }
func (s *transactionStatements) prepare(db *sql.DB) (err error) { func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) {
_, err = db.Exec(transactionsSchema) s := &transactionStatements{}
_, err := db.Exec(transactionsSchema)
if err != nil { if err != nil {
return return nil, err
} }
return statementList{ return s, statementList{
{&s.insertTransactionStmt, insertTransactionSQL}, {&s.insertTransactionStmt, insertTransactionSQL},
{&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL}, {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL},
}.prepare(db) }.prepare(db)
} }
func (s *transactionStatements) insertTransaction( func (s *transactionStatements) InsertTransaction(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
transactionID string, transactionID string,
sessionID int64, sessionID int64,
@ -72,14 +74,13 @@ func (s *transactionStatements) insertTransaction(
return return
} }
func (s *transactionStatements) selectTransactionEventID( func (s *transactionStatements) SelectTransactionEventID(
ctx context.Context, txn *sql.Tx, ctx context.Context,
transactionID string, transactionID string,
sessionID int64, sessionID int64,
userID string, userID string,
) (eventID string, err error) { ) (eventID string, err error) {
stmt := internal.TxStmt(txn, s.selectTransactionEventIDStmt) err = s.selectTransactionEventIDStmt.QueryRowContext(
err = stmt.QueryRowContext(
ctx, transactionID, sessionID, userID, ctx, transactionID, sessionID, userID,
).Scan(&eventID) ).Scan(&eventID)
return return

View file

@ -53,5 +53,46 @@ type Events interface {
// If an event ID is not in the database then it is omitted from the map. // If an event ID is not in the database then it is omitted from the map.
BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error)
SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error)
SelectRoomNIDForEventNID(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (roomNID types.RoomNID, err error) SelectRoomNIDForEventNID(ctx context.Context, eventNID types.EventNID) (roomNID types.RoomNID, err error)
}
type Rooms interface {
InsertRoomNID(ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion) (types.RoomNID, error)
SelectRoomNID(ctx context.Context, txn *sql.Tx, roomID string) (types.RoomNID, error)
SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error)
SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error)
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
SelectRoomVersionForRoomID(ctx context.Context, txn *sql.Tx, roomID string) (gomatrixserverlib.RoomVersion, error)
SelectRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error)
}
type Transactions interface {
InsertTransaction(ctx context.Context, txn *sql.Tx, transactionID string, sessionID int64, userID string, eventID string) error
SelectTransactionEventID(ctx context.Context, transactionID string, sessionID int64, userID string) (eventID string, err error)
}
type StateSnapshot interface {
InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID) (stateNID types.StateSnapshotNID, err error)
BulkSelectStateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
}
type StateBlock interface {
BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries []types.StateEntry) (types.StateBlockNID, error)
BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
}
type RoomAliases interface {
InsertRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) (err error)
SelectRoomIDFromAlias(ctx context.Context, alias string) (roomID string, err error)
SelectAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error)
SelectCreatorIDFromAlias(ctx context.Context, alias string) (creatorID string, err error)
DeleteRoomAlias(ctx context.Context, alias string) (err error)
}
type PreviousEvents interface {
InsertPreviousEvent(ctx context.Context, txn *sql.Tx, previousEventID string, previousEventReferenceSHA256 []byte, eventNID types.EventNID) error
// Check if the event reference exists
// Returns sql.ErrNoRows if the event reference doesn't exist.
SelectPreviousEventExists(ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error
} }