Compare commits

...

11 commits

Author SHA1 Message Date
Matthew Hodgson b265937559 switch to sqlite3_js 2020-01-19 03:55:06 +00:00
Neil Alexander a4fdb6fe2d Rethink transactions a lot 2020-01-14 18:46:42 +00:00
Neil Alexander 49ff4d6d8f All sorts of debugging information and tweaks - still no joy 2020-01-14 15:17:14 +00:00
Neil Alexander 7a7be4f0ac Fix state block schema a bit 2020-01-13 14:53:03 +00:00
Neil Alexander 8bb8642560 Separate out INSERT/SELECT statements in place of RETURNING in SQLite 2020-01-10 18:28:07 +00:00
Neil Alexander 3852f5c714 Merge branch 'master' into neilalexander/sqlite-roomserver 2020-01-10 12:31:22 +00:00
Neil Alexander 30057b399b Fix typo 2020-01-09 17:44:21 +00:00
Neil Alexander c40850d218 Merge branch 'master' into neilalexander/sqlite-roomserver 2020-01-09 17:21:31 +00:00
Neil Alexander 6c64aa0685 Merge branch 'master' into neilalexander/sqlite-roomserver 2020-01-09 17:11:36 +00:00
Neil Alexander 7a71a59dc7 Merge branch 'master' into neilalexander/sqlite-roomserver 2020-01-09 17:05:44 +00:00
Neil Alexander c7000f343e Some SQLite support for roomserver 2020-01-09 16:50:11 +00:00
23 changed files with 3103 additions and 9 deletions

2
go.mod
View file

@ -13,10 +13,12 @@ require (
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
github.com/lib/pq v0.0.0-20170918175043-23da1db4f16d github.com/lib/pq v0.0.0-20170918175043-23da1db4f16d
github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5 github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5
github.com/matrix-org/go-sqlite3-js v0.0.0-20200119033421-beabc8946bf7
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26
github.com/matrix-org/gomatrixserverlib v0.0.0-20200110113524-5f9a44f2fc67 github.com/matrix-org/gomatrixserverlib v0.0.0-20200110113524-5f9a44f2fc67
github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0 github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0
github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5 github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5
github.com/mattn/go-sqlite3 v2.0.2+incompatible
github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5 github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5
github.com/opentracing/opentracing-go v0.0.0-20170806192116-8ebe5d4e236e github.com/opentracing/opentracing-go v0.0.0-20170806192116-8ebe5d4e236e
github.com/pierrec/lz4 v0.0.0-20161206202305-5c9560bfa9ac // indirect github.com/pierrec/lz4 v0.0.0-20161206202305-5c9560bfa9ac // indirect

4
go.sum
View file

@ -61,6 +61,8 @@ github.com/lib/pq v0.0.0-20170918175043-23da1db4f16d h1:Hdtccv31GWxWoCzWsIhZXy5N
github.com/lib/pq v0.0.0-20170918175043-23da1db4f16d/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v0.0.0-20170918175043-23da1db4f16d/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5 h1:nMX2t7hbGF0NYDYySx0pCqEKGKAeZIiSqlWSspetlhY= github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5 h1:nMX2t7hbGF0NYDYySx0pCqEKGKAeZIiSqlWSspetlhY=
github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5/go.mod h1:NgPCr+UavRGH6n5jmdX8DuqFZ4JiCWIJoZiuhTRLSUg= github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5/go.mod h1:NgPCr+UavRGH6n5jmdX8DuqFZ4JiCWIJoZiuhTRLSUg=
github.com/matrix-org/go-sqlite3-js v0.0.0-20200119033421-beabc8946bf7 h1:TfoCHWdbtAzwC3ML0LUqOLZPPML9oB1BggattWfp0xs=
github.com/matrix-org/go-sqlite3-js v0.0.0-20200119033421-beabc8946bf7/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20190130130140-385f072fe9af h1:piaIBNQGIHnni27xRB7VKkEwoWCgAmeuYf8pxAyG0bI= github.com/matrix-org/gomatrix v0.0.0-20190130130140-385f072fe9af h1:piaIBNQGIHnni27xRB7VKkEwoWCgAmeuYf8pxAyG0bI=
github.com/matrix-org/gomatrix v0.0.0-20190130130140-385f072fe9af/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20190130130140-385f072fe9af/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bhrnp3Ky1qgx/fzCtCALOoGYylh2tpS9K4= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bhrnp3Ky1qgx/fzCtCALOoGYylh2tpS9K4=
@ -71,6 +73,8 @@ github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0 h1:p7WTwG+aXM86+
github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0/go.mod h1:cXoYQIENbdWIQHt1SyCo6Bl3C3raHwJ0wgVrXHSqf+A= github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0/go.mod h1:cXoYQIENbdWIQHt1SyCo6Bl3C3raHwJ0wgVrXHSqf+A=
github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5 h1:W7l5CP4V7wPyPb4tYE11dbmeAOwtFQBTW0rf4OonOS8= github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5 h1:W7l5CP4V7wPyPb4tYE11dbmeAOwtFQBTW0rf4OonOS8=
github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5/go.mod h1:lePuOiXLNDott7NZfnQvJk0lAZ5HgvIuWGhel6J+RLA= github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5/go.mod h1:lePuOiXLNDott7NZfnQvJk0lAZ5HgvIuWGhel6J+RLA=
github.com/mattn/go-sqlite3 v2.0.2+incompatible h1:qzw9c2GNT8UFrgWNDhCTqRqYUSmu/Dav/9Z58LGpk7U=
github.com/mattn/go-sqlite3 v2.0.2+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/miekg/dns v1.1.4 h1:rCMZsU2ScVSYcAsOXgmC6+AKOK+6pmQTOcw03nfwYV0= github.com/miekg/dns v1.1.4 h1:rCMZsU2ScVSYcAsOXgmC6+AKOK+6pmQTOcw03nfwYV0=

View file

@ -16,6 +16,7 @@ package input
import ( import (
"context" "context"
"fmt"
"sort" "sort"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
@ -35,16 +36,19 @@ func checkAuthEvents(
if err != nil { if err != nil {
return nil, err return nil, err
} }
fmt.Println("authStateEntries:", authStateEntries)
// TODO: check for duplicate state keys here. // TODO: check for duplicate state keys here.
// Work out which of the state events we actually need. // Work out which of the state events we actually need.
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event}) stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event})
fmt.Println("stateNeeded:", stateNeeded)
// Load the actual auth events from the database. // Load the actual auth events from the database.
authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries) authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries)
if err != nil { if err != nil {
return nil, err return nil, err
} }
fmt.Println("authEvents:", authEvents)
// Check if the event is allowed. // Check if the event is allowed.
if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil {

View file

@ -94,6 +94,7 @@ func processRoomEvent(
// Check that the event passes authentication checks and work out the numeric IDs for the auth events. // Check that the event passes authentication checks and work out the numeric IDs for the auth events.
authEventNIDs, err := checkAuthEvents(ctx, db, event, input.AuthEventIDs) authEventNIDs, err := checkAuthEvents(ctx, db, event, input.AuthEventIDs)
if err != nil { if err != nil {
fmt.Println("failed checkAuthEvents:", err)
return return
} }
@ -104,6 +105,7 @@ func processRoomEvent(
) )
// On error OR event with the transaction already processed/processesing // On error OR event with the transaction already processed/processesing
if err != nil || eventID != "" { if err != nil || eventID != "" {
fmt.Println("failed GetTransactionEventID:", err)
return return
} }
} }
@ -111,6 +113,7 @@ func processRoomEvent(
// Store the event // Store the event
roomNID, stateAtEvent, err := db.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) roomNID, stateAtEvent, err := db.StoreEvent(ctx, event, input.TransactionID, authEventNIDs)
if err != nil { if err != nil {
fmt.Println("failed StoreEvent:", err)
return return
} }
@ -118,6 +121,7 @@ func processRoomEvent(
// For outliers we can stop after we've stored the event itself as it // For outliers we can stop after we've stored the event itself as it
// doesn't have any associated state to store and we don't need to // doesn't have any associated state to store and we don't need to
// notify anyone about it. // notify anyone about it.
fmt.Println("kind is outlier")
return event.EventID(), nil return event.EventID(), nil
} }
@ -126,6 +130,7 @@ func processRoomEvent(
// Lets calculate one. // Lets calculate one.
err = calculateAndSetState(ctx, db, input, roomNID, &stateAtEvent, event) err = calculateAndSetState(ctx, db, input, roomNID, &stateAtEvent, event)
if err != nil { if err != nil {
fmt.Println("failed to calculate and set state")
return return
} }
} }
@ -155,15 +160,18 @@ func calculateAndSetState(
// Check that those state events are in the database and store the state. // Check that those state events are in the database and store the state.
var entries []types.StateEntry var entries []types.StateEntry
if entries, err = db.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { if entries, err = db.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
fmt.Println("failed StateEntriesForEventIDs")
return err return err
} }
if stateAtEvent.BeforeStateSnapshotNID, err = db.AddState(ctx, roomNID, nil, entries); err != nil { if stateAtEvent.BeforeStateSnapshotNID, err = db.AddState(ctx, roomNID, nil, entries); err != nil {
fmt.Println("failed AddState")
return err return err
} }
} else { } else {
// We haven't been told what the state at the event is so we need to calculate it from the prev_events // We haven't been told what the state at the event is so we need to calculate it from the prev_events
if stateAtEvent.BeforeStateSnapshotNID, err = state.CalculateAndStoreStateBeforeEvent(ctx, db, event, roomNID); err != nil { if stateAtEvent.BeforeStateSnapshotNID, err = state.CalculateAndStoreStateBeforeEvent(ctx, db, event, roomNID); err != nil {
fmt.Println("Failed CalculateAndStoreStateBeforeEvent")
return err return err
} }
} }

View file

@ -558,6 +558,7 @@ func CalculateAndStoreStateBeforeEvent(
prevStates, err := db.StateAtEventIDs(ctx, prevEventIDs) prevStates, err := db.StateAtEventIDs(ctx, prevEventIDs)
if err != nil { if err != nil {
fmt.Println("Failed stateAtEventIDs", err)
return 0, err return 0, err
} }
@ -579,6 +580,7 @@ func CalculateAndStoreStateAfterEvents(
// 2) There weren't any prev_events for this event so the state is // 2) There weren't any prev_events for this event so the state is
// empty. // empty.
metrics.algorithm = "empty_state" metrics.algorithm = "empty_state"
fmt.Println("there were't any prev_events!")
return metrics.stop(db.AddState(ctx, roomNID, nil, nil)) return metrics.stop(db.AddState(ctx, roomNID, nil, nil))
} }
@ -590,6 +592,7 @@ func CalculateAndStoreStateAfterEvents(
// as the previous events. // as the previous events.
// This should be the common case. // This should be the common case.
metrics.algorithm = "no_change" metrics.algorithm = "no_change"
fmt.Println("none of the previous events were state events")
return metrics.stop(prevState.BeforeStateSnapshotNID, nil) return metrics.stop(prevState.BeforeStateSnapshotNID, nil)
} }
// The previous event was a state event so we need to store a copy // The previous event was a state event so we need to store a copy
@ -599,6 +602,7 @@ func CalculateAndStoreStateAfterEvents(
) )
if err != nil { if err != nil {
metrics.algorithm = "_load_state_blocks" metrics.algorithm = "_load_state_blocks"
fmt.Println("failed StateBlockNIDs", err)
return metrics.stop(0, err) return metrics.stop(0, err)
} }
stateBlockNIDs := stateBlockNIDLists[0].StateBlockNIDs stateBlockNIDs := stateBlockNIDLists[0].StateBlockNIDs
@ -639,6 +643,7 @@ func calculateAndStoreStateAfterManyEvents(
calculateStateAfterManyEvents(ctx, db, prevStates) calculateStateAfterManyEvents(ctx, db, prevStates)
metrics.algorithm = algorithm metrics.algorithm = algorithm
if err != nil { if err != nil {
fmt.Println("failed calculateStateAfterManyEvents", err)
return metrics.stop(0, err) return metrics.stop(0, err)
} }
@ -658,6 +663,7 @@ func calculateStateAfterManyEvents(
combined, err = LoadCombinedStateAfterEvents(ctx, db, prevStates) combined, err = LoadCombinedStateAfterEvents(ctx, db, prevStates)
if err != nil { if err != nil {
algorithm = "_load_combined_state" algorithm = "_load_combined_state"
fmt.Println("failed LoadCombinedStateAfterEvents")
return return
} }
@ -688,6 +694,7 @@ func calculateStateAfterManyEvents(
resolved, err = resolveConflicts(ctx, db, notConflicted, conflicts) resolved, err = resolveConflicts(ctx, db, notConflicted, conflicts)
if err != nil { if err != nil {
algorithm = "_resolve_conflicts" algorithm = "_resolve_conflicts"
fmt.Println("failed resolveConflicts", err)
return return
} }
algorithm = "full_state_with_conflicts" algorithm = "full_state_with_conflicts"

View file

@ -0,0 +1,102 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types"
)
const eventJSONSchema = `
CREATE TABLE IF NOT EXISTS roomserver_event_json (
event_nid INTEGER NOT NULL PRIMARY KEY,
event_json TEXT NOT NULL
);
`
const insertEventJSONSQL = `
INSERT INTO roomserver_event_json (event_nid, event_json) VALUES ($1, $2)
ON CONFLICT DO NOTHING
`
// Bulk event JSON lookup by numeric event ID.
// Sort by the numeric event ID.
// This means that we can use binary search to lookup by numeric event ID.
const bulkSelectEventJSONSQL = `
SELECT event_nid, event_json FROM roomserver_event_json
WHERE event_nid IN ($1)
ORDER BY event_nid ASC
`
type eventJSONStatements struct {
insertEventJSONStmt *sql.Stmt
bulkSelectEventJSONStmt *sql.Stmt
}
func (s *eventJSONStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(eventJSONSchema)
if err != nil {
return
}
return statementList{
{&s.insertEventJSONStmt, insertEventJSONSQL},
{&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL},
}.prepare(db)
}
func (s *eventJSONStatements) insertEventJSON(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte,
) error {
_, err := common.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON)
return err
}
type eventJSONPair struct {
EventNID types.EventNID
EventJSON []byte
}
func (s *eventJSONStatements) bulkSelectEventJSON(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]eventJSONPair, error) {
rows, err := common.TxStmt(txn, s.bulkSelectEventJSONStmt).QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil {
fmt.Println("bulkSelectEventJSON s.bulkSelectEventJSONStmt.QueryContext:", err)
return nil, err
}
defer rows.Close() // nolint: errcheck
// We know that we will only get as many results as event NIDs
// because of the unique constraint on event NIDs.
// So we can allocate an array of the correct size now.
// We might get fewer results than NIDs so we adjust the length of the slice before returning it.
results := make([]eventJSONPair, len(eventNIDs))
i := 0
for ; rows.Next(); i++ {
result := &results[i]
var eventNID int64
if err := rows.Scan(&eventNID, &result.EventJSON); err != nil {
fmt.Println("bulkSelectEventJSON rows.Scan:", err)
return nil, err
}
result.EventNID = types.EventNID(eventNID)
}
return results[:i], nil
}

View file

@ -0,0 +1,168 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types"
)
const eventStateKeysSchema = `
CREATE TABLE IF NOT EXISTS roomserver_event_state_keys (
event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT,
event_state_key TEXT NOT NULL UNIQUE
);
INSERT INTO roomserver_event_state_keys (event_state_key_nid, event_state_key)
VALUES (1, '')
ON CONFLICT DO NOTHING;
`
// Same as insertEventTypeNIDSQL
const insertEventStateKeyNIDSQL = `
INSERT INTO roomserver_event_state_keys (event_state_key) VALUES ($1)
ON CONFLICT DO NOTHING;
`
const insertEventStateKeyNIDResultSQL = `
SELECT event_state_key_nid FROM roomserver_event_state_keys
WHERE rowid = last_insert_rowid();
`
const selectEventStateKeyNIDSQL = `
SELECT event_state_key_nid FROM roomserver_event_state_keys
WHERE event_state_key = $1
`
// Bulk lookup from string state key to numeric ID for that state key.
// Takes an array of strings as the query parameter.
const bulkSelectEventStateKeyNIDSQL = `
SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys
WHERE event_state_key IN ($1)
`
// Bulk lookup from numeric ID to string state key for that state key.
// Takes an array of strings as the query parameter.
const bulkSelectEventStateKeySQL = `
SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys
WHERE event_state_key_nid IN ($1)
`
type eventStateKeyStatements struct {
insertEventStateKeyNIDStmt *sql.Stmt
insertEventStateKeyNIDResultStmt *sql.Stmt
selectEventStateKeyNIDStmt *sql.Stmt
bulkSelectEventStateKeyNIDStmt *sql.Stmt
bulkSelectEventStateKeyStmt *sql.Stmt
}
func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(eventStateKeysSchema)
if err != nil {
return
}
return statementList{
{&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL},
{&s.insertEventStateKeyNIDResultStmt, insertEventStateKeyNIDResultSQL},
{&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL},
{&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL},
{&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL},
}.prepare(db)
}
func (s *eventStateKeyStatements) insertEventStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) {
var eventStateKeyNID int64
var err error
insertStmt := common.TxStmt(txn, s.insertEventStateKeyNIDStmt)
selectStmt := common.TxStmt(txn, s.insertEventStateKeyNIDResultStmt)
if _, err = insertStmt.ExecContext(ctx, eventStateKey); err == nil {
err = selectStmt.QueryRowContext(ctx).Scan(&eventStateKeyNID)
if err != nil {
fmt.Println("insertEventStateKeyNID selectStmt.QueryRowContext:", err)
}
} else {
fmt.Println("insertEventStateKeyNID insertStmt.ExecContext:", err)
}
return types.EventStateKeyNID(eventStateKeyNID), err
}
func (s *eventStateKeyStatements) selectEventStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) {
var eventStateKeyNID int64
stmt := common.TxStmt(txn, s.selectEventStateKeyNIDStmt)
err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID)
if err != nil {
fmt.Println("selectEventStateKeyNID stmt.QueryRowContext:", err)
}
return types.EventStateKeyNID(eventStateKeyNID), err
}
func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
rows, err := common.TxStmt(txn, s.bulkSelectEventStateKeyNIDStmt).QueryContext(
ctx, sqliteInStr(pq.StringArray(eventStateKeys)),
)
if err != nil {
fmt.Println("bulkSelectEventStateKeyNID s.bulkSelectEventStateKeyNIDStmt.QueryContext:", err)
return nil, err
}
defer rows.Close() // nolint: errcheck
result := make(map[string]types.EventStateKeyNID, len(eventStateKeys))
for rows.Next() {
var stateKey string
var stateKeyNID int64
if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
fmt.Println("bulkSelectEventStateKeyNID rows.Scan:", err)
return nil, err
}
result[stateKey] = types.EventStateKeyNID(stateKeyNID)
}
return result, nil
}
func (s *eventStateKeyStatements) bulkSelectEventStateKey(
ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) {
nIDs := make(pq.Int64Array, len(eventStateKeyNIDs))
for i := range eventStateKeyNIDs {
nIDs[i] = int64(eventStateKeyNIDs[i])
}
rows, err := common.TxStmt(txn, s.bulkSelectEventStateKeyStmt).QueryContext(ctx, nIDs)
if err != nil {
fmt.Println("bulkSelectEventStateKey s.bulkSelectEventStateKeyStmt.QueryContext:", err)
return nil, err
}
defer rows.Close() // nolint: errcheck
result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs))
for rows.Next() {
var stateKey string
var stateKeyNID int64
if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
fmt.Println("bulkSelectEventStateKey rows.Scan:", err)
return nil, err
}
result[types.EventStateKeyNID(stateKeyNID)] = stateKey
}
return result, nil
}

View file

@ -0,0 +1,139 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types"
)
const eventTypesSchema = `
CREATE TABLE IF NOT EXISTS roomserver_event_types (
event_type_nid INTEGER PRIMARY KEY AUTOINCREMENT,
event_type TEXT NOT NULL UNIQUE
);
INSERT INTO roomserver_event_types (event_type_nid, event_type) VALUES
(1, 'm.room.create'),
(2, 'm.room.power_levels'),
(3, 'm.room.join_rules'),
(4, 'm.room.third_party_invite'),
(5, 'm.room.member'),
(6, 'm.room.redaction'),
(7, 'm.room.history_visibility') ON CONFLICT DO NOTHING;
`
// Assign a new numeric event type ID.
// The usual case is that the event type is not in the database.
// In that case the ID will be assigned using the next value from the sequence.
// We use `RETURNING` to tell postgres to return the assigned ID.
// But it's possible that the type was added in a query that raced with us.
// This will result in a conflict on the event_type_unique constraint, in this
// case we do nothing. Postgresql won't return a row in that case so we rely on
// the caller catching the sql.ErrNoRows error and running a select to get the row.
// We could get postgresql to return the row on a conflict by updating the row
// but it doesn't seem like a good idea to modify the rows just to make postgresql
// return it. Modifying the rows will cause postgres to assign a new tuple for the
// row even though the data doesn't change resulting in unncesssary modifications
// to the indexes.
const insertEventTypeNIDSQL = `
INSERT INTO roomserver_event_types (event_type) VALUES ($1)
ON CONFLICT DO NOTHING;
`
const insertEventTypeNIDResultSQL = `
SELECT event_type_nid FROM roomserver_event_types
WHERE rowid = last_insert_rowid();
`
const selectEventTypeNIDSQL = `
SELECT event_type_nid FROM roomserver_event_types WHERE event_type = $1
`
// Bulk lookup from string event type to numeric ID for that event type.
// Takes an array of strings as the query parameter.
const bulkSelectEventTypeNIDSQL = `
SELECT event_type, event_type_nid FROM roomserver_event_types
WHERE event_type IN ($1)
`
type eventTypeStatements struct {
insertEventTypeNIDStmt *sql.Stmt
insertEventTypeNIDResultStmt *sql.Stmt
selectEventTypeNIDStmt *sql.Stmt
bulkSelectEventTypeNIDStmt *sql.Stmt
}
func (s *eventTypeStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(eventTypesSchema)
if err != nil {
return
}
return statementList{
{&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL},
{&s.insertEventTypeNIDResultStmt, insertEventTypeNIDResultSQL},
{&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL},
{&s.bulkSelectEventTypeNIDStmt, bulkSelectEventTypeNIDSQL},
}.prepare(db)
}
func (s *eventTypeStatements) insertEventTypeNID(
ctx context.Context, tx *sql.Tx, eventType string,
) (types.EventTypeNID, error) {
var eventTypeNID int64
var err error
insertStmt := common.TxStmt(tx, s.insertEventTypeNIDStmt)
resultStmt := common.TxStmt(tx, s.insertEventTypeNIDResultStmt)
if _, err = insertStmt.ExecContext(ctx, eventType); err == nil {
err = resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID)
}
return types.EventTypeNID(eventTypeNID), err
}
func (s *eventTypeStatements) selectEventTypeNID(
ctx context.Context, tx *sql.Tx, eventType string,
) (types.EventTypeNID, error) {
var eventTypeNID int64
selectStmt := common.TxStmt(tx, s.selectEventTypeNIDStmt)
err := selectStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID)
return types.EventTypeNID(eventTypeNID), err
}
func (s *eventTypeStatements) bulkSelectEventTypeNID(
ctx context.Context, tx *sql.Tx, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
selectStmt := common.TxStmt(tx, s.bulkSelectEventTypeNIDStmt)
rows, err := selectStmt.QueryContext(ctx, sqliteInStr(pq.StringArray(eventTypes)))
if err != nil {
return nil, err
}
defer rows.Close() // nolint: errcheck
result := make(map[string]types.EventTypeNID, len(eventTypes))
for rows.Next() {
var eventType string
var eventTypeNID int64
if err := rows.Scan(&eventType, &eventTypeNID); err != nil {
return nil, err
}
result[eventType] = types.EventTypeNID(eventTypeNID)
}
return result, nil
}

View file

@ -0,0 +1,428 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
const eventsSchema = `
CREATE TABLE IF NOT EXISTS roomserver_events (
event_nid INTEGER PRIMARY KEY AUTOINCREMENT,
room_nid INTEGER NOT NULL,
event_type_nid INTEGER NOT NULL,
event_state_key_nid INTEGER NOT NULL,
sent_to_output BOOLEAN NOT NULL DEFAULT FALSE,
state_snapshot_nid INTEGER NOT NULL DEFAULT 0,
depth INTEGER NOT NULL,
event_id TEXT NOT NULL UNIQUE,
reference_sha256 BLOB NOT NULL,
auth_event_nids TEXT NOT NULL DEFAULT '{}'
);
`
const insertEventSQL = `
INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT DO NOTHING;
`
const insertEventResultSQL = `
SELECT event_nid, state_snapshot_nid FROM roomserver_events
WHERE rowid = last_insert_rowid();
`
const selectEventSQL = "" +
"SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1"
// Bulk lookup of events by string ID.
// Sort by the numeric IDs for event type and state key.
// This means we can use binary search to lookup entries by type and state key.
const bulkSelectStateEventByIDSQL = "" +
"SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" +
" WHERE event_id IN ($1)" +
" ORDER BY event_type_nid, event_state_key_nid ASC"
const bulkSelectStateAtEventByIDSQL = "" +
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid FROM roomserver_events" +
" WHERE event_id IN ($1)"
const updateEventStateSQL = "" +
"UPDATE roomserver_events SET state_snapshot_nid = $2 WHERE event_nid = $1"
const selectEventSentToOutputSQL = "" +
"SELECT sent_to_output FROM roomserver_events WHERE event_nid = $1"
const updateEventSentToOutputSQL = "" +
"UPDATE roomserver_events SET sent_to_output = TRUE WHERE event_nid = $1"
const selectEventIDSQL = "" +
"SELECT event_id FROM roomserver_events WHERE event_nid = $1"
const bulkSelectStateAtEventAndReferenceSQL = "" +
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" +
" FROM roomserver_events WHERE event_nid IN ($1)"
const bulkSelectEventReferenceSQL = "" +
"SELECT event_id, reference_sha256 FROM roomserver_events WHERE event_nid IN ($1)"
const bulkSelectEventIDSQL = "" +
"SELECT event_nid, event_id FROM roomserver_events WHERE event_nid IN ($1)"
const bulkSelectEventNIDSQL = "" +
"SELECT event_id, event_nid FROM roomserver_events WHERE event_id IN ($1)"
const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)"
type eventStatements struct {
insertEventStmt *sql.Stmt
insertEventResultStmt *sql.Stmt
selectEventStmt *sql.Stmt
bulkSelectStateEventByIDStmt *sql.Stmt
bulkSelectStateAtEventByIDStmt *sql.Stmt
updateEventStateStmt *sql.Stmt
selectEventSentToOutputStmt *sql.Stmt
updateEventSentToOutputStmt *sql.Stmt
selectEventIDStmt *sql.Stmt
bulkSelectStateAtEventAndReferenceStmt *sql.Stmt
bulkSelectEventReferenceStmt *sql.Stmt
bulkSelectEventIDStmt *sql.Stmt
bulkSelectEventNIDStmt *sql.Stmt
selectMaxEventDepthStmt *sql.Stmt
}
func (s *eventStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(eventsSchema)
if err != nil {
return
}
return statementList{
{&s.insertEventStmt, insertEventSQL},
{&s.insertEventResultStmt, insertEventResultSQL},
{&s.selectEventStmt, selectEventSQL},
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
{&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
{&s.updateEventStateStmt, updateEventStateSQL},
{&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL},
{&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL},
{&s.selectEventIDStmt, selectEventIDSQL},
{&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL},
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
{&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL},
}.prepare(db)
}
func (s *eventStatements) insertEvent(
ctx context.Context,
txn *sql.Tx,
roomNID types.RoomNID,
eventTypeNID types.EventTypeNID,
eventStateKeyNID types.EventStateKeyNID,
eventID string,
referenceSHA256 []byte,
authEventNIDs []types.EventNID,
depth int64,
) (types.EventNID, types.StateSnapshotNID, error) {
var eventNID int64
var stateNID int64
var err error
insertStmt := common.TxStmt(txn, s.insertEventStmt)
resultStmt := common.TxStmt(txn, s.insertEventResultStmt)
if _, err = insertStmt.ExecContext(
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
); err == nil {
err = resultStmt.QueryRowContext(ctx).Scan(&eventNID, &stateNID)
}
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
}
func (s *eventStatements) selectEvent(
ctx context.Context, txn *sql.Tx, eventID string,
) (types.EventNID, types.StateSnapshotNID, error) {
var eventNID int64
var stateNID int64
selectStmt := common.TxStmt(txn, s.selectEventStmt)
err := selectStmt.QueryRowContext(ctx, eventID).Scan(&eventNID, &stateNID)
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
}
// bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) bulkSelectStateEventByID(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateEntry, error) {
selectStmt := common.TxStmt(txn, s.bulkSelectStateEventByIDStmt)
rows, err := selectStmt.QueryContext(ctx, sqliteInStr(pq.StringArray(eventIDs)))
if err != nil {
return nil, err
}
defer rows.Close() // nolint: errcheck
// We know that we will only get as many results as event IDs
// because of the unique constraint on event IDs.
// So we can allocate an array of the correct size now.
// We might get fewer results than IDs so we adjust the length of the slice before returning it.
results := make([]types.StateEntry, len(eventIDs))
i := 0
for ; rows.Next(); i++ {
result := &results[i]
if err = rows.Scan(
&result.EventTypeNID,
&result.EventStateKeyNID,
&result.EventNID,
); err != nil {
return nil, err
}
}
if i != len(eventIDs) {
// If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have.
// We don't know which ones were missing because we don't return the string IDs in the query.
// However it should be possible debug this by replaying queries or entries from the input kafka logs.
// If this turns out to be impossible and we do need the debug information here, it would be better
// to do it as a separate query rather than slowing down/complicating the common case.
return nil, types.MissingEventError(
fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventIDs)),
)
}
return results, err
}
// bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
func (s *eventStatements) bulkSelectStateAtEventByID(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateAtEvent, error) {
selectStmt := common.TxStmt(txn, s.bulkSelectStateAtEventByIDStmt)
rows, err := selectStmt.QueryContext(ctx, sqliteInStr(pq.StringArray(eventIDs)))
if err != nil {
return nil, err
}
defer rows.Close() // nolint: errcheck
results := make([]types.StateAtEvent, len(eventIDs))
i := 0
for ; rows.Next(); i++ {
result := &results[i]
if err = rows.Scan(
&result.EventTypeNID,
&result.EventStateKeyNID,
&result.EventNID,
&result.BeforeStateSnapshotNID,
); err != nil {
return nil, err
}
if result.BeforeStateSnapshotNID == 0 {
return nil, types.MissingEventError(
fmt.Sprintf("storage: missing state for event NID %d", result.EventNID),
)
}
}
if i != len(eventIDs) {
return nil, types.MissingEventError(
fmt.Sprintf("storage: event IDs missing from the database (%d != %d)", i, len(eventIDs)),
)
}
return results, err
}
func (s *eventStatements) updateEventState(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error {
updateStmt := common.TxStmt(txn, s.updateEventStateStmt)
_, err := updateStmt.ExecContext(ctx, int64(eventNID), int64(stateNID))
if err != nil {
fmt.Println("updateEventState s.updateEventStateStmt.ExecContext:", err)
}
return err
}
func (s *eventStatements) selectEventSentToOutput(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID,
) (sentToOutput bool, err error) {
selectStmt := common.TxStmt(txn, s.selectEventSentToOutputStmt)
err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput)
//err = s.selectEventSentToOutputStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput)
if err != nil {
fmt.Println("selectEventSentToOutput stmt.QueryRowContext:", err)
}
return
}
func (s *eventStatements) updateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error {
updateStmt := common.TxStmt(txn, s.updateEventSentToOutputStmt)
_, err := updateStmt.ExecContext(ctx, int64(eventNID))
//_, err := s.updateEventSentToOutputStmt.ExecContext(ctx, int64(eventNID))
if err != nil {
fmt.Println("updateEventSentToOutput stmt.QueryRowContext:", err)
}
return err
}
func (s *eventStatements) selectEventID(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID,
) (eventID string, err error) {
selectStmt := common.TxStmt(txn, s.selectEventIDStmt)
err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&eventID)
if err != nil {
fmt.Println("selectEventID stmt.QueryRowContext:", err)
}
return
}
func (s *eventStatements) bulkSelectStateAtEventAndReference(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]types.StateAtEventAndReference, error) {
selectStmt := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt)
rows, err := selectStmt.QueryContext(ctx, sqliteIn(eventNIDsAsArray(eventNIDs)))
if err != nil {
fmt.Println("bulkSelectStateAtEventAndREference stmt.QueryContext:", err)
return nil, err
}
defer rows.Close() // nolint: errcheck
results := make([]types.StateAtEventAndReference, len(eventNIDs))
i := 0
for ; rows.Next(); i++ {
var (
eventTypeNID int64
eventStateKeyNID int64
eventNID int64
stateSnapshotNID int64
eventID string
eventSHA256 []byte
)
if err = rows.Scan(
&eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, &eventSHA256,
); err != nil {
fmt.Println("bulkSelectStateAtEventAndReference rows.Scan:", err)
return nil, err
}
result := &results[i]
result.EventTypeNID = types.EventTypeNID(eventTypeNID)
result.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
result.EventNID = types.EventNID(eventNID)
result.BeforeStateSnapshotNID = types.StateSnapshotNID(stateSnapshotNID)
result.EventID = eventID
result.EventSHA256 = eventSHA256
}
if i != len(eventNIDs) {
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
}
return results, nil
}
func (s *eventStatements) bulkSelectEventReference(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]gomatrixserverlib.EventReference, error) {
selectStmt := common.TxStmt(txn, s.bulkSelectEventReferenceStmt)
rows, err := selectStmt.QueryContext(ctx, sqliteIn(eventNIDsAsArray(eventNIDs)))
if err != nil {
fmt.Println("bulkSelectEventReference s.bulkSelectEventReferenceStmt.QueryContext:", err)
return nil, err
}
defer rows.Close() // nolint: errcheck
results := make([]gomatrixserverlib.EventReference, len(eventNIDs))
i := 0
for ; rows.Next(); i++ {
result := &results[i]
if err = rows.Scan(&result.EventID, &result.EventSHA256); err != nil {
fmt.Println("bulkSelectEventReference rows.Scan:", err)
return nil, err
}
}
if i != len(eventNIDs) {
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
}
return results, nil
}
// bulkSelectEventID returns a map from numeric event ID to string event ID.
func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
selectStmt := common.TxStmt(txn, s.bulkSelectEventIDStmt)
rows, err := selectStmt.QueryContext(ctx, sqliteIn(eventNIDsAsArray(eventNIDs)))
if err != nil {
fmt.Println("bulkSelectEventID s.bulkSelectEventIDStmt.QueryContext:", err)
return nil, err
}
defer rows.Close() // nolint: errcheck
results := make(map[types.EventNID]string, len(eventNIDs))
i := 0
for ; rows.Next(); i++ {
var eventNID int64
var eventID string
if err = rows.Scan(&eventNID, &eventID); err != nil {
fmt.Println("bulkSelectEventID rows.Scan:", err)
return nil, err
}
results[types.EventNID(eventNID)] = eventID
}
if i != len(eventNIDs) {
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
}
return results, nil
}
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
selectStmt := common.TxStmt(txn, s.bulkSelectEventNIDStmt)
rows, err := selectStmt.QueryContext(ctx, sqliteInStr(pq.StringArray(eventIDs)))
if err != nil {
fmt.Println("bulkSelectEventNID s.bulkSelectEventNIDStmt.QueryContext:", err)
return nil, err
}
defer rows.Close() // nolint: errcheck
results := make(map[string]types.EventNID, len(eventIDs))
for rows.Next() {
var eventID string
var eventNID int64
if err = rows.Scan(&eventID, &eventNID); err != nil {
fmt.Println("bulkSelectEventNID rows.Scan:", err)
return nil, err
}
results[eventID] = types.EventNID(eventNID)
}
return results, nil
}
func (s *eventStatements) selectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) {
var result int64
selectStmt := common.TxStmt(txn, s.selectMaxEventDepthStmt)
err := selectStmt.QueryRowContext(ctx, sqliteIn(eventNIDsAsArray(eventNIDs))).Scan(&result)
if err != nil {
fmt.Println("selectMaxEventDepth stmt.QueryRowContext:", err)
return 0, err
}
return result, nil
}
func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array {
nids := make([]int64, len(eventNIDs))
for i := range eventNIDs {
nids[i] = int64(eventNIDs[i])
}
return nids
}

View file

@ -0,0 +1,149 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types"
)
const inviteSchema = `
CREATE TABLE IF NOT EXISTS roomserver_invites (
invite_event_id TEXT PRIMARY KEY,
room_nid INTEGER NOT NULL,
target_nid INTEGER NOT NULL,
sender_nid INTEGER NOT NULL DEFAULT 0,
retired BOOLEAN NOT NULL DEFAULT FALSE,
invite_event_json TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS roomserver_invites_active_idx ON roomserver_invites (target_nid, room_nid)
WHERE NOT retired;
`
const insertInviteEventSQL = "" +
"INSERT INTO roomserver_invites (invite_event_id, room_nid, target_nid," +
" sender_nid, invite_event_json) VALUES ($1, $2, $3, $4, $5)" +
" ON CONFLICT DO NOTHING"
const selectInviteActiveForUserInRoomSQL = "" +
"SELECT sender_nid FROM roomserver_invites" +
" WHERE target_nid = $1 AND room_nid = $2" +
" AND NOT retired"
// Retire every active invite for a user in a room.
// Ideally we'd know which invite events were retired by a given update so we
// wouldn't need to remove every active invite.
// However the matrix protocol doesn't give us a way to reliably identify the
// invites that were retired, so we are forced to retire all of them.
const updateInviteRetiredSQL = `
UPDATE roomserver_invites SET retired = TRUE
WHERE room_nid = $1 AND target_nid = $2 AND NOT retired;
SELECT invite_event_id FROM roomserver_invites
WHERE rowid = last_insert_rowid();
`
type inviteStatements struct {
insertInviteEventStmt *sql.Stmt
selectInviteActiveForUserInRoomStmt *sql.Stmt
updateInviteRetiredStmt *sql.Stmt
}
func (s *inviteStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(inviteSchema)
if err != nil {
return
}
return statementList{
{&s.insertInviteEventStmt, insertInviteEventSQL},
{&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL},
{&s.updateInviteRetiredStmt, updateInviteRetiredSQL},
}.prepare(db)
}
func (s *inviteStatements) insertInviteEvent(
ctx context.Context,
txn *sql.Tx, inviteEventID string, roomNID types.RoomNID,
targetUserNID, senderUserNID types.EventStateKeyNID,
inviteEventJSON []byte,
) (bool, error) {
stmt := common.TxStmt(txn, s.insertInviteEventStmt)
defer stmt.Close()
result, err := stmt.ExecContext(
ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
)
if err != nil {
fmt.Println("insertInviteEvent common.TxStmt.ExecContext:", err)
return false, err
}
count, err := result.RowsAffected()
if err != nil {
fmt.Println("insertInviteEvent result.RowsAffected:", err)
return false, err
}
return count != 0, nil
}
func (s *inviteStatements) updateInviteRetired(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventIDs []string, err error) {
stmt := common.TxStmt(txn, s.updateInviteRetiredStmt)
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
if err != nil {
fmt.Println("updateInviteRetired stmt.QueryContext:", err)
return nil, err
}
defer (func() { err = rows.Close() })()
for rows.Next() {
var inviteEventID string
if err := rows.Scan(&inviteEventID); err != nil {
fmt.Println("updateInviteRetired rows.Scan:", err)
return nil, err
}
eventIDs = append(eventIDs, inviteEventID)
}
return
}
// selectInviteActiveForUserInRoom returns a list of sender state key NIDs
func (s *inviteStatements) selectInviteActiveForUserInRoom(
ctx context.Context,
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
) ([]types.EventStateKeyNID, error) {
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
ctx, targetUserNID, roomNID,
)
if err != nil {
fmt.Println("selectInviteActiveForUserInRoom s.selectInviteActiveForUserInRoomStmt.QueryContext:", err)
return nil, err
}
defer rows.Close() // nolint: errcheck
var result []types.EventStateKeyNID
for rows.Next() {
var senderUserNID int64
if err := rows.Scan(&senderUserNID); err != nil {
fmt.Println("selectInviteActiveForUserInRoom rows.Scan:", err)
return nil, err
}
result = append(result, types.EventStateKeyNID(senderUserNID))
}
return result, nil
}

View file

@ -0,0 +1,22 @@
package sqlite3
import (
"strconv"
"strings"
"github.com/lib/pq"
)
type SqliteList string
func sqliteIn(a pq.Int64Array) string {
var b []string
for _, n := range a {
b = append(b, strconv.FormatInt(n, 10))
}
return strings.Join(b, ",")
}
func sqliteInStr(a pq.StringArray) string {
return "\"" + strings.Join(a, "\",\"") + "\""
}

View file

@ -0,0 +1,195 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types"
)
type membershipState int64
const (
membershipStateLeaveOrBan membershipState = 1
membershipStateInvite membershipState = 2
membershipStateJoin membershipState = 3
)
const membershipSchema = `
CREATE TABLE IF NOT EXISTS roomserver_membership (
room_nid INTEGER NOT NULL,
target_nid INTEGER NOT NULL,
sender_nid INTEGER NOT NULL DEFAULT 0,
membership_nid INTEGER NOT NULL DEFAULT 1,
event_nid INTEGER NOT NULL DEFAULT 0,
UNIQUE (room_nid, target_nid)
);
`
// Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE
const insertMembershipSQL = "" +
"INSERT INTO roomserver_membership (room_nid, target_nid)" +
" VALUES ($1, $2)" +
" ON CONFLICT DO NOTHING"
const selectMembershipFromRoomAndTargetSQL = "" +
"SELECT membership_nid, event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND target_nid = $2"
const selectMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2"
const selectMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1"
const selectMembershipForUpdateSQL = "" +
"SELECT membership_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND target_nid = $2"
const updateMembershipSQL = "" +
"UPDATE roomserver_membership SET sender_nid = $3, membership_nid = $4, event_nid = $5" +
" WHERE room_nid = $1 AND target_nid = $2"
type membershipStatements struct {
insertMembershipStmt *sql.Stmt
selectMembershipForUpdateStmt *sql.Stmt
selectMembershipFromRoomAndTargetStmt *sql.Stmt
selectMembershipsFromRoomAndMembershipStmt *sql.Stmt
selectMembershipsFromRoomStmt *sql.Stmt
updateMembershipStmt *sql.Stmt
}
func (s *membershipStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(membershipSchema)
if err != nil {
return
}
return statementList{
{&s.insertMembershipStmt, insertMembershipSQL},
{&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL},
{&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL},
{&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL},
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
{&s.updateMembershipStmt, updateMembershipSQL},
}.prepare(db)
}
func (s *membershipStatements) insertMembership(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) error {
stmt := common.TxStmt(txn, s.insertMembershipStmt)
_, err := stmt.ExecContext(ctx, roomNID, targetUserNID)
if err != nil {
fmt.Println("insertMembership stmt.ExecContent:", err)
}
return err
}
func (s *membershipStatements) selectMembershipForUpdate(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (membership membershipState, err error) {
stmt := common.TxStmt(txn, s.selectMembershipForUpdateStmt)
err = stmt.QueryRowContext(
ctx, roomNID, targetUserNID,
).Scan(&membership)
if err != nil {
fmt.Println("selectMembershipForUpdate common.TxStmt.Scan:", err)
}
return
}
func (s *membershipStatements) selectMembershipFromRoomAndTarget(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventNID types.EventNID, membership membershipState, err error) {
selectStmt := common.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt)
err = selectStmt.QueryRowContext(
ctx, roomNID, targetUserNID,
).Scan(&membership, &eventNID)
if err != nil {
fmt.Println("selectMembershipForUpdate s.selectMembershipFromRoomAndTargetStmt.QueryRowContext:", err)
}
return
}
func (s *membershipStatements) selectMembershipsFromRoom(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID,
) (eventNIDs []types.EventNID, err error) {
selectStmt := common.TxStmt(txn, s.selectMembershipsFromRoomStmt)
rows, err := selectStmt.QueryContext(ctx, roomNID)
if err != nil {
fmt.Println("selectMembershipsFromRoom s.selectMembershipsFromRoomStmt.QueryContext:", err)
return
}
for rows.Next() {
var eNID types.EventNID
if err = rows.Scan(&eNID); err != nil {
fmt.Println("selectMembershipsFromRoom rows.Scan:", err)
return
}
eventNIDs = append(eventNIDs, eNID)
}
return
}
func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, membership membershipState,
) (eventNIDs []types.EventNID, err error) {
stmt := common.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt)
rows, err := stmt.QueryContext(ctx, roomNID, membership)
if err != nil {
fmt.Println("selectMembershipsFromRoomAndMembership stmt.QueryContext:", err)
return
}
for rows.Next() {
var eNID types.EventNID
if err = rows.Scan(&eNID); err != nil {
fmt.Println("selectMembershipsFromRoomAndMembership rows.Scan:", err)
return
}
eventNIDs = append(eventNIDs, eNID)
}
return
}
func (s *membershipStatements) updateMembership(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
senderUserNID types.EventStateKeyNID, membership membershipState,
eventNID types.EventNID,
) error {
stmt := common.TxStmt(txn, s.updateMembershipStmt)
_, err := stmt.ExecContext(
ctx, roomNID, targetUserNID, senderUserNID, membership, eventNID,
)
if err != nil {
fmt.Println("updateMembership common.TxStmt.ExecContent:", err)
}
return err
}

View file

@ -0,0 +1,36 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"database/sql"
)
// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
type statementList []struct {
statement **sql.Stmt
sql string
}
// prepare the SQL for each statement in the list and assign the result to the prepared statement.
func (s statementList) prepare(db *sql.DB) (err error) {
for _, statement := range s {
if *statement.statement, err = db.Prepare(statement.sql); err != nil {
return
}
}
return
}

View file

@ -0,0 +1,99 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types"
)
const previousEventSchema = `
CREATE TABLE IF NOT EXISTS roomserver_previous_events (
previous_event_id TEXT NOT NULL,
previous_reference_sha256 BLOB NOT NULL,
event_nids TEXT NOT NULL,
UNIQUE (previous_event_id, previous_reference_sha256)
);
`
// Insert an entry into the previous_events table.
// If there is already an entry indicating that an event references that previous event then
// add the event NID to the list to indicate that this event references that previous event as well.
// This should only be modified while holding a "FOR UPDATE" lock on the row in the rooms table for this room.
// The lock is necessary to avoid data races when checking whether an event is already referenced by another event.
const insertPreviousEventSQL = `
INSERT OR REPLACE INTO roomserver_previous_events
(previous_event_id, previous_reference_sha256, event_nids)
VALUES ($1, $2, $3)
`
// Check if the event is referenced by another event in the table.
// This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room.
const selectPreviousEventExistsSQL = `
SELECT 1 FROM roomserver_previous_events
WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
`
type previousEventStatements struct {
insertPreviousEventStmt *sql.Stmt
selectPreviousEventExistsStmt *sql.Stmt
}
func (s *previousEventStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(previousEventSchema)
if err != nil {
return
}
return statementList{
{&s.insertPreviousEventStmt, insertPreviousEventSQL},
{&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL},
}.prepare(db)
}
func (s *previousEventStatements) insertPreviousEvent(
ctx context.Context,
txn *sql.Tx,
previousEventID string,
previousEventReferenceSHA256 []byte,
eventNID types.EventNID,
) error {
stmt := common.TxStmt(txn, s.insertPreviousEventStmt)
_, err := stmt.ExecContext(
ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID),
)
if err != nil {
fmt.Println("insertPreviousEvent stmt.ExecContext:", err)
}
return err
}
// Check if the event reference exists
// Returns sql.ErrNoRows if the event reference doesn't exist.
func (s *previousEventStatements) selectPreviousEventExists(
ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte,
) error {
var ok int64
stmt := common.TxStmt(txn, s.selectPreviousEventExistsStmt)
defer func() {
fmt.Println("SELECTED PREVIOUS EVENT EXISTS", ok)
}()
return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok)
}

View file

@ -0,0 +1,141 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/common"
)
const roomAliasesSchema = `
CREATE TABLE IF NOT EXISTS roomserver_room_aliases (
alias TEXT NOT NULL PRIMARY KEY,
room_id TEXT NOT NULL,
creator_id TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS roomserver_room_id_idx ON roomserver_room_aliases(room_id);
`
const insertRoomAliasSQL = `
INSERT INTO roomserver_room_aliases (alias, room_id, creator_id) VALUES ($1, $2, $3)
`
const selectRoomIDFromAliasSQL = `
SELECT room_id FROM roomserver_room_aliases WHERE alias = $1
`
const selectAliasesFromRoomIDSQL = `
SELECT alias FROM roomserver_room_aliases WHERE room_id = $1
`
const selectCreatorIDFromAliasSQL = `
SELECT creator_id FROM roomserver_room_aliases WHERE alias = $1
`
const deleteRoomAliasSQL = `
DELETE FROM roomserver_room_aliases WHERE alias = $1
`
type roomAliasesStatements struct {
insertRoomAliasStmt *sql.Stmt
selectRoomIDFromAliasStmt *sql.Stmt
selectAliasesFromRoomIDStmt *sql.Stmt
selectCreatorIDFromAliasStmt *sql.Stmt
deleteRoomAliasStmt *sql.Stmt
}
func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(roomAliasesSchema)
if err != nil {
return
}
return statementList{
{&s.insertRoomAliasStmt, insertRoomAliasSQL},
{&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL},
{&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL},
{&s.selectCreatorIDFromAliasStmt, selectCreatorIDFromAliasSQL},
{&s.deleteRoomAliasStmt, deleteRoomAliasSQL},
}.prepare(db)
}
func (s *roomAliasesStatements) insertRoomAlias(
ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string,
) (err error) {
insertStmt := common.TxStmt(txn, s.insertRoomAliasStmt)
_, err = insertStmt.ExecContext(ctx, alias, roomID, creatorUserID)
if err != nil {
fmt.Println("insertRoomAlias s.insertRoomAliasStmt.ExecContent:", err)
}
return
}
func (s *roomAliasesStatements) selectRoomIDFromAlias(
ctx context.Context, txn *sql.Tx, alias string,
) (roomID string, err error) {
selectStmt := common.TxStmt(txn, s.selectRoomIDFromAliasStmt)
err = selectStmt.QueryRowContext(ctx, alias).Scan(&roomID)
if err == sql.ErrNoRows {
return "", nil
}
return
}
func (s *roomAliasesStatements) selectAliasesFromRoomID(
ctx context.Context, txn *sql.Tx, roomID string,
) (aliases []string, err error) {
aliases = []string{}
selectStmt := common.TxStmt(txn, s.selectAliasesFromRoomIDStmt)
rows, err := selectStmt.QueryContext(ctx, roomID)
if err != nil {
fmt.Println("selectAliasesFromRoomID s.selectAliasesFromRoomIDStmt.QueryContext:", err)
return
}
for rows.Next() {
var alias string
if err = rows.Scan(&alias); err != nil {
fmt.Println("selectAliasesFromRoomID rows.Scan:", err)
return
}
aliases = append(aliases, alias)
}
return
}
func (s *roomAliasesStatements) selectCreatorIDFromAlias(
ctx context.Context, txn *sql.Tx, alias string,
) (creatorID string, err error) {
selectStmt := common.TxStmt(txn, s.selectCreatorIDFromAliasStmt)
err = selectStmt.QueryRowContext(ctx, alias).Scan(&creatorID)
if err == sql.ErrNoRows {
return "", nil
}
return
}
func (s *roomAliasesStatements) deleteRoomAlias(
ctx context.Context, txn *sql.Tx, alias string,
) (err error) {
deleteStmt := common.TxStmt(txn, s.deleteRoomAliasStmt)
_, err = deleteStmt.ExecContext(ctx, alias)
return
}

View file

@ -0,0 +1,172 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types"
)
const roomsSchema = `
CREATE TABLE IF NOT EXISTS roomserver_rooms (
room_nid INTEGER PRIMARY KEY AUTOINCREMENT,
room_id TEXT NOT NULL UNIQUE,
latest_event_nids TEXT NOT NULL DEFAULT '{}',
last_event_sent_nid INTEGER NOT NULL DEFAULT 0,
state_snapshot_nid INTEGER NOT NULL DEFAULT 0
);
`
// Same as insertEventTypeNIDSQL
const insertRoomNIDSQL = `
INSERT INTO roomserver_rooms (room_id) VALUES ($1)
ON CONFLICT DO NOTHING;
`
const insertRoomNIDResultSQL = `
SELECT room_nid FROM roomserver_rooms
WHERE rowid = last_insert_rowid();
`
const selectRoomNIDSQL = "" +
"SELECT room_nid FROM roomserver_rooms WHERE room_id = $1"
const selectLatestEventNIDsSQL = "" +
"SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
const selectLatestEventNIDsForUpdateSQL = "" +
"SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
const updateLatestEventNIDsSQL = "" +
"UPDATE roomserver_rooms SET latest_event_nids = $2, last_event_sent_nid = $3, state_snapshot_nid = $4 WHERE room_nid = $1"
type roomStatements struct {
insertRoomNIDStmt *sql.Stmt
insertRoomNIDResultStmt *sql.Stmt
selectRoomNIDStmt *sql.Stmt
selectLatestEventNIDsStmt *sql.Stmt
selectLatestEventNIDsForUpdateStmt *sql.Stmt
updateLatestEventNIDsStmt *sql.Stmt
}
func (s *roomStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(roomsSchema)
if err != nil {
return
}
return statementList{
{&s.insertRoomNIDStmt, insertRoomNIDSQL},
{&s.insertRoomNIDResultStmt, insertRoomNIDResultSQL},
{&s.selectRoomNIDStmt, selectRoomNIDSQL},
{&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL},
{&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL},
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
}.prepare(db)
}
func (s *roomStatements) insertRoomNID(
ctx context.Context, txn *sql.Tx, roomID string,
) (types.RoomNID, error) {
var roomNID int64
var err error
insertStmt := common.TxStmt(txn, s.insertRoomNIDStmt)
resultStmt := common.TxStmt(txn, s.insertRoomNIDResultStmt)
if _, err = insertStmt.ExecContext(ctx, roomID); err == nil {
err = resultStmt.QueryRowContext(ctx).Scan(&roomNID)
if err != nil {
fmt.Println("insertRoomNID resultStmt.QueryRowContext:", err)
}
} else {
fmt.Println("insertRoomNID insertStmt.ExecContext:", err)
}
return types.RoomNID(roomNID), err
}
func (s *roomStatements) selectRoomNID(
ctx context.Context, txn *sql.Tx, roomID string,
) (types.RoomNID, error) {
var roomNID int64
stmt := common.TxStmt(txn, s.selectRoomNIDStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID)
if err != nil {
fmt.Println("selectRoomNID stmt.QueryRowContext:", err)
}
return types.RoomNID(roomNID), err
}
func (s *roomStatements) selectLatestEventNIDs(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
) ([]types.EventNID, types.StateSnapshotNID, error) {
var nids pq.Int64Array
var stateSnapshotNID int64
stmt := common.TxStmt(txn, s.selectLatestEventNIDsStmt)
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID)
if err != nil {
fmt.Println("selectLatestEventNIDs stmt.QueryRowContext:", err)
return nil, 0, err
}
eventNIDs := make([]types.EventNID, len(nids))
for i := range nids {
eventNIDs[i] = types.EventNID(nids[i])
}
return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil
}
func (s *roomStatements) selectLatestEventsNIDsForUpdate(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) {
var nids pq.Int64Array
var lastEventSentNID int64
var stateSnapshotNID int64
stmt := common.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt)
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID)
if err != nil {
fmt.Println("selectLatestEventsNIDsForUpdate stmt.QueryRowContext:", err)
return nil, 0, 0, err
}
eventNIDs := make([]types.EventNID, len(nids))
for i := range nids {
eventNIDs[i] = types.EventNID(nids[i])
}
return eventNIDs, types.EventNID(lastEventSentNID), types.StateSnapshotNID(stateSnapshotNID), nil
}
func (s *roomStatements) updateLatestEventNIDs(
ctx context.Context,
txn *sql.Tx,
roomNID types.RoomNID,
eventNIDs []types.EventNID,
lastEventSentNID types.EventNID,
stateSnapshotNID types.StateSnapshotNID,
) error {
stmt := common.TxStmt(txn, s.updateLatestEventNIDsStmt)
_, err := stmt.ExecContext(
ctx,
roomNID,
eventNIDsAsArray(eventNIDs),
int64(lastEventSentNID),
int64(stateSnapshotNID),
)
if err != nil {
fmt.Println("updateLatestEventNIDs stmt.ExecContext:", err)
}
return err
}

View file

@ -0,0 +1,60 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"database/sql"
)
type statements struct {
eventTypeStatements
eventStateKeyStatements
roomStatements
eventStatements
eventJSONStatements
stateSnapshotStatements
stateBlockStatements
previousEventStatements
roomAliasesStatements
inviteStatements
membershipStatements
transactionStatements
}
func (s *statements) prepare(db *sql.DB) error {
var err error
for _, prepare := range []func(db *sql.DB) error{
s.eventTypeStatements.prepare,
s.eventStateKeyStatements.prepare,
s.roomStatements.prepare,
s.eventStatements.prepare,
s.eventJSONStatements.prepare,
s.stateSnapshotStatements.prepare,
s.stateBlockStatements.prepare,
s.previousEventStatements.prepare,
s.roomAliasesStatements.prepare,
s.inviteStatements.prepare,
s.membershipStatements.prepare,
s.transactionStatements.prepare,
} {
if err = prepare(db); err != nil {
return err
}
}
return nil
}

View file

@ -0,0 +1,285 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
"sort"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/util"
)
const stateDataSchema = `
CREATE TABLE IF NOT EXISTS roomserver_state_block (
state_block_nid INTEGER PRIMARY KEY AUTOINCREMENT,
event_type_nid INTEGER NOT NULL,
event_state_key_nid INTEGER NOT NULL,
event_nid INTEGER NOT NULL,
UNIQUE (state_block_nid, event_type_nid, event_state_key_nid)
);
`
const insertStateDataSQL = "" +
"INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" +
" VALUES ($1, $2, $3, $4)"
const selectNextStateBlockNIDSQL = `
SELECT COALESCE((
SELECT seq+1 AS state_block_nid FROM sqlite_sequence
WHERE name = 'roomserver_state_block'), 0
) AS state_block_nid
`
// Bulk state lookup by numeric state block ID.
// Sort by the state_block_nid, event_type_nid, event_state_key_nid
// This means that all the entries for a given state_block_nid will appear
// together in the list and those entries will sorted by event_type_nid
// and event_state_key_nid. This property makes it easier to merge two
// state data blocks together.
const bulkSelectStateBlockEntriesSQL = "" +
"SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
" FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
// Bulk state lookup by numeric state block ID.
// Filters the rows in each block to the requested types and state keys.
// We would like to restrict to particular type state key pairs but we are
// restricted by the query language to pull the cross product of a list
// of types and a list state_keys. So we have to filter the result in the
// application to restrict it to the list of event types and state keys we
// actually wanted.
const bulkSelectFilteredStateBlockEntriesSQL = "" +
"SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
" FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
" AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" +
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
type stateBlockStatements struct {
insertStateDataStmt *sql.Stmt
selectNextStateBlockNIDStmt *sql.Stmt
bulkSelectStateBlockEntriesStmt *sql.Stmt
bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt
}
func (s *stateBlockStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(stateDataSchema)
if err != nil {
return
}
return statementList{
{&s.insertStateDataStmt, insertStateDataSQL},
{&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL},
{&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL},
{&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL},
}.prepare(db)
}
func (s *stateBlockStatements) bulkInsertStateData(
ctx context.Context, txn *sql.Tx,
stateBlockNID types.StateBlockNID,
entries []types.StateEntry,
) error {
for _, entry := range entries {
_, err := common.TxStmt(txn, s.insertStateDataStmt).ExecContext(
ctx,
int64(stateBlockNID),
int64(entry.EventTypeNID),
int64(entry.EventStateKeyNID),
int64(entry.EventNID),
)
if err != nil {
fmt.Println("bulkInsertStateData s.insertStateDataStmt.ExecContext:", err)
return err
}
}
return nil
}
func (s *stateBlockStatements) selectNextStateBlockNID(
ctx context.Context,
txn *sql.Tx,
) (types.StateBlockNID, error) {
var stateBlockNID int64
selectStmt := common.TxStmt(txn, s.selectNextStateBlockNIDStmt)
err := selectStmt.QueryRowContext(ctx).Scan(&stateBlockNID)
return types.StateBlockNID(stateBlockNID), err
}
func (s *stateBlockStatements) bulkSelectStateBlockEntries(
ctx context.Context, txn *sql.Tx, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
nids := make([]int64, len(stateBlockNIDs))
for i := range stateBlockNIDs {
nids[i] = int64(stateBlockNIDs[i])
}
selectStmt := common.TxStmt(txn, s.bulkSelectStateBlockEntriesStmt)
rows, err := selectStmt.QueryContext(ctx, sqliteIn(pq.Int64Array(nids)))
if err != nil {
fmt.Println("bulkSelectStateBlockEntries s.bulkSelectStateBlockEntriesStmt.QueryContext:", err)
return nil, err
}
defer rows.Close() // nolint: errcheck
results := make([]types.StateEntryList, len(stateBlockNIDs))
// current is a pointer to the StateEntryList to append the state entries to.
var current *types.StateEntryList
i := 0
for rows.Next() {
var (
stateBlockNID int64
eventTypeNID int64
eventStateKeyNID int64
eventNID int64
entry types.StateEntry
)
if err := rows.Scan(
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
); err != nil {
fmt.Println("bulkSelectStateBlockEntries rows.Scan:", err)
return nil, err
}
fmt.Println("state block NID", stateBlockNID, "event type NID", eventTypeNID, "event state key NID", eventStateKeyNID, "event NID", eventNID)
entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
entry.EventNID = types.EventNID(eventNID)
if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
// The state entry row is for a different state data block to the current one.
// So we start appending to the next entry in the list.
current = &results[i]
current.StateBlockNID = types.StateBlockNID(stateBlockNID)
i++
}
current.StateEntries = append(current.StateEntries, entry)
}
if i != len(nids) {
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(nids))
}
return results, nil
}
func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
ctx context.Context, txn *sql.Tx,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) {
tuples := stateKeyTupleSorter(stateKeyTuples)
// Sort the tuples so that we can run binary search against them as we filter the rows returned by the db.
sort.Sort(tuples)
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
selectStmt := common.TxStmt(txn, s.bulkSelectFilteredStateBlockEntriesStmt)
rows, err := selectStmt.QueryContext(
ctx,
stateBlockNIDsAsArray(stateBlockNIDs),
eventTypeNIDArray,
sqliteIn(eventStateKeyNIDArray),
)
if err != nil {
fmt.Println("bulkSelectFilteredStateBlockEntries s.bulkSelectFilteredStateBlockEntriesStmt.QueryContext:", err)
return nil, err
}
defer rows.Close() // nolint: errcheck
var results []types.StateEntryList
var current types.StateEntryList
for rows.Next() {
var (
stateBlockNID int64
eventTypeNID int64
eventStateKeyNID int64
eventNID int64
entry types.StateEntry
)
if err := rows.Scan(
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
); err != nil {
fmt.Println("bulkSelectFilteredStateBlockEntries rows.Scan:", err)
return nil, err
}
entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
entry.EventNID = types.EventNID(eventNID)
// We can use binary search here because we sorted the tuples earlier
if !tuples.contains(entry.StateKeyTuple) {
// The select will return the cross product of types and state keys.
// So we need to check if type of the entry is in the list.
continue
}
if types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
// The state entry row is for a different state data block to the current one.
// So we append the current entry to the results and start adding to a new one.
// The first time through the loop current will be empty.
if current.StateEntries != nil {
results = append(results, current)
}
current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)}
}
current.StateEntries = append(current.StateEntries, entry)
}
// Add the last entry to the list if it is not empty.
if current.StateEntries != nil {
results = append(results, current)
}
return results, nil
}
func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array {
nids := make([]int64, len(stateBlockNIDs))
for i := range stateBlockNIDs {
nids[i] = int64(stateBlockNIDs[i])
}
return pq.Int64Array(nids)
}
type stateKeyTupleSorter []types.StateKeyTuple
func (s stateKeyTupleSorter) Len() int { return len(s) }
func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// Check whether a tuple is in the list. Assumes that the list is sorted.
func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool {
i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
return i < len(s) && s[i] == value
}
// List the unique eventTypeNIDs and eventStateKeyNIDs.
// Assumes that the list is sorted.
func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs pq.Int64Array, eventStateKeyNIDs pq.Int64Array) {
eventTypeNIDs = make(pq.Int64Array, len(s))
eventStateKeyNIDs = make(pq.Int64Array, len(s))
for i := range s {
eventTypeNIDs[i] = int64(s[i].EventTypeNID)
eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
}
eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
return
}
type int64Sorter []int64
func (s int64Sorter) Len() int { return len(s) }
func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

View file

@ -0,0 +1,86 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"sort"
"testing"
"github.com/matrix-org/dendrite/roomserver/types"
)
func TestStateKeyTupleSorter(t *testing.T) {
input := stateKeyTupleSorter{
{EventTypeNID: 1, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 4},
{EventTypeNID: 2, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 1},
}
want := []types.StateKeyTuple{
{EventTypeNID: 1, EventStateKeyNID: 1},
{EventTypeNID: 1, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 4},
{EventTypeNID: 2, EventStateKeyNID: 2},
}
doNotWant := []types.StateKeyTuple{
{EventTypeNID: 0, EventStateKeyNID: 0},
{EventTypeNID: 1, EventStateKeyNID: 3},
{EventTypeNID: 2, EventStateKeyNID: 1},
{EventTypeNID: 3, EventStateKeyNID: 1},
}
wantTypeNIDs := []int64{1, 2}
wantStateKeyNIDs := []int64{1, 2, 4}
// Sort the input and check it's in the right order.
sort.Sort(input)
gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays()
for i := range want {
if input[i] != want[i] {
t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
}
if !input.contains(want[i]) {
t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
}
}
for i := range doNotWant {
if input.contains(doNotWant[i]) {
t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
}
}
if len(wantTypeNIDs) != len(gotTypeNIDs) {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
for i := range wantTypeNIDs {
if wantTypeNIDs[i] != gotTypeNIDs[i] {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
}
if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
}
for i := range wantStateKeyNIDs {
if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
}
}

View file

@ -0,0 +1,124 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types"
)
const stateSnapshotSchema = `
CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT,
room_nid INTEGER NOT NULL,
state_block_nids TEXT NOT NULL DEFAULT '{}'
);
`
const insertStateSQL = `
INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)
VALUES ($1, $2);
`
const insertStateResultSQL = `
SELECT state_snapshot_nid FROM roomserver_state_snapshots
WHERE rowid = last_insert_rowid();
`
// Bulk state data NID lookup.
// Sorting by state_snapshot_nid means we can use binary search over the result
// to lookup the state data NIDs for a state snapshot NID.
const bulkSelectStateBlockNIDsSQL = "" +
"SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" +
" WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC"
type stateSnapshotStatements struct {
insertStateStmt *sql.Stmt
insertStateResultStmt *sql.Stmt
bulkSelectStateBlockNIDsStmt *sql.Stmt
}
func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(stateSnapshotSchema)
if err != nil {
return
}
return statementList{
{&s.insertStateStmt, insertStateSQL},
{&s.insertStateResultStmt, insertStateResultSQL},
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
}.prepare(db)
}
func (s *stateSnapshotStatements) insertState(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID,
) (stateNID types.StateSnapshotNID, err error) {
nids := make([]int64, len(stateBlockNIDs))
for i := range stateBlockNIDs {
nids[i] = int64(stateBlockNIDs[i])
}
insertStmt := common.TxStmt(txn, s.insertStateStmt)
resultStmt := common.TxStmt(txn, s.insertStateResultStmt)
if _, err = insertStmt.ExecContext(ctx, int64(roomNID), pq.Int64Array(nids)); err == nil {
err = resultStmt.QueryRowContext(ctx).Scan(&stateNID)
if err != nil {
fmt.Println("insertState s.insertStateResultStmt.QueryRowContext:", err)
}
} else {
fmt.Println("insertState s.insertStateStmt.ExecContext:", err)
}
return
}
func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
nids := make([]int64, len(stateNIDs))
for i := range stateNIDs {
nids[i] = int64(stateNIDs[i])
}
selectStmt := common.TxStmt(txn, s.bulkSelectStateBlockNIDsStmt)
rows, err := selectStmt.QueryContext(ctx, sqliteIn(pq.Int64Array(nids)))
if err != nil {
fmt.Println("bulkSelectStateBlockNIDs s.bulkSelectStateBlockNIDsStmt.QueryContext:", err)
return nil, err
}
defer rows.Close() // nolint: errcheck
results := make([]types.StateBlockNIDList, len(stateNIDs))
i := 0
for ; rows.Next(); i++ {
result := &results[i]
var stateBlockNIDs pq.Int64Array
if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
fmt.Println("bulkSelectStateBlockNIDs rows.Scan:", err)
return nil, err
}
result.StateBlockNIDs = make([]types.StateBlockNID, len(stateBlockNIDs))
for k := range stateBlockNIDs {
result.StateBlockNIDs[k] = types.StateBlockNID(stateBlockNIDs[k])
}
}
if i != len(stateNIDs) {
return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs))
}
return results, nil
}

View file

@ -0,0 +1,778 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"errors"
"fmt"
"net/url"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
_ "github.com/matrix-org/go-sqlite3-js/sqlite3_js"
)
// A Database is used to store room events and stream offsets.
type Database struct {
statements statements
db *sql.DB
}
// Open a postgres database.
func Open(dataSourceName string) (*Database, error) {
var d Database
uri, err := url.Parse(dataSourceName)
if err != nil {
return nil, err
}
var cs string
if uri.Opaque != "" { // file:filename.db
cs = fmt.Sprintf("%s?cache=shared&_busy_timeout=9999999", uri.Opaque)
} else if uri.Path != "" { // file:///path/to/filename.db
cs = fmt.Sprintf("%s?cache=shared&_busy_timeout=9999999", uri.Path)
} else {
return nil, errors.New("no filename or path in connect string")
}
if d.db, err = sql.Open("sqlite3_js", cs); err != nil {
return nil, err
}
d.db.Exec("PRAGMA journal_mode=WAL;")
//d.db.Exec("PRAGMA parser_trace = true;")
//d.db.SetMaxOpenConns(1)
if err = d.statements.prepare(d.db); err != nil {
return nil, err
}
return &d, nil
}
// StoreEvent implements input.EventDatabase
func (d *Database) StoreEvent(
ctx context.Context, event gomatrixserverlib.Event,
txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID,
) (types.RoomNID, types.StateAtEvent, error) {
var (
roomNID types.RoomNID
eventTypeNID types.EventTypeNID
eventStateKeyNID types.EventStateKeyNID
eventNID types.EventNID
stateNID types.StateSnapshotNID
err error
)
if txnAndSessionID != nil {
if err = d.statements.insertTransaction(
ctx, txnAndSessionID.TransactionID,
txnAndSessionID.SessionID, event.Sender(), event.EventID(),
); err != nil {
return 0, types.StateAtEvent{}, err
}
}
err = common.WithTransaction(d.db, func(tx *sql.Tx) error {
roomNID, err = d.assignRoomNID(ctx, tx, event.RoomID())
return err
})
if err != nil {
return 0, types.StateAtEvent{}, err
}
err = common.WithTransaction(d.db, func(tx *sql.Tx) error {
eventTypeNID, err = d.assignEventTypeNID(ctx, tx, event.Type())
return err
})
if err != nil {
return 0, types.StateAtEvent{}, err
}
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
eventStateKey := event.StateKey()
// Assigned a numeric ID for the state_key if there is one present.
// Otherwise set the numeric ID for the state_key to 0.
if eventStateKey != nil {
if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil {
return err
}
}
if eventNID, stateNID, err = d.statements.insertEvent(
ctx,
txn,
roomNID,
eventTypeNID,
eventStateKeyNID,
event.EventID(),
event.EventReference().EventSHA256,
authEventNIDs,
event.Depth(),
); err != nil {
if err == sql.ErrNoRows {
// We've already inserted the event so select the numeric event ID
eventNID, stateNID, err = d.statements.selectEvent(ctx, txn, event.EventID())
}
if err != nil {
return err
}
}
if err = d.statements.insertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil {
return err
}
return nil
})
if err != nil {
return 0, types.StateAtEvent{}, err
}
return roomNID, types.StateAtEvent{
BeforeStateSnapshotNID: stateNID,
StateEntry: types.StateEntry{
StateKeyTuple: types.StateKeyTuple{
EventTypeNID: eventTypeNID,
EventStateKeyNID: eventStateKeyNID,
},
EventNID: eventNID,
},
}, nil
}
func (d *Database) assignRoomNID(
ctx context.Context, txn *sql.Tx, roomID string,
) (roomNID types.RoomNID, err error) {
// Check if we already have a numeric ID in the database.
roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID)
if err == sql.ErrNoRows {
// We don't have a numeric ID so insert one into the database.
roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID)
if err == sql.ErrNoRows {
// We raced with another insert so run the select again.
roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID)
}
}
return
}
func (d *Database) assignEventTypeNID(
ctx context.Context, txn *sql.Tx, eventType string,
) (eventTypeNID types.EventTypeNID, err error) {
// Check if we already have a numeric ID in the database.
eventTypeNID, err = d.statements.selectEventTypeNID(ctx, txn, eventType)
if err == sql.ErrNoRows {
// We don't have a numeric ID so insert one into the database.
eventTypeNID, err = d.statements.insertEventTypeNID(ctx, txn, eventType)
if err == sql.ErrNoRows {
// We raced with another insert so run the select again.
eventTypeNID, err = d.statements.selectEventTypeNID(ctx, txn, eventType)
}
}
return
}
func (d *Database) assignStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string,
) (eventStateKeyNID types.EventStateKeyNID, err error) {
// Check if we already have a numeric ID in the database.
eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey)
if err == sql.ErrNoRows {
// We don't have a numeric ID so insert one into the database.
eventStateKeyNID, err = d.statements.insertEventStateKeyNID(ctx, txn, eventStateKey)
if err == sql.ErrNoRows {
// We raced with another insert so run the select again.
eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey)
}
}
return
}
// StateEntriesForEventIDs implements input.EventDatabase
func (d *Database) StateEntriesForEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateEntry, error) {
return d.statements.bulkSelectStateEventByID(ctx, nil, eventIDs)
}
// EventTypeNIDs implements state.RoomStateDatabase
func (d *Database) EventTypeNIDs(
ctx context.Context, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
return d.statements.bulkSelectEventTypeNID(ctx, nil, eventTypes)
}
// EventStateKeyNIDs implements state.RoomStateDatabase
func (d *Database) EventStateKeyNIDs(
ctx context.Context, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
return d.statements.bulkSelectEventStateKeyNID(ctx, nil, eventStateKeys)
}
// EventStateKeys implements query.RoomserverQueryAPIDatabase
func (d *Database) EventStateKeys(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) {
return d.statements.bulkSelectEventStateKey(ctx, nil, eventStateKeyNIDs)
}
// EventNIDs implements query.RoomserverQueryAPIDatabase
func (d *Database) EventNIDs(
ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) {
return d.statements.bulkSelectEventNID(ctx, nil, eventIDs)
}
// Events implements input.EventDatabase
func (d *Database) Events(
ctx context.Context, eventNIDs []types.EventNID,
) ([]types.Event, error) {
var eventJSONs []eventJSONPair
var err error
results := make([]types.Event, len(eventJSONs))
common.WithTransaction(d.db, func(txn *sql.Tx) error {
eventJSONs, err = d.statements.bulkSelectEventJSON(ctx, txn, eventNIDs)
if err != nil {
return nil
}
for i, eventJSON := range eventJSONs {
result := &results[i]
result.EventNID = eventJSON.EventNID
// TODO: Use NewEventFromTrustedJSON for efficiency
result.Event, err = gomatrixserverlib.NewEventFromUntrustedJSON(eventJSON.EventJSON)
if err != nil {
return nil
}
}
return nil
})
if err != nil {
return []types.Event{}, err
}
return results, nil
}
// AddState implements input.EventDatabase
func (d *Database) AddState(
ctx context.Context,
roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) {
common.WithTransaction(d.db, func(txn *sql.Tx) error {
if len(state) > 0 {
stateBlockNID, err := d.statements.selectNextStateBlockNID(ctx, txn)
if err != nil {
return err
}
if err = d.statements.bulkInsertStateData(ctx, txn, stateBlockNID, state); err != nil {
return err
}
stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID)
}
stateNID, err = d.statements.insertState(ctx, txn, roomNID, stateBlockNIDs)
return nil
})
if err != nil {
return 0, err
}
return
}
// SetState implements input.EventDatabase
func (d *Database) SetState(
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error {
return d.statements.updateEventState(ctx, nil, eventNID, stateNID)
}
// StateAtEventIDs implements input.EventDatabase
func (d *Database) StateAtEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateAtEvent, error) {
return d.statements.bulkSelectStateAtEventByID(ctx, nil, eventIDs)
}
// StateBlockNIDs implements state.RoomStateDatabase
func (d *Database) StateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
return d.statements.bulkSelectStateBlockNIDs(ctx, nil, stateNIDs)
}
// StateEntries implements state.RoomStateDatabase
func (d *Database) StateEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
return d.statements.bulkSelectStateBlockEntries(ctx, nil, stateBlockNIDs)
}
// SnapshotNIDFromEventID implements state.RoomStateDatabase
func (d *Database) SnapshotNIDFromEventID(
ctx context.Context, eventID string,
) (stateNID types.StateSnapshotNID, err error) {
_, stateNID, err = d.statements.selectEvent(ctx, nil, eventID)
return
}
// EventIDs implements input.RoomEventDatabase
func (d *Database) EventIDs(
ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]string, error) {
return d.statements.bulkSelectEventID(ctx, nil, eventNIDs)
}
// GetLatestEventsForUpdate implements input.EventDatabase
func (d *Database) GetLatestEventsForUpdate(
ctx context.Context, roomNID types.RoomNID,
) (types.RoomRecentEventsUpdater, error) {
txn, err := d.db.Begin()
if err != nil {
return nil, err
}
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
d.statements.selectLatestEventsNIDsForUpdate(ctx, txn, roomNID)
if err != nil {
txn.Rollback() // nolint: errcheck
return nil, err
}
stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs)
if err != nil {
txn.Rollback() // nolint: errcheck
return nil, err
}
var lastEventIDSent string
if lastEventNIDSent != 0 {
lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent)
if err != nil {
txn.Rollback() // nolint: errcheck
return nil, err
}
}
return &roomRecentEventsUpdater{
transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
}, nil
}
// GetTransactionEventID implements input.EventDatabase
func (d *Database) GetTransactionEventID(
ctx context.Context, transactionID string,
sessionID int64, userID string,
) (string, error) {
eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, sessionID, userID)
if err == sql.ErrNoRows {
return "", nil
}
return eventID, err
}
type roomRecentEventsUpdater struct {
transaction
d *Database
roomNID types.RoomNID
latestEvents []types.StateAtEventAndReference
lastEventIDSent string
currentStateSnapshotNID types.StateSnapshotNID
}
// LatestEvents implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) LatestEvents() []types.StateAtEventAndReference {
return u.latestEvents
}
// LastEventIDSent implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) LastEventIDSent() string {
return u.lastEventIDSent
}
// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
return u.currentStateSnapshotNID
}
// StorePreviousEvents implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
for _, ref := range previousEventReferences {
if err := u.d.statements.insertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
return err
}
}
return nil
}
// IsReferenced implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
err := u.d.statements.selectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
if err == nil {
return true, nil
}
if err == sql.ErrNoRows {
return false, nil
}
return false, err
}
// SetLatestEvents implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) SetLatestEvents(
roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
currentStateSnapshotNID types.StateSnapshotNID,
) error {
eventNIDs := make([]types.EventNID, len(latest))
for i := range latest {
eventNIDs[i] = latest[i].EventNID
}
// TODO: transaction was removed here - is this wise?
return u.d.statements.updateLatestEventNIDs(u.ctx, nil, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID)
}
// HasEventBeenSent implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
// TODO: transaction was removed here - is this wise?
return u.d.statements.selectEventSentToOutput(u.ctx, nil, eventNID)
}
// MarkEventAsSent implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
// TODO: transaction was removed here - is this wise?
return u.d.statements.updateEventSentToOutput(u.ctx, nil, eventNID)
}
func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (types.MembershipUpdater, error) {
// TODO: transaction was removed here - is this wise?
return u.d.membershipUpdaterTxn(u.ctx, nil, u.roomNID, targetUserNID)
}
// RoomNID implements query.RoomserverQueryAPIDB
func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) {
roomNID, err := d.statements.selectRoomNID(ctx, nil, roomID)
if err == sql.ErrNoRows {
return 0, nil
}
return roomNID, err
}
// LatestEventIDs implements query.RoomserverQueryAPIDatabase
func (d *Database) LatestEventIDs(
ctx context.Context, roomNID types.RoomNID,
) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var eventNIDs []types.EventNID
eventNIDs, currentStateSnapshotNID, err = d.statements.selectLatestEventNIDs(ctx, txn, roomNID)
if err != nil {
return err
}
references, err = d.statements.bulkSelectEventReference(ctx, txn, eventNIDs)
if err != nil {
return err
}
depth, err = d.statements.selectMaxEventDepth(ctx, txn, eventNIDs)
if err != nil {
return err
}
return nil
})
return
}
// GetInvitesForUser implements query.RoomserverQueryAPIDatabase
func (d *Database) GetInvitesForUser(
ctx context.Context,
roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID,
) (senderUserIDs []types.EventStateKeyNID, err error) {
return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID)
}
// SetRoomAlias implements alias.RoomserverAliasAPIDB
func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error {
return d.statements.insertRoomAlias(ctx, nil, alias, roomID, creatorUserID)
}
// GetRoomIDForAlias implements alias.RoomserverAliasAPIDB
func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) {
return d.statements.selectRoomIDFromAlias(ctx, nil, alias)
}
// GetAliasesForRoomID implements alias.RoomserverAliasAPIDB
func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) {
return d.statements.selectAliasesFromRoomID(ctx, nil, roomID)
}
// GetCreatorIDForAlias implements alias.RoomserverAliasAPIDB
func (d *Database) GetCreatorIDForAlias(
ctx context.Context, alias string,
) (string, error) {
return d.statements.selectCreatorIDFromAlias(ctx, nil, alias)
}
// RemoveRoomAlias implements alias.RoomserverAliasAPIDB
func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
return d.statements.deleteRoomAlias(ctx, nil, alias)
}
// StateEntriesForTuples implements state.RoomStateDatabase
func (d *Database) StateEntriesForTuples(
ctx context.Context,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) {
return d.statements.bulkSelectFilteredStateBlockEntries(
ctx, nil, stateBlockNIDs, stateKeyTuples,
)
}
// MembershipUpdater implements input.RoomEventDatabase
func (d *Database) MembershipUpdater(
ctx context.Context, roomID, targetUserID string,
) (types.MembershipUpdater, error) {
txn, err := d.db.Begin()
if err != nil {
return nil, err
}
succeeded := false
defer func() {
if !succeeded {
txn.Rollback() // nolint: errcheck
}
}()
roomNID, err := d.assignRoomNID(ctx, txn, roomID)
if err != nil {
return nil, err
}
targetUserNID, err := d.assignStateKeyNID(ctx, txn, targetUserID)
if err != nil {
return nil, err
}
updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID)
if err != nil {
return nil, err
}
succeeded = true
return updater, nil
}
type membershipUpdater struct {
transaction
d *Database
roomNID types.RoomNID
targetUserNID types.EventStateKeyNID
membership membershipState
}
func (d *Database) membershipUpdaterTxn(
ctx context.Context,
txn *sql.Tx,
roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID,
) (types.MembershipUpdater, error) {
if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil {
return nil, err
}
membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID)
if err != nil {
return nil, err
}
return &membershipUpdater{
transaction{ctx, txn}, d, roomNID, targetUserNID, membership,
}, nil
}
// IsInvite implements types.MembershipUpdater
func (u *membershipUpdater) IsInvite() bool {
return u.membership == membershipStateInvite
}
// IsJoin implements types.MembershipUpdater
func (u *membershipUpdater) IsJoin() bool {
return u.membership == membershipStateJoin
}
// IsLeave implements types.MembershipUpdater
func (u *membershipUpdater) IsLeave() bool {
return u.membership == membershipStateLeaveOrBan
}
// SetToInvite implements types.MembershipUpdater
func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) {
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender())
if err != nil {
return false, err
}
inserted, err := u.d.statements.insertInviteEvent(
u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
)
if err != nil {
return false, err
}
if u.membership != membershipStateInvite {
if err = u.d.statements.updateMembership(
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0,
); err != nil {
return false, err
}
}
return inserted, nil
}
// SetToJoin implements types.MembershipUpdater
func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) {
var inviteEventIDs []string
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
if err != nil {
return nil, err
}
// If this is a join event update, there is no invite to update
if !isUpdate {
inviteEventIDs, err = u.d.statements.updateInviteRetired(
u.ctx, u.txn, u.roomNID, u.targetUserNID,
)
if err != nil {
return nil, err
}
}
// Look up the NID of the new join event
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
if err != nil {
return nil, err
}
if u.membership != membershipStateJoin || isUpdate {
if err = u.d.statements.updateMembership(
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
membershipStateJoin, nIDs[eventID],
); err != nil {
return nil, err
}
}
return inviteEventIDs, nil
}
// SetToLeave implements types.MembershipUpdater
func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) {
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
if err != nil {
return nil, err
}
inviteEventIDs, err := u.d.statements.updateInviteRetired(
u.ctx, u.txn, u.roomNID, u.targetUserNID,
)
if err != nil {
return nil, err
}
// Look up the NID of the new leave event
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
if err != nil {
return nil, err
}
if u.membership != membershipStateLeaveOrBan {
if err = u.d.statements.updateMembership(
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
membershipStateLeaveOrBan, nIDs[eventID],
); err != nil {
return nil, err
}
}
return inviteEventIDs, nil
}
// GetMembership implements query.RoomserverQueryAPIDB
func (d *Database) GetMembership(
ctx context.Context, roomNID types.RoomNID, requestSenderUserID string,
) (membershipEventNID types.EventNID, stillInRoom bool, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
requestSenderUserNID, err := d.assignStateKeyNID(ctx, txn, requestSenderUserID)
if err != nil {
return err
}
membershipEventNID, _, err =
d.statements.selectMembershipFromRoomAndTarget(
ctx, txn, roomNID, requestSenderUserNID,
)
if err == sql.ErrNoRows {
// The user has never been a member of that room
return nil
}
if err != nil {
return err
}
stillInRoom = true
return nil
})
return
}
// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
func (d *Database) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool,
) (eventNIDs []types.EventNID, err error) {
common.WithTransaction(d.db, func(txn *sql.Tx) error {
if joinOnly {
eventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership(
ctx, txn, roomNID, membershipStateJoin,
)
return nil
}
eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID)
return nil
})
return
}
// EventsFromIDs implements query.RoomserverQueryAPIEventDB
func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
nidMap, err := d.EventNIDs(ctx, eventIDs)
if err != nil {
return nil, err
}
var nids []types.EventNID
for _, nid := range nidMap {
nids = append(nids, nid)
}
return d.Events(ctx, nids)
}
type transaction struct {
ctx context.Context
txn *sql.Tx
}
// Commit implements types.Transaction
func (t *transaction) Commit() error {
return t.txn.Commit()
}
// Rollback implements types.Transaction
func (t *transaction) Rollback() error {
return t.txn.Rollback()
}

View file

@ -0,0 +1,89 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
)
const transactionsSchema = `
CREATE TABLE IF NOT EXISTS roomserver_transactions (
transaction_id TEXT NOT NULL,
session_id INTEGER NOT NULL,
user_id TEXT NOT NULL,
event_id TEXT NOT NULL,
PRIMARY KEY (transaction_id, session_id, user_id)
);
`
const insertTransactionSQL = `
INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id)
VALUES ($1, $2, $3, $4)
`
const selectTransactionEventIDSQL = `
SELECT event_id FROM roomserver_transactions
WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3
`
type transactionStatements struct {
insertTransactionStmt *sql.Stmt
selectTransactionEventIDStmt *sql.Stmt
}
func (s *transactionStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(transactionsSchema)
if err != nil {
return
}
return statementList{
{&s.insertTransactionStmt, insertTransactionSQL},
{&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL},
}.prepare(db)
}
func (s *transactionStatements) insertTransaction(
ctx context.Context,
transactionID string,
sessionID int64,
userID string,
eventID string,
) (err error) {
_, err = s.insertTransactionStmt.ExecContext(
ctx, transactionID, sessionID, userID, eventID,
)
if err != nil {
fmt.Println("insertTransaction s.insertTransactionStmt.ExecContent:", err)
}
return
}
func (s *transactionStatements) selectTransactionEventID(
ctx context.Context,
transactionID string,
sessionID int64,
userID string,
) (eventID string, err error) {
err = s.selectTransactionEventIDStmt.QueryRowContext(
ctx, transactionID, sessionID, userID,
).Scan(&eventID)
if err != nil {
fmt.Println("selectTransactionEventID s.selectTransactionEventIDStmt.QueryRowContext:", err)
}
return
}

View file

@ -19,25 +19,20 @@ import (
"net/url" "net/url"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage/postgres" "github.com/matrix-org/dendrite/roomserver/storage/postgres"
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
type Database interface { type Database interface {
state.RoomStateDatabase
StoreEvent(ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error) StoreEvent(ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error)
StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error)
EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error)
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
SetState(ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID) error SetState(ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID) error
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
GetLatestEventsForUpdate(ctx context.Context, roomNID types.RoomNID) (types.RoomRecentEventsUpdater, error) GetLatestEventsForUpdate(ctx context.Context, roomNID types.RoomNID) (types.RoomRecentEventsUpdater, error)
GetTransactionEventID(ctx context.Context, transactionID string, sessionID int64, userID string) (string, error) GetTransactionEventID(ctx context.Context, transactionID string, sessionID int64, userID string) (string, error)
@ -49,7 +44,6 @@ type Database interface {
GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error)
GetCreatorIDForAlias(ctx context.Context, alias string) (string, error) GetCreatorIDForAlias(ctx context.Context, alias string) (string, error)
RemoveRoomAlias(ctx context.Context, alias string) error RemoveRoomAlias(ctx context.Context, alias string) error
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
MembershipUpdater(ctx context.Context, roomID, targetUserID string) (types.MembershipUpdater, error) MembershipUpdater(ctx context.Context, roomID, targetUserID string) (types.MembershipUpdater, error)
GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error)
GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error) GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error)
@ -65,6 +59,8 @@ func Open(dataSourceName string) (Database, error) {
switch uri.Scheme { switch uri.Scheme {
case "postgres": case "postgres":
return postgres.Open(dataSourceName) return postgres.Open(dataSourceName)
case "file":
return sqlite3.Open(dataSourceName)
default: default:
return postgres.Open(dataSourceName) return postgres.Open(dataSourceName)
} }