Merge branch 'master' of https://github.com/matrix-org/dendrite into basicauth-metrics

This commit is contained in:
Till Faelligen 2020-02-13 20:16:34 +01:00
commit 2ccc149836
142 changed files with 9876 additions and 856 deletions

View file

@ -140,7 +140,7 @@ func RetrieveUserProfile(
ctx context.Context, ctx context.Context,
userID string, userID string,
asAPI AppServiceQueryAPI, asAPI AppServiceQueryAPI,
accountDB *accounts.Database, accountDB accounts.Database,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {

View file

@ -41,8 +41,8 @@ import (
// component. // component.
func SetupAppServiceAPIComponent( func SetupAppServiceAPIComponent(
base *basecomponent.BaseDendrite, base *basecomponent.BaseDendrite,
accountsDB *accounts.Database, accountsDB accounts.Database,
deviceDB *devices.Database, deviceDB devices.Database,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
roomserverAliasAPI roomserverAPI.RoomserverAliasAPI, roomserverAliasAPI roomserverAPI.RoomserverAliasAPI,
roomserverQueryAPI roomserverAPI.RoomserverQueryAPI, roomserverQueryAPI roomserverAPI.RoomserverQueryAPI,
@ -100,7 +100,7 @@ func SetupAppServiceAPIComponent(
// Set up HTTP Endpoints // Set up HTTP Endpoints
routing.Setup( routing.Setup(
base.APIMux, *base.Cfg, roomserverQueryAPI, roomserverAliasAPI, base.APIMux, base.Cfg, roomserverQueryAPI, roomserverAliasAPI,
accountsDB, federation, transactionsCache, accountsDB, federation, transactionsCache,
) )
@ -111,8 +111,8 @@ func SetupAppServiceAPIComponent(
// `sender_localpart` field of each application service if it doesn't // `sender_localpart` field of each application service if it doesn't
// exist already // exist already
func generateAppServiceAccount( func generateAppServiceAccount(
accountsDB *accounts.Database, accountsDB accounts.Database,
deviceDB *devices.Database, deviceDB devices.Database,
as config.ApplicationService, as config.ApplicationService,
) error { ) error {
ctx := context.Background() ctx := context.Background()

View file

@ -33,7 +33,7 @@ import (
// OutputRoomEventConsumer consumes events that originated in the room server. // OutputRoomEventConsumer consumes events that originated in the room server.
type OutputRoomEventConsumer struct { type OutputRoomEventConsumer struct {
roomServerConsumer *common.ContinualConsumer roomServerConsumer *common.ContinualConsumer
db *accounts.Database db accounts.Database
asDB *storage.Database asDB *storage.Database
query api.RoomserverQueryAPI query api.RoomserverQueryAPI
alias api.RoomserverAliasAPI alias api.RoomserverAliasAPI
@ -46,7 +46,7 @@ type OutputRoomEventConsumer struct {
func NewOutputRoomEventConsumer( func NewOutputRoomEventConsumer(
cfg *config.Dendrite, cfg *config.Dendrite,
kafkaConsumer sarama.Consumer, kafkaConsumer sarama.Consumer,
store *accounts.Database, store accounts.Database,
appserviceDB *storage.Database, appserviceDB *storage.Database,
queryAPI api.RoomserverQueryAPI, queryAPI api.RoomserverQueryAPI,
aliasAPI api.RoomserverAliasAPI, aliasAPI api.RoomserverAliasAPI,

View file

@ -36,9 +36,9 @@ const pathPrefixApp = "/_matrix/app/v1"
// applied: // applied:
// nolint: gocyclo // nolint: gocyclo
func Setup( func Setup(
apiMux *mux.Router, cfg config.Dendrite, // nolint: unparam apiMux *mux.Router, cfg *config.Dendrite, // nolint: unparam
queryAPI api.RoomserverQueryAPI, aliasAPI api.RoomserverAliasAPI, // nolint: unparam queryAPI api.RoomserverQueryAPI, aliasAPI api.RoomserverAliasAPI, // nolint: unparam
accountDB *accounts.Database, // nolint: unparam accountDB accounts.Database, // nolint: unparam
federation *gomatrixserverlib.FederationClient, // nolint: unparam federation *gomatrixserverlib.FederationClient, // nolint: unparam
transactionsCache *transactions.Cache, // nolint: unparam transactionsCache *transactions.Cache, // nolint: unparam
) { ) {

View file

@ -0,0 +1,141 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/gomatrixserverlib"
)
const accountDataSchema = `
-- Stores data about accounts data.
CREATE TABLE IF NOT EXISTS account_data (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL,
-- The room ID for this data (empty string if not specific to a room)
room_id TEXT,
-- The account data type
type TEXT NOT NULL,
-- The account data content
content TEXT NOT NULL,
PRIMARY KEY(localpart, room_id, type)
);
`
const insertAccountDataSQL = `
INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content
`
const selectAccountDataSQL = "" +
"SELECT room_id, type, content FROM account_data WHERE localpart = $1"
const selectAccountDataByTypeSQL = "" +
"SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
type accountDataStatements struct {
insertAccountDataStmt *sql.Stmt
selectAccountDataStmt *sql.Stmt
selectAccountDataByTypeStmt *sql.Stmt
}
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(accountDataSchema)
if err != nil {
return
}
if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil {
return
}
if s.selectAccountDataStmt, err = db.Prepare(selectAccountDataSQL); err != nil {
return
}
if s.selectAccountDataByTypeStmt, err = db.Prepare(selectAccountDataByTypeSQL); err != nil {
return
}
return
}
func (s *accountDataStatements) insertAccountData(
ctx context.Context, localpart, roomID, dataType, content string,
) (err error) {
stmt := s.insertAccountDataStmt
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
return
}
func (s *accountDataStatements) selectAccountData(
ctx context.Context, localpart string,
) (
global []gomatrixserverlib.ClientEvent,
rooms map[string][]gomatrixserverlib.ClientEvent,
err error,
) {
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
if err != nil {
return
}
defer rows.Close() // nolint: errcheck
global = []gomatrixserverlib.ClientEvent{}
rooms = make(map[string][]gomatrixserverlib.ClientEvent)
for rows.Next() {
var roomID string
var dataType string
var content []byte
if err = rows.Scan(&roomID, &dataType, &content); err != nil {
return
}
ac := gomatrixserverlib.ClientEvent{
Type: dataType,
Content: content,
}
if len(roomID) > 0 {
rooms[roomID] = append(rooms[roomID], ac)
} else {
global = append(global, ac)
}
}
return global, rooms, rows.Err()
}
func (s *accountDataStatements) selectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) {
stmt := s.selectAccountDataByTypeStmt
var content []byte
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return
}
data = &gomatrixserverlib.ClientEvent{
Type: dataType,
Content: content,
}
return
}

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package accounts package postgres
import ( import (
"context" "context"

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package accounts package postgres
import ( import (
"context" "context"

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package accounts package postgres
import ( import (
"context" "context"
@ -122,11 +122,10 @@ func (s *membershipStatements) selectMembershipsByLocalpart(
for rows.Next() { for rows.Next() {
var m authtypes.Membership var m authtypes.Membership
m.Localpart = localpart m.Localpart = localpart
if err := rows.Scan(&m.RoomID, &m.EventID); err != nil { if err = rows.Scan(&m.RoomID, &m.EventID); err != nil {
return nil, err return
} }
memberships = append(memberships, m) memberships = append(memberships, m)
} }
return memberships, rows.Err()
return
} }

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package accounts package postgres
import ( import (
"context" "context"

View file

@ -0,0 +1,392 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"errors"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
// Import the postgres database driver.
_ "github.com/lib/pq"
)
// Database represents an account database
type Database struct {
db *sql.DB
common.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
memberships membershipStatements
accountDatas accountDataStatements
threepids threepidStatements
filter filterStatements
serverName gomatrixserverlib.ServerName
}
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
var db *sql.DB
var err error
if db, err = sql.Open("postgres", dataSourceName); err != nil {
return nil, err
}
partitions := common.PartitionOffsetStatements{}
if err = partitions.Prepare(db, "account"); err != nil {
return nil, err
}
a := accountsStatements{}
if err = a.prepare(db, serverName); err != nil {
return nil, err
}
p := profilesStatements{}
if err = p.prepare(db); err != nil {
return nil, err
}
m := membershipStatements{}
if err = m.prepare(db); err != nil {
return nil, err
}
ac := accountDataStatements{}
if err = ac.prepare(db); err != nil {
return nil, err
}
t := threepidStatements{}
if err = t.prepare(db); err != nil {
return nil, err
}
f := filterStatements{}
if err = f.prepare(db); err != nil {
return nil, err
}
return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(
ctx context.Context, localpart, plaintextPassword string,
) (*authtypes.Account, error) {
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
if err != nil {
return nil, err
}
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
return nil, err
}
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
return d.profiles.selectProfileByLocalpart(ctx, localpart)
}
// SetAvatarURL updates the avatar URL of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) error {
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
}
// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
) error {
return d.profiles.setDisplayName(ctx, localpart, displayName)
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, nil.
func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string,
) (*authtypes.Account, error) {
var err error
// Generate a password hash if this is not a password-less user
hash := ""
if plaintextPassword != "" {
hash, err = hashPassword(plaintextPassword)
if err != nil {
return nil, err
}
}
if err := d.profiles.insertProfile(ctx, localpart); err != nil {
if common.IsUniqueConstraintViolationErr(err) {
return nil, nil
}
return nil, err
}
if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{
"global": {
"content": [],
"override": [],
"room": [],
"sender": [],
"underride": []
}
}`); err != nil {
return nil, err
}
return d.accounts.insertAccount(ctx, localpart, hash, appserviceID)
}
// SaveMembership saves the user matching a given localpart as a member of a given
// room. It also stores the ID of the membership event.
// If a membership already exists between the user and the room, or if the
// insert fails, returns the SQL error
func (d *Database) saveMembership(
ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
) error {
return d.memberships.insertMembership(ctx, txn, localpart, roomID, eventID)
}
// removeMembershipsByEventIDs removes the memberships corresponding to the
// `join` membership events IDs in the eventIDs slice.
// If the removal fails, or if there is no membership to remove, returns an error
func (d *Database) removeMembershipsByEventIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) error {
return d.memberships.deleteMembershipsByEventIDs(ctx, txn, eventIDs)
}
// UpdateMemberships adds the "join" membership events included in a given state
// events array, and removes those which ID is included in a given array of events
// IDs. All of the process is run in a transaction, which commits only once/if every
// insertion and deletion has been successfully processed.
// Returns a SQL error if there was an issue with any part of the process
func (d *Database) UpdateMemberships(
ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil {
return err
}
for _, event := range eventsToAdd {
if err := d.newMembership(ctx, txn, event); err != nil {
return err
}
}
return nil
})
}
// GetMembershipInRoomByLocalpart returns the membership for an user
// matching the given localpart if he is a member of the room matching roomID,
// if not sql.ErrNoRows is returned.
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetMembershipInRoomByLocalpart(
ctx context.Context, localpart, roomID string,
) (authtypes.Membership, error) {
return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID)
}
// GetMembershipsByLocalpart returns an array containing the memberships for all
// the rooms a user matching a given localpart is a member of
// If no membership match the given localpart, returns an empty array
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetMembershipsByLocalpart(
ctx context.Context, localpart string,
) (memberships []authtypes.Membership, err error) {
return d.memberships.selectMembershipsByLocalpart(ctx, localpart)
}
// newMembership saves a new membership in the database.
// If the event isn't a valid m.room.member event with type `join`, does nothing.
// If an error occurred, returns the SQL error
func (d *Database) newMembership(
ctx context.Context, txn *sql.Tx, ev gomatrixserverlib.Event,
) error {
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
if err != nil {
return err
}
// We only want state events from local users
if string(serverName) != string(d.serverName) {
return nil
}
eventID := ev.EventID()
roomID := ev.RoomID()
membership, err := ev.Membership()
if err != nil {
return err
}
// Only "join" membership events can be considered as new memberships
if membership == gomatrixserverlib.Join {
if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil {
return err
}
}
}
return nil
}
// SaveAccountData saves new account data for a given user and a given room.
// If the account data is not specific to a room, the room ID should be an empty string
// If an account data already exists for a given set (user, room, data type), it will
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType, content string,
) error {
return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content)
}
// GetAccountData returns account data related to a given localpart
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global []gomatrixserverlib.ClientEvent,
rooms map[string][]gomatrixserverlib.ClientEvent,
err error,
) {
return d.accountDatas.selectAccountData(ctx, localpart)
}
// GetAccountDataByType returns account data matching a given
// localpart, room ID and type.
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) {
return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType,
)
}
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx)
}
func hashPassword(plaintext string) (hash string, err error) {
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
return string(hashBytes), err
}
// Err3PIDInUse is the error returned when trying to save an association involving
// a third-party identifier which is already associated to a local user.
var Err3PIDInUse = errors.New("This third-party identifier is already in use")
// SaveThreePIDAssociation saves the association between a third party identifier
// and a local Matrix user (identified by the user's ID's local part).
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string,
) (err error) {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
user, err := d.threepids.selectLocalpartForThreePID(
ctx, txn, threepid, medium,
)
if err != nil {
return err
}
if len(user) > 0 {
return Err3PIDInUse
}
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
})
}
// RemoveThreePIDAssociation removes the association involving a given third-party
// identifier.
// If no association exists involving this third-party identifier, returns nothing.
// If there was a problem talking to the database, returns an error.
func (d *Database) RemoveThreePIDAssociation(
ctx context.Context, threepid string, medium string,
) (err error) {
return d.threepids.deleteThreePID(ctx, threepid, medium)
}
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
// identifier.
// If no association involves the given third-party idenfitier, returns an empty
// string.
// Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
}
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
// a given local user.
// If no association is known for this user, returns an empty slice.
// Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
}
// GetFilter looks up the filter associated with a given local user and filter ID.
// Returns a filter structure. Otherwise returns an error if no such filter exists
// or if there was an error talking to the database.
func (d *Database) GetFilter(
ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) {
return d.filter.selectFilter(ctx, localpart, filterID)
}
// PutFilter puts the passed filter into the database.
// Returns the filterID as a string. Otherwise returns an error if something
// goes wrong.
func (d *Database) PutFilter(
ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
) (string, error) {
return d.filter.insertFilter(ctx, filter, localpart)
}
// CheckAccountAvailability checks if the username/localpart is already present
// in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
if err == sql.ErrNoRows {
return true, nil
}
return false, err
}
// GetAccountByLocalpart returns the account associated with the given localpart.
// This function assumes the request is authenticated or the account data is used only internally.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
) (*authtypes.Account, error) {
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package accounts package postgres
import ( import (
"context" "context"

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package accounts package sqlite3
import ( import (
"context" "context"
@ -39,7 +39,7 @@ CREATE TABLE IF NOT EXISTS account_data (
const insertAccountDataSQL = ` const insertAccountDataSQL = `
INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4) INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
` `
const selectAccountDataSQL = "" + const selectAccountDataSQL = "" +

View file

@ -0,0 +1,151 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
const accountsSchema = `
-- Stores data about accounts.
CREATE TABLE IF NOT EXISTS account_accounts (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY,
-- When this account was first created, as a unix timestamp (ms resolution).
created_ts BIGINT NOT NULL,
-- The password hash for this account. Can be NULL if this is a passwordless account.
password_hash TEXT,
-- Identifies which application service this account belongs to, if any.
appservice_id TEXT
-- TODO:
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
);
`
const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1"
const selectNewNumericLocalpartSQL = "" +
"SELECT COUNT(localpart) FROM account_accounts"
// TODO: Update password
type accountsStatements struct {
insertAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
_, err = db.Exec(accountsSchema)
if err != nil {
return
}
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
return
}
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
return
}
if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil {
return
}
if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil {
return
}
s.serverName = server
return
}
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
// this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success.
func (s *accountsStatements) insertAccount(
ctx context.Context, localpart, hash, appserviceID string,
) (*authtypes.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt
var err error
if appserviceID == "" {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil)
} else {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
}
if err != nil {
return nil, err
}
return &authtypes.Account{
Localpart: localpart,
UserID: userutil.MakeUserID(localpart, s.serverName),
ServerName: s.serverName,
AppServiceID: appserviceID,
}, nil
}
func (s *accountsStatements) selectPasswordHash(
ctx context.Context, localpart string,
) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
return
}
func (s *accountsStatements) selectAccountByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Account, error) {
var appserviceIDPtr sql.NullString
var acc authtypes.Account
stmt := s.selectAccountByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve user from the db")
}
return nil, err
}
if appserviceIDPtr.Valid {
acc.AppServiceID = appserviceIDPtr.String
}
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
acc.ServerName = s.serverName
return &acc, nil
}
func (s *accountsStatements) selectNewNumericLocalpart(
ctx context.Context,
) (id int64, err error) {
err = s.selectNewNumericLocalpartStmt.QueryRowContext(ctx).Scan(&id)
return
}

View file

@ -0,0 +1,139 @@
// Copyright 2017 Jan Christian Grünhage
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"github.com/matrix-org/gomatrixserverlib"
)
const filterSchema = `
-- Stores data about filters
CREATE TABLE IF NOT EXISTS account_filter (
-- The filter
filter TEXT NOT NULL,
-- The ID
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- The localpart of the Matrix user ID associated to this filter
localpart TEXT NOT NULL,
UNIQUE (id, localpart)
);
CREATE INDEX IF NOT EXISTS account_filter_localpart ON account_filter(localpart);
`
const selectFilterSQL = "" +
"SELECT filter FROM account_filter WHERE localpart = $1 AND id = $2"
const selectFilterIDByContentSQL = "" +
"SELECT id FROM account_filter WHERE localpart = $1 AND filter = $2"
const insertFilterSQL = "" +
"INSERT INTO account_filter (filter, localpart) VALUES ($1, $2)"
const selectLastInsertedFilterIDSQL = "" +
"SELECT id FROM account_filter WHERE rowid = last_insert_rowid()"
type filterStatements struct {
selectFilterStmt *sql.Stmt
selectLastInsertedFilterIDStmt *sql.Stmt
selectFilterIDByContentStmt *sql.Stmt
insertFilterStmt *sql.Stmt
}
func (s *filterStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(filterSchema)
if err != nil {
return
}
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
return
}
if s.selectLastInsertedFilterIDStmt, err = db.Prepare(selectLastInsertedFilterIDSQL); err != nil {
return
}
if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil {
return
}
if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil {
return
}
return
}
func (s *filterStatements) selectFilter(
ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) {
// Retrieve filter from database (stored as canonical JSON)
var filterData []byte
err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
if err != nil {
return nil, err
}
// Unmarshal JSON into Filter struct
var filter gomatrixserverlib.Filter
if err = json.Unmarshal(filterData, &filter); err != nil {
return nil, err
}
return &filter, nil
}
func (s *filterStatements) insertFilter(
ctx context.Context, filter *gomatrixserverlib.Filter, localpart string,
) (filterID string, err error) {
var existingFilterID string
// Serialise json
filterJSON, err := json.Marshal(filter)
if err != nil {
return "", err
}
// Remove whitespaces and sort JSON data
// needed to prevent from inserting the same filter multiple times
filterJSON, err = gomatrixserverlib.CanonicalJSON(filterJSON)
if err != nil {
return "", err
}
// Check if filter already exists in the database using its localpart and content
//
// This can result in a race condition when two clients try to insert the
// same filter and localpart at the same time, however this is not a
// problem as both calls will result in the same filterID
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
localpart, filterJSON).Scan(&existingFilterID)
if err != nil && err != sql.ErrNoRows {
return "", err
}
// If it does, return the existing ID
if existingFilterID != "" {
return existingFilterID, err
}
// Otherwise insert the filter and return the new ID
if _, err = s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart); err != nil {
return "", err
}
row := s.selectLastInsertedFilterIDStmt.QueryRowContext(ctx)
if err := row.Scan(&filterID); err != nil {
return "", err
}
return
}

View file

@ -0,0 +1,131 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
const membershipSchema = `
-- Stores data about users memberships to rooms.
CREATE TABLE IF NOT EXISTS account_memberships (
-- The Matrix user ID localpart for the member
localpart TEXT NOT NULL,
-- The room this user is a member of
room_id TEXT NOT NULL,
-- The ID of the join membership event
event_id TEXT NOT NULL,
-- A user can only be member of a room once
PRIMARY KEY (localpart, room_id),
UNIQUE (event_id)
);
`
const insertMembershipSQL = `
INSERT INTO account_memberships(localpart, room_id, event_id) VALUES ($1, $2, $3)
ON CONFLICT (localpart, room_id) DO UPDATE SET event_id = EXCLUDED.event_id
`
const selectMembershipsByLocalpartSQL = "" +
"SELECT room_id, event_id FROM account_memberships WHERE localpart = $1"
const selectMembershipInRoomByLocalpartSQL = "" +
"SELECT event_id FROM account_memberships WHERE localpart = $1 AND room_id = $2"
const deleteMembershipsByEventIDsSQL = "" +
"DELETE FROM account_memberships WHERE event_id IN ($1)"
type membershipStatements struct {
deleteMembershipsByEventIDsStmt *sql.Stmt
insertMembershipStmt *sql.Stmt
selectMembershipInRoomByLocalpartStmt *sql.Stmt
selectMembershipsByLocalpartStmt *sql.Stmt
}
func (s *membershipStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(membershipSchema)
if err != nil {
return
}
if s.deleteMembershipsByEventIDsStmt, err = db.Prepare(deleteMembershipsByEventIDsSQL); err != nil {
return
}
if s.insertMembershipStmt, err = db.Prepare(insertMembershipSQL); err != nil {
return
}
if s.selectMembershipInRoomByLocalpartStmt, err = db.Prepare(selectMembershipInRoomByLocalpartSQL); err != nil {
return
}
if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil {
return
}
return
}
func (s *membershipStatements) insertMembership(
ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
) (err error) {
stmt := txn.Stmt(s.insertMembershipStmt)
_, err = stmt.ExecContext(ctx, localpart, roomID, eventID)
return
}
func (s *membershipStatements) deleteMembershipsByEventIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) (err error) {
stmt := txn.Stmt(s.deleteMembershipsByEventIDsStmt)
_, err = stmt.ExecContext(ctx, pq.StringArray(eventIDs))
return
}
func (s *membershipStatements) selectMembershipInRoomByLocalpart(
ctx context.Context, localpart, roomID string,
) (authtypes.Membership, error) {
membership := authtypes.Membership{Localpart: localpart, RoomID: roomID}
stmt := s.selectMembershipInRoomByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart, roomID).Scan(&membership.EventID)
return membership, err
}
func (s *membershipStatements) selectMembershipsByLocalpart(
ctx context.Context, localpart string,
) (memberships []authtypes.Membership, err error) {
stmt := s.selectMembershipsByLocalpartStmt
rows, err := stmt.QueryContext(ctx, localpart)
if err != nil {
return
}
memberships = []authtypes.Membership{}
defer rows.Close() // nolint: errcheck
for rows.Next() {
var m authtypes.Membership
m.Localpart = localpart
if err := rows.Scan(&m.RoomID, &m.EventID); err != nil {
return nil, err
}
memberships = append(memberships, m)
}
return
}

View file

@ -0,0 +1,107 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
const profilesSchema = `
-- Stores data about accounts profiles.
CREATE TABLE IF NOT EXISTS account_profiles (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY,
-- The display name for this account
display_name TEXT,
-- The URL of the avatar for this account
avatar_url TEXT
);
`
const insertProfileSQL = "" +
"INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
const setAvatarURLSQL = "" +
"UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
const setDisplayNameSQL = "" +
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
type profilesStatements struct {
insertProfileStmt *sql.Stmt
selectProfileByLocalpartStmt *sql.Stmt
setAvatarURLStmt *sql.Stmt
setDisplayNameStmt *sql.Stmt
}
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(profilesSchema)
if err != nil {
return
}
if s.insertProfileStmt, err = db.Prepare(insertProfileSQL); err != nil {
return
}
if s.selectProfileByLocalpartStmt, err = db.Prepare(selectProfileByLocalpartSQL); err != nil {
return
}
if s.setAvatarURLStmt, err = db.Prepare(setAvatarURLSQL); err != nil {
return
}
if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil {
return
}
return
}
func (s *profilesStatements) insertProfile(
ctx context.Context, localpart string,
) (err error) {
_, err = s.insertProfileStmt.ExecContext(ctx, localpart, "", "")
return
}
func (s *profilesStatements) selectProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
var profile authtypes.Profile
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
)
if err != nil {
return nil, err
}
return &profile, nil
}
func (s *profilesStatements) setAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) (err error) {
_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart)
return
}
func (s *profilesStatements) setDisplayName(
ctx context.Context, localpart string, displayName string,
) (err error) {
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
return
}

View file

@ -0,0 +1,392 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"errors"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
// Import the postgres database driver.
_ "github.com/mattn/go-sqlite3"
)
// Database represents an account database
type Database struct {
db *sql.DB
common.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
memberships membershipStatements
accountDatas accountDataStatements
threepids threepidStatements
filter filterStatements
serverName gomatrixserverlib.ServerName
}
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
var db *sql.DB
var err error
if db, err = sql.Open("sqlite3", dataSourceName); err != nil {
return nil, err
}
partitions := common.PartitionOffsetStatements{}
if err = partitions.Prepare(db, "account"); err != nil {
return nil, err
}
a := accountsStatements{}
if err = a.prepare(db, serverName); err != nil {
return nil, err
}
p := profilesStatements{}
if err = p.prepare(db); err != nil {
return nil, err
}
m := membershipStatements{}
if err = m.prepare(db); err != nil {
return nil, err
}
ac := accountDataStatements{}
if err = ac.prepare(db); err != nil {
return nil, err
}
t := threepidStatements{}
if err = t.prepare(db); err != nil {
return nil, err
}
f := filterStatements{}
if err = f.prepare(db); err != nil {
return nil, err
}
return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(
ctx context.Context, localpart, plaintextPassword string,
) (*authtypes.Account, error) {
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
if err != nil {
return nil, err
}
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
return nil, err
}
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
return d.profiles.selectProfileByLocalpart(ctx, localpart)
}
// SetAvatarURL updates the avatar URL of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) error {
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
}
// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
) error {
return d.profiles.setDisplayName(ctx, localpart, displayName)
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, nil.
func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string,
) (*authtypes.Account, error) {
var err error
// Generate a password hash if this is not a password-less user
hash := ""
if plaintextPassword != "" {
hash, err = hashPassword(plaintextPassword)
if err != nil {
return nil, err
}
}
if err := d.profiles.insertProfile(ctx, localpart); err != nil {
if common.IsUniqueConstraintViolationErr(err) {
return nil, nil
}
return nil, err
}
if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{
"global": {
"content": [],
"override": [],
"room": [],
"sender": [],
"underride": []
}
}`); err != nil {
return nil, err
}
return d.accounts.insertAccount(ctx, localpart, hash, appserviceID)
}
// SaveMembership saves the user matching a given localpart as a member of a given
// room. It also stores the ID of the membership event.
// If a membership already exists between the user and the room, or if the
// insert fails, returns the SQL error
func (d *Database) saveMembership(
ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
) error {
return d.memberships.insertMembership(ctx, txn, localpart, roomID, eventID)
}
// removeMembershipsByEventIDs removes the memberships corresponding to the
// `join` membership events IDs in the eventIDs slice.
// If the removal fails, or if there is no membership to remove, returns an error
func (d *Database) removeMembershipsByEventIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) error {
return d.memberships.deleteMembershipsByEventIDs(ctx, txn, eventIDs)
}
// UpdateMemberships adds the "join" membership events included in a given state
// events array, and removes those which ID is included in a given array of events
// IDs. All of the process is run in a transaction, which commits only once/if every
// insertion and deletion has been successfully processed.
// Returns a SQL error if there was an issue with any part of the process
func (d *Database) UpdateMemberships(
ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil {
return err
}
for _, event := range eventsToAdd {
if err := d.newMembership(ctx, txn, event); err != nil {
return err
}
}
return nil
})
}
// GetMembershipInRoomByLocalpart returns the membership for an user
// matching the given localpart if he is a member of the room matching roomID,
// if not sql.ErrNoRows is returned.
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetMembershipInRoomByLocalpart(
ctx context.Context, localpart, roomID string,
) (authtypes.Membership, error) {
return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID)
}
// GetMembershipsByLocalpart returns an array containing the memberships for all
// the rooms a user matching a given localpart is a member of
// If no membership match the given localpart, returns an empty array
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetMembershipsByLocalpart(
ctx context.Context, localpart string,
) (memberships []authtypes.Membership, err error) {
return d.memberships.selectMembershipsByLocalpart(ctx, localpart)
}
// newMembership saves a new membership in the database.
// If the event isn't a valid m.room.member event with type `join`, does nothing.
// If an error occurred, returns the SQL error
func (d *Database) newMembership(
ctx context.Context, txn *sql.Tx, ev gomatrixserverlib.Event,
) error {
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
if err != nil {
return err
}
// We only want state events from local users
if string(serverName) != string(d.serverName) {
return nil
}
eventID := ev.EventID()
roomID := ev.RoomID()
membership, err := ev.Membership()
if err != nil {
return err
}
// Only "join" membership events can be considered as new memberships
if membership == gomatrixserverlib.Join {
if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil {
return err
}
}
}
return nil
}
// SaveAccountData saves new account data for a given user and a given room.
// If the account data is not specific to a room, the room ID should be an empty string
// If an account data already exists for a given set (user, room, data type), it will
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType, content string,
) error {
return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content)
}
// GetAccountData returns account data related to a given localpart
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global []gomatrixserverlib.ClientEvent,
rooms map[string][]gomatrixserverlib.ClientEvent,
err error,
) {
return d.accountDatas.selectAccountData(ctx, localpart)
}
// GetAccountDataByType returns account data matching a given
// localpart, room ID and type.
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) {
return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType,
)
}
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx)
}
func hashPassword(plaintext string) (hash string, err error) {
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
return string(hashBytes), err
}
// Err3PIDInUse is the error returned when trying to save an association involving
// a third-party identifier which is already associated to a local user.
var Err3PIDInUse = errors.New("This third-party identifier is already in use")
// SaveThreePIDAssociation saves the association between a third party identifier
// and a local Matrix user (identified by the user's ID's local part).
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string,
) (err error) {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
user, err := d.threepids.selectLocalpartForThreePID(
ctx, txn, threepid, medium,
)
if err != nil {
return err
}
if len(user) > 0 {
return Err3PIDInUse
}
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
})
}
// RemoveThreePIDAssociation removes the association involving a given third-party
// identifier.
// If no association exists involving this third-party identifier, returns nothing.
// If there was a problem talking to the database, returns an error.
func (d *Database) RemoveThreePIDAssociation(
ctx context.Context, threepid string, medium string,
) (err error) {
return d.threepids.deleteThreePID(ctx, threepid, medium)
}
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
// identifier.
// If no association involves the given third-party idenfitier, returns an empty
// string.
// Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
}
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
// a given local user.
// If no association is known for this user, returns an empty slice.
// Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
}
// GetFilter looks up the filter associated with a given local user and filter ID.
// Returns a filter structure. Otherwise returns an error if no such filter exists
// or if there was an error talking to the database.
func (d *Database) GetFilter(
ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) {
return d.filter.selectFilter(ctx, localpart, filterID)
}
// PutFilter puts the passed filter into the database.
// Returns the filterID as a string. Otherwise returns an error if something
// goes wrong.
func (d *Database) PutFilter(
ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
) (string, error) {
return d.filter.insertFilter(ctx, filter, localpart)
}
// CheckAccountAvailability checks if the username/localpart is already present
// in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
if err == sql.ErrNoRows {
return true, nil
}
return false, err
}
// GetAccountByLocalpart returns the account associated with the given localpart.
// This function assumes the request is authenticated or the account data is used only internally.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
) (*authtypes.Account, error) {
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}

View file

@ -0,0 +1,129 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
const threepidSchema = `
-- Stores data about third party identifiers
CREATE TABLE IF NOT EXISTS account_threepid (
-- The third party identifier
threepid TEXT NOT NULL,
-- The 3PID medium
medium TEXT NOT NULL DEFAULT 'email',
-- The localpart of the Matrix user ID associated to this 3PID
localpart TEXT NOT NULL,
PRIMARY KEY(threepid, medium)
);
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart);
`
const selectLocalpartForThreePIDSQL = "" +
"SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
const selectThreePIDsForLocalpartSQL = "" +
"SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
const insertThreePIDSQL = "" +
"INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)"
const deleteThreePIDSQL = "" +
"DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
type threepidStatements struct {
selectLocalpartForThreePIDStmt *sql.Stmt
selectThreePIDsForLocalpartStmt *sql.Stmt
insertThreePIDStmt *sql.Stmt
deleteThreePIDStmt *sql.Stmt
}
func (s *threepidStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(threepidSchema)
if err != nil {
return
}
if s.selectLocalpartForThreePIDStmt, err = db.Prepare(selectLocalpartForThreePIDSQL); err != nil {
return
}
if s.selectThreePIDsForLocalpartStmt, err = db.Prepare(selectThreePIDsForLocalpartSQL); err != nil {
return
}
if s.insertThreePIDStmt, err = db.Prepare(insertThreePIDSQL); err != nil {
return
}
if s.deleteThreePIDStmt, err = db.Prepare(deleteThreePIDSQL); err != nil {
return
}
return
}
func (s *threepidStatements) selectLocalpartForThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string,
) (localpart string, err error) {
stmt := common.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
if err == sql.ErrNoRows {
return "", nil
}
return
}
func (s *threepidStatements) selectThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
if err != nil {
return
}
defer rows.Close() // nolint: errcheck
threepids = []authtypes.ThreePID{}
for rows.Next() {
var threepid string
var medium string
if err = rows.Scan(&threepid, &medium); err != nil {
return
}
threepids = append(threepids, authtypes.ThreePID{
Address: threepid,
Medium: medium,
})
}
return threepids, rows.Err()
}
func (s *threepidStatements) insertThreePID(
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
) (err error) {
stmt := common.TxStmt(txn, s.insertThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
return
}
func (s *threepidStatements) deleteThreePID(
ctx context.Context, threepid string, medium string) (err error) {
_, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium)
return
}

View file

@ -1,392 +1,56 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package accounts package accounts
import ( import (
"context" "context"
"database/sql"
"errors" "errors"
"net/url"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/postgres"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/sqlite3"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
// Import the postgres database driver.
_ "github.com/lib/pq"
) )
// Database represents an account database type Database interface {
type Database struct { common.PartitionStorer
db *sql.DB GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*authtypes.Account, error)
common.PartitionOffsetStatements GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
accounts accountsStatements SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
profiles profilesStatements SetDisplayName(ctx context.Context, localpart string, displayName string) error
memberships membershipStatements CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*authtypes.Account, error)
accountDatas accountDataStatements UpdateMemberships(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error
threepids threepidStatements GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error)
filter filterStatements GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error)
serverName gomatrixserverlib.ServerName SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error
GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error)
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data *gomatrixserverlib.ClientEvent, err error)
GetNewNumericLocalpart(ctx context.Context) (int64, error)
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
GetFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error)
PutFilter(ctx context.Context, localpart string, filter *gomatrixserverlib.Filter) (string, error)
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
GetAccountByLocalpart(ctx context.Context, localpart string) (*authtypes.Account, error)
} }
// NewDatabase creates a new accounts and profiles database func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (Database, error) {
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) { uri, err := url.Parse(dataSourceName)
var db *sql.DB
var err error
if db, err = sql.Open("postgres", dataSourceName); err != nil {
return nil, err
}
partitions := common.PartitionOffsetStatements{}
if err = partitions.Prepare(db, "account"); err != nil {
return nil, err
}
a := accountsStatements{}
if err = a.prepare(db, serverName); err != nil {
return nil, err
}
p := profilesStatements{}
if err = p.prepare(db); err != nil {
return nil, err
}
m := membershipStatements{}
if err = m.prepare(db); err != nil {
return nil, err
}
ac := accountDataStatements{}
if err = ac.prepare(db); err != nil {
return nil, err
}
t := threepidStatements{}
if err = t.prepare(db); err != nil {
return nil, err
}
f := filterStatements{}
if err = f.prepare(db); err != nil {
return nil, err
}
return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(
ctx context.Context, localpart, plaintextPassword string,
) (*authtypes.Account, error) {
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
if err != nil { if err != nil {
return nil, err return postgres.NewDatabase(dataSourceName, serverName)
} }
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { switch uri.Scheme {
return nil, err case "postgres":
return postgres.NewDatabase(dataSourceName, serverName)
case "file":
return sqlite3.NewDatabase(dataSourceName, serverName)
default:
return postgres.NewDatabase(dataSourceName, serverName)
} }
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
return d.profiles.selectProfileByLocalpart(ctx, localpart)
}
// SetAvatarURL updates the avatar URL of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) error {
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
}
// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
) error {
return d.profiles.setDisplayName(ctx, localpart, displayName)
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, nil.
func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string,
) (*authtypes.Account, error) {
var err error
// Generate a password hash if this is not a password-less user
hash := ""
if plaintextPassword != "" {
hash, err = hashPassword(plaintextPassword)
if err != nil {
return nil, err
}
}
if err := d.profiles.insertProfile(ctx, localpart); err != nil {
if common.IsUniqueConstraintViolationErr(err) {
return nil, nil
}
return nil, err
}
if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{
"global": {
"content": [],
"override": [],
"room": [],
"sender": [],
"underride": []
}
}`); err != nil {
return nil, err
}
return d.accounts.insertAccount(ctx, localpart, hash, appserviceID)
}
// SaveMembership saves the user matching a given localpart as a member of a given
// room. It also stores the ID of the membership event.
// If a membership already exists between the user and the room, or if the
// insert fails, returns the SQL error
func (d *Database) saveMembership(
ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
) error {
return d.memberships.insertMembership(ctx, txn, localpart, roomID, eventID)
}
// removeMembershipsByEventIDs removes the memberships corresponding to the
// `join` membership events IDs in the eventIDs slice.
// If the removal fails, or if there is no membership to remove, returns an error
func (d *Database) removeMembershipsByEventIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) error {
return d.memberships.deleteMembershipsByEventIDs(ctx, txn, eventIDs)
}
// UpdateMemberships adds the "join" membership events included in a given state
// events array, and removes those which ID is included in a given array of events
// IDs. All of the process is run in a transaction, which commits only once/if every
// insertion and deletion has been successfully processed.
// Returns a SQL error if there was an issue with any part of the process
func (d *Database) UpdateMemberships(
ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil {
return err
}
for _, event := range eventsToAdd {
if err := d.newMembership(ctx, txn, event); err != nil {
return err
}
}
return nil
})
}
// GetMembershipInRoomByLocalpart returns the membership for an user
// matching the given localpart if he is a member of the room matching roomID,
// if not sql.ErrNoRows is returned.
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetMembershipInRoomByLocalpart(
ctx context.Context, localpart, roomID string,
) (authtypes.Membership, error) {
return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID)
}
// GetMembershipsByLocalpart returns an array containing the memberships for all
// the rooms a user matching a given localpart is a member of
// If no membership match the given localpart, returns an empty array
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetMembershipsByLocalpart(
ctx context.Context, localpart string,
) (memberships []authtypes.Membership, err error) {
return d.memberships.selectMembershipsByLocalpart(ctx, localpart)
}
// newMembership saves a new membership in the database.
// If the event isn't a valid m.room.member event with type `join`, does nothing.
// If an error occurred, returns the SQL error
func (d *Database) newMembership(
ctx context.Context, txn *sql.Tx, ev gomatrixserverlib.Event,
) error {
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
if err != nil {
return err
}
// We only want state events from local users
if string(serverName) != string(d.serverName) {
return nil
}
eventID := ev.EventID()
roomID := ev.RoomID()
membership, err := ev.Membership()
if err != nil {
return err
}
// Only "join" membership events can be considered as new memberships
if membership == gomatrixserverlib.Join {
if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil {
return err
}
}
}
return nil
}
// SaveAccountData saves new account data for a given user and a given room.
// If the account data is not specific to a room, the room ID should be an empty string
// If an account data already exists for a given set (user, room, data type), it will
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType, content string,
) error {
return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content)
}
// GetAccountData returns account data related to a given localpart
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global []gomatrixserverlib.ClientEvent,
rooms map[string][]gomatrixserverlib.ClientEvent,
err error,
) {
return d.accountDatas.selectAccountData(ctx, localpart)
}
// GetAccountDataByType returns account data matching a given
// localpart, room ID and type.
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) {
return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType,
)
}
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx)
}
func hashPassword(plaintext string) (hash string, err error) {
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
return string(hashBytes), err
} }
// Err3PIDInUse is the error returned when trying to save an association involving // Err3PIDInUse is the error returned when trying to save an association involving
// a third-party identifier which is already associated to a local user. // a third-party identifier which is already associated to a local user.
var Err3PIDInUse = errors.New("This third-party identifier is already in use") var Err3PIDInUse = errors.New("This third-party identifier is already in use")
// SaveThreePIDAssociation saves the association between a third party identifier
// and a local Matrix user (identified by the user's ID's local part).
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string,
) (err error) {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
user, err := d.threepids.selectLocalpartForThreePID(
ctx, txn, threepid, medium,
)
if err != nil {
return err
}
if len(user) > 0 {
return Err3PIDInUse
}
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
})
}
// RemoveThreePIDAssociation removes the association involving a given third-party
// identifier.
// If no association exists involving this third-party identifier, returns nothing.
// If there was a problem talking to the database, returns an error.
func (d *Database) RemoveThreePIDAssociation(
ctx context.Context, threepid string, medium string,
) (err error) {
return d.threepids.deleteThreePID(ctx, threepid, medium)
}
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
// identifier.
// If no association involves the given third-party idenfitier, returns an empty
// string.
// Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
}
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
// a given local user.
// If no association is known for this user, returns an empty slice.
// Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
}
// GetFilter looks up the filter associated with a given local user and filter ID.
// Returns a filter structure. Otherwise returns an error if no such filter exists
// or if there was an error talking to the database.
func (d *Database) GetFilter(
ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) {
return d.filter.selectFilter(ctx, localpart, filterID)
}
// PutFilter puts the passed filter into the database.
// Returns the filterID as a string. Otherwise returns an error if something
// goes wrong.
func (d *Database) PutFilter(
ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
) (string, error) {
return d.filter.insertFilter(ctx, filter, localpart)
}
// CheckAccountAvailability checks if the username/localpart is already present
// in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
if err == sql.ErrNoRows {
return true, nil
}
return false, err
}
// GetAccountByLocalpart returns the account associated with the given localpart.
// This function assumes the request is authenticated or the account data is used only internally.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
) (*authtypes.Account, error) {
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}

View file

@ -12,17 +12,17 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package devices package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"time" "time"
"github.com/matrix-org/dendrite/common" "github.com/lib/pq"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -80,6 +80,9 @@ const deleteDeviceSQL = "" +
const deleteDevicesByLocalpartSQL = "" + const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1" "DELETE FROM device_devices WHERE localpart = $1"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
type devicesStatements struct { type devicesStatements struct {
insertDeviceStmt *sql.Stmt insertDeviceStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt
@ -88,6 +91,7 @@ type devicesStatements struct {
updateDeviceNameStmt *sql.Stmt updateDeviceNameStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt deleteDeviceStmt *sql.Stmt
deleteDevicesByLocalpartStmt *sql.Stmt deleteDevicesByLocalpartStmt *sql.Stmt
deleteDevicesStmt *sql.Stmt
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
@ -117,6 +121,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return return
} }
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
return
}
s.serverName = server s.serverName = server
return return
} }
@ -142,6 +149,7 @@ func (s *devicesStatements) insertDevice(
}, nil }, nil
} }
// deleteDevice removes a single device by id and user localpart.
func (s *devicesStatements) deleteDevice( func (s *devicesStatements) deleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string, ctx context.Context, txn *sql.Tx, id, localpart string,
) error { ) error {
@ -150,6 +158,18 @@ func (s *devicesStatements) deleteDevice(
return err return err
} }
// deleteDevices removes a single or multiple devices by ids and user localpart.
// Returns an error if the execution failed.
func (s *devicesStatements) deleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
) error {
stmt := common.TxStmt(txn, s.deleteDevicesStmt)
_, err := stmt.ExecContext(ctx, localpart, pq.Array(devices))
return err
}
// deleteDevicesByLocalpart removes all devices for the
// given user localpart.
func (s *devicesStatements) deleteDevicesByLocalpart( func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart string,
) error { ) error {
@ -206,6 +226,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
if err != nil { if err != nil {
return devices, err return devices, err
} }
defer rows.Close() // nolint: errcheck
for rows.Next() { for rows.Next() {
var dev authtypes.Device var dev authtypes.Device
@ -217,5 +238,5 @@ func (s *devicesStatements) selectDevicesByLocalpart(
devices = append(devices, dev) devices = append(devices, dev)
} }
return devices, nil return devices, rows.Err()
} }

View file

@ -0,0 +1,182 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib"
)
// The length of generated device IDs
var deviceIDByteLength = 6
// Database represents a device database.
type Database struct {
db *sql.DB
devices devicesStatements
}
// NewDatabase creates a new device database
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
var db *sql.DB
var err error
if db, err = sql.Open("postgres", dataSourceName); err != nil {
return nil, err
}
d := devicesStatements{}
if err = d.prepare(db, serverName); err != nil {
return nil, err
}
return &Database{db, d}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*authtypes.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*authtypes.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]authtypes.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, localpart)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string,
) (dev *authtypes.Device, returnErr error) {
if deviceID != nil {
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}

View file

@ -0,0 +1,243 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"strings"
"time"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/gomatrixserverlib"
)
const devicesSchema = `
-- This sequence is used for automatic allocation of session_id.
-- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
-- Stores data about devices.
CREATE TABLE IF NOT EXISTS device_devices (
access_token TEXT PRIMARY KEY,
session_id INTEGER,
device_id TEXT ,
localpart TEXT ,
created_ts BIGINT,
display_name TEXT,
UNIQUE (localpart, device_id)
);
`
const insertDeviceSQL = "" +
"INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id)" +
" VALUES ($1, $2, $3, $4, $5, $6)"
const selectDevicesCountSQL = "" +
"SELECT COUNT(access_token) FROM device_devices"
const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
type devicesStatements struct {
db *sql.DB
insertDeviceStmt *sql.Stmt
selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt
selectDeviceByIDStmt *sql.Stmt
selectDevicesByLocalpartStmt *sql.Stmt
updateDeviceNameStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt
deleteDevicesByLocalpartStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
s.db = db
_, err = db.Exec(devicesSchema)
if err != nil {
return
}
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
return
}
if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil {
return
}
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
return
}
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
return
}
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
return
}
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
return
}
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
return
}
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return
}
s.serverName = server
return
}
// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
// Returns an error if the user already has a device with the given device ID.
// Returns the device on success.
func (s *devicesStatements) insertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string,
) (*authtypes.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
countStmt := common.TxStmt(txn, s.selectDevicesCountStmt)
insertStmt := common.TxStmt(txn, s.insertDeviceStmt)
if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
return nil, err
}
sessionID++
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
return nil, err
}
return &authtypes.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, s.serverName),
AccessToken: accessToken,
SessionID: sessionID,
}, nil
}
func (s *devicesStatements) deleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string,
) error {
stmt := common.TxStmt(txn, s.deleteDeviceStmt)
_, err := stmt.ExecContext(ctx, id, localpart)
return err
}
func (s *devicesStatements) deleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
) error {
orig := strings.Replace(deleteDevicesSQL, "($1)", common.QueryVariadic(len(devices)), 1)
prep, err := s.db.Prepare(orig)
if err != nil {
return err
}
stmt := common.TxStmt(txn, prep)
params := make([]interface{}, len(devices)+1)
params[0] = localpart
for i, v := range devices {
params[i+1] = v
}
params = append(params, params...)
_, err = stmt.ExecContext(ctx, params...)
return err
}
func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
) error {
stmt := common.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart)
return err
}
func (s *devicesStatements) updateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error {
stmt := common.TxStmt(txn, s.updateDeviceNameStmt)
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
return err
}
func (s *devicesStatements) selectDeviceByToken(
ctx context.Context, accessToken string,
) (*authtypes.Device, error) {
var dev authtypes.Device
var localpart string
stmt := s.selectDeviceByTokenStmt
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
if err == nil {
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.AccessToken = accessToken
}
return &dev, err
}
// selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID
func (s *devicesStatements) selectDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*authtypes.Device, error) {
var dev authtypes.Device
var created sql.NullInt64
stmt := s.selectDeviceByIDStmt
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&created)
if err == nil {
dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
}
return &dev, err
}
func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]authtypes.Device, error) {
devices := []authtypes.Device{}
rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart)
if err != nil {
return devices, err
}
for rows.Next() {
var dev authtypes.Device
err = rows.Scan(&dev.ID)
if err != nil {
return devices, err
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
devices = append(devices, dev)
}
return devices, nil
}

View file

@ -0,0 +1,184 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib"
_ "github.com/mattn/go-sqlite3"
)
// The length of generated device IDs
var deviceIDByteLength = 6
// Database represents a device database.
type Database struct {
db *sql.DB
devices devicesStatements
}
// NewDatabase creates a new device database
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
var db *sql.DB
var err error
if db, err = sql.Open("sqlite3", dataSourceName); err != nil {
return nil, err
}
d := devicesStatements{}
if err = d.prepare(db, serverName); err != nil {
return nil, err
}
return &Database{db, d}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*authtypes.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*authtypes.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]authtypes.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, localpart)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string,
) (dev *authtypes.Device, returnErr error) {
if deviceID != nil {
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}

View file

@ -1,167 +1,37 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package devices package devices
import ( import (
"context" "context"
"crypto/rand" "net/url"
"database/sql"
"encoding/base64"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices/postgres"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices/sqlite3"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
// The length of generated device IDs type Database interface {
var deviceIDByteLength = 6 GetDeviceByAccessToken(ctx context.Context, token string) (*authtypes.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*authtypes.Device, error)
// Database represents a device database. GetDevicesByLocalpart(ctx context.Context, localpart string) ([]authtypes.Device, error)
type Database struct { CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string) (dev *authtypes.Device, returnErr error)
db *sql.DB UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
devices devicesStatements RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
RemoveAllDevices(ctx context.Context, localpart string) error
} }
// NewDatabase creates a new device database func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (Database, error) {
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) { uri, err := url.Parse(dataSourceName)
var db *sql.DB
var err error
if db, err = sql.Open("postgres", dataSourceName); err != nil {
return nil, err
}
d := devicesStatements{}
if err = d.prepare(db, serverName); err != nil {
return nil, err
}
return &Database{db, d}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*authtypes.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*authtypes.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]authtypes.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, localpart)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string,
) (dev *authtypes.Device, returnErr error) {
if deviceID != nil {
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil { if err != nil {
return "", err return postgres.NewDatabase(dataSourceName, serverName)
}
switch uri.Scheme {
case "postgres":
return postgres.NewDatabase(dataSourceName, serverName)
case "file":
return sqlite3.NewDatabase(dataSourceName, serverName)
default:
return postgres.NewDatabase(dataSourceName, serverName)
} }
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
} }

View file

@ -34,8 +34,8 @@ import (
// component. // component.
func SetupClientAPIComponent( func SetupClientAPIComponent(
base *basecomponent.BaseDendrite, base *basecomponent.BaseDendrite,
deviceDB *devices.Database, deviceDB devices.Database,
accountsDB *accounts.Database, accountsDB accounts.Database,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
keyRing *gomatrixserverlib.KeyRing, keyRing *gomatrixserverlib.KeyRing,
aliasAPI roomserverAPI.RoomserverAliasAPI, aliasAPI roomserverAPI.RoomserverAliasAPI,
@ -67,7 +67,7 @@ func SetupClientAPIComponent(
} }
routing.Setup( routing.Setup(
base.APIMux, *base.Cfg, roomserverProducer, queryAPI, aliasAPI, asAPI, base.APIMux, base.Cfg, roomserverProducer, queryAPI, aliasAPI, asAPI,
accountsDB, deviceDB, federation, *keyRing, userUpdateProducer, accountsDB, deviceDB, federation, *keyRing, userUpdateProducer,
syncProducer, typingProducer, transactionsCache, fedSenderAPI, syncProducer, typingProducer, transactionsCache, fedSenderAPI,
) )

View file

@ -31,7 +31,7 @@ import (
// OutputRoomEventConsumer consumes events that originated in the room server. // OutputRoomEventConsumer consumes events that originated in the room server.
type OutputRoomEventConsumer struct { type OutputRoomEventConsumer struct {
roomServerConsumer *common.ContinualConsumer roomServerConsumer *common.ContinualConsumer
db *accounts.Database db accounts.Database
query api.RoomserverQueryAPI query api.RoomserverQueryAPI
serverName string serverName string
} }
@ -40,7 +40,7 @@ type OutputRoomEventConsumer struct {
func NewOutputRoomEventConsumer( func NewOutputRoomEventConsumer(
cfg *config.Dendrite, cfg *config.Dendrite,
kafkaConsumer sarama.Consumer, kafkaConsumer sarama.Consumer,
store *accounts.Database, store accounts.Database,
queryAPI api.RoomserverQueryAPI, queryAPI api.RoomserverQueryAPI,
) *OutputRoomEventConsumer { ) *OutputRoomEventConsumer {

View file

@ -30,7 +30,7 @@ import (
// GetAccountData implements GET /user/{userId}/[rooms/{roomid}/]account_data/{type} // GetAccountData implements GET /user/{userId}/[rooms/{roomid}/]account_data/{type}
func GetAccountData( func GetAccountData(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device, req *http.Request, accountDB accounts.Database, device *authtypes.Device,
userID string, roomID string, dataType string, userID string, roomID string, dataType string,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { if userID != device.UserID {
@ -62,7 +62,7 @@ func GetAccountData(
// SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type} // SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type}
func SaveAccountData( func SaveAccountData(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device, req *http.Request, accountDB accounts.Database, device *authtypes.Device,
userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer, userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { if userID != device.UserID {

View file

@ -102,7 +102,7 @@ func serveTemplate(w http.ResponseWriter, templateHTML string, data map[string]s
// AuthFallback implements GET and POST /auth/{authType}/fallback/web?session={sessionID} // AuthFallback implements GET and POST /auth/{authType}/fallback/web?session={sessionID}
func AuthFallback( func AuthFallback(
w http.ResponseWriter, req *http.Request, authType string, w http.ResponseWriter, req *http.Request, authType string,
cfg config.Dendrite, cfg *config.Dendrite,
) *util.JSONResponse { ) *util.JSONResponse {
sessionID := req.URL.Query().Get("session") sessionID := req.URL.Query().Get("session")
@ -130,7 +130,7 @@ func AuthFallback(
if req.Method == http.MethodGet { if req.Method == http.MethodGet {
// Handle Recaptcha // Handle Recaptcha
if authType == authtypes.LoginTypeRecaptcha { if authType == authtypes.LoginTypeRecaptcha {
if err := checkRecaptchaEnabled(&cfg, w, req); err != nil { if err := checkRecaptchaEnabled(cfg, w, req); err != nil {
return err return err
} }
@ -144,7 +144,7 @@ func AuthFallback(
} else if req.Method == http.MethodPost { } else if req.Method == http.MethodPost {
// Handle Recaptcha // Handle Recaptcha
if authType == authtypes.LoginTypeRecaptcha { if authType == authtypes.LoginTypeRecaptcha {
if err := checkRecaptchaEnabled(&cfg, w, req); err != nil { if err := checkRecaptchaEnabled(cfg, w, req); err != nil {
return err return err
} }
@ -156,7 +156,7 @@ func AuthFallback(
} }
response := req.Form.Get("g-recaptcha-response") response := req.Form.Get("g-recaptcha-response")
if err := validateRecaptcha(&cfg, response, clientIP); err != nil { if err := validateRecaptcha(cfg, response, clientIP); err != nil {
util.GetLogger(req.Context()).Error(err) util.GetLogger(req.Context()).Error(err)
return err return err
} }

View file

@ -134,8 +134,8 @@ type fledglingEvent struct {
// CreateRoom implements /createRoom // CreateRoom implements /createRoom
func CreateRoom( func CreateRoom(
req *http.Request, device *authtypes.Device, req *http.Request, device *authtypes.Device,
cfg config.Dendrite, producer *producers.RoomserverProducer, cfg *config.Dendrite, producer *producers.RoomserverProducer,
accountDB *accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI, accountDB accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
// TODO (#267): Check room ID doesn't clash with an existing one, and we // TODO (#267): Check room ID doesn't clash with an existing one, and we
@ -148,8 +148,8 @@ func CreateRoom(
// nolint: gocyclo // nolint: gocyclo
func createRoom( func createRoom(
req *http.Request, device *authtypes.Device, req *http.Request, device *authtypes.Device,
cfg config.Dendrite, roomID string, producer *producers.RoomserverProducer, cfg *config.Dendrite, roomID string, producer *producers.RoomserverProducer,
accountDB *accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI, accountDB accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
logger := util.GetLogger(req.Context()) logger := util.GetLogger(req.Context())
@ -344,7 +344,7 @@ func createRoom(
func buildEvent( func buildEvent(
builder *gomatrixserverlib.EventBuilder, builder *gomatrixserverlib.EventBuilder,
provider gomatrixserverlib.AuthEventProvider, provider gomatrixserverlib.AuthEventProvider,
cfg config.Dendrite, cfg *config.Dendrite,
evTime time.Time, evTime time.Time,
) (*gomatrixserverlib.Event, error) { ) (*gomatrixserverlib.Event, error) {
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)

View file

@ -40,9 +40,13 @@ type deviceUpdateJSON struct {
DisplayName *string `json:"display_name"` DisplayName *string `json:"display_name"`
} }
type devicesDeleteJSON struct {
Devices []string `json:"devices"`
}
// GetDeviceByID handles /devices/{deviceID} // GetDeviceByID handles /devices/{deviceID}
func GetDeviceByID( func GetDeviceByID(
req *http.Request, deviceDB *devices.Database, device *authtypes.Device, req *http.Request, deviceDB devices.Database, device *authtypes.Device,
deviceID string, deviceID string,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
@ -72,7 +76,7 @@ func GetDeviceByID(
// GetDevicesByLocalpart handles /devices // GetDevicesByLocalpart handles /devices
func GetDevicesByLocalpart( func GetDevicesByLocalpart(
req *http.Request, deviceDB *devices.Database, device *authtypes.Device, req *http.Request, deviceDB devices.Database, device *authtypes.Device,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
@ -103,7 +107,7 @@ func GetDevicesByLocalpart(
// UpdateDeviceByID handles PUT on /devices/{deviceID} // UpdateDeviceByID handles PUT on /devices/{deviceID}
func UpdateDeviceByID( func UpdateDeviceByID(
req *http.Request, deviceDB *devices.Database, device *authtypes.Device, req *http.Request, deviceDB devices.Database, device *authtypes.Device,
deviceID string, deviceID string,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
@ -146,3 +150,54 @@ func UpdateDeviceByID(
JSON: struct{}{}, JSON: struct{}{},
} }
} }
// DeleteDeviceById handles DELETE requests to /devices/{deviceId}
func DeleteDeviceById(
req *http.Request, deviceDB devices.Database, device *authtypes.Device,
deviceID string,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
return httputil.LogThenError(req, err)
}
ctx := req.Context()
defer req.Body.Close() // nolint: errcheck
if err := deviceDB.RemoveDevice(ctx, deviceID, localpart); err != nil {
return httputil.LogThenError(req, err)
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
// DeleteDevices handles POST requests to /delete_devices
func DeleteDevices(
req *http.Request, deviceDB devices.Database, device *authtypes.Device,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
return httputil.LogThenError(req, err)
}
ctx := req.Context()
payload := devicesDeleteJSON{}
if err := json.NewDecoder(req.Body).Decode(&payload); err != nil {
return httputil.LogThenError(req, err)
}
defer req.Body.Close() // nolint: errcheck
if err := deviceDB.RemoveDevices(ctx, localpart, payload.Devices); err != nil {
return httputil.LogThenError(req, err)
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}

View file

@ -27,7 +27,7 @@ import (
// GetFilter implements GET /_matrix/client/r0/user/{userId}/filter/{filterId} // GetFilter implements GET /_matrix/client/r0/user/{userId}/filter/{filterId}
func GetFilter( func GetFilter(
req *http.Request, device *authtypes.Device, accountDB *accounts.Database, userID string, filterID string, req *http.Request, device *authtypes.Device, accountDB accounts.Database, userID string, filterID string,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { if userID != device.UserID {
return util.JSONResponse{ return util.JSONResponse{
@ -63,7 +63,7 @@ type filterResponse struct {
//PutFilter implements POST /_matrix/client/r0/user/{userId}/filter //PutFilter implements POST /_matrix/client/r0/user/{userId}/filter
func PutFilter( func PutFilter(
req *http.Request, device *authtypes.Device, accountDB *accounts.Database, userID string, req *http.Request, device *authtypes.Device, accountDB accounts.Database, userID string,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { if userID != device.UserID {
return util.JSONResponse{ return util.JSONResponse{

View file

@ -31,7 +31,7 @@ type getEventRequest struct {
device *authtypes.Device device *authtypes.Device
roomID string roomID string
eventID string eventID string
cfg config.Dendrite cfg *config.Dendrite
federation *gomatrixserverlib.FederationClient federation *gomatrixserverlib.FederationClient
keyRing gomatrixserverlib.KeyRing keyRing gomatrixserverlib.KeyRing
requestedEvent gomatrixserverlib.Event requestedEvent gomatrixserverlib.Event
@ -44,7 +44,7 @@ func GetEvent(
device *authtypes.Device, device *authtypes.Device,
roomID string, roomID string,
eventID string, eventID string,
cfg config.Dendrite, cfg *config.Dendrite,
queryAPI api.RoomserverQueryAPI, queryAPI api.RoomserverQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
keyRing gomatrixserverlib.KeyRing, keyRing gomatrixserverlib.KeyRing,

View file

@ -39,13 +39,13 @@ func JoinRoomByIDOrAlias(
req *http.Request, req *http.Request,
device *authtypes.Device, device *authtypes.Device,
roomIDOrAlias string, roomIDOrAlias string,
cfg config.Dendrite, cfg *config.Dendrite,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
producer *producers.RoomserverProducer, producer *producers.RoomserverProducer,
queryAPI roomserverAPI.RoomserverQueryAPI, queryAPI roomserverAPI.RoomserverQueryAPI,
aliasAPI roomserverAPI.RoomserverAliasAPI, aliasAPI roomserverAPI.RoomserverAliasAPI,
keyRing gomatrixserverlib.KeyRing, keyRing gomatrixserverlib.KeyRing,
accountDB *accounts.Database, accountDB accounts.Database,
) util.JSONResponse { ) util.JSONResponse {
var content map[string]interface{} // must be a JSON object var content map[string]interface{} // must be a JSON object
if resErr := httputil.UnmarshalJSONRequest(req, &content); resErr != nil { if resErr := httputil.UnmarshalJSONRequest(req, &content); resErr != nil {
@ -98,7 +98,7 @@ type joinRoomReq struct {
evTime time.Time evTime time.Time
content map[string]interface{} content map[string]interface{}
userID string userID string
cfg config.Dendrite cfg *config.Dendrite
federation *gomatrixserverlib.FederationClient federation *gomatrixserverlib.FederationClient
producer *producers.RoomserverProducer producer *producers.RoomserverProducer
queryAPI roomserverAPI.RoomserverQueryAPI queryAPI roomserverAPI.RoomserverQueryAPI

View file

@ -70,8 +70,8 @@ func passwordLogin() loginFlows {
// Login implements GET and POST /login // Login implements GET and POST /login
func Login( func Login(
req *http.Request, accountDB *accounts.Database, deviceDB *devices.Database, req *http.Request, accountDB accounts.Database, deviceDB devices.Database,
cfg config.Dendrite, cfg *config.Dendrite,
) util.JSONResponse { ) util.JSONResponse {
if req.Method == http.MethodGet { // TODO: support other forms of login other than password, depending on config options if req.Method == http.MethodGet { // TODO: support other forms of login other than password, depending on config options
return util.JSONResponse{ return util.JSONResponse{
@ -153,7 +153,7 @@ func Login(
func getDevice( func getDevice(
ctx context.Context, ctx context.Context,
r passwordRequest, r passwordRequest,
deviceDB *devices.Database, deviceDB devices.Database,
acc *authtypes.Account, acc *authtypes.Account,
token string, token string,
) (dev *authtypes.Device, err error) { ) (dev *authtypes.Device, err error) {

View file

@ -26,7 +26,7 @@ import (
// Logout handles POST /logout // Logout handles POST /logout
func Logout( func Logout(
req *http.Request, deviceDB *devices.Database, device *authtypes.Device, req *http.Request, deviceDB devices.Database, device *authtypes.Device,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
@ -45,7 +45,7 @@ func Logout(
// LogoutAll handles POST /logout/all // LogoutAll handles POST /logout/all
func LogoutAll( func LogoutAll(
req *http.Request, deviceDB *devices.Database, device *authtypes.Device, req *http.Request, deviceDB devices.Database, device *authtypes.Device,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {

View file

@ -40,8 +40,8 @@ var errMissingUserID = errors.New("'user_id' must be supplied")
// SendMembership implements PUT /rooms/{roomID}/(join|kick|ban|unban|leave|invite) // SendMembership implements PUT /rooms/{roomID}/(join|kick|ban|unban|leave|invite)
// by building a m.room.member event then sending it to the room server // by building a m.room.member event then sending it to the room server
func SendMembership( func SendMembership(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device, req *http.Request, accountDB accounts.Database, device *authtypes.Device,
roomID string, membership string, cfg config.Dendrite, roomID string, membership string, cfg *config.Dendrite,
queryAPI roomserverAPI.RoomserverQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI, queryAPI roomserverAPI.RoomserverQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
producer *producers.RoomserverProducer, producer *producers.RoomserverProducer,
) util.JSONResponse { ) util.JSONResponse {
@ -116,10 +116,10 @@ func SendMembership(
func buildMembershipEvent( func buildMembershipEvent(
ctx context.Context, ctx context.Context,
body threepid.MembershipRequest, accountDB *accounts.Database, body threepid.MembershipRequest, accountDB accounts.Database,
device *authtypes.Device, device *authtypes.Device,
membership, roomID string, membership, roomID string,
cfg config.Dendrite, evTime time.Time, cfg *config.Dendrite, evTime time.Time,
queryAPI roomserverAPI.RoomserverQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI, queryAPI roomserverAPI.RoomserverQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) (*gomatrixserverlib.Event, error) { ) (*gomatrixserverlib.Event, error) {
stateKey, reason, err := getMembershipStateKey(body, device, membership) stateKey, reason, err := getMembershipStateKey(body, device, membership)
@ -165,8 +165,8 @@ func buildMembershipEvent(
func loadProfile( func loadProfile(
ctx context.Context, ctx context.Context,
userID string, userID string,
cfg config.Dendrite, cfg *config.Dendrite,
accountDB *accounts.Database, accountDB accounts.Database,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
_, serverName, err := gomatrixserverlib.SplitID('@', userID) _, serverName, err := gomatrixserverlib.SplitID('@', userID)
@ -214,9 +214,9 @@ func checkAndProcessThreepid(
req *http.Request, req *http.Request,
device *authtypes.Device, device *authtypes.Device,
body *threepid.MembershipRequest, body *threepid.MembershipRequest,
cfg config.Dendrite, cfg *config.Dendrite,
queryAPI roomserverAPI.RoomserverQueryAPI, queryAPI roomserverAPI.RoomserverQueryAPI,
accountDB *accounts.Database, accountDB accounts.Database,
producer *producers.RoomserverProducer, producer *producers.RoomserverProducer,
membership, roomID string, membership, roomID string,
evTime time.Time, evTime time.Time,

View file

@ -33,7 +33,7 @@ type response struct {
// GetMemberships implements GET /rooms/{roomId}/members // GetMemberships implements GET /rooms/{roomId}/members
func GetMemberships( func GetMemberships(
req *http.Request, device *authtypes.Device, roomID string, joinedOnly bool, req *http.Request, device *authtypes.Device, roomID string, joinedOnly bool,
_ config.Dendrite, _ *config.Dendrite,
queryAPI api.RoomserverQueryAPI, queryAPI api.RoomserverQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
queryReq := api.QueryMembershipsForRoomRequest{ queryReq := api.QueryMembershipsForRoomRequest{

View file

@ -36,7 +36,7 @@ import (
// GetProfile implements GET /profile/{userID} // GetProfile implements GET /profile/{userID}
func GetProfile( func GetProfile(
req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite, req *http.Request, accountDB accounts.Database, cfg *config.Dendrite,
userID string, userID string,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
@ -64,7 +64,7 @@ func GetProfile(
// GetAvatarURL implements GET /profile/{userID}/avatar_url // GetAvatarURL implements GET /profile/{userID}/avatar_url
func GetAvatarURL( func GetAvatarURL(
req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite, req *http.Request, accountDB accounts.Database, cfg *config.Dendrite,
userID string, asAPI appserviceAPI.AppServiceQueryAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
) util.JSONResponse { ) util.JSONResponse {
@ -90,7 +90,7 @@ func GetAvatarURL(
// SetAvatarURL implements PUT /profile/{userID}/avatar_url // SetAvatarURL implements PUT /profile/{userID}/avatar_url
func SetAvatarURL( func SetAvatarURL(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device, req *http.Request, accountDB accounts.Database, device *authtypes.Device,
userID string, producer *producers.UserUpdateProducer, cfg *config.Dendrite, userID string, producer *producers.UserUpdateProducer, cfg *config.Dendrite,
rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI, rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
@ -170,7 +170,7 @@ func SetAvatarURL(
// GetDisplayName implements GET /profile/{userID}/displayname // GetDisplayName implements GET /profile/{userID}/displayname
func GetDisplayName( func GetDisplayName(
req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite, req *http.Request, accountDB accounts.Database, cfg *config.Dendrite,
userID string, asAPI appserviceAPI.AppServiceQueryAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
) util.JSONResponse { ) util.JSONResponse {
@ -196,7 +196,7 @@ func GetDisplayName(
// SetDisplayName implements PUT /profile/{userID}/displayname // SetDisplayName implements PUT /profile/{userID}/displayname
func SetDisplayName( func SetDisplayName(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device, req *http.Request, accountDB accounts.Database, device *authtypes.Device,
userID string, producer *producers.UserUpdateProducer, cfg *config.Dendrite, userID string, producer *producers.UserUpdateProducer, cfg *config.Dendrite,
rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI, rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
@ -279,7 +279,7 @@ func SetDisplayName(
// Returns an error when something goes wrong or specifically // Returns an error when something goes wrong or specifically
// common.ErrProfileNoExists when the profile doesn't exist. // common.ErrProfileNoExists when the profile doesn't exist.
func getProfile( func getProfile(
ctx context.Context, accountDB *accounts.Database, cfg *config.Dendrite, ctx context.Context, accountDB accounts.Database, cfg *config.Dendrite,
userID string, userID string,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
@ -343,7 +343,7 @@ func buildMembershipEvents(
return nil, err return nil, err
} }
event, err := common.BuildEvent(ctx, &builder, *cfg, evTime, queryAPI, nil) event, err := common.BuildEvent(ctx, &builder, cfg, evTime, queryAPI, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -440,8 +440,8 @@ func validateApplicationService(
// http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register // http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register
func Register( func Register(
req *http.Request, req *http.Request,
accountDB *accounts.Database, accountDB accounts.Database,
deviceDB *devices.Database, deviceDB devices.Database,
cfg *config.Dendrite, cfg *config.Dendrite,
) util.JSONResponse { ) util.JSONResponse {
@ -513,8 +513,8 @@ func handleGuestRegistration(
req *http.Request, req *http.Request,
r registerRequest, r registerRequest,
cfg *config.Dendrite, cfg *config.Dendrite,
accountDB *accounts.Database, accountDB accounts.Database,
deviceDB *devices.Database, deviceDB devices.Database,
) util.JSONResponse { ) util.JSONResponse {
//Generate numeric local part for guest user //Generate numeric local part for guest user
@ -570,8 +570,8 @@ func handleRegistrationFlow(
r registerRequest, r registerRequest,
sessionID string, sessionID string,
cfg *config.Dendrite, cfg *config.Dendrite,
accountDB *accounts.Database, accountDB accounts.Database,
deviceDB *devices.Database, deviceDB devices.Database,
) util.JSONResponse { ) util.JSONResponse {
// TODO: Shared secret registration (create new user scripts) // TODO: Shared secret registration (create new user scripts)
// TODO: Enable registration config flag // TODO: Enable registration config flag
@ -668,8 +668,8 @@ func handleApplicationServiceRegistration(
req *http.Request, req *http.Request,
r registerRequest, r registerRequest,
cfg *config.Dendrite, cfg *config.Dendrite,
accountDB *accounts.Database, accountDB accounts.Database,
deviceDB *devices.Database, deviceDB devices.Database,
) util.JSONResponse { ) util.JSONResponse {
// Check if we previously had issues extracting the access token from the // Check if we previously had issues extracting the access token from the
// request. // request.
@ -707,8 +707,8 @@ func checkAndCompleteFlow(
r registerRequest, r registerRequest,
sessionID string, sessionID string,
cfg *config.Dendrite, cfg *config.Dendrite,
accountDB *accounts.Database, accountDB accounts.Database,
deviceDB *devices.Database, deviceDB devices.Database,
) util.JSONResponse { ) util.JSONResponse {
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
// This flow was completed, registration can continue // This flow was completed, registration can continue
@ -730,8 +730,8 @@ func checkAndCompleteFlow(
// LegacyRegister process register requests from the legacy v1 API // LegacyRegister process register requests from the legacy v1 API
func LegacyRegister( func LegacyRegister(
req *http.Request, req *http.Request,
accountDB *accounts.Database, accountDB accounts.Database,
deviceDB *devices.Database, deviceDB devices.Database,
cfg *config.Dendrite, cfg *config.Dendrite,
) util.JSONResponse { ) util.JSONResponse {
var r legacyRegisterRequest var r legacyRegisterRequest
@ -814,8 +814,8 @@ func parseAndValidateLegacyLogin(req *http.Request, r *legacyRegisterRequest) *u
// not all // not all
func completeRegistration( func completeRegistration(
ctx context.Context, ctx context.Context,
accountDB *accounts.Database, accountDB accounts.Database,
deviceDB *devices.Database, deviceDB devices.Database,
username, password, appserviceID string, username, password, appserviceID string,
inhibitLogin common.WeakBoolean, inhibitLogin common.WeakBoolean,
displayName, deviceID *string, displayName, deviceID *string,
@ -991,8 +991,8 @@ type availableResponse struct {
// RegisterAvailable checks if the username is already taken or invalid. // RegisterAvailable checks if the username is already taken or invalid.
func RegisterAvailable( func RegisterAvailable(
req *http.Request, req *http.Request,
cfg config.Dendrite, cfg *config.Dendrite,
accountDB *accounts.Database, accountDB accounts.Database,
) util.JSONResponse { ) util.JSONResponse {
username := req.URL.Query().Get("username") username := req.URL.Query().Get("username")

View file

@ -40,7 +40,7 @@ func newTag() gomatrix.TagContent {
// GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags // GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags
func GetTags( func GetTags(
req *http.Request, req *http.Request,
accountDB *accounts.Database, accountDB accounts.Database,
device *authtypes.Device, device *authtypes.Device,
userID string, userID string,
roomID string, roomID string,
@ -77,7 +77,7 @@ func GetTags(
// the tag to the "map" and saving the new "map" to the DB // the tag to the "map" and saving the new "map" to the DB
func PutTag( func PutTag(
req *http.Request, req *http.Request,
accountDB *accounts.Database, accountDB accounts.Database,
device *authtypes.Device, device *authtypes.Device,
userID string, userID string,
roomID string, roomID string,
@ -134,7 +134,7 @@ func PutTag(
// the "map" and then saving the new "map" in the DB // the "map" and then saving the new "map" in the DB
func DeleteTag( func DeleteTag(
req *http.Request, req *http.Request,
accountDB *accounts.Database, accountDB accounts.Database,
device *authtypes.Device, device *authtypes.Device,
userID string, userID string,
roomID string, roomID string,
@ -203,7 +203,7 @@ func obtainSavedTags(
req *http.Request, req *http.Request,
userID string, userID string,
roomID string, roomID string,
accountDB *accounts.Database, accountDB accounts.Database,
) (string, *gomatrixserverlib.ClientEvent, error) { ) (string, *gomatrixserverlib.ClientEvent, error) {
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
@ -222,7 +222,7 @@ func saveTagData(
req *http.Request, req *http.Request,
localpart string, localpart string,
roomID string, roomID string,
accountDB *accounts.Database, accountDB accounts.Database,
Tag gomatrix.TagContent, Tag gomatrix.TagContent,
) error { ) error {
newTagData, err := json.Marshal(Tag) newTagData, err := json.Marshal(Tag)

View file

@ -47,13 +47,13 @@ const pathPrefixUnstable = "/_matrix/client/unstable"
// applied: // applied:
// nolint: gocyclo // nolint: gocyclo
func Setup( func Setup(
apiMux *mux.Router, cfg config.Dendrite, apiMux *mux.Router, cfg *config.Dendrite,
producer *producers.RoomserverProducer, producer *producers.RoomserverProducer,
queryAPI roomserverAPI.RoomserverQueryAPI, queryAPI roomserverAPI.RoomserverQueryAPI,
aliasAPI roomserverAPI.RoomserverAliasAPI, aliasAPI roomserverAPI.RoomserverAliasAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
accountDB *accounts.Database, accountDB accounts.Database,
deviceDB *devices.Database, deviceDB devices.Database,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
keyRing gomatrixserverlib.KeyRing, keyRing gomatrixserverlib.KeyRing,
userUpdateProducer *producers.UserUpdateProducer, userUpdateProducer *producers.UserUpdateProducer,
@ -170,11 +170,11 @@ func Setup(
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { r0mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
return Register(req, accountDB, deviceDB, &cfg) return Register(req, accountDB, deviceDB, cfg)
})).Methods(http.MethodPost, http.MethodOptions) })).Methods(http.MethodPost, http.MethodOptions)
v1mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { v1mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
return LegacyRegister(req, accountDB, deviceDB, &cfg) return LegacyRegister(req, accountDB, deviceDB, cfg)
})).Methods(http.MethodPost, http.MethodOptions) })).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/register/available", common.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { r0mux.Handle("/register/available", common.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse {
@ -187,7 +187,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return DirectoryRoom(req, vars["roomAlias"], federation, &cfg, aliasAPI, federationSender) return DirectoryRoom(req, vars["roomAlias"], federation, cfg, aliasAPI, federationSender)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
@ -197,7 +197,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SetLocalAlias(req, device, vars["roomAlias"], &cfg, aliasAPI) return SetLocalAlias(req, device, vars["roomAlias"], cfg, aliasAPI)
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
@ -301,7 +301,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetProfile(req, accountDB, &cfg, vars["userID"], asAPI, federation) return GetProfile(req, accountDB, cfg, vars["userID"], asAPI, federation)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
@ -311,7 +311,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetAvatarURL(req, accountDB, &cfg, vars["userID"], asAPI, federation) return GetAvatarURL(req, accountDB, cfg, vars["userID"], asAPI, federation)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
@ -321,7 +321,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SetAvatarURL(req, accountDB, device, vars["userID"], userUpdateProducer, &cfg, producer, queryAPI) return SetAvatarURL(req, accountDB, device, vars["userID"], userUpdateProducer, cfg, producer, queryAPI)
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows // Browsers use the OPTIONS HTTP method to check if the CORS policy allows
@ -333,7 +333,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetDisplayName(req, accountDB, &cfg, vars["userID"], asAPI, federation) return GetDisplayName(req, accountDB, cfg, vars["userID"], asAPI, federation)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
@ -343,7 +343,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SetDisplayName(req, accountDB, device, vars["userID"], userUpdateProducer, &cfg, producer, queryAPI) return SetDisplayName(req, accountDB, device, vars["userID"], userUpdateProducer, cfg, producer, queryAPI)
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows // Browsers use the OPTIONS HTTP method to check if the CORS policy allows
@ -503,6 +503,22 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/devices/{deviceID}",
common.MakeAuthAPI("delete_device", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return DeleteDeviceById(req, deviceDB, device, vars["deviceID"])
}),
).Methods(http.MethodDelete, http.MethodOptions)
r0mux.Handle("/delete_devices",
common.MakeAuthAPI("delete_devices", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return DeleteDevices(req, deviceDB, device)
}),
).Methods(http.MethodPost, http.MethodOptions)
// Stub implementations for sytest // Stub implementations for sytest
r0mux.Handle("/events", r0mux.Handle("/events",
common.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse { common.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {

View file

@ -43,7 +43,7 @@ func SendEvent(
req *http.Request, req *http.Request,
device *authtypes.Device, device *authtypes.Device,
roomID, eventType string, txnID, stateKey *string, roomID, eventType string, txnID, stateKey *string,
cfg config.Dendrite, cfg *config.Dendrite,
queryAPI api.RoomserverQueryAPI, queryAPI api.RoomserverQueryAPI,
producer *producers.RoomserverProducer, producer *producers.RoomserverProducer,
txnCache *transactions.Cache, txnCache *transactions.Cache,
@ -93,7 +93,7 @@ func generateSendEvent(
req *http.Request, req *http.Request,
device *authtypes.Device, device *authtypes.Device,
roomID, eventType string, stateKey *string, roomID, eventType string, stateKey *string,
cfg config.Dendrite, cfg *config.Dendrite,
queryAPI api.RoomserverQueryAPI, queryAPI api.RoomserverQueryAPI,
) (*gomatrixserverlib.Event, *util.JSONResponse) { ) (*gomatrixserverlib.Event, *util.JSONResponse) {
// parse the incoming http request // parse the incoming http request

View file

@ -34,7 +34,7 @@ type typingContentJSON struct {
// sends the typing events to client API typingProducer // sends the typing events to client API typingProducer
func SendTyping( func SendTyping(
req *http.Request, device *authtypes.Device, roomID string, req *http.Request, device *authtypes.Device, roomID string,
userID string, accountDB *accounts.Database, userID string, accountDB accounts.Database,
typingProducer *producers.TypingServerProducer, typingProducer *producers.TypingServerProducer,
) util.JSONResponse { ) util.JSONResponse {
if device.UserID != userID { if device.UserID != userID {

View file

@ -39,7 +39,7 @@ type threePIDsResponse struct {
// RequestEmailToken implements: // RequestEmailToken implements:
// POST /account/3pid/email/requestToken // POST /account/3pid/email/requestToken
// POST /register/email/requestToken // POST /register/email/requestToken
func RequestEmailToken(req *http.Request, accountDB *accounts.Database, cfg config.Dendrite) util.JSONResponse { func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *config.Dendrite) util.JSONResponse {
var body threepid.EmailAssociationRequest var body threepid.EmailAssociationRequest
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr return *reqErr
@ -82,8 +82,8 @@ func RequestEmailToken(req *http.Request, accountDB *accounts.Database, cfg conf
// CheckAndSave3PIDAssociation implements POST /account/3pid // CheckAndSave3PIDAssociation implements POST /account/3pid
func CheckAndSave3PIDAssociation( func CheckAndSave3PIDAssociation(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device, req *http.Request, accountDB accounts.Database, device *authtypes.Device,
cfg config.Dendrite, cfg *config.Dendrite,
) util.JSONResponse { ) util.JSONResponse {
var body threepid.EmailAssociationCheckRequest var body threepid.EmailAssociationCheckRequest
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
@ -142,7 +142,7 @@ func CheckAndSave3PIDAssociation(
// GetAssociated3PIDs implements GET /account/3pid // GetAssociated3PIDs implements GET /account/3pid
func GetAssociated3PIDs( func GetAssociated3PIDs(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device, req *http.Request, accountDB accounts.Database, device *authtypes.Device,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
@ -161,7 +161,7 @@ func GetAssociated3PIDs(
} }
// Forget3PID implements POST /account/3pid/delete // Forget3PID implements POST /account/3pid/delete
func Forget3PID(req *http.Request, accountDB *accounts.Database) util.JSONResponse { func Forget3PID(req *http.Request, accountDB accounts.Database) util.JSONResponse {
var body authtypes.ThreePID var body authtypes.ThreePID
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr return *reqErr

View file

@ -31,7 +31,7 @@ import (
// RequestTurnServer implements: // RequestTurnServer implements:
// GET /voip/turnServer // GET /voip/turnServer
func RequestTurnServer(req *http.Request, device *authtypes.Device, cfg config.Dendrite) util.JSONResponse { func RequestTurnServer(req *http.Request, device *authtypes.Device, cfg *config.Dendrite) util.JSONResponse {
turnConfig := cfg.TURN turnConfig := cfg.TURN
// TODO Guest Support // TODO Guest Support

View file

@ -86,8 +86,8 @@ var (
// can be emitted. // can be emitted.
func CheckAndProcessInvite( func CheckAndProcessInvite(
ctx context.Context, ctx context.Context,
device *authtypes.Device, body *MembershipRequest, cfg config.Dendrite, device *authtypes.Device, body *MembershipRequest, cfg *config.Dendrite,
queryAPI api.RoomserverQueryAPI, db *accounts.Database, queryAPI api.RoomserverQueryAPI, db accounts.Database,
producer *producers.RoomserverProducer, membership string, roomID string, producer *producers.RoomserverProducer, membership string, roomID string,
evTime time.Time, evTime time.Time,
) (inviteStoredOnIDServer bool, err error) { ) (inviteStoredOnIDServer bool, err error) {
@ -137,7 +137,7 @@ func CheckAndProcessInvite(
// Returns an error if a check or a request failed. // Returns an error if a check or a request failed.
func queryIDServer( func queryIDServer(
ctx context.Context, ctx context.Context,
db *accounts.Database, cfg config.Dendrite, device *authtypes.Device, db accounts.Database, cfg *config.Dendrite, device *authtypes.Device,
body *MembershipRequest, roomID string, body *MembershipRequest, roomID string,
) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) { ) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) {
if err = isTrusted(body.IDServer, cfg); err != nil { if err = isTrusted(body.IDServer, cfg); err != nil {
@ -206,7 +206,7 @@ func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServe
// Returns an error if the request failed to send or if the response couldn't be parsed. // Returns an error if the request failed to send or if the response couldn't be parsed.
func queryIDServerStoreInvite( func queryIDServerStoreInvite(
ctx context.Context, ctx context.Context,
db *accounts.Database, cfg config.Dendrite, device *authtypes.Device, db accounts.Database, cfg *config.Dendrite, device *authtypes.Device,
body *MembershipRequest, roomID string, body *MembershipRequest, roomID string,
) (*idServerStoreInviteResponse, error) { ) (*idServerStoreInviteResponse, error) {
// Retrieve the sender's profile to get their display name // Retrieve the sender's profile to get their display name
@ -330,7 +330,7 @@ func checkIDServerSignatures(
func emit3PIDInviteEvent( func emit3PIDInviteEvent(
ctx context.Context, ctx context.Context,
body *MembershipRequest, res *idServerStoreInviteResponse, body *MembershipRequest, res *idServerStoreInviteResponse,
device *authtypes.Device, roomID string, cfg config.Dendrite, device *authtypes.Device, roomID string, cfg *config.Dendrite,
queryAPI api.RoomserverQueryAPI, producer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI, producer *producers.RoomserverProducer,
evTime time.Time, evTime time.Time,
) error { ) error {

View file

@ -53,7 +53,7 @@ type Credentials struct {
// Returns an error if there was a problem sending the request or decoding the // Returns an error if there was a problem sending the request or decoding the
// response, or if the identity server responded with a non-OK status. // response, or if the identity server responded with a non-OK status.
func CreateSession( func CreateSession(
ctx context.Context, req EmailAssociationRequest, cfg config.Dendrite, ctx context.Context, req EmailAssociationRequest, cfg *config.Dendrite,
) (string, error) { ) (string, error) {
if err := isTrusted(req.IDServer, cfg); err != nil { if err := isTrusted(req.IDServer, cfg); err != nil {
return "", err return "", err
@ -101,7 +101,7 @@ func CreateSession(
// Returns an error if there was a problem sending the request or decoding the // Returns an error if there was a problem sending the request or decoding the
// response, or if the identity server responded with a non-OK status. // response, or if the identity server responded with a non-OK status.
func CheckAssociation( func CheckAssociation(
ctx context.Context, creds Credentials, cfg config.Dendrite, ctx context.Context, creds Credentials, cfg *config.Dendrite,
) (bool, string, string, error) { ) (bool, string, string, error) {
if err := isTrusted(creds.IDServer, cfg); err != nil { if err := isTrusted(creds.IDServer, cfg); err != nil {
return false, "", "", err return false, "", "", err
@ -142,7 +142,7 @@ func CheckAssociation(
// identifier and a Matrix ID. // identifier and a Matrix ID.
// Returns an error if there was a problem sending the request or decoding the // Returns an error if there was a problem sending the request or decoding the
// response, or if the identity server responded with a non-OK status. // response, or if the identity server responded with a non-OK status.
func PublishAssociation(creds Credentials, userID string, cfg config.Dendrite) error { func PublishAssociation(creds Credentials, userID string, cfg *config.Dendrite) error {
if err := isTrusted(creds.IDServer, cfg); err != nil { if err := isTrusted(creds.IDServer, cfg); err != nil {
return err return err
} }
@ -177,7 +177,7 @@ func PublishAssociation(creds Credentials, userID string, cfg config.Dendrite) e
// isTrusted checks if a given identity server is part of the list of trusted // isTrusted checks if a given identity server is part of the list of trusted
// identity servers in the configuration file. // identity servers in the configuration file.
// Returns an error if the server isn't trusted. // Returns an error if the server isn't trusted.
func isTrusted(idServer string, cfg config.Dendrite) error { func isTrusted(idServer string, cfg *config.Dendrite) error {
for _, server := range cfg.Matrix.TrustedIDServers { for _, server := range cfg.Matrix.TrustedIDServers {
if idServer == server { if idServer == server {
return nil return nil

View file

@ -21,7 +21,7 @@ import (
"os" "os"
"strings" "strings"
"github.com/Shopify/sarama" sarama "gopkg.in/Shopify/sarama.v1"
) )
const usage = `Usage: %s const usage = `Usage: %s

View file

@ -18,6 +18,7 @@ import (
"database/sql" "database/sql"
"io" "io"
"net/http" "net/http"
"net/url"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
@ -68,7 +69,13 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string) *BaseDendrite {
logrus.WithError(err).Panicf("failed to start opentracing") logrus.WithError(err).Panicf("failed to start opentracing")
} }
kafkaConsumer, kafkaProducer := setupKafka(cfg) var kafkaConsumer sarama.Consumer
var kafkaProducer sarama.SyncProducer
if cfg.Kafka.UseNaffka {
kafkaConsumer, kafkaProducer = setupNaffka(cfg)
} else {
kafkaConsumer, kafkaProducer = setupKafka(cfg)
}
return &BaseDendrite{ return &BaseDendrite{
componentName: componentName, componentName: componentName,
@ -118,7 +125,7 @@ func (b *BaseDendrite) CreateHTTPFederationSenderAPIs() federationSenderAPI.Fede
// CreateDeviceDB creates a new instance of the device database. Should only be // CreateDeviceDB creates a new instance of the device database. Should only be
// called once per component. // called once per component.
func (b *BaseDendrite) CreateDeviceDB() *devices.Database { func (b *BaseDendrite) CreateDeviceDB() devices.Database {
db, err := devices.NewDatabase(string(b.Cfg.Database.Device), b.Cfg.Matrix.ServerName) db, err := devices.NewDatabase(string(b.Cfg.Database.Device), b.Cfg.Matrix.ServerName)
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to connect to devices db") logrus.WithError(err).Panicf("failed to connect to devices db")
@ -129,7 +136,7 @@ func (b *BaseDendrite) CreateDeviceDB() *devices.Database {
// CreateAccountsDB creates a new instance of the accounts database. Should only // CreateAccountsDB creates a new instance of the accounts database. Should only
// be called once per component. // be called once per component.
func (b *BaseDendrite) CreateAccountsDB() *accounts.Database { func (b *BaseDendrite) CreateAccountsDB() accounts.Database {
db, err := accounts.NewDatabase(string(b.Cfg.Database.Account), b.Cfg.Matrix.ServerName) db, err := accounts.NewDatabase(string(b.Cfg.Database.Account), b.Cfg.Matrix.ServerName)
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to connect to accounts db") logrus.WithError(err).Panicf("failed to connect to accounts db")
@ -186,28 +193,8 @@ func (b *BaseDendrite) SetupAndServeHTTP(bindaddr string, listenaddr string) {
logrus.Infof("Stopped %s server on %s", b.componentName, addr) logrus.Infof("Stopped %s server on %s", b.componentName, addr)
} }
// setupKafka creates kafka consumer/producer pair from the config. Checks if // setupKafka creates kafka consumer/producer pair from the config.
// should use naffka.
func setupKafka(cfg *config.Dendrite) (sarama.Consumer, sarama.SyncProducer) { func setupKafka(cfg *config.Dendrite) (sarama.Consumer, sarama.SyncProducer) {
if cfg.Kafka.UseNaffka {
db, err := sql.Open("postgres", string(cfg.Database.Naffka))
if err != nil {
logrus.WithError(err).Panic("Failed to open naffka database")
}
naffkaDB, err := naffka.NewPostgresqlDatabase(db)
if err != nil {
logrus.WithError(err).Panic("Failed to setup naffka database")
}
naff, err := naffka.New(naffkaDB)
if err != nil {
logrus.WithError(err).Panic("Failed to setup naffka")
}
return naff, naff
}
consumer, err := sarama.NewConsumer(cfg.Kafka.Addresses, nil) consumer, err := sarama.NewConsumer(cfg.Kafka.Addresses, nil)
if err != nil { if err != nil {
logrus.WithError(err).Panic("failed to start kafka consumer") logrus.WithError(err).Panic("failed to start kafka consumer")
@ -220,3 +207,44 @@ func setupKafka(cfg *config.Dendrite) (sarama.Consumer, sarama.SyncProducer) {
return consumer, producer return consumer, producer
} }
// setupNaffka creates kafka consumer/producer pair from the config.
func setupNaffka(cfg *config.Dendrite) (sarama.Consumer, sarama.SyncProducer) {
var err error
var db *sql.DB
var naffkaDB *naffka.DatabaseImpl
uri, err := url.Parse(string(cfg.Database.Naffka))
if err != nil || uri.Scheme == "file" {
db, err = sql.Open("sqlite3", string(cfg.Database.Naffka))
if err != nil {
logrus.WithError(err).Panic("Failed to open naffka database")
}
naffkaDB, err = naffka.NewSqliteDatabase(db)
if err != nil {
logrus.WithError(err).Panic("Failed to setup naffka database")
}
} else {
db, err = sql.Open("postgres", string(cfg.Database.Naffka))
if err != nil {
logrus.WithError(err).Panic("Failed to open naffka database")
}
naffkaDB, err = naffka.NewPostgresqlDatabase(db)
if err != nil {
logrus.WithError(err).Panic("Failed to setup naffka database")
}
}
if naffkaDB == nil {
panic("naffka connection string not understood")
}
naff, err := naffka.New(naffkaDB)
if err != nil {
logrus.WithError(err).Panic("Failed to setup naffka")
}
return naff, naff
}

View file

@ -39,7 +39,7 @@ var ErrRoomNoExists = errors.New("Room does not exist")
// Returns an error if something else went wrong // Returns an error if something else went wrong
func BuildEvent( func BuildEvent(
ctx context.Context, ctx context.Context,
builder *gomatrixserverlib.EventBuilder, cfg config.Dendrite, evTime time.Time, builder *gomatrixserverlib.EventBuilder, cfg *config.Dendrite, evTime time.Time,
queryAPI api.RoomserverQueryAPI, queryRes *api.QueryLatestEventsAndStateResponse, queryAPI api.RoomserverQueryAPI, queryRes *api.QueryLatestEventsAndStateResponse,
) (*gomatrixserverlib.Event, error) { ) (*gomatrixserverlib.Event, error) {
err := AddPrevEventsToEvent(ctx, builder, queryAPI, queryRes) err := AddPrevEventsToEvent(ctx, builder, queryAPI, queryRes)

View file

@ -21,6 +21,7 @@ import (
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
"github.com/matrix-org/dendrite/common/keydb/postgres" "github.com/matrix-org/dendrite/common/keydb/postgres"
"github.com/matrix-org/dendrite/common/keydb/sqlite3"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -44,6 +45,8 @@ func NewDatabase(
switch uri.Scheme { switch uri.Scheme {
case "postgres": case "postgres":
return postgres.NewDatabase(dataSourceName, serverName, serverKey, serverKeyID) return postgres.NewDatabase(dataSourceName, serverName, serverKey, serverKeyID)
case "file":
return sqlite3.NewDatabase(dataSourceName, serverName, serverKey, serverKeyID)
default: default:
return postgres.NewDatabase(dataSourceName, serverName, serverKey, serverKeyID) return postgres.NewDatabase(dataSourceName, serverName, serverKey, serverKeyID)
} }

View file

@ -117,7 +117,7 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS), ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
} }
} }
return results, nil return results, rows.Err()
} }
func (s *serverKeyStatements) upsertServerKeys( func (s *serverKeyStatements) upsertServerKeys(

View file

@ -0,0 +1,115 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"math"
"golang.org/x/crypto/ed25519"
"github.com/matrix-org/gomatrixserverlib"
_ "github.com/mattn/go-sqlite3"
)
// A Database implements gomatrixserverlib.KeyDatabase and is used to store
// the public keys for other matrix servers.
type Database struct {
statements serverKeyStatements
}
// NewDatabase prepares a new key database.
// It creates the necessary tables if they don't already exist.
// It prepares all the SQL statements that it will use.
// Returns an error if there was a problem talking to the database.
func NewDatabase(
dataSourceName string,
serverName gomatrixserverlib.ServerName,
serverKey ed25519.PublicKey,
serverKeyID gomatrixserverlib.KeyID,
) (*Database, error) {
db, err := sql.Open("sqlite3", dataSourceName)
if err != nil {
return nil, err
}
d := &Database{}
err = d.statements.prepare(db)
if err != nil {
return nil, err
}
// Store our own keys so that we don't end up making HTTP requests to find our
// own keys
index := gomatrixserverlib.PublicKeyLookupRequest{
ServerName: serverName,
KeyID: serverKeyID,
}
value := gomatrixserverlib.PublicKeyLookupResult{
VerifyKey: gomatrixserverlib.VerifyKey{
Key: gomatrixserverlib.Base64String(serverKey),
},
ValidUntilTS: math.MaxUint64 >> 1,
ExpiredTS: gomatrixserverlib.PublicKeyNotExpired,
}
err = d.StoreKeys(
context.Background(),
map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{
index: value,
},
)
if err != nil {
return nil, err
}
return d, nil
}
// FetcherName implements KeyFetcher
func (d Database) FetcherName() string {
return "KeyDatabase"
}
// FetchKeys implements gomatrixserverlib.KeyDatabase
func (d *Database) FetchKeys(
ctx context.Context,
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
return d.statements.bulkSelectServerKeys(ctx, requests)
}
// StoreKeys implements gomatrixserverlib.KeyDatabase
func (d *Database) StoreKeys(
ctx context.Context,
keyMap map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
) error {
// TODO: Inserting all the keys within a single transaction may
// be more efficient since the transaction overhead can be quite
// high for a single insert statement.
var lastErr error
for request, keys := range keyMap {
if err := d.statements.upsertServerKeys(ctx, request, keys); err != nil {
// Rather than returning immediately on error we try to insert the
// remaining keys.
// Since we are inserting the keys outside of a transaction it is
// possible for some of the inserts to succeed even though some
// of the inserts have failed.
// Ensuring that we always insert all the keys we can means that
// this behaviour won't depend on the iteration order of the map.
lastErr = err
}
}
return lastErr
}

View file

@ -0,0 +1,142 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib"
)
const serverKeysSchema = `
-- A cache of signing keys downloaded from remote servers.
CREATE TABLE IF NOT EXISTS keydb_server_keys (
-- The name of the matrix server the key is for.
server_name TEXT NOT NULL,
-- The ID of the server key.
server_key_id TEXT NOT NULL,
-- Combined server name and key ID separated by the ASCII unit separator
-- to make it easier to run bulk queries.
server_name_and_key_id TEXT NOT NULL,
-- When the key is valid until as a millisecond timestamp.
-- 0 if this is an expired key (in which case expired_ts will be non-zero)
valid_until_ts BIGINT NOT NULL,
-- When the key expired as a millisecond timestamp.
-- 0 if this is an active key (in which case valid_until_ts will be non-zero)
expired_ts BIGINT NOT NULL,
-- The base64-encoded public key.
server_key TEXT NOT NULL,
UNIQUE (server_name, server_key_id)
);
CREATE INDEX IF NOT EXISTS keydb_server_name_and_key_id ON keydb_server_keys (server_name_and_key_id);
`
const bulkSelectServerKeysSQL = "" +
"SELECT server_name, server_key_id, valid_until_ts, expired_ts, " +
" server_key FROM keydb_server_keys" +
" WHERE server_name_and_key_id IN ($1)"
const upsertServerKeysSQL = "" +
"INSERT INTO keydb_server_keys (server_name, server_key_id," +
" server_name_and_key_id, valid_until_ts, expired_ts, server_key)" +
" VALUES ($1, $2, $3, $4, $5, $6)" +
" ON CONFLICT (server_name, server_key_id)" +
" DO UPDATE SET valid_until_ts = $4, expired_ts = $5, server_key = $6"
type serverKeyStatements struct {
bulkSelectServerKeysStmt *sql.Stmt
upsertServerKeysStmt *sql.Stmt
}
func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(serverKeysSchema)
if err != nil {
return
}
if s.bulkSelectServerKeysStmt, err = db.Prepare(bulkSelectServerKeysSQL); err != nil {
return
}
if s.upsertServerKeysStmt, err = db.Prepare(upsertServerKeysSQL); err != nil {
return
}
return
}
func (s *serverKeyStatements) bulkSelectServerKeys(
ctx context.Context,
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
var nameAndKeyIDs []string
for request := range requests {
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
}
stmt := s.bulkSelectServerKeysStmt
rows, err := stmt.QueryContext(ctx, pq.StringArray(nameAndKeyIDs))
if err != nil {
return nil, err
}
defer rows.Close() // nolint: errcheck
results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
for rows.Next() {
var serverName string
var keyID string
var key string
var validUntilTS int64
var expiredTS int64
if err = rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
return nil, err
}
r := gomatrixserverlib.PublicKeyLookupRequest{
ServerName: gomatrixserverlib.ServerName(serverName),
KeyID: gomatrixserverlib.KeyID(keyID),
}
vk := gomatrixserverlib.VerifyKey{}
err = vk.Key.Decode(key)
if err != nil {
return nil, err
}
results[r] = gomatrixserverlib.PublicKeyLookupResult{
VerifyKey: vk,
ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS),
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
}
}
return results, nil
}
func (s *serverKeyStatements) upsertServerKeys(
ctx context.Context,
request gomatrixserverlib.PublicKeyLookupRequest,
key gomatrixserverlib.PublicKeyLookupResult,
) error {
_, err := s.upsertServerKeysStmt.ExecContext(
ctx,
string(request.ServerName),
string(request.KeyID),
nameAndKeyID(request),
key.ValidUntilTS,
key.ExpiredTS,
key.Key.Encode(),
)
return err
}
func nameAndKeyID(request gomatrixserverlib.PublicKeyLookupRequest) string {
return string(request.ServerName) + "\x1F" + string(request.KeyID)
}

View file

@ -29,7 +29,7 @@ CREATE TABLE IF NOT EXISTS ${prefix}_partition_offsets (
partition INTEGER NOT NULL, partition INTEGER NOT NULL,
-- The 64-bit offset. -- The 64-bit offset.
partition_offset BIGINT NOT NULL, partition_offset BIGINT NOT NULL,
CONSTRAINT ${prefix}_topic_partition_unique UNIQUE (topic, partition) UNIQUE (topic, partition)
); );
` `
@ -38,7 +38,7 @@ const selectPartitionOffsetsSQL = "" +
const upsertPartitionOffsetsSQL = "" + const upsertPartitionOffsetsSQL = "" +
"INSERT INTO ${prefix}_partition_offsets (topic, partition, partition_offset) VALUES ($1, $2, $3)" + "INSERT INTO ${prefix}_partition_offsets (topic, partition, partition_offset) VALUES ($1, $2, $3)" +
" ON CONFLICT ON CONSTRAINT ${prefix}_topic_partition_unique" + " ON CONFLICT (topic, partition)" +
" DO UPDATE SET partition_offset = $3" " DO UPDATE SET partition_offset = $3"
// PartitionOffsetStatements represents a set of statements that can be run on a partition_offsets table. // PartitionOffsetStatements represents a set of statements that can be run on a partition_offsets table.
@ -99,7 +99,7 @@ func (s *PartitionOffsetStatements) selectPartitionOffsets(
} }
results = append(results, offset) results = append(results, offset)
} }
return results, nil return results, rows.Err()
} }
// UpsertPartitionOffset updates or inserts the partition offset for the given topic. // UpsertPartitionOffset updates or inserts the partition offset for the given topic.

View file

@ -16,6 +16,7 @@ package common
import ( import (
"database/sql" "database/sql"
"fmt"
"github.com/lib/pq" "github.com/lib/pq"
) )
@ -30,11 +31,13 @@ type Transaction interface {
// EndTransaction ends a transaction. // EndTransaction ends a transaction.
// If the transaction succeeded then it is committed, otherwise it is rolledback. // If the transaction succeeded then it is committed, otherwise it is rolledback.
func EndTransaction(txn Transaction, succeeded *bool) { // You MUST check the error returned from this function to be sure that the transaction
// was applied correctly. For example, 'database is locked' errors in sqlite will happen here.
func EndTransaction(txn Transaction, succeeded *bool) error {
if *succeeded { if *succeeded {
txn.Commit() // nolint: errcheck return txn.Commit() // nolint: errcheck
} else { } else {
txn.Rollback() // nolint: errcheck return txn.Rollback() // nolint: errcheck
} }
} }
@ -47,7 +50,12 @@ func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
return return
} }
succeeded := false succeeded := false
defer EndTransaction(txn, &succeeded) defer func() {
err2 := EndTransaction(txn, &succeeded)
if err == nil && err2 != nil { // failed to commit/rollback
err = err2
}
}()
err = fn(txn) err = fn(txn)
if err != nil { if err != nil {
@ -74,3 +82,20 @@ func IsUniqueConstraintViolationErr(err error) bool {
pqErr, ok := err.(*pq.Error) pqErr, ok := err.(*pq.Error)
return ok && pqErr.Code == "23505" return ok && pqErr.Code == "23505"
} }
// Hack of the century
func QueryVariadic(count int) string {
return QueryVariadicOffset(count, 0)
}
func QueryVariadicOffset(count, offset int) string {
str := "("
for i := 0; i < count; i++ {
str += fmt.Sprintf("$%d", i+offset+1)
if i < (count - 1) {
str += ", "
}
}
str += ")"
return str
}

View file

@ -111,6 +111,7 @@ func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*con
// Bind to the same address as the listen address // Bind to the same address as the listen address
// All microservices are run on the same host in testing // All microservices are run on the same host in testing
cfg.Bind.ClientAPI = cfg.Listen.ClientAPI cfg.Bind.ClientAPI = cfg.Listen.ClientAPI
cfg.Bind.AppServiceAPI = cfg.Listen.AppServiceAPI
cfg.Bind.FederationAPI = cfg.Listen.FederationAPI cfg.Bind.FederationAPI = cfg.Listen.FederationAPI
cfg.Bind.MediaAPI = cfg.Listen.MediaAPI cfg.Bind.MediaAPI = cfg.Listen.MediaAPI
cfg.Bind.RoomServer = cfg.Listen.RoomServer cfg.Bind.RoomServer = cfg.Listen.RoomServer

View file

@ -1,9 +1,13 @@
<<<<<<< HEAD
FROM docker.io/golang:1.13.7-alpine3.11
=======
FROM docker.io/golang:1.13.6-alpine FROM docker.io/golang:1.13.6-alpine
>>>>>>> master
RUN mkdir /build RUN mkdir /build
WORKDIR /build WORKDIR /build
RUN apk --update --no-cache add openssl bash git RUN apk --update --no-cache add openssl bash git build-base
CMD ["bash", "docker/build.sh"] CMD ["bash", "docker/build.sh"]

View file

@ -1,13 +1,21 @@
version: "3.4" version: "3.4"
services: services:
riot:
image: vectorim/riot-web
networks:
- internal
ports:
- "8500:80"
monolith: monolith:
container_name: dendrite_monolith container_name: dendrite_monolith
hostname: monolith hostname: monolith
entrypoint: ["bash", "./docker/services/monolith.sh"] entrypoint: ["bash", "./docker/services/monolith.sh", "--config", "/etc/dendrite/dendrite.yaml"]
build: ./ build: ./
volumes: volumes:
- ..:/build - ..:/build
- ./build/bin:/build/bin - ./build/bin:/build/bin
- ../cfg:/etc/dendrite
networks: networks:
- internal - internal
depends_on: depends_on:

View file

@ -44,6 +44,7 @@ args:
user: dendrite user: dendrite
database: dendrite database: dendrite
host: 127.0.0.1 host: 127.0.0.1
sslmode: disable
type: pg type: pg
EOF EOF
``` ```

View file

@ -32,8 +32,8 @@ import (
// FederationAPI component. // FederationAPI component.
func SetupFederationAPIComponent( func SetupFederationAPIComponent(
base *basecomponent.BaseDendrite, base *basecomponent.BaseDendrite,
accountsDB *accounts.Database, accountsDB accounts.Database,
deviceDB *devices.Database, deviceDB devices.Database,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
keyRing *gomatrixserverlib.KeyRing, keyRing *gomatrixserverlib.KeyRing,
aliasAPI roomserverAPI.RoomserverAliasAPI, aliasAPI roomserverAPI.RoomserverAliasAPI,
@ -45,8 +45,8 @@ func SetupFederationAPIComponent(
roomserverProducer := producers.NewRoomserverProducer(inputAPI) roomserverProducer := producers.NewRoomserverProducer(inputAPI)
routing.Setup( routing.Setup(
base.APIMux, *base.Cfg, queryAPI, aliasAPI, asAPI, base.APIMux, base.Cfg, queryAPI, aliasAPI, asAPI,
roomserverProducer, federationSenderAPI, *keyRing, federation, accountsDB, roomserverProducer, federationSenderAPI, *keyRing,
deviceDB, federation, accountsDB, deviceDB,
) )
} }

View file

@ -34,7 +34,7 @@ func Backfill(
request *gomatrixserverlib.FederationRequest, request *gomatrixserverlib.FederationRequest,
query api.RoomserverQueryAPI, query api.RoomserverQueryAPI,
roomID string, roomID string,
cfg config.Dendrite, cfg *config.Dendrite,
) util.JSONResponse { ) util.JSONResponse {
var res api.QueryBackfillResponse var res api.QueryBackfillResponse
var eIDs []string var eIDs []string

View file

@ -30,7 +30,7 @@ type userDevicesResponse struct {
// GetUserDevices for the given user id // GetUserDevices for the given user id
func GetUserDevices( func GetUserDevices(
req *http.Request, req *http.Request,
deviceDB *devices.Database, deviceDB devices.Database,
userID string, userID string,
) util.JSONResponse { ) util.JSONResponse {
localpart, err := userutil.ParseUsernameParam(userID, nil) localpart, err := userutil.ParseUsernameParam(userID, nil)

View file

@ -32,7 +32,7 @@ func Invite(
request *gomatrixserverlib.FederationRequest, request *gomatrixserverlib.FederationRequest,
roomID string, roomID string,
eventID string, eventID string,
cfg config.Dendrite, cfg *config.Dendrite,
producer *producers.RoomserverProducer, producer *producers.RoomserverProducer,
keys gomatrixserverlib.KeyRing, keys gomatrixserverlib.KeyRing,
) util.JSONResponse { ) util.JSONResponse {

View file

@ -33,7 +33,7 @@ import (
func MakeJoin( func MakeJoin(
httpReq *http.Request, httpReq *http.Request,
request *gomatrixserverlib.FederationRequest, request *gomatrixserverlib.FederationRequest,
cfg config.Dendrite, cfg *config.Dendrite,
query api.RoomserverQueryAPI, query api.RoomserverQueryAPI,
roomID, userID string, roomID, userID string,
) util.JSONResponse { ) util.JSONResponse {
@ -97,7 +97,7 @@ func MakeJoin(
func SendJoin( func SendJoin(
httpReq *http.Request, httpReq *http.Request,
request *gomatrixserverlib.FederationRequest, request *gomatrixserverlib.FederationRequest,
cfg config.Dendrite, cfg *config.Dendrite,
query api.RoomserverQueryAPI, query api.RoomserverQueryAPI,
producer *producers.RoomserverProducer, producer *producers.RoomserverProducer,
keys gomatrixserverlib.KeyRing, keys gomatrixserverlib.KeyRing,

View file

@ -27,7 +27,7 @@ import (
// LocalKeys returns the local keys for the server. // LocalKeys returns the local keys for the server.
// See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys // See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys
func LocalKeys(cfg config.Dendrite) util.JSONResponse { func LocalKeys(cfg *config.Dendrite) util.JSONResponse {
keys, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod)) keys, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -35,7 +35,7 @@ func LocalKeys(cfg config.Dendrite) util.JSONResponse {
return util.JSONResponse{Code: http.StatusOK, JSON: keys} return util.JSONResponse{Code: http.StatusOK, JSON: keys}
} }
func localKeys(cfg config.Dendrite, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) { func localKeys(cfg *config.Dendrite, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) {
var keys gomatrixserverlib.ServerKeys var keys gomatrixserverlib.ServerKeys
keys.ServerName = cfg.Matrix.ServerName keys.ServerName = cfg.Matrix.ServerName

View file

@ -31,7 +31,7 @@ import (
func MakeLeave( func MakeLeave(
httpReq *http.Request, httpReq *http.Request,
request *gomatrixserverlib.FederationRequest, request *gomatrixserverlib.FederationRequest,
cfg config.Dendrite, cfg *config.Dendrite,
query api.RoomserverQueryAPI, query api.RoomserverQueryAPI,
roomID, userID string, roomID, userID string,
) util.JSONResponse { ) util.JSONResponse {
@ -95,7 +95,7 @@ func MakeLeave(
func SendLeave( func SendLeave(
httpReq *http.Request, httpReq *http.Request,
request *gomatrixserverlib.FederationRequest, request *gomatrixserverlib.FederationRequest,
cfg config.Dendrite, cfg *config.Dendrite,
producer *producers.RoomserverProducer, producer *producers.RoomserverProducer,
keys gomatrixserverlib.KeyRing, keys gomatrixserverlib.KeyRing,
roomID, eventID string, roomID, eventID string,

View file

@ -30,8 +30,8 @@ import (
// GetProfile implements GET /_matrix/federation/v1/query/profile // GetProfile implements GET /_matrix/federation/v1/query/profile
func GetProfile( func GetProfile(
httpReq *http.Request, httpReq *http.Request,
accountDB *accounts.Database, accountDB accounts.Database,
cfg config.Dendrite, cfg *config.Dendrite,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
userID, field := httpReq.FormValue("user_id"), httpReq.FormValue("field") userID, field := httpReq.FormValue("user_id"), httpReq.FormValue("field")

View file

@ -32,7 +32,7 @@ import (
func RoomAliasToID( func RoomAliasToID(
httpReq *http.Request, httpReq *http.Request,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
cfg config.Dendrite, cfg *config.Dendrite,
aliasAPI roomserverAPI.RoomserverAliasAPI, aliasAPI roomserverAPI.RoomserverAliasAPI,
senderAPI federationSenderAPI.FederationSenderQueryAPI, senderAPI federationSenderAPI.FederationSenderQueryAPI,
) util.JSONResponse { ) util.JSONResponse {

View file

@ -43,7 +43,7 @@ const (
// nolint: gocyclo // nolint: gocyclo
func Setup( func Setup(
apiMux *mux.Router, apiMux *mux.Router,
cfg config.Dendrite, cfg *config.Dendrite,
query roomserverAPI.RoomserverQueryAPI, query roomserverAPI.RoomserverQueryAPI,
aliasAPI roomserverAPI.RoomserverAliasAPI, aliasAPI roomserverAPI.RoomserverAliasAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
@ -51,8 +51,8 @@ func Setup(
federationSenderAPI federationSenderAPI.FederationSenderQueryAPI, federationSenderAPI federationSenderAPI.FederationSenderQueryAPI,
keys gomatrixserverlib.KeyRing, keys gomatrixserverlib.KeyRing,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
accountDB *accounts.Database, accountDB accounts.Database,
deviceDB *devices.Database, deviceDB devices.Database,
) { ) {
v2keysmux := apiMux.PathPrefix(pathPrefixV2Keys).Subrouter() v2keysmux := apiMux.PathPrefix(pathPrefixV2Keys).Subrouter()
v1fedmux := apiMux.PathPrefix(pathPrefixV1Federation).Subrouter() v1fedmux := apiMux.PathPrefix(pathPrefixV1Federation).Subrouter()

View file

@ -34,7 +34,7 @@ func Send(
httpReq *http.Request, httpReq *http.Request,
request *gomatrixserverlib.FederationRequest, request *gomatrixserverlib.FederationRequest,
txnID gomatrixserverlib.TransactionID, txnID gomatrixserverlib.TransactionID,
cfg config.Dendrite, cfg *config.Dendrite,
query api.RoomserverQueryAPI, query api.RoomserverQueryAPI,
producer *producers.RoomserverProducer, producer *producers.RoomserverProducer,
keys gomatrixserverlib.KeyRing, keys gomatrixserverlib.KeyRing,

View file

@ -59,9 +59,9 @@ var (
// CreateInvitesFrom3PIDInvites implements POST /_matrix/federation/v1/3pid/onbind // CreateInvitesFrom3PIDInvites implements POST /_matrix/federation/v1/3pid/onbind
func CreateInvitesFrom3PIDInvites( func CreateInvitesFrom3PIDInvites(
req *http.Request, queryAPI roomserverAPI.RoomserverQueryAPI, req *http.Request, queryAPI roomserverAPI.RoomserverQueryAPI,
asAPI appserviceAPI.AppServiceQueryAPI, cfg config.Dendrite, asAPI appserviceAPI.AppServiceQueryAPI, cfg *config.Dendrite,
producer *producers.RoomserverProducer, federation *gomatrixserverlib.FederationClient, producer *producers.RoomserverProducer, federation *gomatrixserverlib.FederationClient,
accountDB *accounts.Database, accountDB accounts.Database,
) util.JSONResponse { ) util.JSONResponse {
var body invites var body invites
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
@ -98,7 +98,7 @@ func ExchangeThirdPartyInvite(
request *gomatrixserverlib.FederationRequest, request *gomatrixserverlib.FederationRequest,
roomID string, roomID string,
queryAPI roomserverAPI.RoomserverQueryAPI, queryAPI roomserverAPI.RoomserverQueryAPI,
cfg config.Dendrite, cfg *config.Dendrite,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
producer *producers.RoomserverProducer, producer *producers.RoomserverProducer,
) util.JSONResponse { ) util.JSONResponse {
@ -172,9 +172,9 @@ func ExchangeThirdPartyInvite(
// necessary data to do so. // necessary data to do so.
func createInviteFrom3PIDInvite( func createInviteFrom3PIDInvite(
ctx context.Context, queryAPI roomserverAPI.RoomserverQueryAPI, ctx context.Context, queryAPI roomserverAPI.RoomserverQueryAPI,
asAPI appserviceAPI.AppServiceQueryAPI, cfg config.Dendrite, asAPI appserviceAPI.AppServiceQueryAPI, cfg *config.Dendrite,
inv invite, federation *gomatrixserverlib.FederationClient, inv invite, federation *gomatrixserverlib.FederationClient,
accountDB *accounts.Database, accountDB accounts.Database,
) (*gomatrixserverlib.Event, error) { ) (*gomatrixserverlib.Event, error) {
_, server, err := gomatrixserverlib.SplitID('@', inv.MXID) _, server, err := gomatrixserverlib.SplitID('@', inv.MXID)
if err != nil { if err != nil {
@ -230,7 +230,7 @@ func createInviteFrom3PIDInvite(
func buildMembershipEvent( func buildMembershipEvent(
ctx context.Context, ctx context.Context,
builder *gomatrixserverlib.EventBuilder, queryAPI roomserverAPI.RoomserverQueryAPI, builder *gomatrixserverlib.EventBuilder, queryAPI roomserverAPI.RoomserverQueryAPI,
cfg config.Dendrite, cfg *config.Dendrite,
) (*gomatrixserverlib.Event, error) { ) (*gomatrixserverlib.Event, error) {
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
if err != nil { if err != nil {
@ -290,7 +290,7 @@ func buildMembershipEvent(
// them responded with an error. // them responded with an error.
func sendToRemoteServer( func sendToRemoteServer(
ctx context.Context, inv invite, ctx context.Context, inv invite,
federation *gomatrixserverlib.FederationClient, _ config.Dendrite, federation *gomatrixserverlib.FederationClient, _ *config.Dendrite,
builder gomatrixserverlib.EventBuilder, builder gomatrixserverlib.EventBuilder,
) (err error) { ) (err error) {
remoteServers := make([]gomatrixserverlib.ServerName, 2) remoteServers := make([]gomatrixserverlib.ServerName, 2)

View file

@ -132,5 +132,5 @@ func joinedHostsFromStmt(
}) })
} }
return result, nil return result, rows.Err()
} }

View file

@ -87,7 +87,7 @@ func (d *Database) UpdateRoom(
return nil return nil
} }
if lastSentEventID != oldEventID { if lastSentEventID != "" && lastSentEventID != oldEventID {
return types.EventIDMismatchError{ return types.EventIDMismatchError{
DatabaseID: lastSentEventID, RoomServerID: oldEventID, DatabaseID: lastSentEventID, RoomServerID: oldEventID,
} }

View file

@ -0,0 +1,139 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/federationsender/types"
"github.com/matrix-org/gomatrixserverlib"
)
const joinedHostsSchema = `
-- The joined_hosts table stores a list of m.room.member event ids in the
-- current state for each room where the membership is "join".
-- There will be an entry for every user that is joined to the room.
CREATE TABLE IF NOT EXISTS federationsender_joined_hosts (
-- The string ID of the room.
room_id TEXT NOT NULL,
-- The event ID of the m.room.member join event.
event_id TEXT NOT NULL,
-- The domain part of the user ID the m.room.member event is for.
server_name TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS federatonsender_joined_hosts_event_id_idx
ON federationsender_joined_hosts (event_id);
CREATE INDEX IF NOT EXISTS federatonsender_joined_hosts_room_id_idx
ON federationsender_joined_hosts (room_id)
`
const insertJoinedHostsSQL = "" +
"INSERT INTO federationsender_joined_hosts (room_id, event_id, server_name)" +
" VALUES ($1, $2, $3)"
const deleteJoinedHostsSQL = "" +
"DELETE FROM federationsender_joined_hosts WHERE event_id = $1"
const selectJoinedHostsSQL = "" +
"SELECT event_id, server_name FROM federationsender_joined_hosts" +
" WHERE room_id = $1"
type joinedHostsStatements struct {
insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt
}
func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(joinedHostsSchema)
if err != nil {
return
}
if s.insertJoinedHostsStmt, err = db.Prepare(insertJoinedHostsSQL); err != nil {
return
}
if s.deleteJoinedHostsStmt, err = db.Prepare(deleteJoinedHostsSQL); err != nil {
return
}
if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil {
return
}
return
}
func (s *joinedHostsStatements) insertJoinedHosts(
ctx context.Context,
txn *sql.Tx,
roomID, eventID string,
serverName gomatrixserverlib.ServerName,
) error {
stmt := common.TxStmt(txn, s.insertJoinedHostsStmt)
_, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
return err
}
func (s *joinedHostsStatements) deleteJoinedHosts(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) error {
for _, eventID := range eventIDs {
stmt := common.TxStmt(txn, s.deleteJoinedHostsStmt)
if _, err := stmt.ExecContext(ctx, eventID); err != nil {
return err
}
}
return nil
}
func (s *joinedHostsStatements) selectJoinedHostsWithTx(
ctx context.Context, txn *sql.Tx, roomID string,
) ([]types.JoinedHost, error) {
stmt := common.TxStmt(txn, s.selectJoinedHostsStmt)
return joinedHostsFromStmt(ctx, stmt, roomID)
}
func (s *joinedHostsStatements) selectJoinedHosts(
ctx context.Context, roomID string,
) ([]types.JoinedHost, error) {
return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID)
}
func joinedHostsFromStmt(
ctx context.Context, stmt *sql.Stmt, roomID string,
) ([]types.JoinedHost, error) {
rows, err := stmt.QueryContext(ctx, roomID)
if err != nil {
return nil, err
}
defer rows.Close() // nolint: errcheck
var result []types.JoinedHost
for rows.Next() {
var eventID, serverName string
if err = rows.Scan(&eventID, &serverName); err != nil {
return nil, err
}
result = append(result, types.JoinedHost{
MemberEventID: eventID,
ServerName: gomatrixserverlib.ServerName(serverName),
})
}
return result, nil
}

View file

@ -0,0 +1,101 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/common"
)
const roomSchema = `
CREATE TABLE IF NOT EXISTS federationsender_rooms (
-- The string ID of the room
room_id TEXT PRIMARY KEY,
-- The most recent event state by the room server.
-- We can use this to tell if our view of the room state has become
-- desynchronised.
last_event_id TEXT NOT NULL
);`
const insertRoomSQL = "" +
"INSERT INTO federationsender_rooms (room_id, last_event_id) VALUES ($1, '')" +
" ON CONFLICT DO NOTHING"
const selectRoomForUpdateSQL = "" +
"SELECT last_event_id FROM federationsender_rooms WHERE room_id = $1"
const updateRoomSQL = "" +
"UPDATE federationsender_rooms SET last_event_id = $2 WHERE room_id = $1"
type roomStatements struct {
insertRoomStmt *sql.Stmt
selectRoomForUpdateStmt *sql.Stmt
updateRoomStmt *sql.Stmt
}
func (s *roomStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(roomSchema)
if err != nil {
return
}
if s.insertRoomStmt, err = db.Prepare(insertRoomSQL); err != nil {
return
}
if s.selectRoomForUpdateStmt, err = db.Prepare(selectRoomForUpdateSQL); err != nil {
return
}
if s.updateRoomStmt, err = db.Prepare(updateRoomSQL); err != nil {
return
}
return
}
// insertRoom inserts the room if it didn't already exist.
// If the room didn't exist then last_event_id is set to the empty string.
func (s *roomStatements) insertRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
_, err := common.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID)
return err
}
// selectRoomForUpdate locks the row for the room and returns the last_event_id.
// The row must already exist in the table. Callers can ensure that the row
// exists by calling insertRoom first.
func (s *roomStatements) selectRoomForUpdate(
ctx context.Context, txn *sql.Tx, roomID string,
) (string, error) {
var lastEventID string
stmt := common.TxStmt(txn, s.selectRoomForUpdateStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(&lastEventID)
if err != nil {
return "", err
}
return lastEventID, nil
}
// updateRoom updates the last_event_id for the room. selectRoomForUpdate should
// have already been called earlier within the transaction.
func (s *roomStatements) updateRoom(
ctx context.Context, txn *sql.Tx, roomID, lastEventID string,
) error {
stmt := common.TxStmt(txn, s.updateRoomStmt)
_, err := stmt.ExecContext(ctx, roomID, lastEventID)
return err
}

View file

@ -0,0 +1,124 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
_ "github.com/mattn/go-sqlite3"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/federationsender/types"
)
// Database stores information needed by the federation sender
type Database struct {
joinedHostsStatements
roomStatements
common.PartitionOffsetStatements
db *sql.DB
}
// NewDatabase opens a new database
func NewDatabase(dataSourceName string) (*Database, error) {
var result Database
var err error
if result.db, err = sql.Open("sqlite3", dataSourceName); err != nil {
return nil, err
}
if err = result.prepare(); err != nil {
return nil, err
}
return &result, nil
}
func (d *Database) prepare() error {
var err error
if err = d.joinedHostsStatements.prepare(d.db); err != nil {
return err
}
if err = d.roomStatements.prepare(d.db); err != nil {
return err
}
return d.PartitionOffsetStatements.Prepare(d.db, "federationsender")
}
// UpdateRoom updates the joined hosts for a room and returns what the joined
// hosts were before the update, or nil if this was a duplicate message.
// This is called when we receive a message from kafka, so we pass in
// oldEventID and newEventID to check that we haven't missed any messages or
// this isn't a duplicate message.
func (d *Database) UpdateRoom(
ctx context.Context,
roomID, oldEventID, newEventID string,
addHosts []types.JoinedHost,
removeHosts []string,
) (joinedHosts []types.JoinedHost, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
err = d.insertRoom(ctx, txn, roomID)
if err != nil {
return err
}
lastSentEventID, err := d.selectRoomForUpdate(ctx, txn, roomID)
if err != nil {
return err
}
if lastSentEventID == newEventID {
// We've handled this message before, so let's just ignore it.
// We can only get a duplicate for the last message we processed,
// so its enough just to compare the newEventID with lastSentEventID
return nil
}
if lastSentEventID != "" && lastSentEventID != oldEventID {
return types.EventIDMismatchError{
DatabaseID: lastSentEventID, RoomServerID: oldEventID,
}
}
joinedHosts, err = d.selectJoinedHostsWithTx(ctx, txn, roomID)
if err != nil {
return err
}
for _, add := range addHosts {
err = d.insertJoinedHosts(ctx, txn, roomID, add.MemberEventID, add.ServerName)
if err != nil {
return err
}
}
if err = d.deleteJoinedHosts(ctx, txn, removeHosts); err != nil {
return err
}
return d.updateRoom(ctx, txn, roomID, newEventID)
})
return
}
// GetJoinedHosts returns the currently joined hosts for room,
// as known to federationserver.
// Returns an error if something goes wrong.
func (d *Database) GetJoinedHosts(
ctx context.Context, roomID string,
) ([]types.JoinedHost, error) {
return d.selectJoinedHosts(ctx, roomID)
}

View file

@ -20,6 +20,7 @@ import (
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/federationsender/storage/postgres" "github.com/matrix-org/dendrite/federationsender/storage/postgres"
"github.com/matrix-org/dendrite/federationsender/storage/sqlite3"
"github.com/matrix-org/dendrite/federationsender/types" "github.com/matrix-org/dendrite/federationsender/types"
) )
@ -36,6 +37,8 @@ func NewDatabase(dataSourceName string) (Database, error) {
return postgres.NewDatabase(dataSourceName) return postgres.NewDatabase(dataSourceName)
} }
switch uri.Scheme { switch uri.Scheme {
case "file":
return sqlite3.NewDatabase(dataSourceName)
case "postgres": case "postgres":
return postgres.NewDatabase(dataSourceName) return postgres.NewDatabase(dataSourceName)
default: default:

22
go.mod
View file

@ -1,32 +1,30 @@
module github.com/matrix-org/dendrite module github.com/matrix-org/dendrite
require ( require (
github.com/Shopify/sarama v0.0.0-20170127151855-574d3147eee3 github.com/DataDog/zstd v1.4.4 // indirect
github.com/Shopify/toxiproxy v2.1.4+incompatible // indirect github.com/Shopify/toxiproxy v2.1.4+incompatible // indirect
github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd // indirect github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd // indirect
github.com/eapache/go-resiliency v0.0.0-20160104191539-b86b1ec0dd42 // indirect github.com/eapache/go-resiliency v1.2.0 // indirect
github.com/eapache/go-xerial-snappy v0.0.0-20160609142408-bb955e01b934 // indirect github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 // indirect
github.com/eapache/queue v1.1.0 // indirect github.com/eapache/queue v1.1.0 // indirect
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect github.com/frankban/quicktest v1.7.2 // indirect
github.com/golang/snappy v0.0.1 // indirect
github.com/gorilla/mux v1.7.3 github.com/gorilla/mux v1.7.3
github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6 // indirect
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
github.com/lib/pq v1.2.0 github.com/lib/pq v1.2.0
github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5 github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26
github.com/matrix-org/gomatrixserverlib v0.0.0-20200124100636-0c2ec91d1df5 github.com/matrix-org/gomatrixserverlib v0.0.0-20200124100636-0c2ec91d1df5
github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0 github.com/matrix-org/naffka v0.0.0-20200127221512-0716baaabaf1
github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5 github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5
github.com/mattn/go-sqlite3 v2.0.2+incompatible
github.com/miekg/dns v1.1.12 // indirect github.com/miekg/dns v1.1.12 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.1 // indirect
github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5 github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5
github.com/opentracing/opentracing-go v1.0.2 github.com/opentracing/opentracing-go v1.0.2
github.com/pierrec/lz4 v0.0.0-20161206202305-5c9560bfa9ac // indirect github.com/pierrec/lz4 v2.4.1+incompatible // indirect
github.com/pierrec/xxHash v0.0.0-20160112165351-5a004441f897 // indirect
github.com/pkg/errors v0.8.1 github.com/pkg/errors v0.8.1
github.com/prometheus/client_golang v1.2.1 github.com/prometheus/client_golang v1.2.1
github.com/rcrowley/go-metrics v0.0.0-20161128210544-1f30fe9094a5 // indirect github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563 // indirect
github.com/sirupsen/logrus v1.4.2 github.com/sirupsen/logrus v1.4.2
github.com/stretchr/testify v1.4.0 // indirect github.com/stretchr/testify v1.4.0 // indirect
github.com/uber-go/atomic v1.3.0 // indirect github.com/uber-go/atomic v1.3.0 // indirect
@ -36,7 +34,7 @@ require (
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550
golang.org/x/net v0.0.0-20190620200207-3b0461eec859 // indirect golang.org/x/net v0.0.0-20190620200207-3b0461eec859 // indirect
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6 // indirect golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6 // indirect
gopkg.in/Shopify/sarama.v1 v1.11.0 gopkg.in/Shopify/sarama.v1 v1.20.1
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
gopkg.in/h2non/bimg.v1 v1.0.18 gopkg.in/h2non/bimg.v1 v1.0.18
gopkg.in/yaml.v2 v2.2.2 gopkg.in/yaml.v2 v2.2.2

43
go.sum
View file

@ -1,5 +1,5 @@
github.com/Shopify/sarama v0.0.0-20170127151855-574d3147eee3 h1:j6BAEHYn1kUyW2j7kY0mOJ/R8A0qWwXpvUAEHGemm/g= github.com/DataDog/zstd v1.4.4 h1:+IawcoXhCBylN7ccwdwf8LOH2jKq7NavGpEPanrlTzE=
github.com/Shopify/sarama v0.0.0-20170127151855-574d3147eee3/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= github.com/DataDog/zstd v1.4.4/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo=
github.com/Shopify/toxiproxy v2.1.4+incompatible h1:TKdv8HiTLgE5wdJuEML90aBgNWsokNbMijUGhmcoBJc= github.com/Shopify/toxiproxy v2.1.4+incompatible h1:TKdv8HiTLgE5wdJuEML90aBgNWsokNbMijUGhmcoBJc=
github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
@ -18,13 +18,15 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/eapache/go-resiliency v0.0.0-20160104191539-b86b1ec0dd42 h1:f8ERmXYuaC+kCSv2w+y3rBK/oVu6If4DEm3jywJJ0hc= github.com/eapache/go-resiliency v1.2.0 h1:v7g92e/KSN71Rq7vSThKaWIq68fL4YHvWyiUKorFR1Q=
github.com/eapache/go-resiliency v0.0.0-20160104191539-b86b1ec0dd42/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-resiliency v1.2.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
github.com/eapache/go-xerial-snappy v0.0.0-20160609142408-bb955e01b934 h1:oGLoaVIefp3tiOgi7+KInR/nNPvEpPM6GFo+El7fd14= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 h1:YEetp8/yCZMuEPMUDHG0CW/brkkEp8mzqk2+ODEitlw=
github.com/eapache/go-xerial-snappy v0.0.0-20160609142408-bb955e01b934/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc= github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc=
github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I=
github.com/frankban/quicktest v1.0.0/go.mod h1:R98jIehRai+d1/3Hv2//jOVCTJhW1VBavT6B6CuGq2k= github.com/frankban/quicktest v1.0.0/go.mod h1:R98jIehRai+d1/3Hv2//jOVCTJhW1VBavT6B6CuGq2k=
github.com/frankban/quicktest v1.7.2 h1:2QxQoC1TS09S7fhCPsrvqYdvP1H5M1P1ih5ABm3BTYk=
github.com/frankban/quicktest v1.7.2/go.mod h1:jaStnuzAqU1AJdCO0l53JDCJrVDKcS03DbaAcR7Ks/o=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=
@ -35,11 +37,13 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w= github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY= github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw= github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw=
github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
@ -48,8 +52,6 @@ github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplb
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6 h1:KAZ1BW2TCmT6PRihDPpocIy1QTtsAsrx6TneU/4+CMg=
github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6/go.mod h1:+ZoRqAPRLkC4NPOvfYeR5KNOrY6TD+/sAC3HXPZgDYg=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
@ -69,10 +71,12 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrixserverlib v0.0.0-20200124100636-0c2ec91d1df5 h1:kmRjpmFOenVpOaV/DRlo9p6z/IbOKlUC+hhKsAAh8Qg= github.com/matrix-org/gomatrixserverlib v0.0.0-20200124100636-0c2ec91d1df5 h1:kmRjpmFOenVpOaV/DRlo9p6z/IbOKlUC+hhKsAAh8Qg=
github.com/matrix-org/gomatrixserverlib v0.0.0-20200124100636-0c2ec91d1df5/go.mod h1:FsKa2pWE/bpQql9H7U4boOPXFoJX/QcqaZZ6ijLkaZI= github.com/matrix-org/gomatrixserverlib v0.0.0-20200124100636-0c2ec91d1df5/go.mod h1:FsKa2pWE/bpQql9H7U4boOPXFoJX/QcqaZZ6ijLkaZI=
github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0 h1:p7WTwG+aXM86+yVrYAiCMW3ZHSmotVvuRbjtt3jC+4A= github.com/matrix-org/naffka v0.0.0-20200127221512-0716baaabaf1 h1:osLoFdOy+ChQqVUn2PeTDETFftVkl4w9t/OW18g3lnk=
github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0/go.mod h1:cXoYQIENbdWIQHt1SyCo6Bl3C3raHwJ0wgVrXHSqf+A= github.com/matrix-org/naffka v0.0.0-20200127221512-0716baaabaf1/go.mod h1:cXoYQIENbdWIQHt1SyCo6Bl3C3raHwJ0wgVrXHSqf+A=
github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5 h1:W7l5CP4V7wPyPb4tYE11dbmeAOwtFQBTW0rf4OonOS8= github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5 h1:W7l5CP4V7wPyPb4tYE11dbmeAOwtFQBTW0rf4OonOS8=
github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5/go.mod h1:lePuOiXLNDott7NZfnQvJk0lAZ5HgvIuWGhel6J+RLA= github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5/go.mod h1:lePuOiXLNDott7NZfnQvJk0lAZ5HgvIuWGhel6J+RLA=
github.com/mattn/go-sqlite3 v2.0.2+incompatible h1:qzw9c2GNT8UFrgWNDhCTqRqYUSmu/Dav/9Z58LGpk7U=
github.com/mattn/go-sqlite3 v2.0.2+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/miekg/dns v1.1.4 h1:rCMZsU2ScVSYcAsOXgmC6+AKOK+6pmQTOcw03nfwYV0= github.com/miekg/dns v1.1.4 h1:rCMZsU2ScVSYcAsOXgmC6+AKOK+6pmQTOcw03nfwYV0=
@ -90,10 +94,8 @@ github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5 h1:BvoENQQU+fZ9uukda/R
github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/opentracing/opentracing-go v1.0.2 h1:3jA2P6O1F9UOrWVpwrIo17pu01KWvNWg4X946/Y5Zwg= github.com/opentracing/opentracing-go v1.0.2 h1:3jA2P6O1F9UOrWVpwrIo17pu01KWvNWg4X946/Y5Zwg=
github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o=
github.com/pierrec/lz4 v0.0.0-20161206202305-5c9560bfa9ac h1:tKcxwAA5OHUQjL6sWsuCIcP9OnzN+RwKfvomtIOsfy8= github.com/pierrec/lz4 v2.4.1+incompatible h1:mFe7ttWaflA46Mhqh+jUfjp2qTbPYxLB2/OyBppH9dg=
github.com/pierrec/lz4 v0.0.0-20161206202305-5c9560bfa9ac/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pierrec/lz4 v2.4.1+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
github.com/pierrec/xxHash v0.0.0-20160112165351-5a004441f897 h1:jp3jc/PyyTrTKjJJ6rWnhTbmo7tGgBFyG9AL5FIrO1I=
github.com/pierrec/xxHash v0.0.0-20160112165351-5a004441f897/go.mod h1:w2waW5Zoa/Wc4Yqe0wgrIYAGKqRMf7czn2HNKXmuL+I=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@ -114,8 +116,8 @@ github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/prometheus/procfs v0.0.5 h1:3+auTFlqw+ZaQYJARz6ArODtkaIwtvBTx3N2NehQlL8= github.com/prometheus/procfs v0.0.5 h1:3+auTFlqw+ZaQYJARz6ArODtkaIwtvBTx3N2NehQlL8=
github.com/prometheus/procfs v0.0.5/go.mod h1:4A/X28fw3Fc593LaREMrKMqOKvUAntwMDaekg4FpcdQ= github.com/prometheus/procfs v0.0.5/go.mod h1:4A/X28fw3Fc593LaREMrKMqOKvUAntwMDaekg4FpcdQ=
github.com/rcrowley/go-metrics v0.0.0-20161128210544-1f30fe9094a5 h1:gwcdIpH6NU2iF8CmcqD+CP6+1CkRBOhHaPR+iu6raBY= github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563 h1:dY6ETXrvDG7Sa4vE8ZQG4yqWg6UnOcbqTAahkV813vQ=
github.com/rcrowley/go-metrics v0.0.0-20161128210544-1f30fe9094a5/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.3.0 h1:hI/7Q+DtNZ2kINb6qt/lS+IyXnHQe9e90POfeewL/ME= github.com/sirupsen/logrus v1.3.0 h1:hI/7Q+DtNZ2kINb6qt/lS+IyXnHQe9e90POfeewL/ME=
github.com/sirupsen/logrus v1.3.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.3.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
@ -172,8 +174,8 @@ golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20191010194322-b09406accb47 h1:/XfQ9z7ib8eEJX2hdgFTZJ/ntt0swNk5oYBziWeTCvY= golang.org/x/sys v0.0.0-20191010194322-b09406accb47 h1:/XfQ9z7ib8eEJX2hdgFTZJ/ntt0swNk5oYBziWeTCvY=
golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
gopkg.in/Shopify/sarama.v1 v1.11.0 h1:/3kaCyeYaPbr59IBjeqhIcUOB1vXlIVqXAYa5g5C5F0= gopkg.in/Shopify/sarama.v1 v1.20.1 h1:Gi09A3fJXm0Jgt8kuKZ8YK+r60GfYn7MQuEmI3oq6hE=
gopkg.in/Shopify/sarama.v1 v1.11.0/go.mod h1:AxnvoaevB2nBjNK17cG61A3LleFcWFwVBHBt+cot4Oc= gopkg.in/Shopify/sarama.v1 v1.20.1/go.mod h1:AxnvoaevB2nBjNK17cG61A3LleFcWFwVBHBt+cot4Oc=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
@ -183,6 +185,7 @@ gopkg.in/h2non/bimg.v1 v1.0.18 h1:qn6/RpBHt+7WQqoBcK+aF2puc6nC78eZj5LexxoalT4=
gopkg.in/h2non/bimg.v1 v1.0.18/go.mod h1:PgsZL7dLwUbsGm1NYps320GxGgvQNTnecMCZqxV11So= gopkg.in/h2non/bimg.v1 v1.0.18/go.mod h1:PgsZL7dLwUbsGm1NYps320GxGgvQNTnecMCZqxV11So=
gopkg.in/h2non/gock.v1 v1.0.14 h1:fTeu9fcUvSnLNacYvYI54h+1/XEteDyHvrVCZEEEYNM= gopkg.in/h2non/gock.v1 v1.0.14 h1:fTeu9fcUvSnLNacYvYI54h+1/XEteDyHvrVCZEEEYNM=
gopkg.in/h2non/gock.v1 v1.0.14/go.mod h1:sX4zAkdYX1TRGJ2JY156cFspQn4yRWn6p9EMdODlynE= gopkg.in/h2non/gock.v1 v1.0.14/go.mod h1:sX4zAkdYX1TRGJ2JY156cFspQn4yRWn6p9EMdODlynE=
gopkg.in/macaroon.v2 v2.1.0 h1:HZcsjBCzq9t0eBPMKqTN/uSN6JOm78ZJ2INbqcBQOUI=
gopkg.in/macaroon.v2 v2.1.0/go.mod h1:OUb+TQP/OP0WOerC2Jp/3CwhIKyIa9kQjuc7H24e6/o= gopkg.in/macaroon.v2 v2.1.0/go.mod h1:OUb+TQP/OP0WOerC2Jp/3CwhIKyIa9kQjuc7H24e6/o=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=

View file

@ -27,7 +27,7 @@ import (
// component. // component.
func SetupMediaAPIComponent( func SetupMediaAPIComponent(
base *basecomponent.BaseDendrite, base *basecomponent.BaseDendrite,
deviceDB *devices.Database, deviceDB devices.Database,
) { ) {
mediaDB, err := storage.Open(string(base.Cfg.Database.MediaAPI)) mediaDB, err := storage.Open(string(base.Cfg.Database.MediaAPI))
if err != nil { if err != nil {

View file

@ -44,7 +44,7 @@ func Setup(
apiMux *mux.Router, apiMux *mux.Router,
cfg *config.Dendrite, cfg *config.Dendrite,
db storage.Database, db storage.Database,
deviceDB *devices.Database, deviceDB devices.Database,
client *gomatrixserverlib.Client, client *gomatrixserverlib.Client,
) { ) {
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter() r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()

View file

@ -144,6 +144,7 @@ func (s *thumbnailStatements) selectThumbnails(
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() // nolint: errcheck
var thumbnails []*types.ThumbnailMetadata var thumbnails []*types.ThumbnailMetadata
for rows.Next() { for rows.Next() {
@ -167,5 +168,5 @@ func (s *thumbnailStatements) selectThumbnails(
thumbnails = append(thumbnails, &thumbnailMetadata) thumbnails = append(thumbnails, &thumbnailMetadata)
} }
return thumbnails, err return thumbnails, rows.Err()
} }

View file

@ -28,7 +28,7 @@ import (
// component. // component.
func SetupPublicRoomsAPIComponent( func SetupPublicRoomsAPIComponent(
base *basecomponent.BaseDendrite, base *basecomponent.BaseDendrite,
deviceDB *devices.Database, deviceDB devices.Database,
rsQueryAPI roomserverAPI.RoomserverQueryAPI, rsQueryAPI roomserverAPI.RoomserverQueryAPI,
) { ) {
publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI)) publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI))

View file

@ -34,7 +34,7 @@ const pathPrefixR0 = "/_matrix/client/r0"
// Due to Setup being used to call many other functions, a gocyclo nolint is // Due to Setup being used to call many other functions, a gocyclo nolint is
// applied: // applied:
// nolint: gocyclo // nolint: gocyclo
func Setup(apiMux *mux.Router, deviceDB *devices.Database, publicRoomsDB storage.Database) { func Setup(apiMux *mux.Router, deviceDB devices.Database, publicRoomsDB storage.Database) {
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter() r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
authData := auth.Data{ authData := auth.Data{

View file

@ -203,6 +203,7 @@ func (s *publicRoomsStatements) selectPublicRooms(
if err != nil { if err != nil {
return []types.PublicRoom{}, nil return []types.PublicRoom{}, nil
} }
defer rows.Close() // nolint: errcheck
rooms := []types.PublicRoom{} rooms := []types.PublicRoom{}
for rows.Next() { for rows.Next() {
@ -222,7 +223,7 @@ func (s *publicRoomsStatements) selectPublicRooms(
rooms = append(rooms, r) rooms = append(rooms, r)
} }
return rooms, nil return rooms, rows.Err()
} }
func (s *publicRoomsStatements) selectRoomVisibility( func (s *publicRoomsStatements) selectRoomVisibility(

View file

@ -0,0 +1,36 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"database/sql"
)
// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
type statementList []struct {
statement **sql.Stmt
sql string
}
// prepare the SQL for each statement in the list and assign the result to the prepared statement.
func (s statementList) prepare(db *sql.DB) (err error) {
for _, statement := range s {
if *statement.statement, err = db.Prepare(statement.sql); err != nil {
return
}
}
return
}

View file

@ -0,0 +1,277 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/publicroomsapi/types"
)
var editableAttributes = []string{
"aliases",
"canonical_alias",
"name",
"topic",
"world_readable",
"guest_can_join",
"avatar_url",
"visibility",
}
const publicRoomsSchema = `
-- Stores all of the rooms with data needed to create the server's room directory
CREATE TABLE IF NOT EXISTS publicroomsapi_public_rooms(
-- The room's ID
room_id TEXT NOT NULL PRIMARY KEY,
-- Number of joined members in the room
joined_members INTEGER NOT NULL DEFAULT 0,
-- Aliases of the room (empty array if none)
aliases TEXT[] NOT NULL DEFAULT '{}'::TEXT[],
-- Canonical alias of the room (empty string if none)
canonical_alias TEXT NOT NULL DEFAULT '',
-- Name of the room (empty string if none)
name TEXT NOT NULL DEFAULT '',
-- Topic of the room (empty string if none)
topic TEXT NOT NULL DEFAULT '',
-- Is the room world readable?
world_readable BOOLEAN NOT NULL DEFAULT false,
-- Can guest join the room?
guest_can_join BOOLEAN NOT NULL DEFAULT false,
-- URL of the room avatar (empty string if none)
avatar_url TEXT NOT NULL DEFAULT '',
-- Visibility of the room: true means the room is publicly visible, false
-- means the room is private
visibility BOOLEAN NOT NULL DEFAULT false
);
`
const countPublicRoomsSQL = "" +
"SELECT COUNT(*) FROM publicroomsapi_public_rooms" +
" WHERE visibility = true"
const selectPublicRoomsSQL = "" +
"SELECT room_id, joined_members, aliases, canonical_alias, name, topic, world_readable, guest_can_join, avatar_url" +
" FROM publicroomsapi_public_rooms WHERE visibility = true" +
" ORDER BY joined_members DESC" +
" OFFSET $1"
const selectPublicRoomsWithLimitSQL = "" +
"SELECT room_id, joined_members, aliases, canonical_alias, name, topic, world_readable, guest_can_join, avatar_url" +
" FROM publicroomsapi_public_rooms WHERE visibility = true" +
" ORDER BY joined_members DESC" +
" OFFSET $1 LIMIT $2"
const selectPublicRoomsWithFilterSQL = "" +
"SELECT room_id, joined_members, aliases, canonical_alias, name, topic, world_readable, guest_can_join, avatar_url" +
" FROM publicroomsapi_public_rooms" +
" WHERE visibility = true" +
" AND (LOWER(name) LIKE LOWER($1)" +
" OR LOWER(topic) LIKE LOWER($1)" +
" OR LOWER(ARRAY_TO_STRING(aliases, ',')) LIKE LOWER($1))" +
" ORDER BY joined_members DESC" +
" OFFSET $2"
const selectPublicRoomsWithLimitAndFilterSQL = "" +
"SELECT room_id, joined_members, aliases, canonical_alias, name, topic, world_readable, guest_can_join, avatar_url" +
" FROM publicroomsapi_public_rooms" +
" WHERE visibility = true" +
" AND (LOWER(name) LIKE LOWER($1)" +
" OR LOWER(topic) LIKE LOWER($1)" +
" OR LOWER(ARRAY_TO_STRING(aliases, ',')) LIKE LOWER($1))" +
" ORDER BY joined_members DESC" +
" OFFSET $2 LIMIT $3"
const selectRoomVisibilitySQL = "" +
"SELECT visibility FROM publicroomsapi_public_rooms" +
" WHERE room_id = $1"
const insertNewRoomSQL = "" +
"INSERT INTO publicroomsapi_public_rooms(room_id)" +
" VALUES ($1)"
const incrementJoinedMembersInRoomSQL = "" +
"UPDATE publicroomsapi_public_rooms" +
" SET joined_members = joined_members + 1" +
" WHERE room_id = $1"
const decrementJoinedMembersInRoomSQL = "" +
"UPDATE publicroomsapi_public_rooms" +
" SET joined_members = joined_members - 1" +
" WHERE room_id = $1"
const updateRoomAttributeSQL = "" +
"UPDATE publicroomsapi_public_rooms" +
" SET %s = $1" +
" WHERE room_id = $2"
type publicRoomsStatements struct {
countPublicRoomsStmt *sql.Stmt
selectPublicRoomsStmt *sql.Stmt
selectPublicRoomsWithLimitStmt *sql.Stmt
selectPublicRoomsWithFilterStmt *sql.Stmt
selectPublicRoomsWithLimitAndFilterStmt *sql.Stmt
selectRoomVisibilityStmt *sql.Stmt
insertNewRoomStmt *sql.Stmt
incrementJoinedMembersInRoomStmt *sql.Stmt
decrementJoinedMembersInRoomStmt *sql.Stmt
updateRoomAttributeStmts map[string]*sql.Stmt
}
func (s *publicRoomsStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(publicRoomsSchema)
if err != nil {
return
}
stmts := statementList{
{&s.countPublicRoomsStmt, countPublicRoomsSQL},
{&s.selectPublicRoomsStmt, selectPublicRoomsSQL},
{&s.selectPublicRoomsWithLimitStmt, selectPublicRoomsWithLimitSQL},
{&s.selectPublicRoomsWithFilterStmt, selectPublicRoomsWithFilterSQL},
{&s.selectPublicRoomsWithLimitAndFilterStmt, selectPublicRoomsWithLimitAndFilterSQL},
{&s.selectRoomVisibilityStmt, selectRoomVisibilitySQL},
{&s.insertNewRoomStmt, insertNewRoomSQL},
{&s.incrementJoinedMembersInRoomStmt, incrementJoinedMembersInRoomSQL},
{&s.decrementJoinedMembersInRoomStmt, decrementJoinedMembersInRoomSQL},
}
if err = stmts.prepare(db); err != nil {
return
}
s.updateRoomAttributeStmts = make(map[string]*sql.Stmt)
for _, editable := range editableAttributes {
stmt := fmt.Sprintf(updateRoomAttributeSQL, editable)
if s.updateRoomAttributeStmts[editable], err = db.Prepare(stmt); err != nil {
return
}
}
return
}
func (s *publicRoomsStatements) countPublicRooms(ctx context.Context) (nb int64, err error) {
err = s.countPublicRoomsStmt.QueryRowContext(ctx).Scan(&nb)
return
}
func (s *publicRoomsStatements) selectPublicRooms(
ctx context.Context, offset int64, limit int16, filter string,
) ([]types.PublicRoom, error) {
var rows *sql.Rows
var err error
if len(filter) > 0 {
pattern := "%" + filter + "%"
if limit == 0 {
rows, err = s.selectPublicRoomsWithFilterStmt.QueryContext(
ctx, pattern, offset,
)
} else {
rows, err = s.selectPublicRoomsWithLimitAndFilterStmt.QueryContext(
ctx, pattern, offset, limit,
)
}
} else {
if limit == 0 {
rows, err = s.selectPublicRoomsStmt.QueryContext(ctx, offset)
} else {
rows, err = s.selectPublicRoomsWithLimitStmt.QueryContext(
ctx, offset, limit,
)
}
}
if err != nil {
return []types.PublicRoom{}, nil
}
rooms := []types.PublicRoom{}
for rows.Next() {
var r types.PublicRoom
var aliases pq.StringArray
err = rows.Scan(
&r.RoomID, &r.NumJoinedMembers, &aliases, &r.CanonicalAlias,
&r.Name, &r.Topic, &r.WorldReadable, &r.GuestCanJoin, &r.AvatarURL,
)
if err != nil {
return rooms, err
}
r.Aliases = aliases
rooms = append(rooms, r)
}
return rooms, nil
}
func (s *publicRoomsStatements) selectRoomVisibility(
ctx context.Context, roomID string,
) (v bool, err error) {
err = s.selectRoomVisibilityStmt.QueryRowContext(ctx, roomID).Scan(&v)
return
}
func (s *publicRoomsStatements) insertNewRoom(
ctx context.Context, roomID string,
) error {
_, err := s.insertNewRoomStmt.ExecContext(ctx, roomID)
return err
}
func (s *publicRoomsStatements) incrementJoinedMembersInRoom(
ctx context.Context, roomID string,
) error {
_, err := s.incrementJoinedMembersInRoomStmt.ExecContext(ctx, roomID)
return err
}
func (s *publicRoomsStatements) decrementJoinedMembersInRoom(
ctx context.Context, roomID string,
) error {
_, err := s.decrementJoinedMembersInRoomStmt.ExecContext(ctx, roomID)
return err
}
func (s *publicRoomsStatements) updateRoomAttribute(
ctx context.Context, attrName string, attrValue attributeValue, roomID string,
) error {
stmt, isEditable := s.updateRoomAttributeStmts[attrName]
if !isEditable {
return errors.New("Cannot edit " + attrName)
}
var value interface{}
switch v := attrValue.(type) {
case []string:
value = pq.StringArray(v)
case bool, string:
value = attrValue
default:
return errors.New("Unsupported attribute type, must be bool, string or []string")
}
_, err := stmt.ExecContext(ctx, value, roomID)
return err
}

View file

@ -0,0 +1,256 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
_ "github.com/mattn/go-sqlite3"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/publicroomsapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
// PublicRoomsServerDatabase represents a public rooms server database.
type PublicRoomsServerDatabase struct {
db *sql.DB
common.PartitionOffsetStatements
statements publicRoomsStatements
}
type attributeValue interface{}
// NewPublicRoomsServerDatabase creates a new public rooms server database.
func NewPublicRoomsServerDatabase(dataSourceName string) (*PublicRoomsServerDatabase, error) {
var db *sql.DB
var err error
if db, err = sql.Open("sqlite3", dataSourceName); err != nil {
return nil, err
}
storage := PublicRoomsServerDatabase{
db: db,
}
if err = storage.PartitionOffsetStatements.Prepare(db, "publicroomsapi"); err != nil {
return nil, err
}
if err = storage.statements.prepare(db); err != nil {
return nil, err
}
return &storage, nil
}
// GetRoomVisibility returns the room visibility as a boolean: true if the room
// is publicly visible, false if not.
// Returns an error if the retrieval failed.
func (d *PublicRoomsServerDatabase) GetRoomVisibility(
ctx context.Context, roomID string,
) (bool, error) {
return d.statements.selectRoomVisibility(ctx, roomID)
}
// SetRoomVisibility updates the visibility attribute of a room. This attribute
// must be set to true if the room is publicly visible, false if not.
// Returns an error if the update failed.
func (d *PublicRoomsServerDatabase) SetRoomVisibility(
ctx context.Context, visible bool, roomID string,
) error {
return d.statements.updateRoomAttribute(ctx, "visibility", visible, roomID)
}
// CountPublicRooms returns the number of room set as publicly visible on the server.
// Returns an error if the retrieval failed.
func (d *PublicRoomsServerDatabase) CountPublicRooms(ctx context.Context) (int64, error) {
return d.statements.countPublicRooms(ctx)
}
// GetPublicRooms returns an array containing the local rooms set as publicly visible, ordered by their number
// of joined members. This array can be limited by a given number of elements, and offset by a given value.
// If the limit is 0, doesn't limit the number of results. If the offset is 0 too, the array contains all
// the rooms set as publicly visible on the server.
// Returns an error if the retrieval failed.
func (d *PublicRoomsServerDatabase) GetPublicRooms(
ctx context.Context, offset int64, limit int16, filter string,
) ([]types.PublicRoom, error) {
return d.statements.selectPublicRooms(ctx, offset, limit, filter)
}
// UpdateRoomFromEvents iterate over a slice of state events and call
// UpdateRoomFromEvent on each of them to update the database representation of
// the rooms updated by each event.
// The slice of events to remove is used to update the number of joined members
// for the room in the database.
// If the update triggered by one of the events failed, aborts the process and
// returns an error.
func (d *PublicRoomsServerDatabase) UpdateRoomFromEvents(
ctx context.Context,
eventsToAdd []gomatrixserverlib.Event,
eventsToRemove []gomatrixserverlib.Event,
) error {
for _, event := range eventsToAdd {
if err := d.UpdateRoomFromEvent(ctx, event); err != nil {
return err
}
}
for _, event := range eventsToRemove {
if event.Type() == "m.room.member" {
if err := d.updateNumJoinedUsers(ctx, event, true); err != nil {
return err
}
}
}
return nil
}
// UpdateRoomFromEvent updates the database representation of a room from a Matrix event, by
// checking the event's type to know which attribute to change and using the event's content
// to define the new value of the attribute.
// If the event doesn't match with any property used to compute the public room directory,
// does nothing.
// If something went wrong during the process, returns an error.
func (d *PublicRoomsServerDatabase) UpdateRoomFromEvent(
ctx context.Context, event gomatrixserverlib.Event,
) error {
// Process the event according to its type
switch event.Type() {
case "m.room.create":
return d.statements.insertNewRoom(ctx, event.RoomID())
case "m.room.member":
return d.updateNumJoinedUsers(ctx, event, false)
case "m.room.aliases":
return d.updateRoomAliases(ctx, event)
case "m.room.canonical_alias":
var content common.CanonicalAliasContent
field := &(content.Alias)
attrName := "canonical_alias"
return d.updateStringAttribute(ctx, attrName, event, &content, field)
case "m.room.name":
var content common.NameContent
field := &(content.Name)
attrName := "name"
return d.updateStringAttribute(ctx, attrName, event, &content, field)
case "m.room.topic":
var content common.TopicContent
field := &(content.Topic)
attrName := "topic"
return d.updateStringAttribute(ctx, attrName, event, &content, field)
case "m.room.avatar":
var content common.AvatarContent
field := &(content.URL)
attrName := "avatar_url"
return d.updateStringAttribute(ctx, attrName, event, &content, field)
case "m.room.history_visibility":
var content common.HistoryVisibilityContent
field := &(content.HistoryVisibility)
attrName := "world_readable"
strForTrue := "world_readable"
return d.updateBooleanAttribute(ctx, attrName, event, &content, field, strForTrue)
case "m.room.guest_access":
var content common.GuestAccessContent
field := &(content.GuestAccess)
attrName := "guest_can_join"
strForTrue := "can_join"
return d.updateBooleanAttribute(ctx, attrName, event, &content, field, strForTrue)
}
// If the event type didn't match, return with no error
return nil
}
// updateNumJoinedUsers updates the number of joined user in the database representation
// of a room using a given "m.room.member" Matrix event.
// If the membership property of the event isn't "join", ignores it and returs nil.
// If the remove parameter is set to false, increments the joined members counter in the
// database, if set to truem decrements it.
// Returns an error if the update failed.
func (d *PublicRoomsServerDatabase) updateNumJoinedUsers(
ctx context.Context, membershipEvent gomatrixserverlib.Event, remove bool,
) error {
membership, err := membershipEvent.Membership()
if err != nil {
return err
}
if membership != gomatrixserverlib.Join {
return nil
}
if remove {
return d.statements.decrementJoinedMembersInRoom(ctx, membershipEvent.RoomID())
}
return d.statements.incrementJoinedMembersInRoom(ctx, membershipEvent.RoomID())
}
// updateStringAttribute updates a given string attribute in the database
// representation of a room using a given string data field from content of the
// Matrix event triggering the update.
// Returns an error if decoding the Matrix event's content or updating the attribute
// failed.
func (d *PublicRoomsServerDatabase) updateStringAttribute(
ctx context.Context, attrName string, event gomatrixserverlib.Event,
content interface{}, field *string,
) error {
if err := json.Unmarshal(event.Content(), content); err != nil {
return err
}
return d.statements.updateRoomAttribute(ctx, attrName, *field, event.RoomID())
}
// updateBooleanAttribute updates a given boolean attribute in the database
// representation of a room using a given string data field from content of the
// Matrix event triggering the update.
// The attribute is set to true if the field matches a given string, false if not.
// Returns an error if decoding the Matrix event's content or updating the attribute
// failed.
func (d *PublicRoomsServerDatabase) updateBooleanAttribute(
ctx context.Context, attrName string, event gomatrixserverlib.Event,
content interface{}, field *string, strForTrue string,
) error {
if err := json.Unmarshal(event.Content(), content); err != nil {
return err
}
var attrValue bool
if *field == strForTrue {
attrValue = true
} else {
attrValue = false
}
return d.statements.updateRoomAttribute(ctx, attrName, attrValue, event.RoomID())
}
// updateRoomAliases decodes the content of a "m.room.aliases" Matrix event and update the list of aliases of
// a given room with it.
// Returns an error if decoding the Matrix event or updating the list failed.
func (d *PublicRoomsServerDatabase) updateRoomAliases(
ctx context.Context, aliasesEvent gomatrixserverlib.Event,
) error {
var content common.AliasesContent
if err := json.Unmarshal(aliasesEvent.Content(), &content); err != nil {
return err
}
return d.statements.updateRoomAttribute(
ctx, "aliases", content.Aliases, aliasesEvent.RoomID(),
)
}

View file

@ -196,7 +196,12 @@ func processInviteEvent(
return err return err
} }
succeeded := false succeeded := false
defer common.EndTransaction(updater, &succeeded) defer func() {
txerr := common.EndTransaction(updater, &succeeded)
if err == nil && txerr != nil {
err = txerr
}
}()
if updater.IsJoin() { if updater.IsJoin() {
// If the user is joined to the room then that takes precedence over this // If the user is joined to the room then that takes precedence over this

View file

@ -60,7 +60,12 @@ func updateLatestEvents(
return return
} }
succeeded := false succeeded := false
defer common.EndTransaction(updater, &succeeded) defer func() {
txerr := common.EndTransaction(updater, &succeeded)
if err == nil && txerr != nil {
err = txerr
}
}()
u := latestEventsUpdater{ u := latestEventsUpdater{
ctx: ctx, db: db, updater: updater, ow: ow, roomNID: roomNID, ctx: ctx, db: db, updater: updater, ow: ow, roomNID: roomNID,

View file

@ -102,5 +102,5 @@ func (s *eventJSONStatements) bulkSelectEventJSON(
} }
result.EventNID = types.EventNID(eventNID) result.EventNID = types.EventNID(eventNID)
} }
return results[:i], nil return results[:i], rows.Err()
} }

View file

@ -125,7 +125,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
} }
result[stateKey] = types.EventStateKeyNID(stateKeyNID) result[stateKey] = types.EventStateKeyNID(stateKeyNID)
} }
return result, nil return result, rows.Err()
} }
func (s *eventStateKeyStatements) bulkSelectEventStateKey( func (s *eventStateKeyStatements) bulkSelectEventStateKey(
@ -150,5 +150,5 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKey(
} }
result[types.EventStateKeyNID(stateKeyNID)] = stateKey result[types.EventStateKeyNID(stateKeyNID)] = stateKey
} }
return result, nil return result, rows.Err()
} }

View file

@ -143,5 +143,5 @@ func (s *eventTypeStatements) bulkSelectEventTypeNID(
} }
result[eventType] = types.EventTypeNID(eventTypeNID) result[eventType] = types.EventTypeNID(eventTypeNID)
} }
return result, nil return result, rows.Err()
} }

View file

@ -209,6 +209,9 @@ func (s *eventStatements) bulkSelectStateEventByID(
return nil, err return nil, err
} }
} }
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(eventIDs) { if i != len(eventIDs) {
// If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have. // If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have.
// We don't know which ones were missing because we don't return the string IDs in the query. // We don't know which ones were missing because we don't return the string IDs in the query.
@ -219,7 +222,7 @@ func (s *eventStatements) bulkSelectStateEventByID(
fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventIDs)), fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventIDs)),
) )
} }
return results, err return results, nil
} }
// bulkSelectStateAtEventByID lookups the state at a list of events by event ID. // bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
@ -251,12 +254,15 @@ func (s *eventStatements) bulkSelectStateAtEventByID(
) )
} }
} }
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(eventIDs) { if i != len(eventIDs) {
return nil, types.MissingEventError( return nil, types.MissingEventError(
fmt.Sprintf("storage: event IDs missing from the database (%d != %d)", i, len(eventIDs)), fmt.Sprintf("storage: event IDs missing from the database (%d != %d)", i, len(eventIDs)),
) )
} }
return results, err return results, nil
} }
func (s *eventStatements) updateEventState( func (s *eventStatements) updateEventState(
@ -321,6 +327,9 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference(
result.EventID = eventID result.EventID = eventID
result.EventSHA256 = eventSHA256 result.EventSHA256 = eventSHA256
} }
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(eventNIDs) { if i != len(eventNIDs) {
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
} }
@ -343,6 +352,9 @@ func (s *eventStatements) bulkSelectEventReference(
return nil, err return nil, err
} }
} }
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(eventNIDs) { if i != len(eventNIDs) {
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
} }
@ -366,6 +378,9 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, eventNIDs []typ
} }
results[types.EventNID(eventNID)] = eventID results[types.EventNID(eventNID)] = eventID
} }
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(eventNIDs) { if i != len(eventNIDs) {
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
} }
@ -389,7 +404,7 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, eventIDs []str
} }
results[eventID] = types.EventNID(eventNID) results[eventID] = types.EventNID(eventNID)
} }
return results, nil return results, rows.Err()
} }
func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []types.EventNID) (int64, error) { func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []types.EventNID) (int64, error) {

View file

@ -114,21 +114,23 @@ func (s *inviteStatements) insertInviteEvent(
func (s *inviteStatements) updateInviteRetired( func (s *inviteStatements) updateInviteRetired(
ctx context.Context, ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventIDs []string, err error) { ) ([]string, error) {
stmt := common.TxStmt(txn, s.updateInviteRetiredStmt) stmt := common.TxStmt(txn, s.updateInviteRetiredStmt)
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer (func() { err = rows.Close() })() defer rows.Close() // nolint: errcheck
var eventIDs []string
for rows.Next() { for rows.Next() {
var inviteEventID string var inviteEventID string
if err := rows.Scan(&inviteEventID); err != nil { if err = rows.Scan(&inviteEventID); err != nil {
return nil, err return nil, err
} }
eventIDs = append(eventIDs, inviteEventID) eventIDs = append(eventIDs, inviteEventID)
} }
return return eventIDs, rows.Err()
} }
// selectInviteActiveForUserInRoom returns a list of sender state key NIDs // selectInviteActiveForUserInRoom returns a list of sender state key NIDs
@ -151,5 +153,5 @@ func (s *inviteStatements) selectInviteActiveForUserInRoom(
} }
result = append(result, types.EventStateKeyNID(senderUserNID)) result = append(result, types.EventStateKeyNID(senderUserNID))
} }
return result, nil return result, rows.Err()
} }

View file

@ -151,6 +151,7 @@ func (s *membershipStatements) selectMembershipsFromRoom(
if err != nil { if err != nil {
return return
} }
defer rows.Close() // nolint: errcheck
for rows.Next() { for rows.Next() {
var eNID types.EventNID var eNID types.EventNID
@ -159,8 +160,9 @@ func (s *membershipStatements) selectMembershipsFromRoom(
} }
eventNIDs = append(eventNIDs, eNID) eventNIDs = append(eventNIDs, eNID)
} }
return return eventNIDs, rows.Err()
} }
func (s *membershipStatements) selectMembershipsFromRoomAndMembership( func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
ctx context.Context, ctx context.Context,
roomNID types.RoomNID, membership membershipState, roomNID types.RoomNID, membership membershipState,
@ -170,6 +172,7 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
if err != nil { if err != nil {
return return
} }
defer rows.Close() // nolint: errcheck
for rows.Next() { for rows.Next() {
var eNID types.EventNID var eNID types.EventNID
@ -178,7 +181,7 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
} }
eventNIDs = append(eventNIDs, eNID) eventNIDs = append(eventNIDs, eNID)
} }
return return eventNIDs, rows.Err()
} }
func (s *membershipStatements) updateMembership( func (s *membershipStatements) updateMembership(

View file

@ -90,23 +90,23 @@ func (s *roomAliasesStatements) selectRoomIDFromAlias(
func (s *roomAliasesStatements) selectAliasesFromRoomID( func (s *roomAliasesStatements) selectAliasesFromRoomID(
ctx context.Context, roomID string, ctx context.Context, roomID string,
) (aliases []string, err error) { ) ([]string, error) {
aliases = []string{}
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
if err != nil { if err != nil {
return return nil, err
} }
defer rows.Close() // nolint: errcheck
var aliases []string
for rows.Next() { for rows.Next() {
var alias string var alias string
if err = rows.Scan(&alias); err != nil { if err = rows.Scan(&alias); err != nil {
return return nil, err
} }
aliases = append(aliases, alias) aliases = append(aliases, alias)
} }
return aliases, rows.Err()
return
} }
func (s *roomAliasesStatements) selectCreatorIDFromAlias( func (s *roomAliasesStatements) selectCreatorIDFromAlias(

View file

@ -152,7 +152,7 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries(
eventNID int64 eventNID int64
entry types.StateEntry entry types.StateEntry
) )
if err := rows.Scan( if err = rows.Scan(
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
); err != nil { ); err != nil {
return nil, err return nil, err
@ -169,10 +169,13 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries(
} }
current.StateEntries = append(current.StateEntries, entry) current.StateEntries = append(current.StateEntries, entry)
} }
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(stateBlockNIDs) { if i != len(stateBlockNIDs) {
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateBlockNIDs)) return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateBlockNIDs))
} }
return results, nil return results, err
} }
func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries( func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
@ -237,7 +240,7 @@ func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
if current.StateEntries != nil { if current.StateEntries != nil {
results = append(results, current) results = append(results, current)
} }
return results, nil return results, rows.Err()
} }
func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array { func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array {

View file

@ -104,7 +104,7 @@ func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
for ; rows.Next(); i++ { for ; rows.Next(); i++ {
result := &results[i] result := &results[i]
var stateBlockNIDs pq.Int64Array var stateBlockNIDs pq.Int64Array
if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil { if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
return nil, err return nil, err
} }
result.StateBlockNIDs = make([]types.StateBlockNID, len(stateBlockNIDs)) result.StateBlockNIDs = make([]types.StateBlockNID, len(stateBlockNIDs))
@ -112,6 +112,9 @@ func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
result.StateBlockNIDs[k] = types.StateBlockNID(stateBlockNIDs[k]) result.StateBlockNIDs[k] = types.StateBlockNID(stateBlockNIDs[k])
} }
} }
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(stateNIDs) { if i != len(stateNIDs) {
return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs)) return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs))
} }

Some files were not shown because too many files have changed in this diff Show more