Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking

This commit is contained in:
Till Faelligen 2022-03-16 08:43:19 +01:00
commit b9479a6f18
33 changed files with 279 additions and 156 deletions

View file

@ -59,8 +59,8 @@ global:
# Lists of domains that the server will trust as identity servers to verify third # Lists of domains that the server will trust as identity servers to verify third
# party identifiers such as phone numbers and email addresses. # party identifiers such as phone numbers and email addresses.
trusted_third_party_id_servers: trusted_third_party_id_servers:
- matrix.org - matrix.org
- vector.im - vector.im
# Configuration for NATS JetStream # Configuration for NATS JetStream
jetstream: jetstream:
@ -204,12 +204,12 @@ federation_api:
# be required to satisfy key requests for servers that are no longer online when # be required to satisfy key requests for servers that are no longer online when
# joining some rooms. # joining some rooms.
key_perspectives: key_perspectives:
- server_name: matrix.org - server_name: matrix.org
keys: keys:
- key_id: ed25519:auto - key_id: ed25519:auto
public_key: Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw public_key: Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw
- key_id: ed25519:a_RXGa - key_id: ed25519:a_RXGa
public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ
# This option will control whether Dendrite will prefer to look up keys directly # This option will control whether Dendrite will prefer to look up keys directly
# or whether it should try perspective servers first, using direct fetches as a # or whether it should try perspective servers first, using direct fetches as a
@ -255,15 +255,15 @@ media_api:
# A list of thumbnail sizes to be generated for media content. # A list of thumbnail sizes to be generated for media content.
thumbnail_sizes: thumbnail_sizes:
- width: 32 - width: 32
height: 32 height: 32
method: crop method: crop
- width: 96 - width: 96
height: 96 height: 96
method: crop method: crop
- width: 640 - width: 640
height: 480 height: 480
method: scale method: scale
# Configuration for experimental MSC's # Configuration for experimental MSC's
mscs: mscs:
@ -295,7 +295,7 @@ sync_api:
listen: http://0.0.0.0:7773 listen: http://0.0.0.0:7773
connect: http://sync_api:7773 connect: http://sync_api:7773
external_api: external_api:
listen: http://0.0.0.0:8073 listen: http://0.0.0.0:8073
database: database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_syncapi?sslmode=disable connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_syncapi?sslmode=disable
max_open_conns: 10 max_open_conns: 10
@ -312,11 +312,6 @@ user_api:
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
device_database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_userapi_devices?sslmode=disable
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
# Configuration for the Push Server API. # Configuration for the Push Server API.
push_server: push_server:
@ -348,7 +343,7 @@ tracing:
# Logging configuration, in addition to the standard logging that is sent to # Logging configuration, in addition to the standard logging that is sent to
# stdout by Dendrite. # stdout by Dendrite.
logging: logging:
- type: file - type: file
level: info level: info
params: params:
path: /var/log/dendrite path: /var/log/dendrite

View file

@ -1,5 +1,5 @@
#!/bin/sh #!/bin/sh
for db in userapi_accounts userapi_devices mediaapi syncapi roomserver keyserver federationapi appservice mscs; do for db in userapi_accounts mediaapi syncapi roomserver keyserver federationapi appservice mscs; do
createdb -U dendrite -O dendrite dendrite_$db createdb -U dendrite -O dendrite dendrite_$db
done done

View file

@ -21,16 +21,15 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"os" "os"
"regexp"
"strings" "strings"
"github.com/matrix-org/dendrite/setup/base"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
"golang.org/x/term" "golang.org/x/term"
"github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
userdb "github.com/matrix-org/dendrite/userapi/storage"
) )
const usage = `Usage: %s const usage = `Usage: %s
@ -56,13 +55,14 @@ Arguments:
` `
var ( var (
username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')")
password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)") password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)")
pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)")
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
askPass = flag.Bool("ask-pass", false, "Ask for the password to use") askPass = flag.Bool("ask-pass", false, "Ask for the password to use")
isAdmin = flag.Bool("admin", false, "Create an admin account") isAdmin = flag.Bool("admin", false, "Create an admin account")
resetPassword = flag.Bool("reset-password", false, "Resets the password for the given username") resetPassword = flag.Bool("reset-password", false, "Resets the password for the given username")
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
) )
func main() { func main() {
@ -78,25 +78,21 @@ func main() {
os.Exit(1) os.Exit(1)
} }
if !validUsernameRegex.MatchString(*username) {
logrus.Warn("Username can only contain characters a-z, 0-9, or '_-./='")
os.Exit(1)
}
pass := getPassword(password, pwdFile, pwdStdin, askPass, os.Stdin) pass := getPassword(password, pwdFile, pwdStdin, askPass, os.Stdin)
accountDB, err := userdb.NewDatabase( b := base.NewBaseDendrite(cfg, "create-account")
&config.DatabaseOptions{ accountDB := b.CreateAccountsDB()
ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString,
},
cfg.Global.ServerName, bcrypt.DefaultCost,
cfg.UserAPI.OpenIDTokenLifetimeMS,
api.DefaultLoginTokenLifetime,
)
if err != nil {
logrus.Fatalln("Failed to connect to the database:", err.Error())
}
accType := api.AccountTypeUser accType := api.AccountTypeUser
if *isAdmin { if *isAdmin {
accType = api.AccountTypeAdmin accType = api.AccountTypeAdmin
} }
var err error
if *resetPassword { if *resetPassword {
err = accountDB.SetPassword(context.Background(), *username, pass) err = accountDB.SetPassword(context.Background(), *username, pass)
if err != nil { if err != nil {

View file

@ -87,7 +87,7 @@ On macOS, omit `sudo -u postgres` from the below commands.
* If you want to run each Dendrite component with its own database: * If you want to run each Dendrite component with its own database:
```bash ```bash
for i in mediaapi syncapi roomserver federationapi appservice keyserver userapi_accounts userapi_devices; do for i in mediaapi syncapi roomserver federationapi appservice keyserver userapi_accounts; do
sudo -u postgres createdb -O dendrite dendrite_$i sudo -u postgres createdb -O dendrite dendrite_$i
done done
``` ```

View file

@ -203,9 +203,9 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool {
return err == nil return err == nil
} }
func prevID(streamID int) []int { func prevID(streamID int64) []int64 {
if streamID <= 1 { if streamID <= 1 {
return nil return nil
} }
return []int{streamID - 1} return []int64{streamID - 1}
} }

4
go.mod
View file

@ -40,8 +40,8 @@ require (
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902 github.com/matrix-org/gomatrixserverlib v0.0.0-20220310124155-116ed5cc1bfa
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa github.com/matrix-org/pinecone v0.0.0-20220308124038-cfde1f8054c5
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.10 github.com/mattn/go-sqlite3 v1.14.10
github.com/morikuni/aec v1.0.0 // indirect github.com/morikuni/aec v1.0.0 // indirect

8
go.sum
View file

@ -983,10 +983,10 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902 h1:WHlrE8BYh/hzn1RKwq3YMAlhHivX47jQKAjZFtkJyPE= github.com/matrix-org/gomatrixserverlib v0.0.0-20220310124155-116ed5cc1bfa h1:anEGvpRn4v6akmxFWqGDobB6csEt3OWmp67pufccimE=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= github.com/matrix-org/gomatrixserverlib v0.0.0-20220310124155-116ed5cc1bfa/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo=
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa h1:rMYFNVto66gp+eWS8XAUzgp4m0qmUBid6l1HX3mHstk= github.com/matrix-org/pinecone v0.0.0-20220308124038-cfde1f8054c5 h1:7viLTiLAA2MtGKY+uf14j6TjfKvvGLAMj/qdm70jJuQ=
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/pinecone v0.0.0-20220308124038-cfde1f8054c5/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=

View file

@ -70,7 +70,7 @@ type DeviceMessage struct {
*DeviceKeys `json:"DeviceKeys,omitempty"` *DeviceKeys `json:"DeviceKeys,omitempty"`
*eduapi.OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"` *eduapi.OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"`
// A monotonically increasing number which represents device changes for this user. // A monotonically increasing number which represents device changes for this user.
StreamID int StreamID int64
DeviceChangeID int64 DeviceChangeID int64
} }
@ -108,7 +108,7 @@ type DeviceKeys struct {
} }
// WithStreamID returns a copy of this device message with the given stream ID // WithStreamID returns a copy of this device message with the given stream ID
func (k *DeviceKeys) WithStreamID(streamID int) DeviceMessage { func (k *DeviceKeys) WithStreamID(streamID int64) DeviceMessage {
return DeviceMessage{ return DeviceMessage{
DeviceKeys: k, DeviceKeys: k,
StreamID: streamID, StreamID: streamID,
@ -281,7 +281,7 @@ type QueryDeviceMessagesRequest struct {
type QueryDeviceMessagesResponse struct { type QueryDeviceMessagesResponse struct {
// The latest stream ID // The latest stream ID
StreamID int StreamID int64
Devices []DeviceMessage Devices []DeviceMessage
Error *KeyError Error *KeyError
} }

View file

@ -109,7 +109,7 @@ type DeviceListUpdaterDatabase interface {
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
// PrevIDsExists returns true if all prev IDs exist for this user. // PrevIDsExists returns true if all prev IDs exist for this user.
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error)
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error

View file

@ -46,7 +46,7 @@ func (p *mockKeyChangeProducer) ProduceKeyChanges(keys []api.DeviceMessage) erro
type mockDeviceListUpdaterDatabase struct { type mockDeviceListUpdaterDatabase struct {
staleUsers map[string]bool staleUsers map[string]bool
prevIDsExist func(string, []int) bool prevIDsExist func(string, []int64) bool
storedKeys []api.DeviceMessage storedKeys []api.DeviceMessage
mu sync.Mutex // protect staleUsers mu sync.Mutex // protect staleUsers
} }
@ -101,7 +101,7 @@ func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Contex
} }
// PrevIDsExists returns true if all prev IDs exist for this user. // PrevIDsExists returns true if all prev IDs exist for this user.
func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) { func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
return d.prevIDsExist(userID, prevIDs), nil return d.prevIDsExist(userID, prevIDs), nil
} }
@ -139,7 +139,7 @@ func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrix
func TestUpdateHavePrevID(t *testing.T) { func TestUpdateHavePrevID(t *testing.T) {
db := &mockDeviceListUpdaterDatabase{ db := &mockDeviceListUpdaterDatabase{
staleUsers: make(map[string]bool), staleUsers: make(map[string]bool),
prevIDsExist: func(string, []int) bool { prevIDsExist: func(string, []int64) bool {
return true return true
}, },
} }
@ -151,7 +151,7 @@ func TestUpdateHavePrevID(t *testing.T) {
Deleted: false, Deleted: false,
DeviceID: "FOO", DeviceID: "FOO",
Keys: []byte(`{"key":"value"}`), Keys: []byte(`{"key":"value"}`),
PrevID: []int{0}, PrevID: []int64{0},
StreamID: 1, StreamID: 1,
UserID: "@alice:localhost", UserID: "@alice:localhost",
} }
@ -185,7 +185,7 @@ func TestUpdateHavePrevID(t *testing.T) {
func TestUpdateNoPrevID(t *testing.T) { func TestUpdateNoPrevID(t *testing.T) {
db := &mockDeviceListUpdaterDatabase{ db := &mockDeviceListUpdaterDatabase{
staleUsers: make(map[string]bool), staleUsers: make(map[string]bool),
prevIDsExist: func(string, []int) bool { prevIDsExist: func(string, []int64) bool {
return false return false
}, },
} }
@ -226,7 +226,7 @@ func TestUpdateNoPrevID(t *testing.T) {
Deleted: false, Deleted: false,
DeviceID: "another_device_id", DeviceID: "another_device_id",
Keys: []byte(`{"key":"value"}`), Keys: []byte(`{"key":"value"}`),
PrevID: []int{3}, PrevID: []int64{3},
StreamID: 4, StreamID: 4,
UserID: remoteUserID, UserID: remoteUserID,
} }
@ -268,7 +268,7 @@ func TestDebounce(t *testing.T) {
t.Skipf("panic on closed channel on GHA") t.Skipf("panic on closed channel on GHA")
db := &mockDeviceListUpdaterDatabase{ db := &mockDeviceListUpdaterDatabase{
staleUsers: make(map[string]bool), staleUsers: make(map[string]bool),
prevIDsExist: func(string, []int) bool { prevIDsExist: func(string, []int64) bool {
return true return true
}, },
} }

View file

@ -205,7 +205,7 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
} }
return return
} }
maxStreamID := 0 maxStreamID := int64(0)
for _, m := range msgs { for _, m := range msgs {
if m.StreamID > maxStreamID { if m.StreamID > maxStreamID {
maxStreamID = m.StreamID maxStreamID = m.StreamID

View file

@ -49,7 +49,7 @@ type Database interface {
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
// PrevIDsExists returns true if all prev IDs exist for this user. // PrevIDsExists returns true if all prev IDs exist for this user.
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error)
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.

View file

@ -121,7 +121,7 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
for i, key := range keys { for i, key := range keys {
var keyJSONStr string var keyJSONStr string
var streamID int var streamID int64
var displayName sql.NullString var displayName sql.NullString
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
@ -138,15 +138,15 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
return nil return nil
} }
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
// nullable if there are no results // nullable if there are no results
var nullStream sql.NullInt32 var nullStream sql.NullInt64
err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
err = nil err = nil
} }
if nullStream.Valid { if nullStream.Valid {
streamID = nullStream.Int32 streamID = nullStream.Int64
} }
return return
} }
@ -211,7 +211,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
} }
dk.UserID = userID dk.UserID = userID
var keyJSON string var keyJSON string
var streamID int var streamID int64
var displayName sql.NullString var displayName sql.NullString
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil { if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
return nil, err return nil, err

View file

@ -59,12 +59,8 @@ func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage)
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
} }
func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) { func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
sids := make([]int64, len(prevIDs)) count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs)
for i := range prevIDs {
sids[i] = int64(prevIDs[i])
}
count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, sids)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -85,7 +81,7 @@ func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceM
func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
// work out the latest stream IDs for each user // work out the latest stream IDs for each user
userIDToStreamID := make(map[string]int) userIDToStreamID := make(map[string]int64)
for _, k := range keys { for _, k := range keys {
userIDToStreamID[k.UserID] = 0 userIDToStreamID[k.UserID] = 0
} }
@ -95,7 +91,7 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe
if err != nil { if err != nil {
return err return err
} }
userIDToStreamID[userID] = int(streamID) userIDToStreamID[userID] = streamID
} }
// set the stream IDs for each key // set the stream IDs for each key
for i := range keys { for i := range keys {

View file

@ -145,7 +145,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
dk.Type = api.TypeDeviceKeyUpdate dk.Type = api.TypeDeviceKeyUpdate
dk.UserID = userID dk.UserID = userID
var keyJSON string var keyJSON string
var streamID int var streamID int64
var displayName sql.NullString var displayName sql.NullString
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil { if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
return nil, err return nil, err
@ -166,7 +166,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
for i, key := range keys { for i, key := range keys {
var keyJSONStr string var keyJSONStr string
var streamID int var streamID int64
var displayName sql.NullString var displayName sql.NullString
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
@ -183,15 +183,15 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
return nil return nil
} }
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
// nullable if there are no results // nullable if there are no results
var nullStream sql.NullInt32 var nullStream sql.NullInt64
err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
err = nil err = nil
} }
if nullStream.Valid { if nullStream.Valid {
streamID = nullStream.Int32 streamID = nullStream.Int64
} }
return return
} }
@ -204,13 +204,13 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID
} }
query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1) query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1)
// nullable if there are no results // nullable if there are no results
var count sql.NullInt32 var count sql.NullInt64
err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count) err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if count.Valid { if count.Valid {
return int(count.Int32), nil return int(count.Int64), nil
} }
return 0, nil return 0, nil
} }

View file

@ -177,7 +177,7 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("DeviceKeysForUser returned error: %s", err) t.Fatalf("DeviceKeysForUser returned error: %s", err)
} }
wantStreamIDs := map[string]int{ wantStreamIDs := map[string]int64{
"AAA": 3, "AAA": 3,
"another_device": 2, "another_device": 2,
} }

View file

@ -37,7 +37,7 @@ type OneTimeKeys interface {
type DeviceKeys interface { type DeviceKeys interface {
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error)
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error

View file

@ -610,6 +610,14 @@ func (r *Queryer) QueryPublishedRooms(
req *api.QueryPublishedRoomsRequest, req *api.QueryPublishedRoomsRequest,
res *api.QueryPublishedRoomsResponse, res *api.QueryPublishedRoomsResponse,
) error { ) error {
if req.RoomID != "" {
visible, err := r.DB.GetPublishedRoom(ctx, req.RoomID)
if err == nil && visible {
res.RoomIDs = []string{req.RoomID}
return nil
}
return err
}
rooms, err := r.DB.GetPublishedRooms(ctx) rooms, err := r.DB.GetPublishedRooms(ctx)
if err != nil { if err != nil {
return err return err

View file

@ -139,6 +139,8 @@ type Database interface {
PublishRoom(ctx context.Context, roomID string, publish bool) error PublishRoom(ctx context.Context, roomID string, publish bool) error
// Returns a list of room IDs for rooms which are published. // Returns a list of room IDs for rooms which are published.
GetPublishedRooms(ctx context.Context) ([]string, error) GetPublishedRooms(ctx context.Context) ([]string, error)
// Returns whether a given room is published or not.
GetPublishedRoom(ctx context.Context, roomID string) (bool, error)
// TODO: factor out - from currentstateserver // TODO: factor out - from currentstateserver

View file

@ -669,6 +669,10 @@ func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool)
}) })
} }
func (d *Database) GetPublishedRoom(ctx context.Context, roomID string) (bool, error) {
return d.PublishedTable.SelectPublishedFromRoomID(ctx, nil, roomID)
}
func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) { func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) {
return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, true) return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, true)
} }

View file

@ -99,6 +99,7 @@ const (
// The componentName is used for logging purposes, and should be a friendly name // The componentName is used for logging purposes, and should be a friendly name
// of the compontent running, e.g. "SyncAPI" // of the compontent running, e.g. "SyncAPI"
func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...BaseDendriteOptions) *BaseDendrite { func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...BaseDendriteOptions) *BaseDendrite {
platformSanityChecks()
useHTTPAPIs := false useHTTPAPIs := false
cacheMetrics := true cacheMetrics := true
for _, opt := range options { for _, opt := range options {

View file

@ -0,0 +1,8 @@
//go:build !linux && !darwin && !netbsd && !freebsd && !openbsd && !solaris && !dragonfly && !aix
// +build !linux,!darwin,!netbsd,!freebsd,!openbsd,!solaris,!dragonfly,!aix
package base
func platformSanityChecks() {
// Nothing to do yet.
}

22
setup/base/sanity_unix.go Normal file
View file

@ -0,0 +1,22 @@
//go:build linux || darwin || netbsd || freebsd || openbsd || solaris || dragonfly || aix
// +build linux darwin netbsd freebsd openbsd solaris dragonfly aix
package base
import (
"syscall"
"github.com/sirupsen/logrus"
)
func platformSanityChecks() {
// Dendrite needs a relatively high number of file descriptors in order
// to function properly, particularly when federating with lots of servers.
// If we run out of file descriptors, we might run into problems accessing
// PostgreSQL amongst other things. Complain at startup if we think the
// number of file descriptors is too low.
var rLimit syscall.Rlimit
if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit); err == nil && rLimit.Cur < 65535 {
logrus.Warnf("IMPORTANT: Process file descriptor limit is currently %d, it is recommended to raise the limit for Dendrite to at least 65535 to avoid issues", rLimit.Cur)
}
}

View file

@ -200,11 +200,6 @@ user_api:
max_open_conns: 100 max_open_conns: 100
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
device_database:
connection_string: file:userapi_devices.db
max_open_conns: 100
max_idle_conns: 2
conn_max_lifetime: -1
pusher_database: pusher_database:
connection_string: file:pushserver.db connection_string: file:pushserver.db
max_open_conns: 100 max_open_conns: 100

View file

@ -66,7 +66,7 @@ func Context(
membershipRes := roomserver.QueryMembershipForUserResponse{} membershipRes := roomserver.QueryMembershipForUserResponse{}
membershipReq := roomserver.QueryMembershipForUserRequest{UserID: device.UserID, RoomID: roomID} membershipReq := roomserver.QueryMembershipForUserRequest{UserID: device.UserID, RoomID: roomID}
if err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes); err != nil { if err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes); err != nil {
logrus.WithError(err).Error("unable to fo membership") logrus.WithError(err).Error("unable to query membership")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -158,17 +158,19 @@ func applyLazyLoadMembers(filter *gomatrixserverlib.RoomEventFilter, eventsAfter
} }
newState := []*gomatrixserverlib.HeaderedEvent{} newState := []*gomatrixserverlib.HeaderedEvent{}
membershipEvents := []*gomatrixserverlib.HeaderedEvent{}
for _, event := range state { for _, event := range state {
if event.Type() != gomatrixserverlib.MRoomMember { if event.Type() != gomatrixserverlib.MRoomMember {
newState = append(newState, event) newState = append(newState, event)
} else { } else {
// did the user send an event? // did the user send an event?
if x[event.Sender()] { if x[event.Sender()] {
newState = append(newState, event) membershipEvents = append(membershipEvents, event)
} }
} }
} }
return newState // Add the membershipEvents to the end of the list, to make Sytest happy
return append(newState, membershipEvents...)
} }
func parseRoomEventFilter(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) { func parseRoomEventFilter(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) {

View file

@ -77,6 +77,9 @@ const DeleteRoomStateForRoomSQL = "" +
const selectRoomIDsWithMembershipSQL = "" + const selectRoomIDsWithMembershipSQL = "" +
"SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2"
const selectRoomIDsWithAnyMembershipSQL = "" +
"SELECT DISTINCT room_id, membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1"
const selectCurrentStateSQL = "" + const selectCurrentStateSQL = "" +
"SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" + "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" +
" AND ( $2::text[] IS NULL OR sender = ANY($2) )" + " AND ( $2::text[] IS NULL OR sender = ANY($2) )" +
@ -102,14 +105,15 @@ const selectEventsWithEventIDsSQL = "" +
" FROM syncapi_current_room_state WHERE event_id = ANY($1)" " FROM syncapi_current_room_state WHERE event_id = ANY($1)"
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
upsertRoomStateStmt *sql.Stmt upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt
DeleteRoomStateForRoomStmt *sql.Stmt DeleteRoomStateForRoomStmt *sql.Stmt
selectRoomIDsWithMembershipStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt
selectCurrentStateStmt *sql.Stmt selectRoomIDsWithAnyMembershipStmt *sql.Stmt
selectJoinedUsersStmt *sql.Stmt selectCurrentStateStmt *sql.Stmt
selectEventsWithEventIDsStmt *sql.Stmt selectJoinedUsersStmt *sql.Stmt
selectStateEventStmt *sql.Stmt selectEventsWithEventIDsStmt *sql.Stmt
selectStateEventStmt *sql.Stmt
} }
func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
@ -130,6 +134,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil {
return nil, err return nil, err
} }
if s.selectRoomIDsWithAnyMembershipStmt, err = db.Prepare(selectRoomIDsWithAnyMembershipSQL); err != nil {
return nil, err
}
if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil { if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil {
return nil, err return nil, err
} }
@ -194,6 +201,31 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
return result, rows.Err() return result, rows.Err()
} }
// SelectRoomIDsWithAnyMembership returns a map of all memberships for the given user.
func (s *currentRoomStateStatements) SelectRoomIDsWithAnyMembership(
ctx context.Context,
txn *sql.Tx,
userID string,
) (map[string]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsWithAnyMembershipStmt)
rows, err := stmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsWithAnyMembership: rows.close() failed")
result := map[string]string{}
for rows.Next() {
var roomID string
var membership string
if err := rows.Scan(&roomID, &membership); err != nil {
return nil, err
}
result[roomID] = membership
}
return result, rows.Err()
}
// SelectCurrentState returns all the current state events for the given room. // SelectCurrentState returns all the current state events for the given room.
func (s *currentRoomStateStatements) SelectCurrentState( func (s *currentRoomStateStatements) SelectCurrentState(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,

View file

@ -119,13 +119,14 @@ const selectStateInRangeSQL = "" +
"SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" + "SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" +
" FROM syncapi_output_room_events" + " FROM syncapi_output_room_events" +
" WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" + " WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" +
" AND ( $3::text[] IS NULL OR sender = ANY($3) )" + " AND room_id = ANY($3)" +
" AND ( $4::text[] IS NULL OR NOT(sender = ANY($4)) )" + " AND ( $4::text[] IS NULL OR sender = ANY($4) )" +
" AND ( $5::text[] IS NULL OR type LIKE ANY($5) )" + " AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" +
" AND ( $6::text[] IS NULL OR NOT(type LIKE ANY($6)) )" + " AND ( $6::text[] IS NULL OR type LIKE ANY($6) )" +
" AND ( $7::bool IS NULL OR contains_url = $7 )" + " AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" +
" AND ( $8::bool IS NULL OR contains_url = $8 )" +
" ORDER BY id ASC" + " ORDER BY id ASC" +
" LIMIT $8" " LIMIT $9"
const deleteEventsForRoomSQL = "" + const deleteEventsForRoomSQL = "" +
"DELETE FROM syncapi_output_room_events WHERE room_id = $1" "DELETE FROM syncapi_output_room_events WHERE room_id = $1"
@ -200,12 +201,12 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event
// two positions, only the most recent state is returned. // two positions, only the most recent state is returned.
func (s *outputRoomEventsStatements) SelectStateInRange( func (s *outputRoomEventsStatements) SelectStateInRange(
ctx context.Context, txn *sql.Tx, r types.Range, ctx context.Context, txn *sql.Tx, r types.Range,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter, roomIDs []string,
) (map[string]map[string]bool, map[string]types.StreamEvent, error) { ) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt) stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt)
rows, err := stmt.QueryContext( rows, err := stmt.QueryContext(
ctx, r.Low(), r.High(), ctx, r.Low(), r.High(), pq.StringArray(roomIDs),
pq.StringArray(stateFilter.Senders), pq.StringArray(stateFilter.Senders),
pq.StringArray(stateFilter.NotSenders), pq.StringArray(stateFilter.NotSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)),

View file

@ -689,10 +689,26 @@ func (d *Database) GetStateDeltas(
var succeeded bool var succeeded bool
defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err) defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err)
// Look up all memberships for the user. We only care about rooms that a
// user has ever interacted with — joined to, kicked/banned from, left.
memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID)
if err != nil {
return nil, nil, err
}
allRoomIDs := make([]string, 0, len(memberships))
joinedRoomIDs := make([]string, 0, len(memberships))
for roomID, membership := range memberships {
allRoomIDs = append(allRoomIDs, roomID)
if membership == gomatrixserverlib.Join {
joinedRoomIDs = append(joinedRoomIDs, roomID)
}
}
var deltas []types.StateDelta var deltas []types.StateDelta
// get all the state events ever (i.e. for all available rooms) between these two positions // get all the state events ever (i.e. for all available rooms) between these two positions
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter) stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -760,10 +776,6 @@ func (d *Database) GetStateDeltas(
} }
// Add in currently joined rooms // Add in currently joined rooms
joinedRoomIDs, err := d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
if err != nil {
return nil, nil, err
}
for _, joinedRoomID := range joinedRoomIDs { for _, joinedRoomID := range joinedRoomIDs {
deltas = append(deltas, types.StateDelta{ deltas = append(deltas, types.StateDelta{
Membership: gomatrixserverlib.Join, Membership: gomatrixserverlib.Join,
@ -792,6 +804,22 @@ func (d *Database) GetStateDeltasForFullStateSync(
var succeeded bool var succeeded bool
defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err) defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err)
// Look up all memberships for the user. We only care about rooms that a
// user has ever interacted with — joined to, kicked/banned from, left.
memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID)
if err != nil {
return nil, nil, err
}
allRoomIDs := make([]string, 0, len(memberships))
joinedRoomIDs := make([]string, 0, len(memberships))
for roomID, membership := range memberships {
allRoomIDs = append(allRoomIDs, roomID)
if membership == gomatrixserverlib.Join {
joinedRoomIDs = append(joinedRoomIDs, roomID)
}
}
// Use a reasonable initial capacity // Use a reasonable initial capacity
deltas := make(map[string]types.StateDelta) deltas := make(map[string]types.StateDelta)
@ -816,7 +844,7 @@ func (d *Database) GetStateDeltasForFullStateSync(
} }
// Get all the state events ever between these two positions // Get all the state events ever between these two positions
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter) stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -842,11 +870,6 @@ func (d *Database) GetStateDeltasForFullStateSync(
} }
} }
joinedRoomIDs, err := d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
if err != nil {
return nil, nil, err
}
// Add full states for all joined rooms // Add full states for all joined rooms
for _, joinedRoomID := range joinedRoomIDs { for _, joinedRoomID := range joinedRoomIDs {
s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilter) s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilter)

View file

@ -66,6 +66,9 @@ const DeleteRoomStateForRoomSQL = "" +
const selectRoomIDsWithMembershipSQL = "" + const selectRoomIDsWithMembershipSQL = "" +
"SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2"
const selectRoomIDsWithAnyMembershipSQL = "" +
"SELECT DISTINCT room_id, membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1"
const selectCurrentStateSQL = "" + const selectCurrentStateSQL = "" +
"SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1"
@ -86,14 +89,15 @@ const selectEventsWithEventIDsSQL = "" +
" FROM syncapi_current_room_state WHERE event_id IN ($1)" " FROM syncapi_current_room_state WHERE event_id IN ($1)"
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
db *sql.DB db *sql.DB
streamIDStatements *streamIDStatements streamIDStatements *streamIDStatements
upsertRoomStateStmt *sql.Stmt upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt
DeleteRoomStateForRoomStmt *sql.Stmt DeleteRoomStateForRoomStmt *sql.Stmt
selectRoomIDsWithMembershipStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt
selectJoinedUsersStmt *sql.Stmt selectRoomIDsWithAnyMembershipStmt *sql.Stmt
selectStateEventStmt *sql.Stmt selectJoinedUsersStmt *sql.Stmt
selectStateEventStmt *sql.Stmt
} }
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) { func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
@ -117,6 +121,9 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (t
if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil {
return nil, err return nil, err
} }
if s.selectRoomIDsWithAnyMembershipStmt, err = db.Prepare(selectRoomIDsWithAnyMembershipSQL); err != nil {
return nil, err
}
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
return nil, err return nil, err
} }
@ -175,6 +182,31 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
return result, nil return result, nil
} }
// SelectRoomIDsWithAnyMembership returns a map of all memberships for the given user.
func (s *currentRoomStateStatements) SelectRoomIDsWithAnyMembership(
ctx context.Context,
txn *sql.Tx,
userID string,
) (map[string]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsWithAnyMembershipStmt)
rows, err := stmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsWithAnyMembership: rows.close() failed")
result := map[string]string{}
for rows.Next() {
var roomID string
var membership string
if err := rows.Scan(&roomID, &membership); err != nil {
return nil, err
}
result[roomID] = membership
}
return result, rows.Err()
}
// CurrentState returns all the current state events for the given room. // CurrentState returns all the current state events for the given room.
func (s *currentRoomStateStatements) SelectCurrentState( func (s *currentRoomStateStatements) SelectCurrentState(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,

View file

@ -21,6 +21,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"sort" "sort"
"strings"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -87,6 +88,7 @@ const selectStateInRangeSQL = "" +
"SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" + "SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" +
" FROM syncapi_output_room_events" + " FROM syncapi_output_room_events" +
" WHERE (id > $1 AND id <= $2)" + " WHERE (id > $1 AND id <= $2)" +
" AND room_id IN ($3)" +
" AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))" " AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
@ -155,13 +157,17 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event
// two positions, only the most recent state is returned. // two positions, only the most recent state is returned.
func (s *outputRoomEventsStatements) SelectStateInRange( func (s *outputRoomEventsStatements) SelectStateInRange(
ctx context.Context, txn *sql.Tx, r types.Range, ctx context.Context, txn *sql.Tx, r types.Range,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter, roomIDs []string,
) (map[string]map[string]bool, map[string]types.StreamEvent, error) { ) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
stmtSQL := strings.Replace(selectStateInRangeSQL, "($3)", sqlutil.QueryVariadicOffset(len(roomIDs), 2), 1)
inputParams := []interface{}{
r.Low(), r.High(),
}
for _, roomID := range roomIDs {
inputParams = append(inputParams, roomID)
}
stmt, params, err := prepareWithFilters( stmt, params, err := prepareWithFilters(
s.db, txn, selectStateInRangeSQL, s.db, txn, stmtSQL, inputParams,
[]interface{}{
r.Low(), r.High(),
},
stateFilter.Senders, stateFilter.NotSenders, stateFilter.Senders, stateFilter.NotSenders,
stateFilter.Types, stateFilter.NotTypes, stateFilter.Types, stateFilter.NotTypes,
nil, stateFilter.Limit, FilterOrderAsc, nil, stateFilter.Limit, FilterOrderAsc,

View file

@ -51,7 +51,7 @@ type Peeks interface {
} }
type Events interface { type Events interface {
SelectStateInRange(ctx context.Context, txn *sql.Tx, r types.Range, stateFilter *gomatrixserverlib.StateFilter) (map[string]map[string]bool, map[string]types.StreamEvent, error) SelectStateInRange(ctx context.Context, txn *sql.Tx, r types.Range, stateFilter *gomatrixserverlib.StateFilter, roomIDs []string) (map[string]map[string]bool, map[string]types.StreamEvent, error)
SelectMaxEventID(ctx context.Context, txn *sql.Tx) (id int64, err error) SelectMaxEventID(ctx context.Context, txn *sql.Tx) (id int64, err error)
InsertEvent(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, addState, removeState []string, transactionID *api.TransactionID, excludeFromSync bool) (streamPos types.StreamPosition, err error) InsertEvent(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, addState, removeState []string, transactionID *api.TransactionID, excludeFromSync bool) (streamPos types.StreamPosition, err error)
// SelectRecentEvents returns events between the two stream positions: exclusive of low and inclusive of high. // SelectRecentEvents returns events between the two stream positions: exclusive of low and inclusive of high.
@ -99,6 +99,8 @@ type CurrentRoomState interface {
SelectCurrentState(ctx context.Context, txn *sql.Tx, roomID string, stateFilter *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) SelectCurrentState(ctx context.Context, txn *sql.Tx, roomID string, stateFilter *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error)
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error) SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error)
// SelectRoomIDsWithAnyMembership returns a map of all memberships for the given user.
SelectRoomIDsWithAnyMembership(ctx context.Context, txn *sql.Tx, userID string) (map[string]string, error)
// SelectJoinedUsers returns a map of room ID to a list of joined user IDs. // SelectJoinedUsers returns a map of room ID to a list of joined user IDs.
SelectJoinedUsers(ctx context.Context) (map[string][]string, error) SelectJoinedUsers(ctx context.Context) (map[string][]string, error)
} }

View file

@ -24,7 +24,6 @@ Local device key changes get to remote servers with correct prev_id
# Flakey # Flakey
Local device key changes appear in /keys/changes Local device key changes appear in /keys/changes
/context/ with lazy_load_members filter works
# we don't support groups # we don't support groups
Remove group category Remove group category
@ -32,9 +31,10 @@ Remove group role
# Flakey # Flakey
AS-ghosted users can use rooms themselves AS-ghosted users can use rooms themselves
/context/ with lazy_load_members filter works
AS-ghosted users can use rooms via AS AS-ghosted users can use rooms via AS
Events in rooms with AS-hosted room aliases are sent to AS server Events in rooms with AS-hosted room aliases are sent to AS server
Inviting an AS-hosted user asks the AS server
Accesing an AS-hosted room alias asks the AS server
# Flakey, need additional investigation # Flakey, need additional investigation
Messages that notify from another user increment notification_count Messages that notify from another user increment notification_count

View file

@ -515,7 +515,6 @@ AS can create a user with inhibit_login
AS can set avatar for ghosted users AS can set avatar for ghosted users
AS can set displayname for ghosted users AS can set displayname for ghosted users
Ghost user must register before joining room Ghost user must register before joining room
Inviting an AS-hosted user asks the AS server
Can generate a openid access_token that can be exchanged for information about a user Can generate a openid access_token that can be exchanged for information about a user
Invalid openid access tokens are rejected Invalid openid access tokens are rejected
Requests to userinfo without access tokens are rejected Requests to userinfo without access tokens are rejected
@ -661,6 +660,5 @@ Multiple calls to /sync should not cause 500 errors
Canonical alias can be set Canonical alias can be set
Canonical alias can include alt_aliases Canonical alias can include alt_aliases
Can delete canonical alias Can delete canonical alias
Multiple calls to /sync should not cause 500 errors
AS can make room aliases AS can make room aliases
Accesing an AS-hosted room alias asks the AS server /context/ with lazy_load_members filter works