refresh dendrite main

This commit is contained in:
Tak Wai Wong 2022-03-10 10:20:43 -08:00
parent 7771e5fac9
commit 3c0185d610
20 changed files with 95 additions and 79 deletions

View file

@ -312,8 +312,14 @@ user_api:
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
device_database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_userapi_devices?sslmode=disable # Configuration for the Push Server API.
push_server:
internal_api:
listen: http://localhost:7782
connect: http://localhost:7782
database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_pushserver?sslmode=disable
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1

View file

@ -1,5 +1,5 @@
#!/bin/sh #!/bin/sh
for db in userapi_accounts userapi_devices mediaapi syncapi roomserver keyserver federationapi appservice mscs; do for db in userapi_accounts mediaapi syncapi roomserver keyserver federationapi appservice mscs; do
createdb -U dendrite -O dendrite dendrite_$db createdb -U dendrite -O dendrite dendrite_$db
done done

View file

@ -87,7 +87,7 @@ On macOS, omit `sudo -u postgres` from the below commands.
* If you want to run each Dendrite component with its own database: * If you want to run each Dendrite component with its own database:
```bash ```bash
for i in mediaapi syncapi roomserver federationapi appservice keyserver userapi_accounts userapi_devices; do for i in mediaapi syncapi roomserver federationapi appservice keyserver userapi_accounts; do
sudo -u postgres createdb -O dendrite dendrite_$i sudo -u postgres createdb -O dendrite dendrite_$i
done done
``` ```

View file

@ -203,9 +203,9 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool {
return err == nil return err == nil
} }
func prevID(streamID int) []int { func prevID(streamID int64) []int64 {
if streamID <= 1 { if streamID <= 1 {
return nil return nil
} }
return []int{streamID - 1} return []int64{streamID - 1}
} }

4
go.mod
View file

@ -40,8 +40,8 @@ require (
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902 github.com/matrix-org/gomatrixserverlib v0.0.0-20220310124155-116ed5cc1bfa
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa github.com/matrix-org/pinecone v0.0.0-20220308124038-cfde1f8054c5
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.10 github.com/mattn/go-sqlite3 v1.14.10
github.com/morikuni/aec v1.0.0 // indirect github.com/morikuni/aec v1.0.0 // indirect

8
go.sum
View file

@ -983,10 +983,10 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902 h1:WHlrE8BYh/hzn1RKwq3YMAlhHivX47jQKAjZFtkJyPE= github.com/matrix-org/gomatrixserverlib v0.0.0-20220310124155-116ed5cc1bfa h1:anEGvpRn4v6akmxFWqGDobB6csEt3OWmp67pufccimE=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= github.com/matrix-org/gomatrixserverlib v0.0.0-20220310124155-116ed5cc1bfa/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo=
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa h1:rMYFNVto66gp+eWS8XAUzgp4m0qmUBid6l1HX3mHstk= github.com/matrix-org/pinecone v0.0.0-20220308124038-cfde1f8054c5 h1:7viLTiLAA2MtGKY+uf14j6TjfKvvGLAMj/qdm70jJuQ=
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/pinecone v0.0.0-20220308124038-cfde1f8054c5/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=

View file

@ -70,7 +70,7 @@ type DeviceMessage struct {
*DeviceKeys `json:"DeviceKeys,omitempty"` *DeviceKeys `json:"DeviceKeys,omitempty"`
*eduapi.OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"` *eduapi.OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"`
// A monotonically increasing number which represents device changes for this user. // A monotonically increasing number which represents device changes for this user.
StreamID int StreamID int64
DeviceChangeID int64 DeviceChangeID int64
} }
@ -108,7 +108,7 @@ type DeviceKeys struct {
} }
// WithStreamID returns a copy of this device message with the given stream ID // WithStreamID returns a copy of this device message with the given stream ID
func (k *DeviceKeys) WithStreamID(streamID int) DeviceMessage { func (k *DeviceKeys) WithStreamID(streamID int64) DeviceMessage {
return DeviceMessage{ return DeviceMessage{
DeviceKeys: k, DeviceKeys: k,
StreamID: streamID, StreamID: streamID,
@ -281,7 +281,7 @@ type QueryDeviceMessagesRequest struct {
type QueryDeviceMessagesResponse struct { type QueryDeviceMessagesResponse struct {
// The latest stream ID // The latest stream ID
StreamID int StreamID int64
Devices []DeviceMessage Devices []DeviceMessage
Error *KeyError Error *KeyError
} }

View file

@ -109,7 +109,7 @@ type DeviceListUpdaterDatabase interface {
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
// PrevIDsExists returns true if all prev IDs exist for this user. // PrevIDsExists returns true if all prev IDs exist for this user.
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error)
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error

View file

@ -46,7 +46,7 @@ func (p *mockKeyChangeProducer) ProduceKeyChanges(keys []api.DeviceMessage) erro
type mockDeviceListUpdaterDatabase struct { type mockDeviceListUpdaterDatabase struct {
staleUsers map[string]bool staleUsers map[string]bool
prevIDsExist func(string, []int) bool prevIDsExist func(string, []int64) bool
storedKeys []api.DeviceMessage storedKeys []api.DeviceMessage
mu sync.Mutex // protect staleUsers mu sync.Mutex // protect staleUsers
} }
@ -101,7 +101,7 @@ func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Contex
} }
// PrevIDsExists returns true if all prev IDs exist for this user. // PrevIDsExists returns true if all prev IDs exist for this user.
func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) { func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
return d.prevIDsExist(userID, prevIDs), nil return d.prevIDsExist(userID, prevIDs), nil
} }
@ -139,7 +139,7 @@ func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrix
func TestUpdateHavePrevID(t *testing.T) { func TestUpdateHavePrevID(t *testing.T) {
db := &mockDeviceListUpdaterDatabase{ db := &mockDeviceListUpdaterDatabase{
staleUsers: make(map[string]bool), staleUsers: make(map[string]bool),
prevIDsExist: func(string, []int) bool { prevIDsExist: func(string, []int64) bool {
return true return true
}, },
} }
@ -151,7 +151,7 @@ func TestUpdateHavePrevID(t *testing.T) {
Deleted: false, Deleted: false,
DeviceID: "FOO", DeviceID: "FOO",
Keys: []byte(`{"key":"value"}`), Keys: []byte(`{"key":"value"}`),
PrevID: []int{0}, PrevID: []int64{0},
StreamID: 1, StreamID: 1,
UserID: "@alice:localhost", UserID: "@alice:localhost",
} }
@ -185,7 +185,7 @@ func TestUpdateHavePrevID(t *testing.T) {
func TestUpdateNoPrevID(t *testing.T) { func TestUpdateNoPrevID(t *testing.T) {
db := &mockDeviceListUpdaterDatabase{ db := &mockDeviceListUpdaterDatabase{
staleUsers: make(map[string]bool), staleUsers: make(map[string]bool),
prevIDsExist: func(string, []int) bool { prevIDsExist: func(string, []int64) bool {
return false return false
}, },
} }
@ -226,7 +226,7 @@ func TestUpdateNoPrevID(t *testing.T) {
Deleted: false, Deleted: false,
DeviceID: "another_device_id", DeviceID: "another_device_id",
Keys: []byte(`{"key":"value"}`), Keys: []byte(`{"key":"value"}`),
PrevID: []int{3}, PrevID: []int64{3},
StreamID: 4, StreamID: 4,
UserID: remoteUserID, UserID: remoteUserID,
} }
@ -268,7 +268,7 @@ func TestDebounce(t *testing.T) {
t.Skipf("panic on closed channel on GHA") t.Skipf("panic on closed channel on GHA")
db := &mockDeviceListUpdaterDatabase{ db := &mockDeviceListUpdaterDatabase{
staleUsers: make(map[string]bool), staleUsers: make(map[string]bool),
prevIDsExist: func(string, []int) bool { prevIDsExist: func(string, []int64) bool {
return true return true
}, },
} }

View file

@ -205,7 +205,7 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
} }
return return
} }
maxStreamID := 0 maxStreamID := int64(0)
for _, m := range msgs { for _, m := range msgs {
if m.StreamID > maxStreamID { if m.StreamID > maxStreamID {
maxStreamID = m.StreamID maxStreamID = m.StreamID

View file

@ -49,7 +49,7 @@ type Database interface {
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
// PrevIDsExists returns true if all prev IDs exist for this user. // PrevIDsExists returns true if all prev IDs exist for this user.
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error)
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.

View file

@ -121,7 +121,7 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
for i, key := range keys { for i, key := range keys {
var keyJSONStr string var keyJSONStr string
var streamID int var streamID int64
var displayName sql.NullString var displayName sql.NullString
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
@ -138,15 +138,15 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
return nil return nil
} }
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
// nullable if there are no results // nullable if there are no results
var nullStream sql.NullInt32 var nullStream sql.NullInt64
err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
err = nil err = nil
} }
if nullStream.Valid { if nullStream.Valid {
streamID = nullStream.Int32 streamID = nullStream.Int64
} }
return return
} }
@ -211,7 +211,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
} }
dk.UserID = userID dk.UserID = userID
var keyJSON string var keyJSON string
var streamID int var streamID int64
var displayName sql.NullString var displayName sql.NullString
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil { if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
return nil, err return nil, err

View file

@ -59,12 +59,8 @@ func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage)
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
} }
func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) { func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
sids := make([]int64, len(prevIDs)) count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs)
for i := range prevIDs {
sids[i] = int64(prevIDs[i])
}
count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, sids)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -85,7 +81,7 @@ func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceM
func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
// work out the latest stream IDs for each user // work out the latest stream IDs for each user
userIDToStreamID := make(map[string]int) userIDToStreamID := make(map[string]int64)
for _, k := range keys { for _, k := range keys {
userIDToStreamID[k.UserID] = 0 userIDToStreamID[k.UserID] = 0
} }
@ -95,7 +91,7 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe
if err != nil { if err != nil {
return err return err
} }
userIDToStreamID[userID] = int(streamID) userIDToStreamID[userID] = streamID
} }
// set the stream IDs for each key // set the stream IDs for each key
for i := range keys { for i := range keys {

View file

@ -145,7 +145,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
dk.Type = api.TypeDeviceKeyUpdate dk.Type = api.TypeDeviceKeyUpdate
dk.UserID = userID dk.UserID = userID
var keyJSON string var keyJSON string
var streamID int var streamID int64
var displayName sql.NullString var displayName sql.NullString
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil { if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
return nil, err return nil, err
@ -166,7 +166,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
for i, key := range keys { for i, key := range keys {
var keyJSONStr string var keyJSONStr string
var streamID int var streamID int64
var displayName sql.NullString var displayName sql.NullString
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
@ -183,15 +183,15 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
return nil return nil
} }
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
// nullable if there are no results // nullable if there are no results
var nullStream sql.NullInt32 var nullStream sql.NullInt64
err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
err = nil err = nil
} }
if nullStream.Valid { if nullStream.Valid {
streamID = nullStream.Int32 streamID = nullStream.Int64
} }
return return
} }
@ -204,13 +204,13 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID
} }
query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1) query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1)
// nullable if there are no results // nullable if there are no results
var count sql.NullInt32 var count sql.NullInt64
err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count) err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if count.Valid { if count.Valid {
return int(count.Int32), nil return int(count.Int64), nil
} }
return 0, nil return 0, nil
} }

View file

@ -177,7 +177,7 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("DeviceKeysForUser returned error: %s", err) t.Fatalf("DeviceKeysForUser returned error: %s", err)
} }
wantStreamIDs := map[string]int{ wantStreamIDs := map[string]int64{
"AAA": 3, "AAA": 3,
"another_device": 2, "another_device": 2,
} }

View file

@ -37,7 +37,7 @@ type OneTimeKeys interface {
type DeviceKeys interface { type DeviceKeys interface {
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error)
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error

View file

@ -610,6 +610,14 @@ func (r *Queryer) QueryPublishedRooms(
req *api.QueryPublishedRoomsRequest, req *api.QueryPublishedRoomsRequest,
res *api.QueryPublishedRoomsResponse, res *api.QueryPublishedRoomsResponse,
) error { ) error {
if req.RoomID != "" {
visible, err := r.DB.GetPublishedRoom(ctx, req.RoomID)
if err == nil && visible {
res.RoomIDs = []string{req.RoomID}
return nil
}
return err
}
rooms, err := r.DB.GetPublishedRooms(ctx) rooms, err := r.DB.GetPublishedRooms(ctx)
if err != nil { if err != nil {
return err return err

View file

@ -139,6 +139,8 @@ type Database interface {
PublishRoom(ctx context.Context, roomID string, publish bool) error PublishRoom(ctx context.Context, roomID string, publish bool) error
// Returns a list of room IDs for rooms which are published. // Returns a list of room IDs for rooms which are published.
GetPublishedRooms(ctx context.Context) ([]string, error) GetPublishedRooms(ctx context.Context) ([]string, error)
// Returns whether a given room is published or not.
GetPublishedRoom(ctx context.Context, roomID string) (bool, error)
// TODO: factor out - from currentstateserver // TODO: factor out - from currentstateserver

View file

@ -669,6 +669,10 @@ func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool)
}) })
} }
func (d *Database) GetPublishedRoom(ctx context.Context, roomID string) (bool, error) {
return d.PublishedTable.SelectPublishedFromRoomID(ctx, nil, roomID)
}
func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) { func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) {
return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, true) return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, true)
} }

View file

@ -200,8 +200,8 @@ user_api:
max_open_conns: 100 max_open_conns: 100
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
device_database: pusher_database:
connection_string: file:userapi_devices.db connection_string: file:pushserver.db
max_open_conns: 100 max_open_conns: 100
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1