Merge branch 'master' of https://github.com/matrix-org/dendrite into basicauth-metrics

This commit is contained in:
Till Faelligen 2020-02-13 20:16:34 +01:00
commit 2ccc149836
142 changed files with 9876 additions and 856 deletions

View file

@ -140,7 +140,7 @@ func RetrieveUserProfile(
ctx context.Context,
userID string,
asAPI AppServiceQueryAPI,
accountDB *accounts.Database,
accountDB accounts.Database,
) (*authtypes.Profile, error) {
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {

View file

@ -41,8 +41,8 @@ import (
// component.
func SetupAppServiceAPIComponent(
base *basecomponent.BaseDendrite,
accountsDB *accounts.Database,
deviceDB *devices.Database,
accountsDB accounts.Database,
deviceDB devices.Database,
federation *gomatrixserverlib.FederationClient,
roomserverAliasAPI roomserverAPI.RoomserverAliasAPI,
roomserverQueryAPI roomserverAPI.RoomserverQueryAPI,
@ -100,7 +100,7 @@ func SetupAppServiceAPIComponent(
// Set up HTTP Endpoints
routing.Setup(
base.APIMux, *base.Cfg, roomserverQueryAPI, roomserverAliasAPI,
base.APIMux, base.Cfg, roomserverQueryAPI, roomserverAliasAPI,
accountsDB, federation, transactionsCache,
)
@ -111,8 +111,8 @@ func SetupAppServiceAPIComponent(
// `sender_localpart` field of each application service if it doesn't
// exist already
func generateAppServiceAccount(
accountsDB *accounts.Database,
deviceDB *devices.Database,
accountsDB accounts.Database,
deviceDB devices.Database,
as config.ApplicationService,
) error {
ctx := context.Background()

View file

@ -33,7 +33,7 @@ import (
// OutputRoomEventConsumer consumes events that originated in the room server.
type OutputRoomEventConsumer struct {
roomServerConsumer *common.ContinualConsumer
db *accounts.Database
db accounts.Database
asDB *storage.Database
query api.RoomserverQueryAPI
alias api.RoomserverAliasAPI
@ -46,7 +46,7 @@ type OutputRoomEventConsumer struct {
func NewOutputRoomEventConsumer(
cfg *config.Dendrite,
kafkaConsumer sarama.Consumer,
store *accounts.Database,
store accounts.Database,
appserviceDB *storage.Database,
queryAPI api.RoomserverQueryAPI,
aliasAPI api.RoomserverAliasAPI,

View file

@ -36,9 +36,9 @@ const pathPrefixApp = "/_matrix/app/v1"
// applied:
// nolint: gocyclo
func Setup(
apiMux *mux.Router, cfg config.Dendrite, // nolint: unparam
apiMux *mux.Router, cfg *config.Dendrite, // nolint: unparam
queryAPI api.RoomserverQueryAPI, aliasAPI api.RoomserverAliasAPI, // nolint: unparam
accountDB *accounts.Database, // nolint: unparam
accountDB accounts.Database, // nolint: unparam
federation *gomatrixserverlib.FederationClient, // nolint: unparam
transactionsCache *transactions.Cache, // nolint: unparam
) {

View file

@ -0,0 +1,141 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/gomatrixserverlib"
)
const accountDataSchema = `
-- Stores data about accounts data.
CREATE TABLE IF NOT EXISTS account_data (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL,
-- The room ID for this data (empty string if not specific to a room)
room_id TEXT,
-- The account data type
type TEXT NOT NULL,
-- The account data content
content TEXT NOT NULL,
PRIMARY KEY(localpart, room_id, type)
);
`
const insertAccountDataSQL = `
INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content
`
const selectAccountDataSQL = "" +
"SELECT room_id, type, content FROM account_data WHERE localpart = $1"
const selectAccountDataByTypeSQL = "" +
"SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
type accountDataStatements struct {
insertAccountDataStmt *sql.Stmt
selectAccountDataStmt *sql.Stmt
selectAccountDataByTypeStmt *sql.Stmt
}
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(accountDataSchema)
if err != nil {
return
}
if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil {
return
}
if s.selectAccountDataStmt, err = db.Prepare(selectAccountDataSQL); err != nil {
return
}
if s.selectAccountDataByTypeStmt, err = db.Prepare(selectAccountDataByTypeSQL); err != nil {
return
}
return
}
func (s *accountDataStatements) insertAccountData(
ctx context.Context, localpart, roomID, dataType, content string,
) (err error) {
stmt := s.insertAccountDataStmt
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
return
}
func (s *accountDataStatements) selectAccountData(
ctx context.Context, localpart string,
) (
global []gomatrixserverlib.ClientEvent,
rooms map[string][]gomatrixserverlib.ClientEvent,
err error,
) {
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
if err != nil {
return
}
defer rows.Close() // nolint: errcheck
global = []gomatrixserverlib.ClientEvent{}
rooms = make(map[string][]gomatrixserverlib.ClientEvent)
for rows.Next() {
var roomID string
var dataType string
var content []byte
if err = rows.Scan(&roomID, &dataType, &content); err != nil {
return
}
ac := gomatrixserverlib.ClientEvent{
Type: dataType,
Content: content,
}
if len(roomID) > 0 {
rooms[roomID] = append(rooms[roomID], ac)
} else {
global = append(global, ac)
}
}
return global, rooms, rows.Err()
}
func (s *accountDataStatements) selectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) {
stmt := s.selectAccountDataByTypeStmt
var content []byte
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return
}
data = &gomatrixserverlib.ClientEvent{
Type: dataType,
Content: content,
}
return
}

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package accounts
package postgres
import (
"context"

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package accounts
package postgres
import (
"context"

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package accounts
package postgres
import (
"context"
@ -122,11 +122,10 @@ func (s *membershipStatements) selectMembershipsByLocalpart(
for rows.Next() {
var m authtypes.Membership
m.Localpart = localpart
if err := rows.Scan(&m.RoomID, &m.EventID); err != nil {
return nil, err
if err = rows.Scan(&m.RoomID, &m.EventID); err != nil {
return
}
memberships = append(memberships, m)
}
return
return memberships, rows.Err()
}

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package accounts
package postgres
import (
"context"

View file

@ -0,0 +1,392 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"errors"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
// Import the postgres database driver.
_ "github.com/lib/pq"
)
// Database represents an account database
type Database struct {
db *sql.DB
common.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
memberships membershipStatements
accountDatas accountDataStatements
threepids threepidStatements
filter filterStatements
serverName gomatrixserverlib.ServerName
}
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
var db *sql.DB
var err error
if db, err = sql.Open("postgres", dataSourceName); err != nil {
return nil, err
}
partitions := common.PartitionOffsetStatements{}
if err = partitions.Prepare(db, "account"); err != nil {
return nil, err
}
a := accountsStatements{}
if err = a.prepare(db, serverName); err != nil {
return nil, err
}
p := profilesStatements{}
if err = p.prepare(db); err != nil {
return nil, err
}
m := membershipStatements{}
if err = m.prepare(db); err != nil {
return nil, err
}
ac := accountDataStatements{}
if err = ac.prepare(db); err != nil {
return nil, err
}
t := threepidStatements{}
if err = t.prepare(db); err != nil {
return nil, err
}
f := filterStatements{}
if err = f.prepare(db); err != nil {
return nil, err
}
return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(
ctx context.Context, localpart, plaintextPassword string,
) (*authtypes.Account, error) {
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
if err != nil {
return nil, err
}
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
return nil, err
}
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
return d.profiles.selectProfileByLocalpart(ctx, localpart)
}
// SetAvatarURL updates the avatar URL of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) error {
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
}
// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
) error {
return d.profiles.setDisplayName(ctx, localpart, displayName)
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, nil.
func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string,
) (*authtypes.Account, error) {
var err error
// Generate a password hash if this is not a password-less user
hash := ""
if plaintextPassword != "" {
hash, err = hashPassword(plaintextPassword)
if err != nil {
return nil, err
}
}
if err := d.profiles.insertProfile(ctx, localpart); err != nil {
if common.IsUniqueConstraintViolationErr(err) {
return nil, nil
}
return nil, err
}
if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{
"global": {
"content": [],
"override": [],
"room": [],
"sender": [],
"underride": []
}
}`); err != nil {
return nil, err
}
return d.accounts.insertAccount(ctx, localpart, hash, appserviceID)
}
// SaveMembership saves the user matching a given localpart as a member of a given
// room. It also stores the ID of the membership event.
// If a membership already exists between the user and the room, or if the
// insert fails, returns the SQL error
func (d *Database) saveMembership(
ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
) error {
return d.memberships.insertMembership(ctx, txn, localpart, roomID, eventID)
}
// removeMembershipsByEventIDs removes the memberships corresponding to the
// `join` membership events IDs in the eventIDs slice.
// If the removal fails, or if there is no membership to remove, returns an error
func (d *Database) removeMembershipsByEventIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) error {
return d.memberships.deleteMembershipsByEventIDs(ctx, txn, eventIDs)
}
// UpdateMemberships adds the "join" membership events included in a given state
// events array, and removes those which ID is included in a given array of events
// IDs. All of the process is run in a transaction, which commits only once/if every
// insertion and deletion has been successfully processed.
// Returns a SQL error if there was an issue with any part of the process
func (d *Database) UpdateMemberships(
ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil {
return err
}
for _, event := range eventsToAdd {
if err := d.newMembership(ctx, txn, event); err != nil {
return err
}
}
return nil
})
}
// GetMembershipInRoomByLocalpart returns the membership for an user
// matching the given localpart if he is a member of the room matching roomID,
// if not sql.ErrNoRows is returned.
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetMembershipInRoomByLocalpart(
ctx context.Context, localpart, roomID string,
) (authtypes.Membership, error) {
return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID)
}
// GetMembershipsByLocalpart returns an array containing the memberships for all
// the rooms a user matching a given localpart is a member of
// If no membership match the given localpart, returns an empty array
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetMembershipsByLocalpart(
ctx context.Context, localpart string,
) (memberships []authtypes.Membership, err error) {
return d.memberships.selectMembershipsByLocalpart(ctx, localpart)
}
// newMembership saves a new membership in the database.
// If the event isn't a valid m.room.member event with type `join`, does nothing.
// If an error occurred, returns the SQL error
func (d *Database) newMembership(
ctx context.Context, txn *sql.Tx, ev gomatrixserverlib.Event,
) error {
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
if err != nil {
return err
}
// We only want state events from local users
if string(serverName) != string(d.serverName) {
return nil
}
eventID := ev.EventID()
roomID := ev.RoomID()
membership, err := ev.Membership()
if err != nil {
return err
}
// Only "join" membership events can be considered as new memberships
if membership == gomatrixserverlib.Join {
if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil {
return err
}
}
}
return nil
}
// SaveAccountData saves new account data for a given user and a given room.
// If the account data is not specific to a room, the room ID should be an empty string
// If an account data already exists for a given set (user, room, data type), it will
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType, content string,
) error {
return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content)
}
// GetAccountData returns account data related to a given localpart
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global []gomatrixserverlib.ClientEvent,
rooms map[string][]gomatrixserverlib.ClientEvent,
err error,
) {
return d.accountDatas.selectAccountData(ctx, localpart)
}
// GetAccountDataByType returns account data matching a given
// localpart, room ID and type.
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) {
return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType,
)
}
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx)
}
func hashPassword(plaintext string) (hash string, err error) {
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
return string(hashBytes), err
}
// Err3PIDInUse is the error returned when trying to save an association involving
// a third-party identifier which is already associated to a local user.
var Err3PIDInUse = errors.New("This third-party identifier is already in use")
// SaveThreePIDAssociation saves the association between a third party identifier
// and a local Matrix user (identified by the user's ID's local part).
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string,
) (err error) {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
user, err := d.threepids.selectLocalpartForThreePID(
ctx, txn, threepid, medium,
)
if err != nil {
return err
}
if len(user) > 0 {
return Err3PIDInUse
}
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
})
}
// RemoveThreePIDAssociation removes the association involving a given third-party
// identifier.
// If no association exists involving this third-party identifier, returns nothing.
// If there was a problem talking to the database, returns an error.
func (d *Database) RemoveThreePIDAssociation(
ctx context.Context, threepid string, medium string,
) (err error) {
return d.threepids.deleteThreePID(ctx, threepid, medium)
}
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
// identifier.
// If no association involves the given third-party idenfitier, returns an empty
// string.
// Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
}
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
// a given local user.
// If no association is known for this user, returns an empty slice.
// Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
}
// GetFilter looks up the filter associated with a given local user and filter ID.
// Returns a filter structure. Otherwise returns an error if no such filter exists
// or if there was an error talking to the database.
func (d *Database) GetFilter(
ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) {
return d.filter.selectFilter(ctx, localpart, filterID)
}
// PutFilter puts the passed filter into the database.
// Returns the filterID as a string. Otherwise returns an error if something
// goes wrong.
func (d *Database) PutFilter(
ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
) (string, error) {
return d.filter.insertFilter(ctx, filter, localpart)
}
// CheckAccountAvailability checks if the username/localpart is already present
// in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
if err == sql.ErrNoRows {
return true, nil
}
return false, err
}
// GetAccountByLocalpart returns the account associated with the given localpart.
// This function assumes the request is authenticated or the account data is used only internally.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
) (*authtypes.Account, error) {
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package accounts
package postgres
import (
"context"

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package accounts
package sqlite3
import (
"context"
@ -39,7 +39,7 @@ CREATE TABLE IF NOT EXISTS account_data (
const insertAccountDataSQL = `
INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
`
const selectAccountDataSQL = "" +

View file

@ -0,0 +1,151 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
const accountsSchema = `
-- Stores data about accounts.
CREATE TABLE IF NOT EXISTS account_accounts (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY,
-- When this account was first created, as a unix timestamp (ms resolution).
created_ts BIGINT NOT NULL,
-- The password hash for this account. Can be NULL if this is a passwordless account.
password_hash TEXT,
-- Identifies which application service this account belongs to, if any.
appservice_id TEXT
-- TODO:
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
);
`
const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1"
const selectNewNumericLocalpartSQL = "" +
"SELECT COUNT(localpart) FROM account_accounts"
// TODO: Update password
type accountsStatements struct {
insertAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
_, err = db.Exec(accountsSchema)
if err != nil {
return
}
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
return
}
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
return
}
if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil {
return
}
if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil {
return
}
s.serverName = server
return
}
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
// this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success.
func (s *accountsStatements) insertAccount(
ctx context.Context, localpart, hash, appserviceID string,
) (*authtypes.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt
var err error
if appserviceID == "" {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil)
} else {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
}
if err != nil {
return nil, err
}
return &authtypes.Account{
Localpart: localpart,
UserID: userutil.MakeUserID(localpart, s.serverName),
ServerName: s.serverName,
AppServiceID: appserviceID,
}, nil
}
func (s *accountsStatements) selectPasswordHash(
ctx context.Context, localpart string,
) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
return
}
func (s *accountsStatements) selectAccountByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Account, error) {
var appserviceIDPtr sql.NullString
var acc authtypes.Account
stmt := s.selectAccountByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve user from the db")
}
return nil, err
}
if appserviceIDPtr.Valid {
acc.AppServiceID = appserviceIDPtr.String
}
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
acc.ServerName = s.serverName
return &acc, nil
}
func (s *accountsStatements) selectNewNumericLocalpart(
ctx context.Context,
) (id int64, err error) {
err = s.selectNewNumericLocalpartStmt.QueryRowContext(ctx).Scan(&id)
return
}

View file

@ -0,0 +1,139 @@
// Copyright 2017 Jan Christian Grünhage
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"github.com/matrix-org/gomatrixserverlib"
)
const filterSchema = `
-- Stores data about filters
CREATE TABLE IF NOT EXISTS account_filter (
-- The filter
filter TEXT NOT NULL,
-- The ID
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- The localpart of the Matrix user ID associated to this filter
localpart TEXT NOT NULL,
UNIQUE (id, localpart)
);
CREATE INDEX IF NOT EXISTS account_filter_localpart ON account_filter(localpart);
`
const selectFilterSQL = "" +
"SELECT filter FROM account_filter WHERE localpart = $1 AND id = $2"
const selectFilterIDByContentSQL = "" +
"SELECT id FROM account_filter WHERE localpart = $1 AND filter = $2"
const insertFilterSQL = "" +
"INSERT INTO account_filter (filter, localpart) VALUES ($1, $2)"
const selectLastInsertedFilterIDSQL = "" +
"SELECT id FROM account_filter WHERE rowid = last_insert_rowid()"
type filterStatements struct {
selectFilterStmt *sql.Stmt
selectLastInsertedFilterIDStmt *sql.Stmt
selectFilterIDByContentStmt *sql.Stmt
insertFilterStmt *sql.Stmt
}
func (s *filterStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(filterSchema)
if err != nil {
return
}
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
return
}
if s.selectLastInsertedFilterIDStmt, err = db.Prepare(selectLastInsertedFilterIDSQL); err != nil {
return
}
if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil {
return
}
if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil {
return
}
return
}
func (s *filterStatements) selectFilter(
ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) {
// Retrieve filter from database (stored as canonical JSON)
var filterData []byte
err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
if err != nil {
return nil, err
}
// Unmarshal JSON into Filter struct
var filter gomatrixserverlib.Filter
if err = json.Unmarshal(filterData, &filter); err != nil {
return nil, err
}
return &filter, nil
}
func (s *filterStatements) insertFilter(
ctx context.Context, filter *gomatrixserverlib.Filter, localpart string,
) (filterID string, err error) {
var existingFilterID string
// Serialise json
filterJSON, err := json.Marshal(filter)
if err != nil {
return "", err
}
// Remove whitespaces and sort JSON data
// needed to prevent from inserting the same filter multiple times
filterJSON, err = gomatrixserverlib.CanonicalJSON(filterJSON)
if err != nil {
return "", err
}
// Check if filter already exists in the database using its localpart and content
//
// This can result in a race condition when two clients try to insert the
// same filter and localpart at the same time, however this is not a
// problem as both calls will result in the same filterID
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
localpart, filterJSON).Scan(&existingFilterID)
if err != nil && err != sql.ErrNoRows {
return "", err
}
// If it does, return the existing ID
if existingFilterID != "" {
return existingFilterID, err
}
// Otherwise insert the filter and return the new ID
if _, err = s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart); err != nil {
return "", err
}
row := s.selectLastInsertedFilterIDStmt.QueryRowContext(ctx)
if err := row.Scan(&filterID); err != nil {
return "", err
}
return
}

View file

@ -0,0 +1,131 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
const membershipSchema = `
-- Stores data about users memberships to rooms.
CREATE TABLE IF NOT EXISTS account_memberships (
-- The Matrix user ID localpart for the member
localpart TEXT NOT NULL,
-- The room this user is a member of
room_id TEXT NOT NULL,
-- The ID of the join membership event
event_id TEXT NOT NULL,
-- A user can only be member of a room once
PRIMARY KEY (localpart, room_id),
UNIQUE (event_id)
);
`
const insertMembershipSQL = `
INSERT INTO account_memberships(localpart, room_id, event_id) VALUES ($1, $2, $3)
ON CONFLICT (localpart, room_id) DO UPDATE SET event_id = EXCLUDED.event_id
`
const selectMembershipsByLocalpartSQL = "" +
"SELECT room_id, event_id FROM account_memberships WHERE localpart = $1"
const selectMembershipInRoomByLocalpartSQL = "" +
"SELECT event_id FROM account_memberships WHERE localpart = $1 AND room_id = $2"
const deleteMembershipsByEventIDsSQL = "" +
"DELETE FROM account_memberships WHERE event_id IN ($1)"
type membershipStatements struct {
deleteMembershipsByEventIDsStmt *sql.Stmt
insertMembershipStmt *sql.Stmt
selectMembershipInRoomByLocalpartStmt *sql.Stmt
selectMembershipsByLocalpartStmt *sql.Stmt
}
func (s *membershipStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(membershipSchema)
if err != nil {
return
}
if s.deleteMembershipsByEventIDsStmt, err = db.Prepare(deleteMembershipsByEventIDsSQL); err != nil {
return
}
if s.insertMembershipStmt, err = db.Prepare(insertMembershipSQL); err != nil {
return
}
if s.selectMembershipInRoomByLocalpartStmt, err = db.Prepare(selectMembershipInRoomByLocalpartSQL); err != nil {
return
}
if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil {
return
}
return
}
func (s *membershipStatements) insertMembership(
ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
) (err error) {
stmt := txn.Stmt(s.insertMembershipStmt)
_, err = stmt.ExecContext(ctx, localpart, roomID, eventID)
return
}
func (s *membershipStatements) deleteMembershipsByEventIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) (err error) {
stmt := txn.Stmt(s.deleteMembershipsByEventIDsStmt)
_, err = stmt.ExecContext(ctx, pq.StringArray(eventIDs))
return
}
func (s *membershipStatements) selectMembershipInRoomByLocalpart(
ctx context.Context, localpart, roomID string,
) (authtypes.Membership, error) {
membership := authtypes.Membership{Localpart: localpart, RoomID: roomID}
stmt := s.selectMembershipInRoomByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart, roomID).Scan(&membership.EventID)
return membership, err
}
func (s *membershipStatements) selectMembershipsByLocalpart(
ctx context.Context, localpart string,
) (memberships []authtypes.Membership, err error) {
stmt := s.selectMembershipsByLocalpartStmt
rows, err := stmt.QueryContext(ctx, localpart)
if err != nil {
return
}
memberships = []authtypes.Membership{}
defer rows.Close() // nolint: errcheck
for rows.Next() {
var m authtypes.Membership
m.Localpart = localpart
if err := rows.Scan(&m.RoomID, &m.EventID); err != nil {
return nil, err
}
memberships = append(memberships, m)
}
return
}

View file

@ -0,0 +1,107 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
const profilesSchema = `
-- Stores data about accounts profiles.
CREATE TABLE IF NOT EXISTS account_profiles (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY,
-- The display name for this account
display_name TEXT,
-- The URL of the avatar for this account
avatar_url TEXT
);
`
const insertProfileSQL = "" +
"INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
const setAvatarURLSQL = "" +
"UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
const setDisplayNameSQL = "" +
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
type profilesStatements struct {
insertProfileStmt *sql.Stmt
selectProfileByLocalpartStmt *sql.Stmt
setAvatarURLStmt *sql.Stmt
setDisplayNameStmt *sql.Stmt
}
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(profilesSchema)
if err != nil {
return
}
if s.insertProfileStmt, err = db.Prepare(insertProfileSQL); err != nil {
return
}
if s.selectProfileByLocalpartStmt, err = db.Prepare(selectProfileByLocalpartSQL); err != nil {
return
}
if s.setAvatarURLStmt, err = db.Prepare(setAvatarURLSQL); err != nil {
return
}
if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil {
return
}
return
}
func (s *profilesStatements) insertProfile(
ctx context.Context, localpart string,
) (err error) {
_, err = s.insertProfileStmt.ExecContext(ctx, localpart, "", "")
return
}
func (s *profilesStatements) selectProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
var profile authtypes.Profile
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
)
if err != nil {
return nil, err
}
return &profile, nil
}
func (s *profilesStatements) setAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) (err error) {
_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart)
return
}
func (s *profilesStatements) setDisplayName(
ctx context.Context, localpart string, displayName string,
) (err error) {
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
return
}

View file

@ -0,0 +1,392 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"errors"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
// Import the postgres database driver.
_ "github.com/mattn/go-sqlite3"
)
// Database represents an account database
type Database struct {
db *sql.DB
common.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
memberships membershipStatements
accountDatas accountDataStatements
threepids threepidStatements
filter filterStatements
serverName gomatrixserverlib.ServerName
}
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
var db *sql.DB
var err error
if db, err = sql.Open("sqlite3", dataSourceName); err != nil {
return nil, err
}
partitions := common.PartitionOffsetStatements{}
if err = partitions.Prepare(db, "account"); err != nil {
return nil, err
}
a := accountsStatements{}
if err = a.prepare(db, serverName); err != nil {
return nil, err
}
p := profilesStatements{}
if err = p.prepare(db); err != nil {
return nil, err
}
m := membershipStatements{}
if err = m.prepare(db); err != nil {
return nil, err
}
ac := accountDataStatements{}
if err = ac.prepare(db); err != nil {
return nil, err
}
t := threepidStatements{}
if err = t.prepare(db); err != nil {
return nil, err
}
f := filterStatements{}
if err = f.prepare(db); err != nil {
return nil, err
}
return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(
ctx context.Context, localpart, plaintextPassword string,
) (*authtypes.Account, error) {
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
if err != nil {
return nil, err
}
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
return nil, err
}
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
return d.profiles.selectProfileByLocalpart(ctx, localpart)
}
// SetAvatarURL updates the avatar URL of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) error {
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
}
// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
) error {
return d.profiles.setDisplayName(ctx, localpart, displayName)
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, nil.
func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string,
) (*authtypes.Account, error) {
var err error
// Generate a password hash if this is not a password-less user
hash := ""
if plaintextPassword != "" {
hash, err = hashPassword(plaintextPassword)
if err != nil {
return nil, err
}
}
if err := d.profiles.insertProfile(ctx, localpart); err != nil {
if common.IsUniqueConstraintViolationErr(err) {
return nil, nil
}
return nil, err
}
if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{
"global": {
"content": [],
"override": [],
"room": [],
"sender": [],
"underride": []
}
}`); err != nil {
return nil, err
}
return d.accounts.insertAccount(ctx, localpart, hash, appserviceID)
}
// SaveMembership saves the user matching a given localpart as a member of a given
// room. It also stores the ID of the membership event.
// If a membership already exists between the user and the room, or if the
// insert fails, returns the SQL error
func (d *Database) saveMembership(
ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
) error {
return d.memberships.insertMembership(ctx, txn, localpart, roomID, eventID)
}
// removeMembershipsByEventIDs removes the memberships corresponding to the
// `join` membership events IDs in the eventIDs slice.
// If the removal fails, or if there is no membership to remove, returns an error
func (d *Database) removeMembershipsByEventIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) error {
return d.memberships.deleteMembershipsByEventIDs(ctx, txn, eventIDs)
}
// UpdateMemberships adds the "join" membership events included in a given state
// events array, and removes those which ID is included in a given array of events
// IDs. All of the process is run in a transaction, which commits only once/if every
// insertion and deletion has been successfully processed.
// Returns a SQL error if there was an issue with any part of the process
func (d *Database) UpdateMemberships(
ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil {
return err
}
for _, event := range eventsToAdd {
if err := d.newMembership(ctx, txn, event); err != nil {
return err
}
}
return nil
})
}
// GetMembershipInRoomByLocalpart returns the membership for an user
// matching the given localpart if he is a member of the room matching roomID,
// if not sql.ErrNoRows is returned.
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetMembershipInRoomByLocalpart(
ctx context.Context, localpart, roomID string,
) (authtypes.Membership, error) {
return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID)
}
// GetMembershipsByLocalpart returns an array containing the memberships for all
// the rooms a user matching a given localpart is a member of
// If no membership match the given localpart, returns an empty array
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetMembershipsByLocalpart(
ctx context.Context, localpart string,
) (memberships []authtypes.Membership, err error) {
return d.memberships.selectMembershipsByLocalpart(ctx, localpart)
}
// newMembership saves a new membership in the database.
// If the event isn't a valid m.room.member event with type `join`, does nothing.
// If an error occurred, returns the SQL error
func (d *Database) newMembership(
ctx context.Context, txn *sql.Tx, ev gomatrixserverlib.Event,
) error {
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
if err != nil {
return err
}
// We only want state events from local users
if string(serverName) != string(d.serverName) {
return nil
}
eventID := ev.EventID()
roomID := ev.RoomID()
membership, err := ev.Membership()
if err != nil {
return err
}
// Only "join" membership events can be considered as new memberships
if membership == gomatrixserverlib.Join {
if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil {
return err
}
}
}
return nil
}
// SaveAccountData saves new account data for a given user and a given room.
// If the account data is not specific to a room, the room ID should be an empty string
// If an account data already exists for a given set (user, room, data type), it will
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType, content string,
) error {
return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content)
}
// GetAccountData returns account data related to a given localpart
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global []gomatrixserverlib.ClientEvent,
rooms map[string][]gomatrixserverlib.ClientEvent,
err error,
) {
return d.accountDatas.selectAccountData(ctx, localpart)
}
// GetAccountDataByType returns account data matching a given
// localpart, room ID and type.
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) {
return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType,
)
}
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx)
}
func hashPassword(plaintext string) (hash string, err error) {
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
return string(hashBytes), err
}
// Err3PIDInUse is the error returned when trying to save an association involving
// a third-party identifier which is already associated to a local user.
var Err3PIDInUse = errors.New("This third-party identifier is already in use")
// SaveThreePIDAssociation saves the association between a third party identifier
// and a local Matrix user (identified by the user's ID's local part).
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string,
) (err error) {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
user, err := d.threepids.selectLocalpartForThreePID(
ctx, txn, threepid, medium,
)
if err != nil {
return err
}
if len(user) > 0 {
return Err3PIDInUse
}
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
})
}
// RemoveThreePIDAssociation removes the association involving a given third-party
// identifier.
// If no association exists involving this third-party identifier, returns nothing.
// If there was a problem talking to the database, returns an error.
func (d *Database) RemoveThreePIDAssociation(
ctx context.Context, threepid string, medium string,
) (err error) {
return d.threepids.deleteThreePID(ctx, threepid, medium)
}
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
// identifier.
// If no association involves the given third-party idenfitier, returns an empty
// string.
// Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
}
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
// a given local user.
// If no association is known for this user, returns an empty slice.
// Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
}
// GetFilter looks up the filter associated with a given local user and filter ID.
// Returns a filter structure. Otherwise returns an error if no such filter exists
// or if there was an error talking to the database.
func (d *Database) GetFilter(
ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) {
return d.filter.selectFilter(ctx, localpart, filterID)
}
// PutFilter puts the passed filter into the database.
// Returns the filterID as a string. Otherwise returns an error if something
// goes wrong.
func (d *Database) PutFilter(
ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
) (string, error) {
return d.filter.insertFilter(ctx, filter, localpart)
}
// CheckAccountAvailability checks if the username/localpart is already present
// in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
if err == sql.ErrNoRows {
return true, nil
}
return false, err
}
// GetAccountByLocalpart returns the account associated with the given localpart.
// This function assumes the request is authenticated or the account data is used only internally.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
) (*authtypes.Account, error) {
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}

View file

@ -0,0 +1,129 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
const threepidSchema = `
-- Stores data about third party identifiers
CREATE TABLE IF NOT EXISTS account_threepid (
-- The third party identifier
threepid TEXT NOT NULL,
-- The 3PID medium
medium TEXT NOT NULL DEFAULT 'email',
-- The localpart of the Matrix user ID associated to this 3PID
localpart TEXT NOT NULL,
PRIMARY KEY(threepid, medium)
);
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart);
`
const selectLocalpartForThreePIDSQL = "" +
"SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
const selectThreePIDsForLocalpartSQL = "" +
"SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
const insertThreePIDSQL = "" +
"INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)"
const deleteThreePIDSQL = "" +
"DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
type threepidStatements struct {
selectLocalpartForThreePIDStmt *sql.Stmt
selectThreePIDsForLocalpartStmt *sql.Stmt
insertThreePIDStmt *sql.Stmt
deleteThreePIDStmt *sql.Stmt
}
func (s *threepidStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(threepidSchema)
if err != nil {
return
}
if s.selectLocalpartForThreePIDStmt, err = db.Prepare(selectLocalpartForThreePIDSQL); err != nil {
return
}
if s.selectThreePIDsForLocalpartStmt, err = db.Prepare(selectThreePIDsForLocalpartSQL); err != nil {
return
}
if s.insertThreePIDStmt, err = db.Prepare(insertThreePIDSQL); err != nil {
return
}
if s.deleteThreePIDStmt, err = db.Prepare(deleteThreePIDSQL); err != nil {
return
}
return
}
func (s *threepidStatements) selectLocalpartForThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string,
) (localpart string, err error) {
stmt := common.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
if err == sql.ErrNoRows {
return "", nil
}
return
}
func (s *threepidStatements) selectThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
if err != nil {
return
}
defer rows.Close() // nolint: errcheck
threepids = []authtypes.ThreePID{}
for rows.Next() {
var threepid string
var medium string
if err = rows.Scan(&threepid, &medium); err != nil {
return
}
threepids = append(threepids, authtypes.ThreePID{
Address: threepid,
Medium: medium,
})
}
return threepids, rows.Err()
}
func (s *threepidStatements) insertThreePID(
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
) (err error) {
stmt := common.TxStmt(txn, s.insertThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
return
}
func (s *threepidStatements) deleteThreePID(
ctx context.Context, threepid string, medium string) (err error) {
_, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium)
return
}

View file

@ -1,392 +1,56 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package accounts
import (
"context"
"database/sql"
"errors"
"net/url"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/postgres"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/sqlite3"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
// Import the postgres database driver.
_ "github.com/lib/pq"
)
// Database represents an account database
type Database struct {
db *sql.DB
common.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
memberships membershipStatements
accountDatas accountDataStatements
threepids threepidStatements
filter filterStatements
serverName gomatrixserverlib.ServerName
type Database interface {
common.PartitionStorer
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*authtypes.Account, error)
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
SetDisplayName(ctx context.Context, localpart string, displayName string) error
CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*authtypes.Account, error)
UpdateMemberships(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error
GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error)
GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error)
SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error
GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error)
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data *gomatrixserverlib.ClientEvent, err error)
GetNewNumericLocalpart(ctx context.Context) (int64, error)
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
GetFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error)
PutFilter(ctx context.Context, localpart string, filter *gomatrixserverlib.Filter) (string, error)
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
GetAccountByLocalpart(ctx context.Context, localpart string) (*authtypes.Account, error)
}
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
var db *sql.DB
var err error
if db, err = sql.Open("postgres", dataSourceName); err != nil {
return nil, err
}
partitions := common.PartitionOffsetStatements{}
if err = partitions.Prepare(db, "account"); err != nil {
return nil, err
}
a := accountsStatements{}
if err = a.prepare(db, serverName); err != nil {
return nil, err
}
p := profilesStatements{}
if err = p.prepare(db); err != nil {
return nil, err
}
m := membershipStatements{}
if err = m.prepare(db); err != nil {
return nil, err
}
ac := accountDataStatements{}
if err = ac.prepare(db); err != nil {
return nil, err
}
t := threepidStatements{}
if err = t.prepare(db); err != nil {
return nil, err
}
f := filterStatements{}
if err = f.prepare(db); err != nil {
return nil, err
}
return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(
ctx context.Context, localpart, plaintextPassword string,
) (*authtypes.Account, error) {
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (Database, error) {
uri, err := url.Parse(dataSourceName)
if err != nil {
return nil, err
return postgres.NewDatabase(dataSourceName, serverName)
}
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
return nil, err
switch uri.Scheme {
case "postgres":
return postgres.NewDatabase(dataSourceName, serverName)
case "file":
return sqlite3.NewDatabase(dataSourceName, serverName)
default:
return postgres.NewDatabase(dataSourceName, serverName)
}
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
return d.profiles.selectProfileByLocalpart(ctx, localpart)
}
// SetAvatarURL updates the avatar URL of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) error {
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
}
// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
) error {
return d.profiles.setDisplayName(ctx, localpart, displayName)
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, nil.
func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string,
) (*authtypes.Account, error) {
var err error
// Generate a password hash if this is not a password-less user
hash := ""
if plaintextPassword != "" {
hash, err = hashPassword(plaintextPassword)
if err != nil {
return nil, err
}
}
if err := d.profiles.insertProfile(ctx, localpart); err != nil {
if common.IsUniqueConstraintViolationErr(err) {
return nil, nil
}
return nil, err
}
if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{
"global": {
"content": [],
"override": [],
"room": [],
"sender": [],
"underride": []
}
}`); err != nil {
return nil, err
}
return d.accounts.insertAccount(ctx, localpart, hash, appserviceID)
}
// SaveMembership saves the user matching a given localpart as a member of a given
// room. It also stores the ID of the membership event.
// If a membership already exists between the user and the room, or if the
// insert fails, returns the SQL error
func (d *Database) saveMembership(
ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
) error {
return d.memberships.insertMembership(ctx, txn, localpart, roomID, eventID)
}
// removeMembershipsByEventIDs removes the memberships corresponding to the
// `join` membership events IDs in the eventIDs slice.
// If the removal fails, or if there is no membership to remove, returns an error
func (d *Database) removeMembershipsByEventIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) error {
return d.memberships.deleteMembershipsByEventIDs(ctx, txn, eventIDs)
}
// UpdateMemberships adds the "join" membership events included in a given state
// events array, and removes those which ID is included in a given array of events
// IDs. All of the process is run in a transaction, which commits only once/if every
// insertion and deletion has been successfully processed.
// Returns a SQL error if there was an issue with any part of the process
func (d *Database) UpdateMemberships(
ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil {
return err
}
for _, event := range eventsToAdd {
if err := d.newMembership(ctx, txn, event); err != nil {
return err
}
}
return nil
})
}
// GetMembershipInRoomByLocalpart returns the membership for an user
// matching the given localpart if he is a member of the room matching roomID,
// if not sql.ErrNoRows is returned.
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetMembershipInRoomByLocalpart(
ctx context.Context, localpart, roomID string,
) (authtypes.Membership, error) {
return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID)
}
// GetMembershipsByLocalpart returns an array containing the memberships for all
// the rooms a user matching a given localpart is a member of
// If no membership match the given localpart, returns an empty array
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetMembershipsByLocalpart(
ctx context.Context, localpart string,
) (memberships []authtypes.Membership, err error) {
return d.memberships.selectMembershipsByLocalpart(ctx, localpart)
}
// newMembership saves a new membership in the database.
// If the event isn't a valid m.room.member event with type `join`, does nothing.
// If an error occurred, returns the SQL error
func (d *Database) newMembership(
ctx context.Context, txn *sql.Tx, ev gomatrixserverlib.Event,
) error {
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
if err != nil {
return err
}
// We only want state events from local users
if string(serverName) != string(d.serverName) {
return nil
}
eventID := ev.EventID()
roomID := ev.RoomID()
membership, err := ev.Membership()
if err != nil {
return err
}
// Only "join" membership events can be considered as new memberships
if membership == gomatrixserverlib.Join {
if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil {
return err
}
}
}
return nil
}
// SaveAccountData saves new account data for a given user and a given room.
// If the account data is not specific to a room, the room ID should be an empty string
// If an account data already exists for a given set (user, room, data type), it will
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType, content string,
) error {
return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content)
}
// GetAccountData returns account data related to a given localpart
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global []gomatrixserverlib.ClientEvent,
rooms map[string][]gomatrixserverlib.ClientEvent,
err error,
) {
return d.accountDatas.selectAccountData(ctx, localpart)
}
// GetAccountDataByType returns account data matching a given
// localpart, room ID and type.
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) {
return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType,
)
}
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx)
}
func hashPassword(plaintext string) (hash string, err error) {
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
return string(hashBytes), err
}
// Err3PIDInUse is the error returned when trying to save an association involving
// a third-party identifier which is already associated to a local user.
var Err3PIDInUse = errors.New("This third-party identifier is already in use")
// SaveThreePIDAssociation saves the association between a third party identifier
// and a local Matrix user (identified by the user's ID's local part).
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string,
) (err error) {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
user, err := d.threepids.selectLocalpartForThreePID(
ctx, txn, threepid, medium,
)
if err != nil {
return err
}
if len(user) > 0 {
return Err3PIDInUse
}
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
})
}
// RemoveThreePIDAssociation removes the association involving a given third-party
// identifier.
// If no association exists involving this third-party identifier, returns nothing.
// If there was a problem talking to the database, returns an error.
func (d *Database) RemoveThreePIDAssociation(
ctx context.Context, threepid string, medium string,
) (err error) {
return d.threepids.deleteThreePID(ctx, threepid, medium)
}
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
// identifier.
// If no association involves the given third-party idenfitier, returns an empty
// string.
// Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
}
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
// a given local user.
// If no association is known for this user, returns an empty slice.
// Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
}
// GetFilter looks up the filter associated with a given local user and filter ID.
// Returns a filter structure. Otherwise returns an error if no such filter exists
// or if there was an error talking to the database.
func (d *Database) GetFilter(
ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) {
return d.filter.selectFilter(ctx, localpart, filterID)
}
// PutFilter puts the passed filter into the database.
// Returns the filterID as a string. Otherwise returns an error if something
// goes wrong.
func (d *Database) PutFilter(
ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
) (string, error) {
return d.filter.insertFilter(ctx, filter, localpart)
}
// CheckAccountAvailability checks if the username/localpart is already present
// in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
if err == sql.ErrNoRows {
return true, nil
}
return false, err
}
// GetAccountByLocalpart returns the account associated with the given localpart.
// This function assumes the request is authenticated or the account data is used only internally.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
) (*authtypes.Account, error) {
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}

View file

@ -12,17 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package devices
package postgres
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/common"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib"
)
@ -80,6 +80,9 @@ const deleteDeviceSQL = "" +
const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
type devicesStatements struct {
insertDeviceStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt
@ -88,6 +91,7 @@ type devicesStatements struct {
updateDeviceNameStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt
deleteDevicesByLocalpartStmt *sql.Stmt
deleteDevicesStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
@ -117,6 +121,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return
}
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
return
}
s.serverName = server
return
}
@ -142,6 +149,7 @@ func (s *devicesStatements) insertDevice(
}, nil
}
// deleteDevice removes a single device by id and user localpart.
func (s *devicesStatements) deleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string,
) error {
@ -150,6 +158,18 @@ func (s *devicesStatements) deleteDevice(
return err
}
// deleteDevices removes a single or multiple devices by ids and user localpart.
// Returns an error if the execution failed.
func (s *devicesStatements) deleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
) error {
stmt := common.TxStmt(txn, s.deleteDevicesStmt)
_, err := stmt.ExecContext(ctx, localpart, pq.Array(devices))
return err
}
// deleteDevicesByLocalpart removes all devices for the
// given user localpart.
func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
) error {
@ -206,6 +226,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
if err != nil {
return devices, err
}
defer rows.Close() // nolint: errcheck
for rows.Next() {
var dev authtypes.Device
@ -217,5 +238,5 @@ func (s *devicesStatements) selectDevicesByLocalpart(
devices = append(devices, dev)
}
return devices, nil
return devices, rows.Err()
}

View file

@ -0,0 +1,182 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib"
)
// The length of generated device IDs
var deviceIDByteLength = 6
// Database represents a device database.
type Database struct {
db *sql.DB
devices devicesStatements
}
// NewDatabase creates a new device database
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
var db *sql.DB
var err error
if db, err = sql.Open("postgres", dataSourceName); err != nil {
return nil, err
}
d := devicesStatements{}
if err = d.prepare(db, serverName); err != nil {
return nil, err
}
return &Database{db, d}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*authtypes.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*authtypes.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]authtypes.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, localpart)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string,
) (dev *authtypes.Device, returnErr error) {
if deviceID != nil {
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}

View file

@ -0,0 +1,243 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"strings"
"time"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/gomatrixserverlib"
)
const devicesSchema = `
-- This sequence is used for automatic allocation of session_id.
-- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
-- Stores data about devices.
CREATE TABLE IF NOT EXISTS device_devices (
access_token TEXT PRIMARY KEY,
session_id INTEGER,
device_id TEXT ,
localpart TEXT ,
created_ts BIGINT,
display_name TEXT,
UNIQUE (localpart, device_id)
);
`
const insertDeviceSQL = "" +
"INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id)" +
" VALUES ($1, $2, $3, $4, $5, $6)"
const selectDevicesCountSQL = "" +
"SELECT COUNT(access_token) FROM device_devices"
const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
type devicesStatements struct {
db *sql.DB
insertDeviceStmt *sql.Stmt
selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt
selectDeviceByIDStmt *sql.Stmt
selectDevicesByLocalpartStmt *sql.Stmt
updateDeviceNameStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt
deleteDevicesByLocalpartStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
s.db = db
_, err = db.Exec(devicesSchema)
if err != nil {
return
}
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
return
}
if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil {
return
}
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
return
}
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
return
}
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
return
}
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
return
}
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
return
}
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return
}
s.serverName = server
return
}
// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
// Returns an error if the user already has a device with the given device ID.
// Returns the device on success.
func (s *devicesStatements) insertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string,
) (*authtypes.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
countStmt := common.TxStmt(txn, s.selectDevicesCountStmt)
insertStmt := common.TxStmt(txn, s.insertDeviceStmt)
if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
return nil, err
}
sessionID++
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
return nil, err
}
return &authtypes.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, s.serverName),
AccessToken: accessToken,
SessionID: sessionID,
}, nil
}
func (s *devicesStatements) deleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string,
) error {
stmt := common.TxStmt(txn, s.deleteDeviceStmt)
_, err := stmt.ExecContext(ctx, id, localpart)
return err
}
func (s *devicesStatements) deleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
) error {
orig := strings.Replace(deleteDevicesSQL, "($1)", common.QueryVariadic(len(devices)), 1)
prep, err := s.db.Prepare(orig)
if err != nil {
return err
}
stmt := common.TxStmt(txn, prep)
params := make([]interface{}, len(devices)+1)
params[0] = localpart
for i, v := range devices {
params[i+1] = v
}
params = append(params, params...)
_, err = stmt.ExecContext(ctx, params...)
return err
}
func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
) error {
stmt := common.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart)
return err
}
func (s *devicesStatements) updateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error {
stmt := common.TxStmt(txn, s.updateDeviceNameStmt)
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
return err
}
func (s *devicesStatements) selectDeviceByToken(
ctx context.Context, accessToken string,
) (*authtypes.Device, error) {
var dev authtypes.Device
var localpart string
stmt := s.selectDeviceByTokenStmt
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
if err == nil {
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.AccessToken = accessToken
}
return &dev, err
}
// selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID
func (s *devicesStatements) selectDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*authtypes.Device, error) {
var dev authtypes.Device
var created sql.NullInt64
stmt := s.selectDeviceByIDStmt
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&created)
if err == nil {
dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
}
return &dev, err
}
func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]authtypes.Device, error) {
devices := []authtypes.Device{}
rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart)
if err != nil {
return devices, err
}
for rows.Next() {
var dev authtypes.Device
err = rows.Scan(&dev.ID)
if err != nil {
return devices, err
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
devices = append(devices, dev)
}
return devices, nil
}

View file

@ -0,0 +1,184 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib"
_ "github.com/mattn/go-sqlite3"
)
// The length of generated device IDs
var deviceIDByteLength = 6
// Database represents a device database.
type Database struct {
db *sql.DB
devices devicesStatements
}
// NewDatabase creates a new device database
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
var db *sql.DB
var err error
if db, err = sql.Open("sqlite3", dataSourceName); err != nil {
return nil, err
}
d := devicesStatements{}
if err = d.prepare(db, serverName); err != nil {
return nil, err
}
return &Database{db, d}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*authtypes.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*authtypes.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]authtypes.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, localpart)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string,
) (dev *authtypes.Device, returnErr error) {
if deviceID != nil {
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}

View file

@ -1,167 +1,37 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package devices
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"net/url"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices/postgres"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
)
// The length of generated device IDs
var deviceIDByteLength = 6
// Database represents a device database.
type Database struct {
db *sql.DB
devices devicesStatements
type Database interface {
GetDeviceByAccessToken(ctx context.Context, token string) (*authtypes.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*authtypes.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]authtypes.Device, error)
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string) (dev *authtypes.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
RemoveAllDevices(ctx context.Context, localpart string) error
}
// NewDatabase creates a new device database
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
var db *sql.DB
var err error
if db, err = sql.Open("postgres", dataSourceName); err != nil {
return nil, err
}
d := devicesStatements{}
if err = d.prepare(db, serverName); err != nil {
return nil, err
}
return &Database{db, d}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*authtypes.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*authtypes.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]authtypes.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, localpart)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string,
) (dev *authtypes.Device, returnErr error) {
if deviceID != nil {
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (Database, error) {
uri, err := url.Parse(dataSourceName)
if err != nil {
return "", err
return postgres.NewDatabase(dataSourceName, serverName)
}
switch uri.Scheme {
case "postgres":
return postgres.NewDatabase(dataSourceName, serverName)
case "file":
return sqlite3.NewDatabase(dataSourceName, serverName)
default:
return postgres.NewDatabase(dataSourceName, serverName)
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}

View file

@ -34,8 +34,8 @@ import (
// component.
func SetupClientAPIComponent(
base *basecomponent.BaseDendrite,
deviceDB *devices.Database,
accountsDB *accounts.Database,
deviceDB devices.Database,
accountsDB accounts.Database,
federation *gomatrixserverlib.FederationClient,
keyRing *gomatrixserverlib.KeyRing,
aliasAPI roomserverAPI.RoomserverAliasAPI,
@ -67,7 +67,7 @@ func SetupClientAPIComponent(
}
routing.Setup(
base.APIMux, *base.Cfg, roomserverProducer, queryAPI, aliasAPI, asAPI,
base.APIMux, base.Cfg, roomserverProducer, queryAPI, aliasAPI, asAPI,
accountsDB, deviceDB, federation, *keyRing, userUpdateProducer,
syncProducer, typingProducer, transactionsCache, fedSenderAPI,
)

View file

@ -31,7 +31,7 @@ import (
// OutputRoomEventConsumer consumes events that originated in the room server.
type OutputRoomEventConsumer struct {
roomServerConsumer *common.ContinualConsumer
db *accounts.Database
db accounts.Database
query api.RoomserverQueryAPI
serverName string
}
@ -40,7 +40,7 @@ type OutputRoomEventConsumer struct {
func NewOutputRoomEventConsumer(
cfg *config.Dendrite,
kafkaConsumer sarama.Consumer,
store *accounts.Database,
store accounts.Database,
queryAPI api.RoomserverQueryAPI,
) *OutputRoomEventConsumer {

View file

@ -30,7 +30,7 @@ import (
// GetAccountData implements GET /user/{userId}/[rooms/{roomid}/]account_data/{type}
func GetAccountData(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device,
req *http.Request, accountDB accounts.Database, device *authtypes.Device,
userID string, roomID string, dataType string,
) util.JSONResponse {
if userID != device.UserID {
@ -62,7 +62,7 @@ func GetAccountData(
// SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type}
func SaveAccountData(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device,
req *http.Request, accountDB accounts.Database, device *authtypes.Device,
userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer,
) util.JSONResponse {
if userID != device.UserID {

View file

@ -102,7 +102,7 @@ func serveTemplate(w http.ResponseWriter, templateHTML string, data map[string]s
// AuthFallback implements GET and POST /auth/{authType}/fallback/web?session={sessionID}
func AuthFallback(
w http.ResponseWriter, req *http.Request, authType string,
cfg config.Dendrite,
cfg *config.Dendrite,
) *util.JSONResponse {
sessionID := req.URL.Query().Get("session")
@ -130,7 +130,7 @@ func AuthFallback(
if req.Method == http.MethodGet {
// Handle Recaptcha
if authType == authtypes.LoginTypeRecaptcha {
if err := checkRecaptchaEnabled(&cfg, w, req); err != nil {
if err := checkRecaptchaEnabled(cfg, w, req); err != nil {
return err
}
@ -144,7 +144,7 @@ func AuthFallback(
} else if req.Method == http.MethodPost {
// Handle Recaptcha
if authType == authtypes.LoginTypeRecaptcha {
if err := checkRecaptchaEnabled(&cfg, w, req); err != nil {
if err := checkRecaptchaEnabled(cfg, w, req); err != nil {
return err
}
@ -156,7 +156,7 @@ func AuthFallback(
}
response := req.Form.Get("g-recaptcha-response")
if err := validateRecaptcha(&cfg, response, clientIP); err != nil {
if err := validateRecaptcha(cfg, response, clientIP); err != nil {
util.GetLogger(req.Context()).Error(err)
return err
}

View file

@ -134,8 +134,8 @@ type fledglingEvent struct {
// CreateRoom implements /createRoom
func CreateRoom(
req *http.Request, device *authtypes.Device,
cfg config.Dendrite, producer *producers.RoomserverProducer,
accountDB *accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI,
cfg *config.Dendrite, producer *producers.RoomserverProducer,
accountDB accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI,
asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
// TODO (#267): Check room ID doesn't clash with an existing one, and we
@ -148,8 +148,8 @@ func CreateRoom(
// nolint: gocyclo
func createRoom(
req *http.Request, device *authtypes.Device,
cfg config.Dendrite, roomID string, producer *producers.RoomserverProducer,
accountDB *accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI,
cfg *config.Dendrite, roomID string, producer *producers.RoomserverProducer,
accountDB accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI,
asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
logger := util.GetLogger(req.Context())
@ -344,7 +344,7 @@ func createRoom(
func buildEvent(
builder *gomatrixserverlib.EventBuilder,
provider gomatrixserverlib.AuthEventProvider,
cfg config.Dendrite,
cfg *config.Dendrite,
evTime time.Time,
) (*gomatrixserverlib.Event, error) {
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)

View file

@ -40,9 +40,13 @@ type deviceUpdateJSON struct {
DisplayName *string `json:"display_name"`
}
type devicesDeleteJSON struct {
Devices []string `json:"devices"`
}
// GetDeviceByID handles /devices/{deviceID}
func GetDeviceByID(
req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
req *http.Request, deviceDB devices.Database, device *authtypes.Device,
deviceID string,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
@ -72,7 +76,7 @@ func GetDeviceByID(
// GetDevicesByLocalpart handles /devices
func GetDevicesByLocalpart(
req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
req *http.Request, deviceDB devices.Database, device *authtypes.Device,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
@ -103,7 +107,7 @@ func GetDevicesByLocalpart(
// UpdateDeviceByID handles PUT on /devices/{deviceID}
func UpdateDeviceByID(
req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
req *http.Request, deviceDB devices.Database, device *authtypes.Device,
deviceID string,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
@ -146,3 +150,54 @@ func UpdateDeviceByID(
JSON: struct{}{},
}
}
// DeleteDeviceById handles DELETE requests to /devices/{deviceId}
func DeleteDeviceById(
req *http.Request, deviceDB devices.Database, device *authtypes.Device,
deviceID string,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
return httputil.LogThenError(req, err)
}
ctx := req.Context()
defer req.Body.Close() // nolint: errcheck
if err := deviceDB.RemoveDevice(ctx, deviceID, localpart); err != nil {
return httputil.LogThenError(req, err)
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
// DeleteDevices handles POST requests to /delete_devices
func DeleteDevices(
req *http.Request, deviceDB devices.Database, device *authtypes.Device,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
return httputil.LogThenError(req, err)
}
ctx := req.Context()
payload := devicesDeleteJSON{}
if err := json.NewDecoder(req.Body).Decode(&payload); err != nil {
return httputil.LogThenError(req, err)
}
defer req.Body.Close() // nolint: errcheck
if err := deviceDB.RemoveDevices(ctx, localpart, payload.Devices); err != nil {
return httputil.LogThenError(req, err)
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}

View file

@ -27,7 +27,7 @@ import (
// GetFilter implements GET /_matrix/client/r0/user/{userId}/filter/{filterId}
func GetFilter(
req *http.Request, device *authtypes.Device, accountDB *accounts.Database, userID string, filterID string,
req *http.Request, device *authtypes.Device, accountDB accounts.Database, userID string, filterID string,
) util.JSONResponse {
if userID != device.UserID {
return util.JSONResponse{
@ -63,7 +63,7 @@ type filterResponse struct {
//PutFilter implements POST /_matrix/client/r0/user/{userId}/filter
func PutFilter(
req *http.Request, device *authtypes.Device, accountDB *accounts.Database, userID string,
req *http.Request, device *authtypes.Device, accountDB accounts.Database, userID string,
) util.JSONResponse {
if userID != device.UserID {
return util.JSONResponse{

View file

@ -31,7 +31,7 @@ type getEventRequest struct {
device *authtypes.Device
roomID string
eventID string
cfg config.Dendrite
cfg *config.Dendrite
federation *gomatrixserverlib.FederationClient
keyRing gomatrixserverlib.KeyRing
requestedEvent gomatrixserverlib.Event
@ -44,7 +44,7 @@ func GetEvent(
device *authtypes.Device,
roomID string,
eventID string,
cfg config.Dendrite,
cfg *config.Dendrite,
queryAPI api.RoomserverQueryAPI,
federation *gomatrixserverlib.FederationClient,
keyRing gomatrixserverlib.KeyRing,

View file

@ -39,13 +39,13 @@ func JoinRoomByIDOrAlias(
req *http.Request,
device *authtypes.Device,
roomIDOrAlias string,
cfg config.Dendrite,
cfg *config.Dendrite,
federation *gomatrixserverlib.FederationClient,
producer *producers.RoomserverProducer,
queryAPI roomserverAPI.RoomserverQueryAPI,
aliasAPI roomserverAPI.RoomserverAliasAPI,
keyRing gomatrixserverlib.KeyRing,
accountDB *accounts.Database,
accountDB accounts.Database,
) util.JSONResponse {
var content map[string]interface{} // must be a JSON object
if resErr := httputil.UnmarshalJSONRequest(req, &content); resErr != nil {
@ -98,7 +98,7 @@ type joinRoomReq struct {
evTime time.Time
content map[string]interface{}
userID string
cfg config.Dendrite
cfg *config.Dendrite
federation *gomatrixserverlib.FederationClient
producer *producers.RoomserverProducer
queryAPI roomserverAPI.RoomserverQueryAPI

View file

@ -70,8 +70,8 @@ func passwordLogin() loginFlows {
// Login implements GET and POST /login
func Login(
req *http.Request, accountDB *accounts.Database, deviceDB *devices.Database,
cfg config.Dendrite,
req *http.Request, accountDB accounts.Database, deviceDB devices.Database,
cfg *config.Dendrite,
) util.JSONResponse {
if req.Method == http.MethodGet { // TODO: support other forms of login other than password, depending on config options
return util.JSONResponse{
@ -153,7 +153,7 @@ func Login(
func getDevice(
ctx context.Context,
r passwordRequest,
deviceDB *devices.Database,
deviceDB devices.Database,
acc *authtypes.Account,
token string,
) (dev *authtypes.Device, err error) {

View file

@ -26,7 +26,7 @@ import (
// Logout handles POST /logout
func Logout(
req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
req *http.Request, deviceDB devices.Database, device *authtypes.Device,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
@ -45,7 +45,7 @@ func Logout(
// LogoutAll handles POST /logout/all
func LogoutAll(
req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
req *http.Request, deviceDB devices.Database, device *authtypes.Device,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {

View file

@ -40,8 +40,8 @@ var errMissingUserID = errors.New("'user_id' must be supplied")
// SendMembership implements PUT /rooms/{roomID}/(join|kick|ban|unban|leave|invite)
// by building a m.room.member event then sending it to the room server
func SendMembership(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device,
roomID string, membership string, cfg config.Dendrite,
req *http.Request, accountDB accounts.Database, device *authtypes.Device,
roomID string, membership string, cfg *config.Dendrite,
queryAPI roomserverAPI.RoomserverQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
producer *producers.RoomserverProducer,
) util.JSONResponse {
@ -116,10 +116,10 @@ func SendMembership(
func buildMembershipEvent(
ctx context.Context,
body threepid.MembershipRequest, accountDB *accounts.Database,
body threepid.MembershipRequest, accountDB accounts.Database,
device *authtypes.Device,
membership, roomID string,
cfg config.Dendrite, evTime time.Time,
cfg *config.Dendrite, evTime time.Time,
queryAPI roomserverAPI.RoomserverQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) (*gomatrixserverlib.Event, error) {
stateKey, reason, err := getMembershipStateKey(body, device, membership)
@ -165,8 +165,8 @@ func buildMembershipEvent(
func loadProfile(
ctx context.Context,
userID string,
cfg config.Dendrite,
accountDB *accounts.Database,
cfg *config.Dendrite,
accountDB accounts.Database,
asAPI appserviceAPI.AppServiceQueryAPI,
) (*authtypes.Profile, error) {
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
@ -214,9 +214,9 @@ func checkAndProcessThreepid(
req *http.Request,
device *authtypes.Device,
body *threepid.MembershipRequest,
cfg config.Dendrite,
cfg *config.Dendrite,
queryAPI roomserverAPI.RoomserverQueryAPI,
accountDB *accounts.Database,
accountDB accounts.Database,
producer *producers.RoomserverProducer,
membership, roomID string,
evTime time.Time,

View file

@ -33,7 +33,7 @@ type response struct {
// GetMemberships implements GET /rooms/{roomId}/members
func GetMemberships(
req *http.Request, device *authtypes.Device, roomID string, joinedOnly bool,
_ config.Dendrite,
_ *config.Dendrite,
queryAPI api.RoomserverQueryAPI,
) util.JSONResponse {
queryReq := api.QueryMembershipsForRoomRequest{

View file

@ -36,7 +36,7 @@ import (
// GetProfile implements GET /profile/{userID}
func GetProfile(
req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite,
req *http.Request, accountDB accounts.Database, cfg *config.Dendrite,
userID string,
asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient,
@ -64,7 +64,7 @@ func GetProfile(
// GetAvatarURL implements GET /profile/{userID}/avatar_url
func GetAvatarURL(
req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite,
req *http.Request, accountDB accounts.Database, cfg *config.Dendrite,
userID string, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient,
) util.JSONResponse {
@ -90,7 +90,7 @@ func GetAvatarURL(
// SetAvatarURL implements PUT /profile/{userID}/avatar_url
func SetAvatarURL(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device,
req *http.Request, accountDB accounts.Database, device *authtypes.Device,
userID string, producer *producers.UserUpdateProducer, cfg *config.Dendrite,
rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI,
) util.JSONResponse {
@ -170,7 +170,7 @@ func SetAvatarURL(
// GetDisplayName implements GET /profile/{userID}/displayname
func GetDisplayName(
req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite,
req *http.Request, accountDB accounts.Database, cfg *config.Dendrite,
userID string, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient,
) util.JSONResponse {
@ -196,7 +196,7 @@ func GetDisplayName(
// SetDisplayName implements PUT /profile/{userID}/displayname
func SetDisplayName(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device,
req *http.Request, accountDB accounts.Database, device *authtypes.Device,
userID string, producer *producers.UserUpdateProducer, cfg *config.Dendrite,
rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI,
) util.JSONResponse {
@ -279,7 +279,7 @@ func SetDisplayName(
// Returns an error when something goes wrong or specifically
// common.ErrProfileNoExists when the profile doesn't exist.
func getProfile(
ctx context.Context, accountDB *accounts.Database, cfg *config.Dendrite,
ctx context.Context, accountDB accounts.Database, cfg *config.Dendrite,
userID string,
asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient,
@ -343,7 +343,7 @@ func buildMembershipEvents(
return nil, err
}
event, err := common.BuildEvent(ctx, &builder, *cfg, evTime, queryAPI, nil)
event, err := common.BuildEvent(ctx, &builder, cfg, evTime, queryAPI, nil)
if err != nil {
return nil, err
}

View file

@ -440,8 +440,8 @@ func validateApplicationService(
// http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register
func Register(
req *http.Request,
accountDB *accounts.Database,
deviceDB *devices.Database,
accountDB accounts.Database,
deviceDB devices.Database,
cfg *config.Dendrite,
) util.JSONResponse {
@ -513,8 +513,8 @@ func handleGuestRegistration(
req *http.Request,
r registerRequest,
cfg *config.Dendrite,
accountDB *accounts.Database,
deviceDB *devices.Database,
accountDB accounts.Database,
deviceDB devices.Database,
) util.JSONResponse {
//Generate numeric local part for guest user
@ -570,8 +570,8 @@ func handleRegistrationFlow(
r registerRequest,
sessionID string,
cfg *config.Dendrite,
accountDB *accounts.Database,
deviceDB *devices.Database,
accountDB accounts.Database,
deviceDB devices.Database,
) util.JSONResponse {
// TODO: Shared secret registration (create new user scripts)
// TODO: Enable registration config flag
@ -668,8 +668,8 @@ func handleApplicationServiceRegistration(
req *http.Request,
r registerRequest,
cfg *config.Dendrite,
accountDB *accounts.Database,
deviceDB *devices.Database,
accountDB accounts.Database,
deviceDB devices.Database,
) util.JSONResponse {
// Check if we previously had issues extracting the access token from the
// request.
@ -707,8 +707,8 @@ func checkAndCompleteFlow(
r registerRequest,
sessionID string,
cfg *config.Dendrite,
accountDB *accounts.Database,
deviceDB *devices.Database,
accountDB accounts.Database,
deviceDB devices.Database,
) util.JSONResponse {
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
// This flow was completed, registration can continue
@ -730,8 +730,8 @@ func checkAndCompleteFlow(
// LegacyRegister process register requests from the legacy v1 API
func LegacyRegister(
req *http.Request,
accountDB *accounts.Database,
deviceDB *devices.Database,
accountDB accounts.Database,
deviceDB devices.Database,
cfg *config.Dendrite,
) util.JSONResponse {
var r legacyRegisterRequest
@ -814,8 +814,8 @@ func parseAndValidateLegacyLogin(req *http.Request, r *legacyRegisterRequest) *u
// not all
func completeRegistration(
ctx context.Context,
accountDB *accounts.Database,
deviceDB *devices.Database,
accountDB accounts.Database,
deviceDB devices.Database,
username, password, appserviceID string,
inhibitLogin common.WeakBoolean,
displayName, deviceID *string,
@ -991,8 +991,8 @@ type availableResponse struct {
// RegisterAvailable checks if the username is already taken or invalid.
func RegisterAvailable(
req *http.Request,
cfg config.Dendrite,
accountDB *accounts.Database,
cfg *config.Dendrite,
accountDB accounts.Database,
) util.JSONResponse {
username := req.URL.Query().Get("username")

View file

@ -40,7 +40,7 @@ func newTag() gomatrix.TagContent {
// GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags
func GetTags(
req *http.Request,
accountDB *accounts.Database,
accountDB accounts.Database,
device *authtypes.Device,
userID string,
roomID string,
@ -77,7 +77,7 @@ func GetTags(
// the tag to the "map" and saving the new "map" to the DB
func PutTag(
req *http.Request,
accountDB *accounts.Database,
accountDB accounts.Database,
device *authtypes.Device,
userID string,
roomID string,
@ -134,7 +134,7 @@ func PutTag(
// the "map" and then saving the new "map" in the DB
func DeleteTag(
req *http.Request,
accountDB *accounts.Database,
accountDB accounts.Database,
device *authtypes.Device,
userID string,
roomID string,
@ -203,7 +203,7 @@ func obtainSavedTags(
req *http.Request,
userID string,
roomID string,
accountDB *accounts.Database,
accountDB accounts.Database,
) (string, *gomatrixserverlib.ClientEvent, error) {
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
@ -222,7 +222,7 @@ func saveTagData(
req *http.Request,
localpart string,
roomID string,
accountDB *accounts.Database,
accountDB accounts.Database,
Tag gomatrix.TagContent,
) error {
newTagData, err := json.Marshal(Tag)

View file

@ -47,13 +47,13 @@ const pathPrefixUnstable = "/_matrix/client/unstable"
// applied:
// nolint: gocyclo
func Setup(
apiMux *mux.Router, cfg config.Dendrite,
apiMux *mux.Router, cfg *config.Dendrite,
producer *producers.RoomserverProducer,
queryAPI roomserverAPI.RoomserverQueryAPI,
aliasAPI roomserverAPI.RoomserverAliasAPI,
asAPI appserviceAPI.AppServiceQueryAPI,
accountDB *accounts.Database,
deviceDB *devices.Database,
accountDB accounts.Database,
deviceDB devices.Database,
federation *gomatrixserverlib.FederationClient,
keyRing gomatrixserverlib.KeyRing,
userUpdateProducer *producers.UserUpdateProducer,
@ -170,11 +170,11 @@ func Setup(
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
return Register(req, accountDB, deviceDB, &cfg)
return Register(req, accountDB, deviceDB, cfg)
})).Methods(http.MethodPost, http.MethodOptions)
v1mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
return LegacyRegister(req, accountDB, deviceDB, &cfg)
return LegacyRegister(req, accountDB, deviceDB, cfg)
})).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/register/available", common.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse {
@ -187,7 +187,7 @@ func Setup(
if err != nil {
return util.ErrorResponse(err)
}
return DirectoryRoom(req, vars["roomAlias"], federation, &cfg, aliasAPI, federationSender)
return DirectoryRoom(req, vars["roomAlias"], federation, cfg, aliasAPI, federationSender)
}),
).Methods(http.MethodGet, http.MethodOptions)
@ -197,7 +197,7 @@ func Setup(
if err != nil {
return util.ErrorResponse(err)
}
return SetLocalAlias(req, device, vars["roomAlias"], &cfg, aliasAPI)
return SetLocalAlias(req, device, vars["roomAlias"], cfg, aliasAPI)
}),
).Methods(http.MethodPut, http.MethodOptions)
@ -301,7 +301,7 @@ func Setup(
if err != nil {
return util.ErrorResponse(err)
}
return GetProfile(req, accountDB, &cfg, vars["userID"], asAPI, federation)
return GetProfile(req, accountDB, cfg, vars["userID"], asAPI, federation)
}),
).Methods(http.MethodGet, http.MethodOptions)
@ -311,7 +311,7 @@ func Setup(
if err != nil {
return util.ErrorResponse(err)
}
return GetAvatarURL(req, accountDB, &cfg, vars["userID"], asAPI, federation)
return GetAvatarURL(req, accountDB, cfg, vars["userID"], asAPI, federation)
}),
).Methods(http.MethodGet, http.MethodOptions)
@ -321,7 +321,7 @@ func Setup(
if err != nil {
return util.ErrorResponse(err)
}
return SetAvatarURL(req, accountDB, device, vars["userID"], userUpdateProducer, &cfg, producer, queryAPI)
return SetAvatarURL(req, accountDB, device, vars["userID"], userUpdateProducer, cfg, producer, queryAPI)
}),
).Methods(http.MethodPut, http.MethodOptions)
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows
@ -333,7 +333,7 @@ func Setup(
if err != nil {
return util.ErrorResponse(err)
}
return GetDisplayName(req, accountDB, &cfg, vars["userID"], asAPI, federation)
return GetDisplayName(req, accountDB, cfg, vars["userID"], asAPI, federation)
}),
).Methods(http.MethodGet, http.MethodOptions)
@ -343,7 +343,7 @@ func Setup(
if err != nil {
return util.ErrorResponse(err)
}
return SetDisplayName(req, accountDB, device, vars["userID"], userUpdateProducer, &cfg, producer, queryAPI)
return SetDisplayName(req, accountDB, device, vars["userID"], userUpdateProducer, cfg, producer, queryAPI)
}),
).Methods(http.MethodPut, http.MethodOptions)
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows
@ -503,6 +503,22 @@ func Setup(
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/devices/{deviceID}",
common.MakeAuthAPI("delete_device", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return DeleteDeviceById(req, deviceDB, device, vars["deviceID"])
}),
).Methods(http.MethodDelete, http.MethodOptions)
r0mux.Handle("/delete_devices",
common.MakeAuthAPI("delete_devices", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return DeleteDevices(req, deviceDB, device)
}),
).Methods(http.MethodPost, http.MethodOptions)
// Stub implementations for sytest
r0mux.Handle("/events",
common.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {

View file

@ -43,7 +43,7 @@ func SendEvent(
req *http.Request,
device *authtypes.Device,
roomID, eventType string, txnID, stateKey *string,
cfg config.Dendrite,
cfg *config.Dendrite,
queryAPI api.RoomserverQueryAPI,
producer *producers.RoomserverProducer,
txnCache *transactions.Cache,
@ -93,7 +93,7 @@ func generateSendEvent(
req *http.Request,
device *authtypes.Device,
roomID, eventType string, stateKey *string,
cfg config.Dendrite,
cfg *config.Dendrite,
queryAPI api.RoomserverQueryAPI,
) (*gomatrixserverlib.Event, *util.JSONResponse) {
// parse the incoming http request

View file

@ -34,7 +34,7 @@ type typingContentJSON struct {
// sends the typing events to client API typingProducer
func SendTyping(
req *http.Request, device *authtypes.Device, roomID string,
userID string, accountDB *accounts.Database,
userID string, accountDB accounts.Database,
typingProducer *producers.TypingServerProducer,
) util.JSONResponse {
if device.UserID != userID {

View file

@ -39,7 +39,7 @@ type threePIDsResponse struct {
// RequestEmailToken implements:
// POST /account/3pid/email/requestToken
// POST /register/email/requestToken
func RequestEmailToken(req *http.Request, accountDB *accounts.Database, cfg config.Dendrite) util.JSONResponse {
func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *config.Dendrite) util.JSONResponse {
var body threepid.EmailAssociationRequest
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr
@ -82,8 +82,8 @@ func RequestEmailToken(req *http.Request, accountDB *accounts.Database, cfg conf
// CheckAndSave3PIDAssociation implements POST /account/3pid
func CheckAndSave3PIDAssociation(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device,
cfg config.Dendrite,
req *http.Request, accountDB accounts.Database, device *authtypes.Device,
cfg *config.Dendrite,
) util.JSONResponse {
var body threepid.EmailAssociationCheckRequest
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
@ -142,7 +142,7 @@ func CheckAndSave3PIDAssociation(
// GetAssociated3PIDs implements GET /account/3pid
func GetAssociated3PIDs(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device,
req *http.Request, accountDB accounts.Database, device *authtypes.Device,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
@ -161,7 +161,7 @@ func GetAssociated3PIDs(
}
// Forget3PID implements POST /account/3pid/delete
func Forget3PID(req *http.Request, accountDB *accounts.Database) util.JSONResponse {
func Forget3PID(req *http.Request, accountDB accounts.Database) util.JSONResponse {
var body authtypes.ThreePID
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr

View file

@ -31,7 +31,7 @@ import (
// RequestTurnServer implements:
// GET /voip/turnServer
func RequestTurnServer(req *http.Request, device *authtypes.Device, cfg config.Dendrite) util.JSONResponse {
func RequestTurnServer(req *http.Request, device *authtypes.Device, cfg *config.Dendrite) util.JSONResponse {
turnConfig := cfg.TURN
// TODO Guest Support

View file

@ -86,8 +86,8 @@ var (
// can be emitted.
func CheckAndProcessInvite(
ctx context.Context,
device *authtypes.Device, body *MembershipRequest, cfg config.Dendrite,
queryAPI api.RoomserverQueryAPI, db *accounts.Database,
device *authtypes.Device, body *MembershipRequest, cfg *config.Dendrite,
queryAPI api.RoomserverQueryAPI, db accounts.Database,
producer *producers.RoomserverProducer, membership string, roomID string,
evTime time.Time,
) (inviteStoredOnIDServer bool, err error) {
@ -137,7 +137,7 @@ func CheckAndProcessInvite(
// Returns an error if a check or a request failed.
func queryIDServer(
ctx context.Context,
db *accounts.Database, cfg config.Dendrite, device *authtypes.Device,
db accounts.Database, cfg *config.Dendrite, device *authtypes.Device,
body *MembershipRequest, roomID string,
) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) {
if err = isTrusted(body.IDServer, cfg); err != nil {
@ -206,7 +206,7 @@ func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServe
// Returns an error if the request failed to send or if the response couldn't be parsed.
func queryIDServerStoreInvite(
ctx context.Context,
db *accounts.Database, cfg config.Dendrite, device *authtypes.Device,
db accounts.Database, cfg *config.Dendrite, device *authtypes.Device,
body *MembershipRequest, roomID string,
) (*idServerStoreInviteResponse, error) {
// Retrieve the sender's profile to get their display name
@ -330,7 +330,7 @@ func checkIDServerSignatures(
func emit3PIDInviteEvent(
ctx context.Context,
body *MembershipRequest, res *idServerStoreInviteResponse,
device *authtypes.Device, roomID string, cfg config.Dendrite,
device *authtypes.Device, roomID string, cfg *config.Dendrite,
queryAPI api.RoomserverQueryAPI, producer *producers.RoomserverProducer,
evTime time.Time,
) error {

View file

@ -53,7 +53,7 @@ type Credentials struct {
// Returns an error if there was a problem sending the request or decoding the
// response, or if the identity server responded with a non-OK status.
func CreateSession(
ctx context.Context, req EmailAssociationRequest, cfg config.Dendrite,
ctx context.Context, req EmailAssociationRequest, cfg *config.Dendrite,
) (string, error) {
if err := isTrusted(req.IDServer, cfg); err != nil {
return "", err
@ -101,7 +101,7 @@ func CreateSession(
// Returns an error if there was a problem sending the request or decoding the
// response, or if the identity server responded with a non-OK status.
func CheckAssociation(
ctx context.Context, creds Credentials, cfg config.Dendrite,
ctx context.Context, creds Credentials, cfg *config.Dendrite,
) (bool, string, string, error) {
if err := isTrusted(creds.IDServer, cfg); err != nil {
return false, "", "", err
@ -142,7 +142,7 @@ func CheckAssociation(
// identifier and a Matrix ID.
// Returns an error if there was a problem sending the request or decoding the
// response, or if the identity server responded with a non-OK status.
func PublishAssociation(creds Credentials, userID string, cfg config.Dendrite) error {
func PublishAssociation(creds Credentials, userID string, cfg *config.Dendrite) error {
if err := isTrusted(creds.IDServer, cfg); err != nil {
return err
}
@ -177,7 +177,7 @@ func PublishAssociation(creds Credentials, userID string, cfg config.Dendrite) e
// isTrusted checks if a given identity server is part of the list of trusted
// identity servers in the configuration file.
// Returns an error if the server isn't trusted.
func isTrusted(idServer string, cfg config.Dendrite) error {
func isTrusted(idServer string, cfg *config.Dendrite) error {
for _, server := range cfg.Matrix.TrustedIDServers {
if idServer == server {
return nil

View file

@ -21,7 +21,7 @@ import (
"os"
"strings"
"github.com/Shopify/sarama"
sarama "gopkg.in/Shopify/sarama.v1"
)
const usage = `Usage: %s

View file

@ -18,6 +18,7 @@ import (
"database/sql"
"io"
"net/http"
"net/url"
"golang.org/x/crypto/ed25519"
@ -68,7 +69,13 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string) *BaseDendrite {
logrus.WithError(err).Panicf("failed to start opentracing")
}
kafkaConsumer, kafkaProducer := setupKafka(cfg)
var kafkaConsumer sarama.Consumer
var kafkaProducer sarama.SyncProducer
if cfg.Kafka.UseNaffka {
kafkaConsumer, kafkaProducer = setupNaffka(cfg)
} else {
kafkaConsumer, kafkaProducer = setupKafka(cfg)
}
return &BaseDendrite{
componentName: componentName,
@ -118,7 +125,7 @@ func (b *BaseDendrite) CreateHTTPFederationSenderAPIs() federationSenderAPI.Fede
// CreateDeviceDB creates a new instance of the device database. Should only be
// called once per component.
func (b *BaseDendrite) CreateDeviceDB() *devices.Database {
func (b *BaseDendrite) CreateDeviceDB() devices.Database {
db, err := devices.NewDatabase(string(b.Cfg.Database.Device), b.Cfg.Matrix.ServerName)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to devices db")
@ -129,7 +136,7 @@ func (b *BaseDendrite) CreateDeviceDB() *devices.Database {
// CreateAccountsDB creates a new instance of the accounts database. Should only
// be called once per component.
func (b *BaseDendrite) CreateAccountsDB() *accounts.Database {
func (b *BaseDendrite) CreateAccountsDB() accounts.Database {
db, err := accounts.NewDatabase(string(b.Cfg.Database.Account), b.Cfg.Matrix.ServerName)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to accounts db")
@ -186,28 +193,8 @@ func (b *BaseDendrite) SetupAndServeHTTP(bindaddr string, listenaddr string) {
logrus.Infof("Stopped %s server on %s", b.componentName, addr)
}
// setupKafka creates kafka consumer/producer pair from the config. Checks if
// should use naffka.
// setupKafka creates kafka consumer/producer pair from the config.
func setupKafka(cfg *config.Dendrite) (sarama.Consumer, sarama.SyncProducer) {
if cfg.Kafka.UseNaffka {
db, err := sql.Open("postgres", string(cfg.Database.Naffka))
if err != nil {
logrus.WithError(err).Panic("Failed to open naffka database")
}
naffkaDB, err := naffka.NewPostgresqlDatabase(db)
if err != nil {
logrus.WithError(err).Panic("Failed to setup naffka database")
}
naff, err := naffka.New(naffkaDB)
if err != nil {
logrus.WithError(err).Panic("Failed to setup naffka")
}
return naff, naff
}
consumer, err := sarama.NewConsumer(cfg.Kafka.Addresses, nil)
if err != nil {
logrus.WithError(err).Panic("failed to start kafka consumer")
@ -220,3 +207,44 @@ func setupKafka(cfg *config.Dendrite) (sarama.Consumer, sarama.SyncProducer) {
return consumer, producer
}
// setupNaffka creates kafka consumer/producer pair from the config.
func setupNaffka(cfg *config.Dendrite) (sarama.Consumer, sarama.SyncProducer) {
var err error
var db *sql.DB
var naffkaDB *naffka.DatabaseImpl
uri, err := url.Parse(string(cfg.Database.Naffka))
if err != nil || uri.Scheme == "file" {
db, err = sql.Open("sqlite3", string(cfg.Database.Naffka))
if err != nil {
logrus.WithError(err).Panic("Failed to open naffka database")
}
naffkaDB, err = naffka.NewSqliteDatabase(db)
if err != nil {
logrus.WithError(err).Panic("Failed to setup naffka database")
}
} else {
db, err = sql.Open("postgres", string(cfg.Database.Naffka))
if err != nil {
logrus.WithError(err).Panic("Failed to open naffka database")
}
naffkaDB, err = naffka.NewPostgresqlDatabase(db)
if err != nil {
logrus.WithError(err).Panic("Failed to setup naffka database")
}
}
if naffkaDB == nil {
panic("naffka connection string not understood")
}
naff, err := naffka.New(naffkaDB)
if err != nil {
logrus.WithError(err).Panic("Failed to setup naffka")
}
return naff, naff
}

View file

@ -39,7 +39,7 @@ var ErrRoomNoExists = errors.New("Room does not exist")
// Returns an error if something else went wrong
func BuildEvent(
ctx context.Context,
builder *gomatrixserverlib.EventBuilder, cfg config.Dendrite, evTime time.Time,
builder *gomatrixserverlib.EventBuilder, cfg *config.Dendrite, evTime time.Time,
queryAPI api.RoomserverQueryAPI, queryRes *api.QueryLatestEventsAndStateResponse,
) (*gomatrixserverlib.Event, error) {
err := AddPrevEventsToEvent(ctx, builder, queryAPI, queryRes)

View file

@ -21,6 +21,7 @@ import (
"golang.org/x/crypto/ed25519"
"github.com/matrix-org/dendrite/common/keydb/postgres"
"github.com/matrix-org/dendrite/common/keydb/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
)
@ -44,6 +45,8 @@ func NewDatabase(
switch uri.Scheme {
case "postgres":
return postgres.NewDatabase(dataSourceName, serverName, serverKey, serverKeyID)
case "file":
return sqlite3.NewDatabase(dataSourceName, serverName, serverKey, serverKeyID)
default:
return postgres.NewDatabase(dataSourceName, serverName, serverKey, serverKeyID)
}

View file

@ -117,7 +117,7 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
}
}
return results, nil
return results, rows.Err()
}
func (s *serverKeyStatements) upsertServerKeys(

View file

@ -0,0 +1,115 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"math"
"golang.org/x/crypto/ed25519"
"github.com/matrix-org/gomatrixserverlib"
_ "github.com/mattn/go-sqlite3"
)
// A Database implements gomatrixserverlib.KeyDatabase and is used to store
// the public keys for other matrix servers.
type Database struct {
statements serverKeyStatements
}
// NewDatabase prepares a new key database.
// It creates the necessary tables if they don't already exist.
// It prepares all the SQL statements that it will use.
// Returns an error if there was a problem talking to the database.
func NewDatabase(
dataSourceName string,
serverName gomatrixserverlib.ServerName,
serverKey ed25519.PublicKey,
serverKeyID gomatrixserverlib.KeyID,
) (*Database, error) {
db, err := sql.Open("sqlite3", dataSourceName)
if err != nil {
return nil, err
}
d := &Database{}
err = d.statements.prepare(db)
if err != nil {
return nil, err
}
// Store our own keys so that we don't end up making HTTP requests to find our
// own keys
index := gomatrixserverlib.PublicKeyLookupRequest{
ServerName: serverName,
KeyID: serverKeyID,
}
value := gomatrixserverlib.PublicKeyLookupResult{
VerifyKey: gomatrixserverlib.VerifyKey{
Key: gomatrixserverlib.Base64String(serverKey),
},
ValidUntilTS: math.MaxUint64 >> 1,
ExpiredTS: gomatrixserverlib.PublicKeyNotExpired,
}
err = d.StoreKeys(
context.Background(),
map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{
index: value,
},
)
if err != nil {
return nil, err
}
return d, nil
}
// FetcherName implements KeyFetcher
func (d Database) FetcherName() string {
return "KeyDatabase"
}
// FetchKeys implements gomatrixserverlib.KeyDatabase
func (d *Database) FetchKeys(
ctx context.Context,
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
return d.statements.bulkSelectServerKeys(ctx, requests)
}
// StoreKeys implements gomatrixserverlib.KeyDatabase
func (d *Database) StoreKeys(
ctx context.Context,
keyMap map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
) error {
// TODO: Inserting all the keys within a single transaction may
// be more efficient since the transaction overhead can be quite
// high for a single insert statement.
var lastErr error
for request, keys := range keyMap {
if err := d.statements.upsertServerKeys(ctx, request, keys); err != nil {
// Rather than returning immediately on error we try to insert the
// remaining keys.
// Since we are inserting the keys outside of a transaction it is
// possible for some of the inserts to succeed even though some
// of the inserts have failed.
// Ensuring that we always insert all the keys we can means that
// this behaviour won't depend on the iteration order of the map.
lastErr = err
}
}
return lastErr
}

View file

@ -0,0 +1,142 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib"
)
const serverKeysSchema = `
-- A cache of signing keys downloaded from remote servers.
CREATE TABLE IF NOT EXISTS keydb_server_keys (
-- The name of the matrix server the key is for.
server_name TEXT NOT NULL,
-- The ID of the server key.
server_key_id TEXT NOT NULL,
-- Combined server name and key ID separated by the ASCII unit separator
-- to make it easier to run bulk queries.
server_name_and_key_id TEXT NOT NULL,
-- When the key is valid until as a millisecond timestamp.
-- 0 if this is an expired key (in which case expired_ts will be non-zero)
valid_until_ts BIGINT NOT NULL,
-- When the key expired as a millisecond timestamp.
-- 0 if this is an active key (in which case valid_until_ts will be non-zero)
expired_ts BIGINT NOT NULL,
-- The base64-encoded public key.
server_key TEXT NOT NULL,
UNIQUE (server_name, server_key_id)
);
CREATE INDEX IF NOT EXISTS keydb_server_name_and_key_id ON keydb_server_keys (server_name_and_key_id);
`
const bulkSelectServerKeysSQL = "" +
"SELECT server_name, server_key_id, valid_until_ts, expired_ts, " +
" server_key FROM keydb_server_keys" +
" WHERE server_name_and_key_id IN ($1)"
const upsertServerKeysSQL = "" +
"INSERT INTO keydb_server_keys (server_name, server_key_id," +
" server_name_and_key_id, valid_until_ts, expired_ts, server_key)" +
" VALUES ($1, $2, $3, $4, $5, $6)" +
" ON CONFLICT (server_name, server_key_id)" +
" DO UPDATE SET valid_until_ts = $4, expired_ts = $5, server_key = $6"
type serverKeyStatements struct {
bulkSelectServerKeysStmt *sql.Stmt
upsertServerKeysStmt *sql.Stmt
}
func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(serverKeysSchema)
if err != nil {
return
}
if s.bulkSelectServerKeysStmt, err = db.Prepare(bulkSelectServerKeysSQL); err != nil {
return
}
if s.upsertServerKeysStmt, err = db.Prepare(upsertServerKeysSQL); err != nil {
return
}
return
}
func (s *serverKeyStatements) bulkSelectServerKeys(
ctx context.Context,
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
var nameAndKeyIDs []string
for request := range requests {
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
}
stmt := s.bulkSelectServerKeysStmt
rows, err := stmt.QueryContext(ctx, pq.StringArray(nameAndKeyIDs))
if err != nil {
return nil, err
}
defer rows.Close() // nolint: errcheck
results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
for rows.Next() {
var serverName string
var keyID string
var key string
var validUntilTS int64
var expiredTS int64
if err = rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
return nil, err
}
r := gomatrixserverlib.PublicKeyLookupRequest{
ServerName: gomatrixserverlib.ServerName(serverName),
KeyID: gomatrixserverlib.KeyID(keyID),
}
vk := gomatrixserverlib.VerifyKey{}
err = vk.Key.Decode(key)
if err != nil {
return nil, err
}
results[r] = gomatrixserverlib.PublicKeyLookupResult{
VerifyKey: vk,
ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS),
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
}
}
return results, nil
}
func (s *serverKeyStatements) upsertServerKeys(
ctx context.Context,
request gomatrixserverlib.PublicKeyLookupRequest,
key gomatrixserverlib.PublicKeyLookupResult,
) error {
_, err := s.upsertServerKeysStmt.ExecContext(
ctx,
string(request.ServerName),
string(request.KeyID),
nameAndKeyID(request),
key.ValidUntilTS,
key.ExpiredTS,
key.Key.Encode(),
)
return err
}
func nameAndKeyID(request gomatrixserverlib.PublicKeyLookupRequest) string {
return string(request.ServerName) + "\x1F" + string(request.KeyID)
}

View file

@ -29,7 +29,7 @@ CREATE TABLE IF NOT EXISTS ${prefix}_partition_offsets (
partition INTEGER NOT NULL,
-- The 64-bit offset.
partition_offset BIGINT NOT NULL,
CONSTRAINT ${prefix}_topic_partition_unique UNIQUE (topic, partition)
UNIQUE (topic, partition)
);
`
@ -38,7 +38,7 @@ const selectPartitionOffsetsSQL = "" +
const upsertPartitionOffsetsSQL = "" +
"INSERT INTO ${prefix}_partition_offsets (topic, partition, partition_offset) VALUES ($1, $2, $3)" +
" ON CONFLICT ON CONSTRAINT ${prefix}_topic_partition_unique" +
" ON CONFLICT (topic, partition)" +
" DO UPDATE SET partition_offset = $3"
// PartitionOffsetStatements represents a set of statements that can be run on a partition_offsets table.
@ -99,7 +99,7 @@ func (s *PartitionOffsetStatements) selectPartitionOffsets(
}
results = append(results, offset)
}
return results, nil
return results, rows.Err()
}
// UpsertPartitionOffset updates or inserts the partition offset for the given topic.

View file

@ -16,6 +16,7 @@ package common
import (
"database/sql"
"fmt"
"github.com/lib/pq"
)
@ -30,11 +31,13 @@ type Transaction interface {
// EndTransaction ends a transaction.
// If the transaction succeeded then it is committed, otherwise it is rolledback.
func EndTransaction(txn Transaction, succeeded *bool) {
// You MUST check the error returned from this function to be sure that the transaction
// was applied correctly. For example, 'database is locked' errors in sqlite will happen here.
func EndTransaction(txn Transaction, succeeded *bool) error {
if *succeeded {
txn.Commit() // nolint: errcheck
return txn.Commit() // nolint: errcheck
} else {
txn.Rollback() // nolint: errcheck
return txn.Rollback() // nolint: errcheck
}
}
@ -47,7 +50,12 @@ func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
return
}
succeeded := false
defer EndTransaction(txn, &succeeded)
defer func() {
err2 := EndTransaction(txn, &succeeded)
if err == nil && err2 != nil { // failed to commit/rollback
err = err2
}
}()
err = fn(txn)
if err != nil {
@ -74,3 +82,20 @@ func IsUniqueConstraintViolationErr(err error) bool {
pqErr, ok := err.(*pq.Error)
return ok && pqErr.Code == "23505"
}
// Hack of the century
func QueryVariadic(count int) string {
return QueryVariadicOffset(count, 0)
}
func QueryVariadicOffset(count, offset int) string {
str := "("
for i := 0; i < count; i++ {
str += fmt.Sprintf("$%d", i+offset+1)
if i < (count - 1) {
str += ", "
}
}
str += ")"
return str
}

View file

@ -111,6 +111,7 @@ func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*con
// Bind to the same address as the listen address
// All microservices are run on the same host in testing
cfg.Bind.ClientAPI = cfg.Listen.ClientAPI
cfg.Bind.AppServiceAPI = cfg.Listen.AppServiceAPI
cfg.Bind.FederationAPI = cfg.Listen.FederationAPI
cfg.Bind.MediaAPI = cfg.Listen.MediaAPI
cfg.Bind.RoomServer = cfg.Listen.RoomServer

View file

@ -1,9 +1,13 @@
<<<<<<< HEAD
FROM docker.io/golang:1.13.7-alpine3.11
=======
FROM docker.io/golang:1.13.6-alpine
>>>>>>> master
RUN mkdir /build
WORKDIR /build
RUN apk --update --no-cache add openssl bash git
RUN apk --update --no-cache add openssl bash git build-base
CMD ["bash", "docker/build.sh"]

View file

@ -1,13 +1,21 @@
version: "3.4"
services:
riot:
image: vectorim/riot-web
networks:
- internal
ports:
- "8500:80"
monolith:
container_name: dendrite_monolith
hostname: monolith
entrypoint: ["bash", "./docker/services/monolith.sh"]
entrypoint: ["bash", "./docker/services/monolith.sh", "--config", "/etc/dendrite/dendrite.yaml"]
build: ./
volumes:
- ..:/build
- ./build/bin:/build/bin
- ../cfg:/etc/dendrite
networks:
- internal
depends_on:

View file

@ -44,6 +44,7 @@ args:
user: dendrite
database: dendrite
host: 127.0.0.1
sslmode: disable
type: pg
EOF
```

View file

@ -32,8 +32,8 @@ import (
// FederationAPI component.
func SetupFederationAPIComponent(
base *basecomponent.BaseDendrite,
accountsDB *accounts.Database,
deviceDB *devices.Database,
accountsDB accounts.Database,
deviceDB devices.Database,
federation *gomatrixserverlib.FederationClient,
keyRing *gomatrixserverlib.KeyRing,
aliasAPI roomserverAPI.RoomserverAliasAPI,
@ -45,8 +45,8 @@ func SetupFederationAPIComponent(
roomserverProducer := producers.NewRoomserverProducer(inputAPI)
routing.Setup(
base.APIMux, *base.Cfg, queryAPI, aliasAPI, asAPI,
roomserverProducer, federationSenderAPI, *keyRing, federation, accountsDB,
deviceDB,
base.APIMux, base.Cfg, queryAPI, aliasAPI, asAPI,
roomserverProducer, federationSenderAPI, *keyRing,
federation, accountsDB, deviceDB,
)
}

View file

@ -34,7 +34,7 @@ func Backfill(
request *gomatrixserverlib.FederationRequest,
query api.RoomserverQueryAPI,
roomID string,
cfg config.Dendrite,
cfg *config.Dendrite,
) util.JSONResponse {
var res api.QueryBackfillResponse
var eIDs []string

View file

@ -30,7 +30,7 @@ type userDevicesResponse struct {
// GetUserDevices for the given user id
func GetUserDevices(
req *http.Request,
deviceDB *devices.Database,
deviceDB devices.Database,
userID string,
) util.JSONResponse {
localpart, err := userutil.ParseUsernameParam(userID, nil)

View file

@ -32,7 +32,7 @@ func Invite(
request *gomatrixserverlib.FederationRequest,
roomID string,
eventID string,
cfg config.Dendrite,
cfg *config.Dendrite,
producer *producers.RoomserverProducer,
keys gomatrixserverlib.KeyRing,
) util.JSONResponse {

View file

@ -33,7 +33,7 @@ import (
func MakeJoin(
httpReq *http.Request,
request *gomatrixserverlib.FederationRequest,
cfg config.Dendrite,
cfg *config.Dendrite,
query api.RoomserverQueryAPI,
roomID, userID string,
) util.JSONResponse {
@ -97,7 +97,7 @@ func MakeJoin(
func SendJoin(
httpReq *http.Request,
request *gomatrixserverlib.FederationRequest,
cfg config.Dendrite,
cfg *config.Dendrite,
query api.RoomserverQueryAPI,
producer *producers.RoomserverProducer,
keys gomatrixserverlib.KeyRing,

View file

@ -27,7 +27,7 @@ import (
// LocalKeys returns the local keys for the server.
// See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys
func LocalKeys(cfg config.Dendrite) util.JSONResponse {
func LocalKeys(cfg *config.Dendrite) util.JSONResponse {
keys, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod))
if err != nil {
return util.ErrorResponse(err)
@ -35,7 +35,7 @@ func LocalKeys(cfg config.Dendrite) util.JSONResponse {
return util.JSONResponse{Code: http.StatusOK, JSON: keys}
}
func localKeys(cfg config.Dendrite, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) {
func localKeys(cfg *config.Dendrite, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) {
var keys gomatrixserverlib.ServerKeys
keys.ServerName = cfg.Matrix.ServerName

View file

@ -31,7 +31,7 @@ import (
func MakeLeave(
httpReq *http.Request,
request *gomatrixserverlib.FederationRequest,
cfg config.Dendrite,
cfg *config.Dendrite,
query api.RoomserverQueryAPI,
roomID, userID string,
) util.JSONResponse {
@ -95,7 +95,7 @@ func MakeLeave(
func SendLeave(
httpReq *http.Request,
request *gomatrixserverlib.FederationRequest,
cfg config.Dendrite,
cfg *config.Dendrite,
producer *producers.RoomserverProducer,
keys gomatrixserverlib.KeyRing,
roomID, eventID string,

View file

@ -30,8 +30,8 @@ import (
// GetProfile implements GET /_matrix/federation/v1/query/profile
func GetProfile(
httpReq *http.Request,
accountDB *accounts.Database,
cfg config.Dendrite,
accountDB accounts.Database,
cfg *config.Dendrite,
asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
userID, field := httpReq.FormValue("user_id"), httpReq.FormValue("field")

View file

@ -32,7 +32,7 @@ import (
func RoomAliasToID(
httpReq *http.Request,
federation *gomatrixserverlib.FederationClient,
cfg config.Dendrite,
cfg *config.Dendrite,
aliasAPI roomserverAPI.RoomserverAliasAPI,
senderAPI federationSenderAPI.FederationSenderQueryAPI,
) util.JSONResponse {

View file

@ -43,7 +43,7 @@ const (
// nolint: gocyclo
func Setup(
apiMux *mux.Router,
cfg config.Dendrite,
cfg *config.Dendrite,
query roomserverAPI.RoomserverQueryAPI,
aliasAPI roomserverAPI.RoomserverAliasAPI,
asAPI appserviceAPI.AppServiceQueryAPI,
@ -51,8 +51,8 @@ func Setup(
federationSenderAPI federationSenderAPI.FederationSenderQueryAPI,
keys gomatrixserverlib.KeyRing,
federation *gomatrixserverlib.FederationClient,
accountDB *accounts.Database,
deviceDB *devices.Database,
accountDB accounts.Database,
deviceDB devices.Database,
) {
v2keysmux := apiMux.PathPrefix(pathPrefixV2Keys).Subrouter()
v1fedmux := apiMux.PathPrefix(pathPrefixV1Federation).Subrouter()

View file

@ -34,7 +34,7 @@ func Send(
httpReq *http.Request,
request *gomatrixserverlib.FederationRequest,
txnID gomatrixserverlib.TransactionID,
cfg config.Dendrite,
cfg *config.Dendrite,
query api.RoomserverQueryAPI,
producer *producers.RoomserverProducer,
keys gomatrixserverlib.KeyRing,

View file

@ -59,9 +59,9 @@ var (
// CreateInvitesFrom3PIDInvites implements POST /_matrix/federation/v1/3pid/onbind
func CreateInvitesFrom3PIDInvites(
req *http.Request, queryAPI roomserverAPI.RoomserverQueryAPI,
asAPI appserviceAPI.AppServiceQueryAPI, cfg config.Dendrite,
asAPI appserviceAPI.AppServiceQueryAPI, cfg *config.Dendrite,
producer *producers.RoomserverProducer, federation *gomatrixserverlib.FederationClient,
accountDB *accounts.Database,
accountDB accounts.Database,
) util.JSONResponse {
var body invites
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
@ -98,7 +98,7 @@ func ExchangeThirdPartyInvite(
request *gomatrixserverlib.FederationRequest,
roomID string,
queryAPI roomserverAPI.RoomserverQueryAPI,
cfg config.Dendrite,
cfg *config.Dendrite,
federation *gomatrixserverlib.FederationClient,
producer *producers.RoomserverProducer,
) util.JSONResponse {
@ -172,9 +172,9 @@ func ExchangeThirdPartyInvite(
// necessary data to do so.
func createInviteFrom3PIDInvite(
ctx context.Context, queryAPI roomserverAPI.RoomserverQueryAPI,
asAPI appserviceAPI.AppServiceQueryAPI, cfg config.Dendrite,
asAPI appserviceAPI.AppServiceQueryAPI, cfg *config.Dendrite,
inv invite, federation *gomatrixserverlib.FederationClient,
accountDB *accounts.Database,
accountDB accounts.Database,
) (*gomatrixserverlib.Event, error) {
_, server, err := gomatrixserverlib.SplitID('@', inv.MXID)
if err != nil {
@ -230,7 +230,7 @@ func createInviteFrom3PIDInvite(
func buildMembershipEvent(
ctx context.Context,
builder *gomatrixserverlib.EventBuilder, queryAPI roomserverAPI.RoomserverQueryAPI,
cfg config.Dendrite,
cfg *config.Dendrite,
) (*gomatrixserverlib.Event, error) {
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
if err != nil {
@ -290,7 +290,7 @@ func buildMembershipEvent(
// them responded with an error.
func sendToRemoteServer(
ctx context.Context, inv invite,
federation *gomatrixserverlib.FederationClient, _ config.Dendrite,
federation *gomatrixserverlib.FederationClient, _ *config.Dendrite,
builder gomatrixserverlib.EventBuilder,
) (err error) {
remoteServers := make([]gomatrixserverlib.ServerName, 2)

View file

@ -132,5 +132,5 @@ func joinedHostsFromStmt(
})
}
return result, nil
return result, rows.Err()
}

View file

@ -87,7 +87,7 @@ func (d *Database) UpdateRoom(
return nil
}
if lastSentEventID != oldEventID {
if lastSentEventID != "" && lastSentEventID != oldEventID {
return types.EventIDMismatchError{
DatabaseID: lastSentEventID, RoomServerID: oldEventID,
}

View file

@ -0,0 +1,139 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/federationsender/types"
"github.com/matrix-org/gomatrixserverlib"
)
const joinedHostsSchema = `
-- The joined_hosts table stores a list of m.room.member event ids in the
-- current state for each room where the membership is "join".
-- There will be an entry for every user that is joined to the room.
CREATE TABLE IF NOT EXISTS federationsender_joined_hosts (
-- The string ID of the room.
room_id TEXT NOT NULL,
-- The event ID of the m.room.member join event.
event_id TEXT NOT NULL,
-- The domain part of the user ID the m.room.member event is for.
server_name TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS federatonsender_joined_hosts_event_id_idx
ON federationsender_joined_hosts (event_id);
CREATE INDEX IF NOT EXISTS federatonsender_joined_hosts_room_id_idx
ON federationsender_joined_hosts (room_id)
`
const insertJoinedHostsSQL = "" +
"INSERT INTO federationsender_joined_hosts (room_id, event_id, server_name)" +
" VALUES ($1, $2, $3)"
const deleteJoinedHostsSQL = "" +
"DELETE FROM federationsender_joined_hosts WHERE event_id = $1"
const selectJoinedHostsSQL = "" +
"SELECT event_id, server_name FROM federationsender_joined_hosts" +
" WHERE room_id = $1"
type joinedHostsStatements struct {
insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt
}
func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(joinedHostsSchema)
if err != nil {
return
}
if s.insertJoinedHostsStmt, err = db.Prepare(insertJoinedHostsSQL); err != nil {
return
}
if s.deleteJoinedHostsStmt, err = db.Prepare(deleteJoinedHostsSQL); err != nil {
return
}
if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil {
return
}
return
}
func (s *joinedHostsStatements) insertJoinedHosts(
ctx context.Context,
txn *sql.Tx,
roomID, eventID string,
serverName gomatrixserverlib.ServerName,
) error {
stmt := common.TxStmt(txn, s.insertJoinedHostsStmt)
_, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
return err
}
func (s *joinedHostsStatements) deleteJoinedHosts(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) error {
for _, eventID := range eventIDs {
stmt := common.TxStmt(txn, s.deleteJoinedHostsStmt)
if _, err := stmt.ExecContext(ctx, eventID); err != nil {
return err
}
}
return nil
}
func (s *joinedHostsStatements) selectJoinedHostsWithTx(
ctx context.Context, txn *sql.Tx, roomID string,
) ([]types.JoinedHost, error) {
stmt := common.TxStmt(txn, s.selectJoinedHostsStmt)
return joinedHostsFromStmt(ctx, stmt, roomID)
}
func (s *joinedHostsStatements) selectJoinedHosts(
ctx context.Context, roomID string,
) ([]types.JoinedHost, error) {
return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID)
}
func joinedHostsFromStmt(
ctx context.Context, stmt *sql.Stmt, roomID string,
) ([]types.JoinedHost, error) {
rows, err := stmt.QueryContext(ctx, roomID)
if err != nil {
return nil, err
}
defer rows.Close() // nolint: errcheck
var result []types.JoinedHost
for rows.Next() {
var eventID, serverName string
if err = rows.Scan(&eventID, &serverName); err != nil {
return nil, err
}
result = append(result, types.JoinedHost{
MemberEventID: eventID,
ServerName: gomatrixserverlib.ServerName(serverName),
})
}
return result, nil
}

View file

@ -0,0 +1,101 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/common"
)
const roomSchema = `
CREATE TABLE IF NOT EXISTS federationsender_rooms (
-- The string ID of the room
room_id TEXT PRIMARY KEY,
-- The most recent event state by the room server.
-- We can use this to tell if our view of the room state has become
-- desynchronised.
last_event_id TEXT NOT NULL
);`
const insertRoomSQL = "" +
"INSERT INTO federationsender_rooms (room_id, last_event_id) VALUES ($1, '')" +
" ON CONFLICT DO NOTHING"
const selectRoomForUpdateSQL = "" +
"SELECT last_event_id FROM federationsender_rooms WHERE room_id = $1"
const updateRoomSQL = "" +
"UPDATE federationsender_rooms SET last_event_id = $2 WHERE room_id = $1"
type roomStatements struct {
insertRoomStmt *sql.Stmt
selectRoomForUpdateStmt *sql.Stmt
updateRoomStmt *sql.Stmt
}
func (s *roomStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(roomSchema)
if err != nil {
return
}
if s.insertRoomStmt, err = db.Prepare(insertRoomSQL); err != nil {
return
}
if s.selectRoomForUpdateStmt, err = db.Prepare(selectRoomForUpdateSQL); err != nil {
return
}
if s.updateRoomStmt, err = db.Prepare(updateRoomSQL); err != nil {
return
}
return
}
// insertRoom inserts the room if it didn't already exist.
// If the room didn't exist then last_event_id is set to the empty string.
func (s *roomStatements) insertRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
_, err := common.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID)
return err
}
// selectRoomForUpdate locks the row for the room and returns the last_event_id.
// The row must already exist in the table. Callers can ensure that the row
// exists by calling insertRoom first.
func (s *roomStatements) selectRoomForUpdate(
ctx context.Context, txn *sql.Tx, roomID string,
) (string, error) {
var lastEventID string
stmt := common.TxStmt(txn, s.selectRoomForUpdateStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(&lastEventID)
if err != nil {
return "", err
}
return lastEventID, nil
}
// updateRoom updates the last_event_id for the room. selectRoomForUpdate should
// have already been called earlier within the transaction.
func (s *roomStatements) updateRoom(
ctx context.Context, txn *sql.Tx, roomID, lastEventID string,
) error {
stmt := common.TxStmt(txn, s.updateRoomStmt)
_, err := stmt.ExecContext(ctx, roomID, lastEventID)
return err
}

View file

@ -0,0 +1,124 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
_ "github.com/mattn/go-sqlite3"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/federationsender/types"
)
// Database stores information needed by the federation sender
type Database struct {
joinedHostsStatements
roomStatements
common.PartitionOffsetStatements
db *sql.DB
}
// NewDatabase opens a new database
func NewDatabase(dataSourceName string) (*Database, error) {
var result Database
var err error
if result.db, err = sql.Open("sqlite3", dataSourceName); err != nil {
return nil, err
}
if err = result.prepare(); err != nil {
return nil, err
}
return &result, nil
}
func (d *Database) prepare() error {
var err error
if err = d.joinedHostsStatements.prepare(d.db); err != nil {
return err
}
if err = d.roomStatements.prepare(d.db); err != nil {
return err
}
return d.PartitionOffsetStatements.Prepare(d.db, "federationsender")
}
// UpdateRoom updates the joined hosts for a room and returns what the joined
// hosts were before the update, or nil if this was a duplicate message.
// This is called when we receive a message from kafka, so we pass in
// oldEventID and newEventID to check that we haven't missed any messages or
// this isn't a duplicate message.
func (d *Database) UpdateRoom(
ctx context.Context,
roomID, oldEventID, newEventID string,
addHosts []types.JoinedHost,
removeHosts []string,
) (joinedHosts []types.JoinedHost, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
err = d.insertRoom(ctx, txn, roomID)
if err != nil {
return err
}
lastSentEventID, err := d.selectRoomForUpdate(ctx, txn, roomID)
if err != nil {
return err
}
if lastSentEventID == newEventID {
// We've handled this message before, so let's just ignore it.
// We can only get a duplicate for the last message we processed,
// so its enough just to compare the newEventID with lastSentEventID
return nil
}
if lastSentEventID != "" && lastSentEventID != oldEventID {
return types.EventIDMismatchError{
DatabaseID: lastSentEventID, RoomServerID: oldEventID,
}
}
joinedHosts, err = d.selectJoinedHostsWithTx(ctx, txn, roomID)
if err != nil {
return err
}
for _, add := range addHosts {
err = d.insertJoinedHosts(ctx, txn, roomID, add.MemberEventID, add.ServerName)
if err != nil {
return err
}
}
if err = d.deleteJoinedHosts(ctx, txn, removeHosts); err != nil {
return err
}
return d.updateRoom(ctx, txn, roomID, newEventID)
})
return
}
// GetJoinedHosts returns the currently joined hosts for room,
// as known to federationserver.
// Returns an error if something goes wrong.
func (d *Database) GetJoinedHosts(
ctx context.Context, roomID string,
) ([]types.JoinedHost, error) {
return d.selectJoinedHosts(ctx, roomID)
}

View file

@ -20,6 +20,7 @@ import (
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/federationsender/storage/postgres"
"github.com/matrix-org/dendrite/federationsender/storage/sqlite3"
"github.com/matrix-org/dendrite/federationsender/types"
)
@ -36,6 +37,8 @@ func NewDatabase(dataSourceName string) (Database, error) {
return postgres.NewDatabase(dataSourceName)
}
switch uri.Scheme {
case "file":
return sqlite3.NewDatabase(dataSourceName)
case "postgres":
return postgres.NewDatabase(dataSourceName)
default:

22
go.mod
View file

@ -1,32 +1,30 @@
module github.com/matrix-org/dendrite
require (
github.com/Shopify/sarama v0.0.0-20170127151855-574d3147eee3
github.com/DataDog/zstd v1.4.4 // indirect
github.com/Shopify/toxiproxy v2.1.4+incompatible // indirect
github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd // indirect
github.com/eapache/go-resiliency v0.0.0-20160104191539-b86b1ec0dd42 // indirect
github.com/eapache/go-xerial-snappy v0.0.0-20160609142408-bb955e01b934 // indirect
github.com/eapache/go-resiliency v1.2.0 // indirect
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 // indirect
github.com/eapache/queue v1.1.0 // indirect
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
github.com/frankban/quicktest v1.7.2 // indirect
github.com/golang/snappy v0.0.1 // indirect
github.com/gorilla/mux v1.7.3
github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6 // indirect
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
github.com/lib/pq v1.2.0
github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26
github.com/matrix-org/gomatrixserverlib v0.0.0-20200124100636-0c2ec91d1df5
github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0
github.com/matrix-org/naffka v0.0.0-20200127221512-0716baaabaf1
github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5
github.com/mattn/go-sqlite3 v2.0.2+incompatible
github.com/miekg/dns v1.1.12 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.1 // indirect
github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5
github.com/opentracing/opentracing-go v1.0.2
github.com/pierrec/lz4 v0.0.0-20161206202305-5c9560bfa9ac // indirect
github.com/pierrec/xxHash v0.0.0-20160112165351-5a004441f897 // indirect
github.com/pierrec/lz4 v2.4.1+incompatible // indirect
github.com/pkg/errors v0.8.1
github.com/prometheus/client_golang v1.2.1
github.com/rcrowley/go-metrics v0.0.0-20161128210544-1f30fe9094a5 // indirect
github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563 // indirect
github.com/sirupsen/logrus v1.4.2
github.com/stretchr/testify v1.4.0 // indirect
github.com/uber-go/atomic v1.3.0 // indirect
@ -36,7 +34,7 @@ require (
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550
golang.org/x/net v0.0.0-20190620200207-3b0461eec859 // indirect
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6 // indirect
gopkg.in/Shopify/sarama.v1 v1.11.0
gopkg.in/Shopify/sarama.v1 v1.20.1
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
gopkg.in/h2non/bimg.v1 v1.0.18
gopkg.in/yaml.v2 v2.2.2

43
go.sum
View file

@ -1,5 +1,5 @@
github.com/Shopify/sarama v0.0.0-20170127151855-574d3147eee3 h1:j6BAEHYn1kUyW2j7kY0mOJ/R8A0qWwXpvUAEHGemm/g=
github.com/Shopify/sarama v0.0.0-20170127151855-574d3147eee3/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo=
github.com/DataDog/zstd v1.4.4 h1:+IawcoXhCBylN7ccwdwf8LOH2jKq7NavGpEPanrlTzE=
github.com/DataDog/zstd v1.4.4/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo=
github.com/Shopify/toxiproxy v2.1.4+incompatible h1:TKdv8HiTLgE5wdJuEML90aBgNWsokNbMijUGhmcoBJc=
github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
@ -18,13 +18,15 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/eapache/go-resiliency v0.0.0-20160104191539-b86b1ec0dd42 h1:f8ERmXYuaC+kCSv2w+y3rBK/oVu6If4DEm3jywJJ0hc=
github.com/eapache/go-resiliency v0.0.0-20160104191539-b86b1ec0dd42/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
github.com/eapache/go-xerial-snappy v0.0.0-20160609142408-bb955e01b934 h1:oGLoaVIefp3tiOgi7+KInR/nNPvEpPM6GFo+El7fd14=
github.com/eapache/go-xerial-snappy v0.0.0-20160609142408-bb955e01b934/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
github.com/eapache/go-resiliency v1.2.0 h1:v7g92e/KSN71Rq7vSThKaWIq68fL4YHvWyiUKorFR1Q=
github.com/eapache/go-resiliency v1.2.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 h1:YEetp8/yCZMuEPMUDHG0CW/brkkEp8mzqk2+ODEitlw=
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc=
github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I=
github.com/frankban/quicktest v1.0.0/go.mod h1:R98jIehRai+d1/3Hv2//jOVCTJhW1VBavT6B6CuGq2k=
github.com/frankban/quicktest v1.7.2 h1:2QxQoC1TS09S7fhCPsrvqYdvP1H5M1P1ih5ABm3BTYk=
github.com/frankban/quicktest v1.7.2/go.mod h1:jaStnuzAqU1AJdCO0l53JDCJrVDKcS03DbaAcR7Ks/o=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=
@ -35,11 +37,13 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw=
github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
@ -48,8 +52,6 @@ github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplb
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6 h1:KAZ1BW2TCmT6PRihDPpocIy1QTtsAsrx6TneU/4+CMg=
github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6/go.mod h1:+ZoRqAPRLkC4NPOvfYeR5KNOrY6TD+/sAC3HXPZgDYg=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
@ -69,10 +71,12 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrixserverlib v0.0.0-20200124100636-0c2ec91d1df5 h1:kmRjpmFOenVpOaV/DRlo9p6z/IbOKlUC+hhKsAAh8Qg=
github.com/matrix-org/gomatrixserverlib v0.0.0-20200124100636-0c2ec91d1df5/go.mod h1:FsKa2pWE/bpQql9H7U4boOPXFoJX/QcqaZZ6ijLkaZI=
github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0 h1:p7WTwG+aXM86+yVrYAiCMW3ZHSmotVvuRbjtt3jC+4A=
github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0/go.mod h1:cXoYQIENbdWIQHt1SyCo6Bl3C3raHwJ0wgVrXHSqf+A=
github.com/matrix-org/naffka v0.0.0-20200127221512-0716baaabaf1 h1:osLoFdOy+ChQqVUn2PeTDETFftVkl4w9t/OW18g3lnk=
github.com/matrix-org/naffka v0.0.0-20200127221512-0716baaabaf1/go.mod h1:cXoYQIENbdWIQHt1SyCo6Bl3C3raHwJ0wgVrXHSqf+A=
github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5 h1:W7l5CP4V7wPyPb4tYE11dbmeAOwtFQBTW0rf4OonOS8=
github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5/go.mod h1:lePuOiXLNDott7NZfnQvJk0lAZ5HgvIuWGhel6J+RLA=
github.com/mattn/go-sqlite3 v2.0.2+incompatible h1:qzw9c2GNT8UFrgWNDhCTqRqYUSmu/Dav/9Z58LGpk7U=
github.com/mattn/go-sqlite3 v2.0.2+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/miekg/dns v1.1.4 h1:rCMZsU2ScVSYcAsOXgmC6+AKOK+6pmQTOcw03nfwYV0=
@ -90,10 +94,8 @@ github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5 h1:BvoENQQU+fZ9uukda/R
github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/opentracing/opentracing-go v1.0.2 h1:3jA2P6O1F9UOrWVpwrIo17pu01KWvNWg4X946/Y5Zwg=
github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o=
github.com/pierrec/lz4 v0.0.0-20161206202305-5c9560bfa9ac h1:tKcxwAA5OHUQjL6sWsuCIcP9OnzN+RwKfvomtIOsfy8=
github.com/pierrec/lz4 v0.0.0-20161206202305-5c9560bfa9ac/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
github.com/pierrec/xxHash v0.0.0-20160112165351-5a004441f897 h1:jp3jc/PyyTrTKjJJ6rWnhTbmo7tGgBFyG9AL5FIrO1I=
github.com/pierrec/xxHash v0.0.0-20160112165351-5a004441f897/go.mod h1:w2waW5Zoa/Wc4Yqe0wgrIYAGKqRMf7czn2HNKXmuL+I=
github.com/pierrec/lz4 v2.4.1+incompatible h1:mFe7ttWaflA46Mhqh+jUfjp2qTbPYxLB2/OyBppH9dg=
github.com/pierrec/lz4 v2.4.1+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@ -114,8 +116,8 @@ github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/prometheus/procfs v0.0.5 h1:3+auTFlqw+ZaQYJARz6ArODtkaIwtvBTx3N2NehQlL8=
github.com/prometheus/procfs v0.0.5/go.mod h1:4A/X28fw3Fc593LaREMrKMqOKvUAntwMDaekg4FpcdQ=
github.com/rcrowley/go-metrics v0.0.0-20161128210544-1f30fe9094a5 h1:gwcdIpH6NU2iF8CmcqD+CP6+1CkRBOhHaPR+iu6raBY=
github.com/rcrowley/go-metrics v0.0.0-20161128210544-1f30fe9094a5/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563 h1:dY6ETXrvDG7Sa4vE8ZQG4yqWg6UnOcbqTAahkV813vQ=
github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.3.0 h1:hI/7Q+DtNZ2kINb6qt/lS+IyXnHQe9e90POfeewL/ME=
github.com/sirupsen/logrus v1.3.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
@ -172,8 +174,8 @@ golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20191010194322-b09406accb47 h1:/XfQ9z7ib8eEJX2hdgFTZJ/ntt0swNk5oYBziWeTCvY=
golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
gopkg.in/Shopify/sarama.v1 v1.11.0 h1:/3kaCyeYaPbr59IBjeqhIcUOB1vXlIVqXAYa5g5C5F0=
gopkg.in/Shopify/sarama.v1 v1.11.0/go.mod h1:AxnvoaevB2nBjNK17cG61A3LleFcWFwVBHBt+cot4Oc=
gopkg.in/Shopify/sarama.v1 v1.20.1 h1:Gi09A3fJXm0Jgt8kuKZ8YK+r60GfYn7MQuEmI3oq6hE=
gopkg.in/Shopify/sarama.v1 v1.20.1/go.mod h1:AxnvoaevB2nBjNK17cG61A3LleFcWFwVBHBt+cot4Oc=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
@ -183,6 +185,7 @@ gopkg.in/h2non/bimg.v1 v1.0.18 h1:qn6/RpBHt+7WQqoBcK+aF2puc6nC78eZj5LexxoalT4=
gopkg.in/h2non/bimg.v1 v1.0.18/go.mod h1:PgsZL7dLwUbsGm1NYps320GxGgvQNTnecMCZqxV11So=
gopkg.in/h2non/gock.v1 v1.0.14 h1:fTeu9fcUvSnLNacYvYI54h+1/XEteDyHvrVCZEEEYNM=
gopkg.in/h2non/gock.v1 v1.0.14/go.mod h1:sX4zAkdYX1TRGJ2JY156cFspQn4yRWn6p9EMdODlynE=
gopkg.in/macaroon.v2 v2.1.0 h1:HZcsjBCzq9t0eBPMKqTN/uSN6JOm78ZJ2INbqcBQOUI=
gopkg.in/macaroon.v2 v2.1.0/go.mod h1:OUb+TQP/OP0WOerC2Jp/3CwhIKyIa9kQjuc7H24e6/o=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=

View file

@ -27,7 +27,7 @@ import (
// component.
func SetupMediaAPIComponent(
base *basecomponent.BaseDendrite,
deviceDB *devices.Database,
deviceDB devices.Database,
) {
mediaDB, err := storage.Open(string(base.Cfg.Database.MediaAPI))
if err != nil {

View file

@ -44,7 +44,7 @@ func Setup(
apiMux *mux.Router,
cfg *config.Dendrite,
db storage.Database,
deviceDB *devices.Database,
deviceDB devices.Database,
client *gomatrixserverlib.Client,
) {
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()

View file

@ -144,6 +144,7 @@ func (s *thumbnailStatements) selectThumbnails(
if err != nil {
return nil, err
}
defer rows.Close() // nolint: errcheck
var thumbnails []*types.ThumbnailMetadata
for rows.Next() {
@ -167,5 +168,5 @@ func (s *thumbnailStatements) selectThumbnails(
thumbnails = append(thumbnails, &thumbnailMetadata)
}
return thumbnails, err
return thumbnails, rows.Err()
}

View file

@ -28,7 +28,7 @@ import (
// component.
func SetupPublicRoomsAPIComponent(
base *basecomponent.BaseDendrite,
deviceDB *devices.Database,
deviceDB devices.Database,
rsQueryAPI roomserverAPI.RoomserverQueryAPI,
) {
publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI))

View file

@ -34,7 +34,7 @@ const pathPrefixR0 = "/_matrix/client/r0"
// Due to Setup being used to call many other functions, a gocyclo nolint is
// applied:
// nolint: gocyclo
func Setup(apiMux *mux.Router, deviceDB *devices.Database, publicRoomsDB storage.Database) {
func Setup(apiMux *mux.Router, deviceDB devices.Database, publicRoomsDB storage.Database) {
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
authData := auth.Data{

View file

@ -203,6 +203,7 @@ func (s *publicRoomsStatements) selectPublicRooms(
if err != nil {
return []types.PublicRoom{}, nil
}
defer rows.Close() // nolint: errcheck
rooms := []types.PublicRoom{}
for rows.Next() {
@ -222,7 +223,7 @@ func (s *publicRoomsStatements) selectPublicRooms(
rooms = append(rooms, r)
}
return rooms, nil
return rooms, rows.Err()
}
func (s *publicRoomsStatements) selectRoomVisibility(

View file

@ -0,0 +1,36 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"database/sql"
)
// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
type statementList []struct {
statement **sql.Stmt
sql string
}
// prepare the SQL for each statement in the list and assign the result to the prepared statement.
func (s statementList) prepare(db *sql.DB) (err error) {
for _, statement := range s {
if *statement.statement, err = db.Prepare(statement.sql); err != nil {
return
}
}
return
}

View file

@ -0,0 +1,277 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/publicroomsapi/types"
)
var editableAttributes = []string{
"aliases",
"canonical_alias",
"name",
"topic",
"world_readable",
"guest_can_join",
"avatar_url",
"visibility",
}
const publicRoomsSchema = `
-- Stores all of the rooms with data needed to create the server's room directory
CREATE TABLE IF NOT EXISTS publicroomsapi_public_rooms(
-- The room's ID
room_id TEXT NOT NULL PRIMARY KEY,
-- Number of joined members in the room
joined_members INTEGER NOT NULL DEFAULT 0,
-- Aliases of the room (empty array if none)
aliases TEXT[] NOT NULL DEFAULT '{}'::TEXT[],
-- Canonical alias of the room (empty string if none)
canonical_alias TEXT NOT NULL DEFAULT '',
-- Name of the room (empty string if none)
name TEXT NOT NULL DEFAULT '',
-- Topic of the room (empty string if none)
topic TEXT NOT NULL DEFAULT '',
-- Is the room world readable?
world_readable BOOLEAN NOT NULL DEFAULT false,
-- Can guest join the room?
guest_can_join BOOLEAN NOT NULL DEFAULT false,
-- URL of the room avatar (empty string if none)
avatar_url TEXT NOT NULL DEFAULT '',
-- Visibility of the room: true means the room is publicly visible, false
-- means the room is private
visibility BOOLEAN NOT NULL DEFAULT false
);
`
const countPublicRoomsSQL = "" +
"SELECT COUNT(*) FROM publicroomsapi_public_rooms" +
" WHERE visibility = true"
const selectPublicRoomsSQL = "" +
"SELECT room_id, joined_members, aliases, canonical_alias, name, topic, world_readable, guest_can_join, avatar_url" +
" FROM publicroomsapi_public_rooms WHERE visibility = true" +
" ORDER BY joined_members DESC" +
" OFFSET $1"
const selectPublicRoomsWithLimitSQL = "" +
"SELECT room_id, joined_members, aliases, canonical_alias, name, topic, world_readable, guest_can_join, avatar_url" +
" FROM publicroomsapi_public_rooms WHERE visibility = true" +
" ORDER BY joined_members DESC" +
" OFFSET $1 LIMIT $2"
const selectPublicRoomsWithFilterSQL = "" +
"SELECT room_id, joined_members, aliases, canonical_alias, name, topic, world_readable, guest_can_join, avatar_url" +
" FROM publicroomsapi_public_rooms" +
" WHERE visibility = true" +
" AND (LOWER(name) LIKE LOWER($1)" +
" OR LOWER(topic) LIKE LOWER($1)" +
" OR LOWER(ARRAY_TO_STRING(aliases, ',')) LIKE LOWER($1))" +
" ORDER BY joined_members DESC" +
" OFFSET $2"
const selectPublicRoomsWithLimitAndFilterSQL = "" +
"SELECT room_id, joined_members, aliases, canonical_alias, name, topic, world_readable, guest_can_join, avatar_url" +
" FROM publicroomsapi_public_rooms" +
" WHERE visibility = true" +
" AND (LOWER(name) LIKE LOWER($1)" +
" OR LOWER(topic) LIKE LOWER($1)" +
" OR LOWER(ARRAY_TO_STRING(aliases, ',')) LIKE LOWER($1))" +
" ORDER BY joined_members DESC" +
" OFFSET $2 LIMIT $3"
const selectRoomVisibilitySQL = "" +
"SELECT visibility FROM publicroomsapi_public_rooms" +
" WHERE room_id = $1"
const insertNewRoomSQL = "" +
"INSERT INTO publicroomsapi_public_rooms(room_id)" +
" VALUES ($1)"
const incrementJoinedMembersInRoomSQL = "" +
"UPDATE publicroomsapi_public_rooms" +
" SET joined_members = joined_members + 1" +
" WHERE room_id = $1"
const decrementJoinedMembersInRoomSQL = "" +
"UPDATE publicroomsapi_public_rooms" +
" SET joined_members = joined_members - 1" +
" WHERE room_id = $1"
const updateRoomAttributeSQL = "" +
"UPDATE publicroomsapi_public_rooms" +
" SET %s = $1" +
" WHERE room_id = $2"
type publicRoomsStatements struct {
countPublicRoomsStmt *sql.Stmt
selectPublicRoomsStmt *sql.Stmt
selectPublicRoomsWithLimitStmt *sql.Stmt
selectPublicRoomsWithFilterStmt *sql.Stmt
selectPublicRoomsWithLimitAndFilterStmt *sql.Stmt
selectRoomVisibilityStmt *sql.Stmt
insertNewRoomStmt *sql.Stmt
incrementJoinedMembersInRoomStmt *sql.Stmt
decrementJoinedMembersInRoomStmt *sql.Stmt
updateRoomAttributeStmts map[string]*sql.Stmt
}
func (s *publicRoomsStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(publicRoomsSchema)
if err != nil {
return
}
stmts := statementList{
{&s.countPublicRoomsStmt, countPublicRoomsSQL},
{&s.selectPublicRoomsStmt, selectPublicRoomsSQL},
{&s.selectPublicRoomsWithLimitStmt, selectPublicRoomsWithLimitSQL},
{&s.selectPublicRoomsWithFilterStmt, selectPublicRoomsWithFilterSQL},
{&s.selectPublicRoomsWithLimitAndFilterStmt, selectPublicRoomsWithLimitAndFilterSQL},
{&s.selectRoomVisibilityStmt, selectRoomVisibilitySQL},
{&s.insertNewRoomStmt, insertNewRoomSQL},
{&s.incrementJoinedMembersInRoomStmt, incrementJoinedMembersInRoomSQL},
{&s.decrementJoinedMembersInRoomStmt, decrementJoinedMembersInRoomSQL},
}
if err = stmts.prepare(db); err != nil {
return
}
s.updateRoomAttributeStmts = make(map[string]*sql.Stmt)
for _, editable := range editableAttributes {
stmt := fmt.Sprintf(updateRoomAttributeSQL, editable)
if s.updateRoomAttributeStmts[editable], err = db.Prepare(stmt); err != nil {
return
}
}
return
}
func (s *publicRoomsStatements) countPublicRooms(ctx context.Context) (nb int64, err error) {
err = s.countPublicRoomsStmt.QueryRowContext(ctx).Scan(&nb)
return
}
func (s *publicRoomsStatements) selectPublicRooms(
ctx context.Context, offset int64, limit int16, filter string,
) ([]types.PublicRoom, error) {
var rows *sql.Rows
var err error
if len(filter) > 0 {
pattern := "%" + filter + "%"
if limit == 0 {
rows, err = s.selectPublicRoomsWithFilterStmt.QueryContext(
ctx, pattern, offset,
)
} else {
rows, err = s.selectPublicRoomsWithLimitAndFilterStmt.QueryContext(
ctx, pattern, offset, limit,
)
}
} else {
if limit == 0 {
rows, err = s.selectPublicRoomsStmt.QueryContext(ctx, offset)
} else {
rows, err = s.selectPublicRoomsWithLimitStmt.QueryContext(
ctx, offset, limit,
)
}
}
if err != nil {
return []types.PublicRoom{}, nil
}
rooms := []types.PublicRoom{}
for rows.Next() {
var r types.PublicRoom
var aliases pq.StringArray
err = rows.Scan(
&r.RoomID, &r.NumJoinedMembers, &aliases, &r.CanonicalAlias,
&r.Name, &r.Topic, &r.WorldReadable, &r.GuestCanJoin, &r.AvatarURL,
)
if err != nil {
return rooms, err
}
r.Aliases = aliases
rooms = append(rooms, r)
}
return rooms, nil
}
func (s *publicRoomsStatements) selectRoomVisibility(
ctx context.Context, roomID string,
) (v bool, err error) {
err = s.selectRoomVisibilityStmt.QueryRowContext(ctx, roomID).Scan(&v)
return
}
func (s *publicRoomsStatements) insertNewRoom(
ctx context.Context, roomID string,
) error {
_, err := s.insertNewRoomStmt.ExecContext(ctx, roomID)
return err
}
func (s *publicRoomsStatements) incrementJoinedMembersInRoom(
ctx context.Context, roomID string,
) error {
_, err := s.incrementJoinedMembersInRoomStmt.ExecContext(ctx, roomID)
return err
}
func (s *publicRoomsStatements) decrementJoinedMembersInRoom(
ctx context.Context, roomID string,
) error {
_, err := s.decrementJoinedMembersInRoomStmt.ExecContext(ctx, roomID)
return err
}
func (s *publicRoomsStatements) updateRoomAttribute(
ctx context.Context, attrName string, attrValue attributeValue, roomID string,
) error {
stmt, isEditable := s.updateRoomAttributeStmts[attrName]
if !isEditable {
return errors.New("Cannot edit " + attrName)
}
var value interface{}
switch v := attrValue.(type) {
case []string:
value = pq.StringArray(v)
case bool, string:
value = attrValue
default:
return errors.New("Unsupported attribute type, must be bool, string or []string")
}
_, err := stmt.ExecContext(ctx, value, roomID)
return err
}

View file

@ -0,0 +1,256 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
_ "github.com/mattn/go-sqlite3"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/publicroomsapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
// PublicRoomsServerDatabase represents a public rooms server database.
type PublicRoomsServerDatabase struct {
db *sql.DB
common.PartitionOffsetStatements
statements publicRoomsStatements
}
type attributeValue interface{}
// NewPublicRoomsServerDatabase creates a new public rooms server database.
func NewPublicRoomsServerDatabase(dataSourceName string) (*PublicRoomsServerDatabase, error) {
var db *sql.DB
var err error
if db, err = sql.Open("sqlite3", dataSourceName); err != nil {
return nil, err
}
storage := PublicRoomsServerDatabase{
db: db,
}
if err = storage.PartitionOffsetStatements.Prepare(db, "publicroomsapi"); err != nil {
return nil, err
}
if err = storage.statements.prepare(db); err != nil {
return nil, err
}
return &storage, nil
}
// GetRoomVisibility returns the room visibility as a boolean: true if the room
// is publicly visible, false if not.
// Returns an error if the retrieval failed.
func (d *PublicRoomsServerDatabase) GetRoomVisibility(
ctx context.Context, roomID string,
) (bool, error) {
return d.statements.selectRoomVisibility(ctx, roomID)
}
// SetRoomVisibility updates the visibility attribute of a room. This attribute
// must be set to true if the room is publicly visible, false if not.
// Returns an error if the update failed.
func (d *PublicRoomsServerDatabase) SetRoomVisibility(
ctx context.Context, visible bool, roomID string,
) error {
return d.statements.updateRoomAttribute(ctx, "visibility", visible, roomID)
}
// CountPublicRooms returns the number of room set as publicly visible on the server.
// Returns an error if the retrieval failed.
func (d *PublicRoomsServerDatabase) CountPublicRooms(ctx context.Context) (int64, error) {
return d.statements.countPublicRooms(ctx)
}
// GetPublicRooms returns an array containing the local rooms set as publicly visible, ordered by their number
// of joined members. This array can be limited by a given number of elements, and offset by a given value.
// If the limit is 0, doesn't limit the number of results. If the offset is 0 too, the array contains all
// the rooms set as publicly visible on the server.
// Returns an error if the retrieval failed.
func (d *PublicRoomsServerDatabase) GetPublicRooms(
ctx context.Context, offset int64, limit int16, filter string,
) ([]types.PublicRoom, error) {
return d.statements.selectPublicRooms(ctx, offset, limit, filter)
}
// UpdateRoomFromEvents iterate over a slice of state events and call
// UpdateRoomFromEvent on each of them to update the database representation of
// the rooms updated by each event.
// The slice of events to remove is used to update the number of joined members
// for the room in the database.
// If the update triggered by one of the events failed, aborts the process and
// returns an error.
func (d *PublicRoomsServerDatabase) UpdateRoomFromEvents(
ctx context.Context,
eventsToAdd []gomatrixserverlib.Event,
eventsToRemove []gomatrixserverlib.Event,
) error {
for _, event := range eventsToAdd {
if err := d.UpdateRoomFromEvent(ctx, event); err != nil {
return err
}
}
for _, event := range eventsToRemove {
if event.Type() == "m.room.member" {
if err := d.updateNumJoinedUsers(ctx, event, true); err != nil {
return err
}
}
}
return nil
}
// UpdateRoomFromEvent updates the database representation of a room from a Matrix event, by
// checking the event's type to know which attribute to change and using the event's content
// to define the new value of the attribute.
// If the event doesn't match with any property used to compute the public room directory,
// does nothing.
// If something went wrong during the process, returns an error.
func (d *PublicRoomsServerDatabase) UpdateRoomFromEvent(
ctx context.Context, event gomatrixserverlib.Event,
) error {
// Process the event according to its type
switch event.Type() {
case "m.room.create":
return d.statements.insertNewRoom(ctx, event.RoomID())
case "m.room.member":
return d.updateNumJoinedUsers(ctx, event, false)
case "m.room.aliases":
return d.updateRoomAliases(ctx, event)
case "m.room.canonical_alias":
var content common.CanonicalAliasContent
field := &(content.Alias)
attrName := "canonical_alias"
return d.updateStringAttribute(ctx, attrName, event, &content, field)
case "m.room.name":
var content common.NameContent
field := &(content.Name)
attrName := "name"
return d.updateStringAttribute(ctx, attrName, event, &content, field)
case "m.room.topic":
var content common.TopicContent
field := &(content.Topic)
attrName := "topic"
return d.updateStringAttribute(ctx, attrName, event, &content, field)
case "m.room.avatar":
var content common.AvatarContent
field := &(content.URL)
attrName := "avatar_url"
return d.updateStringAttribute(ctx, attrName, event, &content, field)
case "m.room.history_visibility":
var content common.HistoryVisibilityContent
field := &(content.HistoryVisibility)
attrName := "world_readable"
strForTrue := "world_readable"
return d.updateBooleanAttribute(ctx, attrName, event, &content, field, strForTrue)
case "m.room.guest_access":
var content common.GuestAccessContent
field := &(content.GuestAccess)
attrName := "guest_can_join"
strForTrue := "can_join"
return d.updateBooleanAttribute(ctx, attrName, event, &content, field, strForTrue)
}
// If the event type didn't match, return with no error
return nil
}
// updateNumJoinedUsers updates the number of joined user in the database representation
// of a room using a given "m.room.member" Matrix event.
// If the membership property of the event isn't "join", ignores it and returs nil.
// If the remove parameter is set to false, increments the joined members counter in the
// database, if set to truem decrements it.
// Returns an error if the update failed.
func (d *PublicRoomsServerDatabase) updateNumJoinedUsers(
ctx context.Context, membershipEvent gomatrixserverlib.Event, remove bool,
) error {
membership, err := membershipEvent.Membership()
if err != nil {
return err
}
if membership != gomatrixserverlib.Join {
return nil
}
if remove {
return d.statements.decrementJoinedMembersInRoom(ctx, membershipEvent.RoomID())
}
return d.statements.incrementJoinedMembersInRoom(ctx, membershipEvent.RoomID())
}
// updateStringAttribute updates a given string attribute in the database
// representation of a room using a given string data field from content of the
// Matrix event triggering the update.
// Returns an error if decoding the Matrix event's content or updating the attribute
// failed.
func (d *PublicRoomsServerDatabase) updateStringAttribute(
ctx context.Context, attrName string, event gomatrixserverlib.Event,
content interface{}, field *string,
) error {
if err := json.Unmarshal(event.Content(), content); err != nil {
return err
}
return d.statements.updateRoomAttribute(ctx, attrName, *field, event.RoomID())
}
// updateBooleanAttribute updates a given boolean attribute in the database
// representation of a room using a given string data field from content of the
// Matrix event triggering the update.
// The attribute is set to true if the field matches a given string, false if not.
// Returns an error if decoding the Matrix event's content or updating the attribute
// failed.
func (d *PublicRoomsServerDatabase) updateBooleanAttribute(
ctx context.Context, attrName string, event gomatrixserverlib.Event,
content interface{}, field *string, strForTrue string,
) error {
if err := json.Unmarshal(event.Content(), content); err != nil {
return err
}
var attrValue bool
if *field == strForTrue {
attrValue = true
} else {
attrValue = false
}
return d.statements.updateRoomAttribute(ctx, attrName, attrValue, event.RoomID())
}
// updateRoomAliases decodes the content of a "m.room.aliases" Matrix event and update the list of aliases of
// a given room with it.
// Returns an error if decoding the Matrix event or updating the list failed.
func (d *PublicRoomsServerDatabase) updateRoomAliases(
ctx context.Context, aliasesEvent gomatrixserverlib.Event,
) error {
var content common.AliasesContent
if err := json.Unmarshal(aliasesEvent.Content(), &content); err != nil {
return err
}
return d.statements.updateRoomAttribute(
ctx, "aliases", content.Aliases, aliasesEvent.RoomID(),
)
}

View file

@ -196,7 +196,12 @@ func processInviteEvent(
return err
}
succeeded := false
defer common.EndTransaction(updater, &succeeded)
defer func() {
txerr := common.EndTransaction(updater, &succeeded)
if err == nil && txerr != nil {
err = txerr
}
}()
if updater.IsJoin() {
// If the user is joined to the room then that takes precedence over this

View file

@ -60,7 +60,12 @@ func updateLatestEvents(
return
}
succeeded := false
defer common.EndTransaction(updater, &succeeded)
defer func() {
txerr := common.EndTransaction(updater, &succeeded)
if err == nil && txerr != nil {
err = txerr
}
}()
u := latestEventsUpdater{
ctx: ctx, db: db, updater: updater, ow: ow, roomNID: roomNID,

View file

@ -102,5 +102,5 @@ func (s *eventJSONStatements) bulkSelectEventJSON(
}
result.EventNID = types.EventNID(eventNID)
}
return results[:i], nil
return results[:i], rows.Err()
}

View file

@ -125,7 +125,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
}
result[stateKey] = types.EventStateKeyNID(stateKeyNID)
}
return result, nil
return result, rows.Err()
}
func (s *eventStateKeyStatements) bulkSelectEventStateKey(
@ -150,5 +150,5 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKey(
}
result[types.EventStateKeyNID(stateKeyNID)] = stateKey
}
return result, nil
return result, rows.Err()
}

View file

@ -143,5 +143,5 @@ func (s *eventTypeStatements) bulkSelectEventTypeNID(
}
result[eventType] = types.EventTypeNID(eventTypeNID)
}
return result, nil
return result, rows.Err()
}

View file

@ -209,6 +209,9 @@ func (s *eventStatements) bulkSelectStateEventByID(
return nil, err
}
}
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(eventIDs) {
// If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have.
// We don't know which ones were missing because we don't return the string IDs in the query.
@ -219,7 +222,7 @@ func (s *eventStatements) bulkSelectStateEventByID(
fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventIDs)),
)
}
return results, err
return results, nil
}
// bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
@ -251,12 +254,15 @@ func (s *eventStatements) bulkSelectStateAtEventByID(
)
}
}
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(eventIDs) {
return nil, types.MissingEventError(
fmt.Sprintf("storage: event IDs missing from the database (%d != %d)", i, len(eventIDs)),
)
}
return results, err
return results, nil
}
func (s *eventStatements) updateEventState(
@ -321,6 +327,9 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference(
result.EventID = eventID
result.EventSHA256 = eventSHA256
}
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(eventNIDs) {
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
}
@ -343,6 +352,9 @@ func (s *eventStatements) bulkSelectEventReference(
return nil, err
}
}
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(eventNIDs) {
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
}
@ -366,6 +378,9 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, eventNIDs []typ
}
results[types.EventNID(eventNID)] = eventID
}
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(eventNIDs) {
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
}
@ -389,7 +404,7 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, eventIDs []str
}
results[eventID] = types.EventNID(eventNID)
}
return results, nil
return results, rows.Err()
}
func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []types.EventNID) (int64, error) {

View file

@ -114,21 +114,23 @@ func (s *inviteStatements) insertInviteEvent(
func (s *inviteStatements) updateInviteRetired(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventIDs []string, err error) {
) ([]string, error) {
stmt := common.TxStmt(txn, s.updateInviteRetiredStmt)
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
if err != nil {
return nil, err
}
defer (func() { err = rows.Close() })()
defer rows.Close() // nolint: errcheck
var eventIDs []string
for rows.Next() {
var inviteEventID string
if err := rows.Scan(&inviteEventID); err != nil {
if err = rows.Scan(&inviteEventID); err != nil {
return nil, err
}
eventIDs = append(eventIDs, inviteEventID)
}
return
return eventIDs, rows.Err()
}
// selectInviteActiveForUserInRoom returns a list of sender state key NIDs
@ -151,5 +153,5 @@ func (s *inviteStatements) selectInviteActiveForUserInRoom(
}
result = append(result, types.EventStateKeyNID(senderUserNID))
}
return result, nil
return result, rows.Err()
}

View file

@ -151,6 +151,7 @@ func (s *membershipStatements) selectMembershipsFromRoom(
if err != nil {
return
}
defer rows.Close() // nolint: errcheck
for rows.Next() {
var eNID types.EventNID
@ -159,8 +160,9 @@ func (s *membershipStatements) selectMembershipsFromRoom(
}
eventNIDs = append(eventNIDs, eNID)
}
return
return eventNIDs, rows.Err()
}
func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
ctx context.Context,
roomNID types.RoomNID, membership membershipState,
@ -170,6 +172,7 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
if err != nil {
return
}
defer rows.Close() // nolint: errcheck
for rows.Next() {
var eNID types.EventNID
@ -178,7 +181,7 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
}
eventNIDs = append(eventNIDs, eNID)
}
return
return eventNIDs, rows.Err()
}
func (s *membershipStatements) updateMembership(

View file

@ -90,23 +90,23 @@ func (s *roomAliasesStatements) selectRoomIDFromAlias(
func (s *roomAliasesStatements) selectAliasesFromRoomID(
ctx context.Context, roomID string,
) (aliases []string, err error) {
aliases = []string{}
) ([]string, error) {
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
if err != nil {
return
return nil, err
}
defer rows.Close() // nolint: errcheck
var aliases []string
for rows.Next() {
var alias string
if err = rows.Scan(&alias); err != nil {
return
return nil, err
}
aliases = append(aliases, alias)
}
return
return aliases, rows.Err()
}
func (s *roomAliasesStatements) selectCreatorIDFromAlias(

View file

@ -152,7 +152,7 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries(
eventNID int64
entry types.StateEntry
)
if err := rows.Scan(
if err = rows.Scan(
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
); err != nil {
return nil, err
@ -169,10 +169,13 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries(
}
current.StateEntries = append(current.StateEntries, entry)
}
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(stateBlockNIDs) {
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateBlockNIDs))
}
return results, nil
return results, err
}
func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
@ -237,7 +240,7 @@ func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
if current.StateEntries != nil {
results = append(results, current)
}
return results, nil
return results, rows.Err()
}
func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array {

View file

@ -104,7 +104,7 @@ func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
for ; rows.Next(); i++ {
result := &results[i]
var stateBlockNIDs pq.Int64Array
if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
return nil, err
}
result.StateBlockNIDs = make([]types.StateBlockNID, len(stateBlockNIDs))
@ -112,6 +112,9 @@ func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
result.StateBlockNIDs[k] = types.StateBlockNID(stateBlockNIDs[k])
}
}
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(stateNIDs) {
return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs))
}

Some files were not shown because too many files have changed in this diff Show more