mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-16 11:23:11 -06:00
Merge branch 'master' of https://github.com/matrix-org/dendrite into basicauth-metrics
This commit is contained in:
commit
2ccc149836
|
|
@ -140,7 +140,7 @@ func RetrieveUserProfile(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
userID string,
|
userID string,
|
||||||
asAPI AppServiceQueryAPI,
|
asAPI AppServiceQueryAPI,
|
||||||
accountDB *accounts.Database,
|
accountDB accounts.Database,
|
||||||
) (*authtypes.Profile, error) {
|
) (*authtypes.Profile, error) {
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
|
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -41,8 +41,8 @@ import (
|
||||||
// component.
|
// component.
|
||||||
func SetupAppServiceAPIComponent(
|
func SetupAppServiceAPIComponent(
|
||||||
base *basecomponent.BaseDendrite,
|
base *basecomponent.BaseDendrite,
|
||||||
accountsDB *accounts.Database,
|
accountsDB accounts.Database,
|
||||||
deviceDB *devices.Database,
|
deviceDB devices.Database,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation *gomatrixserverlib.FederationClient,
|
||||||
roomserverAliasAPI roomserverAPI.RoomserverAliasAPI,
|
roomserverAliasAPI roomserverAPI.RoomserverAliasAPI,
|
||||||
roomserverQueryAPI roomserverAPI.RoomserverQueryAPI,
|
roomserverQueryAPI roomserverAPI.RoomserverQueryAPI,
|
||||||
|
|
@ -100,7 +100,7 @@ func SetupAppServiceAPIComponent(
|
||||||
|
|
||||||
// Set up HTTP Endpoints
|
// Set up HTTP Endpoints
|
||||||
routing.Setup(
|
routing.Setup(
|
||||||
base.APIMux, *base.Cfg, roomserverQueryAPI, roomserverAliasAPI,
|
base.APIMux, base.Cfg, roomserverQueryAPI, roomserverAliasAPI,
|
||||||
accountsDB, federation, transactionsCache,
|
accountsDB, federation, transactionsCache,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -111,8 +111,8 @@ func SetupAppServiceAPIComponent(
|
||||||
// `sender_localpart` field of each application service if it doesn't
|
// `sender_localpart` field of each application service if it doesn't
|
||||||
// exist already
|
// exist already
|
||||||
func generateAppServiceAccount(
|
func generateAppServiceAccount(
|
||||||
accountsDB *accounts.Database,
|
accountsDB accounts.Database,
|
||||||
deviceDB *devices.Database,
|
deviceDB devices.Database,
|
||||||
as config.ApplicationService,
|
as config.ApplicationService,
|
||||||
) error {
|
) error {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ import (
|
||||||
// OutputRoomEventConsumer consumes events that originated in the room server.
|
// OutputRoomEventConsumer consumes events that originated in the room server.
|
||||||
type OutputRoomEventConsumer struct {
|
type OutputRoomEventConsumer struct {
|
||||||
roomServerConsumer *common.ContinualConsumer
|
roomServerConsumer *common.ContinualConsumer
|
||||||
db *accounts.Database
|
db accounts.Database
|
||||||
asDB *storage.Database
|
asDB *storage.Database
|
||||||
query api.RoomserverQueryAPI
|
query api.RoomserverQueryAPI
|
||||||
alias api.RoomserverAliasAPI
|
alias api.RoomserverAliasAPI
|
||||||
|
|
@ -46,7 +46,7 @@ type OutputRoomEventConsumer struct {
|
||||||
func NewOutputRoomEventConsumer(
|
func NewOutputRoomEventConsumer(
|
||||||
cfg *config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
kafkaConsumer sarama.Consumer,
|
kafkaConsumer sarama.Consumer,
|
||||||
store *accounts.Database,
|
store accounts.Database,
|
||||||
appserviceDB *storage.Database,
|
appserviceDB *storage.Database,
|
||||||
queryAPI api.RoomserverQueryAPI,
|
queryAPI api.RoomserverQueryAPI,
|
||||||
aliasAPI api.RoomserverAliasAPI,
|
aliasAPI api.RoomserverAliasAPI,
|
||||||
|
|
|
||||||
|
|
@ -36,9 +36,9 @@ const pathPrefixApp = "/_matrix/app/v1"
|
||||||
// applied:
|
// applied:
|
||||||
// nolint: gocyclo
|
// nolint: gocyclo
|
||||||
func Setup(
|
func Setup(
|
||||||
apiMux *mux.Router, cfg config.Dendrite, // nolint: unparam
|
apiMux *mux.Router, cfg *config.Dendrite, // nolint: unparam
|
||||||
queryAPI api.RoomserverQueryAPI, aliasAPI api.RoomserverAliasAPI, // nolint: unparam
|
queryAPI api.RoomserverQueryAPI, aliasAPI api.RoomserverAliasAPI, // nolint: unparam
|
||||||
accountDB *accounts.Database, // nolint: unparam
|
accountDB accounts.Database, // nolint: unparam
|
||||||
federation *gomatrixserverlib.FederationClient, // nolint: unparam
|
federation *gomatrixserverlib.FederationClient, // nolint: unparam
|
||||||
transactionsCache *transactions.Cache, // nolint: unparam
|
transactionsCache *transactions.Cache, // nolint: unparam
|
||||||
) {
|
) {
|
||||||
|
|
|
||||||
141
clientapi/auth/storage/accounts/postgres/account_data_table.go
Normal file
141
clientapi/auth/storage/accounts/postgres/account_data_table.go
Normal file
|
|
@ -0,0 +1,141 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
const accountDataSchema = `
|
||||||
|
-- Stores data about accounts data.
|
||||||
|
CREATE TABLE IF NOT EXISTS account_data (
|
||||||
|
-- The Matrix user ID localpart for this account
|
||||||
|
localpart TEXT NOT NULL,
|
||||||
|
-- The room ID for this data (empty string if not specific to a room)
|
||||||
|
room_id TEXT,
|
||||||
|
-- The account data type
|
||||||
|
type TEXT NOT NULL,
|
||||||
|
-- The account data content
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
|
||||||
|
PRIMARY KEY(localpart, room_id, type)
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
const insertAccountDataSQL = `
|
||||||
|
INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
|
||||||
|
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content
|
||||||
|
`
|
||||||
|
|
||||||
|
const selectAccountDataSQL = "" +
|
||||||
|
"SELECT room_id, type, content FROM account_data WHERE localpart = $1"
|
||||||
|
|
||||||
|
const selectAccountDataByTypeSQL = "" +
|
||||||
|
"SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
||||||
|
|
||||||
|
type accountDataStatements struct {
|
||||||
|
insertAccountDataStmt *sql.Stmt
|
||||||
|
selectAccountDataStmt *sql.Stmt
|
||||||
|
selectAccountDataByTypeStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
_, err = db.Exec(accountDataSchema)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectAccountDataStmt, err = db.Prepare(selectAccountDataSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectAccountDataByTypeStmt, err = db.Prepare(selectAccountDataByTypeSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *accountDataStatements) insertAccountData(
|
||||||
|
ctx context.Context, localpart, roomID, dataType, content string,
|
||||||
|
) (err error) {
|
||||||
|
stmt := s.insertAccountDataStmt
|
||||||
|
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *accountDataStatements) selectAccountData(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) (
|
||||||
|
global []gomatrixserverlib.ClientEvent,
|
||||||
|
rooms map[string][]gomatrixserverlib.ClientEvent,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer rows.Close() // nolint: errcheck
|
||||||
|
|
||||||
|
global = []gomatrixserverlib.ClientEvent{}
|
||||||
|
rooms = make(map[string][]gomatrixserverlib.ClientEvent)
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var roomID string
|
||||||
|
var dataType string
|
||||||
|
var content []byte
|
||||||
|
|
||||||
|
if err = rows.Scan(&roomID, &dataType, &content); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ac := gomatrixserverlib.ClientEvent{
|
||||||
|
Type: dataType,
|
||||||
|
Content: content,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(roomID) > 0 {
|
||||||
|
rooms[roomID] = append(rooms[roomID], ac)
|
||||||
|
} else {
|
||||||
|
global = append(global, ac)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return global, rooms, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *accountDataStatements) selectAccountDataByType(
|
||||||
|
ctx context.Context, localpart, roomID, dataType string,
|
||||||
|
) (data *gomatrixserverlib.ClientEvent, err error) {
|
||||||
|
stmt := s.selectAccountDataByTypeStmt
|
||||||
|
var content []byte
|
||||||
|
|
||||||
|
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data = &gomatrixserverlib.ClientEvent{
|
||||||
|
Type: dataType,
|
||||||
|
Content: content,
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
@ -12,7 +12,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package accounts
|
package postgres
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
@ -12,7 +12,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package accounts
|
package postgres
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
@ -12,7 +12,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package accounts
|
package postgres
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
@ -122,11 +122,10 @@ func (s *membershipStatements) selectMembershipsByLocalpart(
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var m authtypes.Membership
|
var m authtypes.Membership
|
||||||
m.Localpart = localpart
|
m.Localpart = localpart
|
||||||
if err := rows.Scan(&m.RoomID, &m.EventID); err != nil {
|
if err = rows.Scan(&m.RoomID, &m.EventID); err != nil {
|
||||||
return nil, err
|
return
|
||||||
}
|
}
|
||||||
memberships = append(memberships, m)
|
memberships = append(memberships, m)
|
||||||
}
|
}
|
||||||
|
return memberships, rows.Err()
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
@ -12,7 +12,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package accounts
|
package postgres
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
392
clientapi/auth/storage/accounts/postgres/storage.go
Normal file
392
clientapi/auth/storage/accounts/postgres/storage.go
Normal file
|
|
@ -0,0 +1,392 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
|
// Import the postgres database driver.
|
||||||
|
_ "github.com/lib/pq"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Database represents an account database
|
||||||
|
type Database struct {
|
||||||
|
db *sql.DB
|
||||||
|
common.PartitionOffsetStatements
|
||||||
|
accounts accountsStatements
|
||||||
|
profiles profilesStatements
|
||||||
|
memberships membershipStatements
|
||||||
|
accountDatas accountDataStatements
|
||||||
|
threepids threepidStatements
|
||||||
|
filter filterStatements
|
||||||
|
serverName gomatrixserverlib.ServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDatabase creates a new accounts and profiles database
|
||||||
|
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
|
||||||
|
var db *sql.DB
|
||||||
|
var err error
|
||||||
|
if db, err = sql.Open("postgres", dataSourceName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
partitions := common.PartitionOffsetStatements{}
|
||||||
|
if err = partitions.Prepare(db, "account"); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
a := accountsStatements{}
|
||||||
|
if err = a.prepare(db, serverName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
p := profilesStatements{}
|
||||||
|
if err = p.prepare(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m := membershipStatements{}
|
||||||
|
if err = m.prepare(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ac := accountDataStatements{}
|
||||||
|
if err = ac.prepare(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t := threepidStatements{}
|
||||||
|
if err = t.prepare(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
f := filterStatements{}
|
||||||
|
if err = f.prepare(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountByPassword returns the account associated with the given localpart and password.
|
||||||
|
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||||
|
func (d *Database) GetAccountByPassword(
|
||||||
|
ctx context.Context, localpart, plaintextPassword string,
|
||||||
|
) (*authtypes.Account, error) {
|
||||||
|
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
||||||
|
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
|
||||||
|
func (d *Database) GetProfileByLocalpart(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) (*authtypes.Profile, error) {
|
||||||
|
return d.profiles.selectProfileByLocalpart(ctx, localpart)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAvatarURL updates the avatar URL of the profile associated with the given
|
||||||
|
// localpart. Returns an error if something went wrong with the SQL query
|
||||||
|
func (d *Database) SetAvatarURL(
|
||||||
|
ctx context.Context, localpart string, avatarURL string,
|
||||||
|
) error {
|
||||||
|
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDisplayName updates the display name of the profile associated with the given
|
||||||
|
// localpart. Returns an error if something went wrong with the SQL query
|
||||||
|
func (d *Database) SetDisplayName(
|
||||||
|
ctx context.Context, localpart string, displayName string,
|
||||||
|
) error {
|
||||||
|
return d.profiles.setDisplayName(ctx, localpart, displayName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
||||||
|
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||||
|
// account already exists, it will return nil, nil.
|
||||||
|
func (d *Database) CreateAccount(
|
||||||
|
ctx context.Context, localpart, plaintextPassword, appserviceID string,
|
||||||
|
) (*authtypes.Account, error) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Generate a password hash if this is not a password-less user
|
||||||
|
hash := ""
|
||||||
|
if plaintextPassword != "" {
|
||||||
|
hash, err = hashPassword(plaintextPassword)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := d.profiles.insertProfile(ctx, localpart); err != nil {
|
||||||
|
if common.IsUniqueConstraintViolationErr(err) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{
|
||||||
|
"global": {
|
||||||
|
"content": [],
|
||||||
|
"override": [],
|
||||||
|
"room": [],
|
||||||
|
"sender": [],
|
||||||
|
"underride": []
|
||||||
|
}
|
||||||
|
}`); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return d.accounts.insertAccount(ctx, localpart, hash, appserviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveMembership saves the user matching a given localpart as a member of a given
|
||||||
|
// room. It also stores the ID of the membership event.
|
||||||
|
// If a membership already exists between the user and the room, or if the
|
||||||
|
// insert fails, returns the SQL error
|
||||||
|
func (d *Database) saveMembership(
|
||||||
|
ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
|
||||||
|
) error {
|
||||||
|
return d.memberships.insertMembership(ctx, txn, localpart, roomID, eventID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeMembershipsByEventIDs removes the memberships corresponding to the
|
||||||
|
// `join` membership events IDs in the eventIDs slice.
|
||||||
|
// If the removal fails, or if there is no membership to remove, returns an error
|
||||||
|
func (d *Database) removeMembershipsByEventIDs(
|
||||||
|
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||||
|
) error {
|
||||||
|
return d.memberships.deleteMembershipsByEventIDs(ctx, txn, eventIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateMemberships adds the "join" membership events included in a given state
|
||||||
|
// events array, and removes those which ID is included in a given array of events
|
||||||
|
// IDs. All of the process is run in a transaction, which commits only once/if every
|
||||||
|
// insertion and deletion has been successfully processed.
|
||||||
|
// Returns a SQL error if there was an issue with any part of the process
|
||||||
|
func (d *Database) UpdateMemberships(
|
||||||
|
ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string,
|
||||||
|
) error {
|
||||||
|
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, event := range eventsToAdd {
|
||||||
|
if err := d.newMembership(ctx, txn, event); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMembershipInRoomByLocalpart returns the membership for an user
|
||||||
|
// matching the given localpart if he is a member of the room matching roomID,
|
||||||
|
// if not sql.ErrNoRows is returned.
|
||||||
|
// If there was an issue during the retrieval, returns the SQL error
|
||||||
|
func (d *Database) GetMembershipInRoomByLocalpart(
|
||||||
|
ctx context.Context, localpart, roomID string,
|
||||||
|
) (authtypes.Membership, error) {
|
||||||
|
return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMembershipsByLocalpart returns an array containing the memberships for all
|
||||||
|
// the rooms a user matching a given localpart is a member of
|
||||||
|
// If no membership match the given localpart, returns an empty array
|
||||||
|
// If there was an issue during the retrieval, returns the SQL error
|
||||||
|
func (d *Database) GetMembershipsByLocalpart(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) (memberships []authtypes.Membership, err error) {
|
||||||
|
return d.memberships.selectMembershipsByLocalpart(ctx, localpart)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newMembership saves a new membership in the database.
|
||||||
|
// If the event isn't a valid m.room.member event with type `join`, does nothing.
|
||||||
|
// If an error occurred, returns the SQL error
|
||||||
|
func (d *Database) newMembership(
|
||||||
|
ctx context.Context, txn *sql.Tx, ev gomatrixserverlib.Event,
|
||||||
|
) error {
|
||||||
|
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
|
||||||
|
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// We only want state events from local users
|
||||||
|
if string(serverName) != string(d.serverName) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
eventID := ev.EventID()
|
||||||
|
roomID := ev.RoomID()
|
||||||
|
membership, err := ev.Membership()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only "join" membership events can be considered as new memberships
|
||||||
|
if membership == gomatrixserverlib.Join {
|
||||||
|
if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveAccountData saves new account data for a given user and a given room.
|
||||||
|
// If the account data is not specific to a room, the room ID should be an empty string
|
||||||
|
// If an account data already exists for a given set (user, room, data type), it will
|
||||||
|
// update the corresponding row with the new content
|
||||||
|
// Returns a SQL error if there was an issue with the insertion/update
|
||||||
|
func (d *Database) SaveAccountData(
|
||||||
|
ctx context.Context, localpart, roomID, dataType, content string,
|
||||||
|
) error {
|
||||||
|
return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountData returns account data related to a given localpart
|
||||||
|
// If no account data could be found, returns an empty arrays
|
||||||
|
// Returns an error if there was an issue with the retrieval
|
||||||
|
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
||||||
|
global []gomatrixserverlib.ClientEvent,
|
||||||
|
rooms map[string][]gomatrixserverlib.ClientEvent,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
return d.accountDatas.selectAccountData(ctx, localpart)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountDataByType returns account data matching a given
|
||||||
|
// localpart, room ID and type.
|
||||||
|
// If no account data could be found, returns nil
|
||||||
|
// Returns an error if there was an issue with the retrieval
|
||||||
|
func (d *Database) GetAccountDataByType(
|
||||||
|
ctx context.Context, localpart, roomID, dataType string,
|
||||||
|
) (data *gomatrixserverlib.ClientEvent, err error) {
|
||||||
|
return d.accountDatas.selectAccountDataByType(
|
||||||
|
ctx, localpart, roomID, dataType,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
|
||||||
|
func (d *Database) GetNewNumericLocalpart(
|
||||||
|
ctx context.Context,
|
||||||
|
) (int64, error) {
|
||||||
|
return d.accounts.selectNewNumericLocalpart(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func hashPassword(plaintext string) (hash string, err error) {
|
||||||
|
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
|
||||||
|
return string(hashBytes), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Err3PIDInUse is the error returned when trying to save an association involving
|
||||||
|
// a third-party identifier which is already associated to a local user.
|
||||||
|
var Err3PIDInUse = errors.New("This third-party identifier is already in use")
|
||||||
|
|
||||||
|
// SaveThreePIDAssociation saves the association between a third party identifier
|
||||||
|
// and a local Matrix user (identified by the user's ID's local part).
|
||||||
|
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
|
||||||
|
// Returns an error if there was a problem talking to the database.
|
||||||
|
func (d *Database) SaveThreePIDAssociation(
|
||||||
|
ctx context.Context, threepid, localpart, medium string,
|
||||||
|
) (err error) {
|
||||||
|
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
user, err := d.threepids.selectLocalpartForThreePID(
|
||||||
|
ctx, txn, threepid, medium,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(user) > 0 {
|
||||||
|
return Err3PIDInUse
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveThreePIDAssociation removes the association involving a given third-party
|
||||||
|
// identifier.
|
||||||
|
// If no association exists involving this third-party identifier, returns nothing.
|
||||||
|
// If there was a problem talking to the database, returns an error.
|
||||||
|
func (d *Database) RemoveThreePIDAssociation(
|
||||||
|
ctx context.Context, threepid string, medium string,
|
||||||
|
) (err error) {
|
||||||
|
return d.threepids.deleteThreePID(ctx, threepid, medium)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
|
||||||
|
// identifier.
|
||||||
|
// If no association involves the given third-party idenfitier, returns an empty
|
||||||
|
// string.
|
||||||
|
// Returns an error if there was a problem talking to the database.
|
||||||
|
func (d *Database) GetLocalpartForThreePID(
|
||||||
|
ctx context.Context, threepid string, medium string,
|
||||||
|
) (localpart string, err error) {
|
||||||
|
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
|
||||||
|
// a given local user.
|
||||||
|
// If no association is known for this user, returns an empty slice.
|
||||||
|
// Returns an error if there was an issue talking to the database.
|
||||||
|
func (d *Database) GetThreePIDsForLocalpart(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) (threepids []authtypes.ThreePID, err error) {
|
||||||
|
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFilter looks up the filter associated with a given local user and filter ID.
|
||||||
|
// Returns a filter structure. Otherwise returns an error if no such filter exists
|
||||||
|
// or if there was an error talking to the database.
|
||||||
|
func (d *Database) GetFilter(
|
||||||
|
ctx context.Context, localpart string, filterID string,
|
||||||
|
) (*gomatrixserverlib.Filter, error) {
|
||||||
|
return d.filter.selectFilter(ctx, localpart, filterID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutFilter puts the passed filter into the database.
|
||||||
|
// Returns the filterID as a string. Otherwise returns an error if something
|
||||||
|
// goes wrong.
|
||||||
|
func (d *Database) PutFilter(
|
||||||
|
ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
|
||||||
|
) (string, error) {
|
||||||
|
return d.filter.insertFilter(ctx, filter, localpart)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckAccountAvailability checks if the username/localpart is already present
|
||||||
|
// in the database.
|
||||||
|
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
|
||||||
|
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
|
||||||
|
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountByLocalpart returns the account associated with the given localpart.
|
||||||
|
// This function assumes the request is authenticated or the account data is used only internally.
|
||||||
|
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||||
|
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
|
||||||
|
) (*authtypes.Account, error) {
|
||||||
|
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
||||||
|
}
|
||||||
|
|
@ -12,7 +12,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package accounts
|
package postgres
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
@ -12,7 +12,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package accounts
|
package sqlite3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
@ -39,7 +39,7 @@ CREATE TABLE IF NOT EXISTS account_data (
|
||||||
|
|
||||||
const insertAccountDataSQL = `
|
const insertAccountDataSQL = `
|
||||||
INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
|
INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
|
||||||
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content
|
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
|
||||||
`
|
`
|
||||||
|
|
||||||
const selectAccountDataSQL = "" +
|
const selectAccountDataSQL = "" +
|
||||||
151
clientapi/auth/storage/accounts/sqlite3/accounts_table.go
Normal file
151
clientapi/auth/storage/accounts/sqlite3/accounts_table.go
Normal file
|
|
@ -0,0 +1,151 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const accountsSchema = `
|
||||||
|
-- Stores data about accounts.
|
||||||
|
CREATE TABLE IF NOT EXISTS account_accounts (
|
||||||
|
-- The Matrix user ID localpart for this account
|
||||||
|
localpart TEXT NOT NULL PRIMARY KEY,
|
||||||
|
-- When this account was first created, as a unix timestamp (ms resolution).
|
||||||
|
created_ts BIGINT NOT NULL,
|
||||||
|
-- The password hash for this account. Can be NULL if this is a passwordless account.
|
||||||
|
password_hash TEXT,
|
||||||
|
-- Identifies which application service this account belongs to, if any.
|
||||||
|
appservice_id TEXT
|
||||||
|
-- TODO:
|
||||||
|
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
const insertAccountSQL = "" +
|
||||||
|
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
|
||||||
|
|
||||||
|
const selectAccountByLocalpartSQL = "" +
|
||||||
|
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
|
||||||
|
|
||||||
|
const selectPasswordHashSQL = "" +
|
||||||
|
"SELECT password_hash FROM account_accounts WHERE localpart = $1"
|
||||||
|
|
||||||
|
const selectNewNumericLocalpartSQL = "" +
|
||||||
|
"SELECT COUNT(localpart) FROM account_accounts"
|
||||||
|
|
||||||
|
// TODO: Update password
|
||||||
|
|
||||||
|
type accountsStatements struct {
|
||||||
|
insertAccountStmt *sql.Stmt
|
||||||
|
selectAccountByLocalpartStmt *sql.Stmt
|
||||||
|
selectPasswordHashStmt *sql.Stmt
|
||||||
|
selectNewNumericLocalpartStmt *sql.Stmt
|
||||||
|
serverName gomatrixserverlib.ServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
||||||
|
_, err = db.Exec(accountsSchema)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.serverName = server
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
|
||||||
|
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||||
|
// on success.
|
||||||
|
func (s *accountsStatements) insertAccount(
|
||||||
|
ctx context.Context, localpart, hash, appserviceID string,
|
||||||
|
) (*authtypes.Account, error) {
|
||||||
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
|
stmt := s.insertAccountStmt
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if appserviceID == "" {
|
||||||
|
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil)
|
||||||
|
} else {
|
||||||
|
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &authtypes.Account{
|
||||||
|
Localpart: localpart,
|
||||||
|
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||||
|
ServerName: s.serverName,
|
||||||
|
AppServiceID: appserviceID,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *accountsStatements) selectPasswordHash(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) (hash string, err error) {
|
||||||
|
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *accountsStatements) selectAccountByLocalpart(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) (*authtypes.Account, error) {
|
||||||
|
var appserviceIDPtr sql.NullString
|
||||||
|
var acc authtypes.Account
|
||||||
|
|
||||||
|
stmt := s.selectAccountByLocalpartStmt
|
||||||
|
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr)
|
||||||
|
if err != nil {
|
||||||
|
if err != sql.ErrNoRows {
|
||||||
|
log.WithError(err).Error("Unable to retrieve user from the db")
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if appserviceIDPtr.Valid {
|
||||||
|
acc.AppServiceID = appserviceIDPtr.String
|
||||||
|
}
|
||||||
|
|
||||||
|
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||||
|
acc.ServerName = s.serverName
|
||||||
|
|
||||||
|
return &acc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *accountsStatements) selectNewNumericLocalpart(
|
||||||
|
ctx context.Context,
|
||||||
|
) (id int64, err error) {
|
||||||
|
err = s.selectNewNumericLocalpartStmt.QueryRowContext(ctx).Scan(&id)
|
||||||
|
return
|
||||||
|
}
|
||||||
139
clientapi/auth/storage/accounts/sqlite3/filter_table.go
Normal file
139
clientapi/auth/storage/accounts/sqlite3/filter_table.go
Normal file
|
|
@ -0,0 +1,139 @@
|
||||||
|
// Copyright 2017 Jan Christian Grünhage
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
const filterSchema = `
|
||||||
|
-- Stores data about filters
|
||||||
|
CREATE TABLE IF NOT EXISTS account_filter (
|
||||||
|
-- The filter
|
||||||
|
filter TEXT NOT NULL,
|
||||||
|
-- The ID
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
-- The localpart of the Matrix user ID associated to this filter
|
||||||
|
localpart TEXT NOT NULL,
|
||||||
|
|
||||||
|
UNIQUE (id, localpart)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS account_filter_localpart ON account_filter(localpart);
|
||||||
|
`
|
||||||
|
|
||||||
|
const selectFilterSQL = "" +
|
||||||
|
"SELECT filter FROM account_filter WHERE localpart = $1 AND id = $2"
|
||||||
|
|
||||||
|
const selectFilterIDByContentSQL = "" +
|
||||||
|
"SELECT id FROM account_filter WHERE localpart = $1 AND filter = $2"
|
||||||
|
|
||||||
|
const insertFilterSQL = "" +
|
||||||
|
"INSERT INTO account_filter (filter, localpart) VALUES ($1, $2)"
|
||||||
|
|
||||||
|
const selectLastInsertedFilterIDSQL = "" +
|
||||||
|
"SELECT id FROM account_filter WHERE rowid = last_insert_rowid()"
|
||||||
|
|
||||||
|
type filterStatements struct {
|
||||||
|
selectFilterStmt *sql.Stmt
|
||||||
|
selectLastInsertedFilterIDStmt *sql.Stmt
|
||||||
|
selectFilterIDByContentStmt *sql.Stmt
|
||||||
|
insertFilterStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *filterStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
_, err = db.Exec(filterSchema)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectLastInsertedFilterIDStmt, err = db.Prepare(selectLastInsertedFilterIDSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *filterStatements) selectFilter(
|
||||||
|
ctx context.Context, localpart string, filterID string,
|
||||||
|
) (*gomatrixserverlib.Filter, error) {
|
||||||
|
// Retrieve filter from database (stored as canonical JSON)
|
||||||
|
var filterData []byte
|
||||||
|
err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal JSON into Filter struct
|
||||||
|
var filter gomatrixserverlib.Filter
|
||||||
|
if err = json.Unmarshal(filterData, &filter); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &filter, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *filterStatements) insertFilter(
|
||||||
|
ctx context.Context, filter *gomatrixserverlib.Filter, localpart string,
|
||||||
|
) (filterID string, err error) {
|
||||||
|
var existingFilterID string
|
||||||
|
|
||||||
|
// Serialise json
|
||||||
|
filterJSON, err := json.Marshal(filter)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
// Remove whitespaces and sort JSON data
|
||||||
|
// needed to prevent from inserting the same filter multiple times
|
||||||
|
filterJSON, err = gomatrixserverlib.CanonicalJSON(filterJSON)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if filter already exists in the database using its localpart and content
|
||||||
|
//
|
||||||
|
// This can result in a race condition when two clients try to insert the
|
||||||
|
// same filter and localpart at the same time, however this is not a
|
||||||
|
// problem as both calls will result in the same filterID
|
||||||
|
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
|
||||||
|
localpart, filterJSON).Scan(&existingFilterID)
|
||||||
|
if err != nil && err != sql.ErrNoRows {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
// If it does, return the existing ID
|
||||||
|
if existingFilterID != "" {
|
||||||
|
return existingFilterID, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise insert the filter and return the new ID
|
||||||
|
if _, err = s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
row := s.selectLastInsertedFilterIDStmt.QueryRowContext(ctx)
|
||||||
|
if err := row.Scan(&filterID); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
131
clientapi/auth/storage/accounts/sqlite3/membership_table.go
Normal file
131
clientapi/auth/storage/accounts/sqlite3/membership_table.go
Normal file
|
|
@ -0,0 +1,131 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/lib/pq"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
const membershipSchema = `
|
||||||
|
-- Stores data about users memberships to rooms.
|
||||||
|
CREATE TABLE IF NOT EXISTS account_memberships (
|
||||||
|
-- The Matrix user ID localpart for the member
|
||||||
|
localpart TEXT NOT NULL,
|
||||||
|
-- The room this user is a member of
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
-- The ID of the join membership event
|
||||||
|
event_id TEXT NOT NULL,
|
||||||
|
|
||||||
|
-- A user can only be member of a room once
|
||||||
|
PRIMARY KEY (localpart, room_id),
|
||||||
|
|
||||||
|
UNIQUE (event_id)
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
const insertMembershipSQL = `
|
||||||
|
INSERT INTO account_memberships(localpart, room_id, event_id) VALUES ($1, $2, $3)
|
||||||
|
ON CONFLICT (localpart, room_id) DO UPDATE SET event_id = EXCLUDED.event_id
|
||||||
|
`
|
||||||
|
|
||||||
|
const selectMembershipsByLocalpartSQL = "" +
|
||||||
|
"SELECT room_id, event_id FROM account_memberships WHERE localpart = $1"
|
||||||
|
|
||||||
|
const selectMembershipInRoomByLocalpartSQL = "" +
|
||||||
|
"SELECT event_id FROM account_memberships WHERE localpart = $1 AND room_id = $2"
|
||||||
|
|
||||||
|
const deleteMembershipsByEventIDsSQL = "" +
|
||||||
|
"DELETE FROM account_memberships WHERE event_id IN ($1)"
|
||||||
|
|
||||||
|
type membershipStatements struct {
|
||||||
|
deleteMembershipsByEventIDsStmt *sql.Stmt
|
||||||
|
insertMembershipStmt *sql.Stmt
|
||||||
|
selectMembershipInRoomByLocalpartStmt *sql.Stmt
|
||||||
|
selectMembershipsByLocalpartStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *membershipStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
_, err = db.Exec(membershipSchema)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.deleteMembershipsByEventIDsStmt, err = db.Prepare(deleteMembershipsByEventIDsSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.insertMembershipStmt, err = db.Prepare(insertMembershipSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectMembershipInRoomByLocalpartStmt, err = db.Prepare(selectMembershipInRoomByLocalpartSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *membershipStatements) insertMembership(
|
||||||
|
ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
|
||||||
|
) (err error) {
|
||||||
|
stmt := txn.Stmt(s.insertMembershipStmt)
|
||||||
|
_, err = stmt.ExecContext(ctx, localpart, roomID, eventID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *membershipStatements) deleteMembershipsByEventIDs(
|
||||||
|
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||||
|
) (err error) {
|
||||||
|
stmt := txn.Stmt(s.deleteMembershipsByEventIDsStmt)
|
||||||
|
_, err = stmt.ExecContext(ctx, pq.StringArray(eventIDs))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *membershipStatements) selectMembershipInRoomByLocalpart(
|
||||||
|
ctx context.Context, localpart, roomID string,
|
||||||
|
) (authtypes.Membership, error) {
|
||||||
|
membership := authtypes.Membership{Localpart: localpart, RoomID: roomID}
|
||||||
|
stmt := s.selectMembershipInRoomByLocalpartStmt
|
||||||
|
err := stmt.QueryRowContext(ctx, localpart, roomID).Scan(&membership.EventID)
|
||||||
|
|
||||||
|
return membership, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *membershipStatements) selectMembershipsByLocalpart(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) (memberships []authtypes.Membership, err error) {
|
||||||
|
stmt := s.selectMembershipsByLocalpartStmt
|
||||||
|
rows, err := stmt.QueryContext(ctx, localpart)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
memberships = []authtypes.Membership{}
|
||||||
|
|
||||||
|
defer rows.Close() // nolint: errcheck
|
||||||
|
for rows.Next() {
|
||||||
|
var m authtypes.Membership
|
||||||
|
m.Localpart = localpart
|
||||||
|
if err := rows.Scan(&m.RoomID, &m.EventID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
memberships = append(memberships, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
107
clientapi/auth/storage/accounts/sqlite3/profile_table.go
Normal file
107
clientapi/auth/storage/accounts/sqlite3/profile_table.go
Normal file
|
|
@ -0,0 +1,107 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
const profilesSchema = `
|
||||||
|
-- Stores data about accounts profiles.
|
||||||
|
CREATE TABLE IF NOT EXISTS account_profiles (
|
||||||
|
-- The Matrix user ID localpart for this account
|
||||||
|
localpart TEXT NOT NULL PRIMARY KEY,
|
||||||
|
-- The display name for this account
|
||||||
|
display_name TEXT,
|
||||||
|
-- The URL of the avatar for this account
|
||||||
|
avatar_url TEXT
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
const insertProfileSQL = "" +
|
||||||
|
"INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
|
||||||
|
|
||||||
|
const selectProfileByLocalpartSQL = "" +
|
||||||
|
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
|
||||||
|
|
||||||
|
const setAvatarURLSQL = "" +
|
||||||
|
"UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
|
||||||
|
|
||||||
|
const setDisplayNameSQL = "" +
|
||||||
|
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
|
||||||
|
|
||||||
|
type profilesStatements struct {
|
||||||
|
insertProfileStmt *sql.Stmt
|
||||||
|
selectProfileByLocalpartStmt *sql.Stmt
|
||||||
|
setAvatarURLStmt *sql.Stmt
|
||||||
|
setDisplayNameStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
_, err = db.Exec(profilesSchema)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.insertProfileStmt, err = db.Prepare(insertProfileSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectProfileByLocalpartStmt, err = db.Prepare(selectProfileByLocalpartSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.setAvatarURLStmt, err = db.Prepare(setAvatarURLSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *profilesStatements) insertProfile(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) (err error) {
|
||||||
|
_, err = s.insertProfileStmt.ExecContext(ctx, localpart, "", "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *profilesStatements) selectProfileByLocalpart(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) (*authtypes.Profile, error) {
|
||||||
|
var profile authtypes.Profile
|
||||||
|
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
|
||||||
|
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &profile, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *profilesStatements) setAvatarURL(
|
||||||
|
ctx context.Context, localpart string, avatarURL string,
|
||||||
|
) (err error) {
|
||||||
|
_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *profilesStatements) setDisplayName(
|
||||||
|
ctx context.Context, localpart string, displayName string,
|
||||||
|
) (err error) {
|
||||||
|
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
|
||||||
|
return
|
||||||
|
}
|
||||||
392
clientapi/auth/storage/accounts/sqlite3/storage.go
Normal file
392
clientapi/auth/storage/accounts/sqlite3/storage.go
Normal file
|
|
@ -0,0 +1,392 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
|
// Import the postgres database driver.
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Database represents an account database
|
||||||
|
type Database struct {
|
||||||
|
db *sql.DB
|
||||||
|
common.PartitionOffsetStatements
|
||||||
|
accounts accountsStatements
|
||||||
|
profiles profilesStatements
|
||||||
|
memberships membershipStatements
|
||||||
|
accountDatas accountDataStatements
|
||||||
|
threepids threepidStatements
|
||||||
|
filter filterStatements
|
||||||
|
serverName gomatrixserverlib.ServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDatabase creates a new accounts and profiles database
|
||||||
|
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
|
||||||
|
var db *sql.DB
|
||||||
|
var err error
|
||||||
|
if db, err = sql.Open("sqlite3", dataSourceName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
partitions := common.PartitionOffsetStatements{}
|
||||||
|
if err = partitions.Prepare(db, "account"); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
a := accountsStatements{}
|
||||||
|
if err = a.prepare(db, serverName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
p := profilesStatements{}
|
||||||
|
if err = p.prepare(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m := membershipStatements{}
|
||||||
|
if err = m.prepare(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ac := accountDataStatements{}
|
||||||
|
if err = ac.prepare(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t := threepidStatements{}
|
||||||
|
if err = t.prepare(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
f := filterStatements{}
|
||||||
|
if err = f.prepare(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountByPassword returns the account associated with the given localpart and password.
|
||||||
|
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||||
|
func (d *Database) GetAccountByPassword(
|
||||||
|
ctx context.Context, localpart, plaintextPassword string,
|
||||||
|
) (*authtypes.Account, error) {
|
||||||
|
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
||||||
|
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
|
||||||
|
func (d *Database) GetProfileByLocalpart(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) (*authtypes.Profile, error) {
|
||||||
|
return d.profiles.selectProfileByLocalpart(ctx, localpart)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAvatarURL updates the avatar URL of the profile associated with the given
|
||||||
|
// localpart. Returns an error if something went wrong with the SQL query
|
||||||
|
func (d *Database) SetAvatarURL(
|
||||||
|
ctx context.Context, localpart string, avatarURL string,
|
||||||
|
) error {
|
||||||
|
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDisplayName updates the display name of the profile associated with the given
|
||||||
|
// localpart. Returns an error if something went wrong with the SQL query
|
||||||
|
func (d *Database) SetDisplayName(
|
||||||
|
ctx context.Context, localpart string, displayName string,
|
||||||
|
) error {
|
||||||
|
return d.profiles.setDisplayName(ctx, localpart, displayName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
||||||
|
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||||
|
// account already exists, it will return nil, nil.
|
||||||
|
func (d *Database) CreateAccount(
|
||||||
|
ctx context.Context, localpart, plaintextPassword, appserviceID string,
|
||||||
|
) (*authtypes.Account, error) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Generate a password hash if this is not a password-less user
|
||||||
|
hash := ""
|
||||||
|
if plaintextPassword != "" {
|
||||||
|
hash, err = hashPassword(plaintextPassword)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := d.profiles.insertProfile(ctx, localpart); err != nil {
|
||||||
|
if common.IsUniqueConstraintViolationErr(err) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{
|
||||||
|
"global": {
|
||||||
|
"content": [],
|
||||||
|
"override": [],
|
||||||
|
"room": [],
|
||||||
|
"sender": [],
|
||||||
|
"underride": []
|
||||||
|
}
|
||||||
|
}`); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return d.accounts.insertAccount(ctx, localpart, hash, appserviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveMembership saves the user matching a given localpart as a member of a given
|
||||||
|
// room. It also stores the ID of the membership event.
|
||||||
|
// If a membership already exists between the user and the room, or if the
|
||||||
|
// insert fails, returns the SQL error
|
||||||
|
func (d *Database) saveMembership(
|
||||||
|
ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
|
||||||
|
) error {
|
||||||
|
return d.memberships.insertMembership(ctx, txn, localpart, roomID, eventID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeMembershipsByEventIDs removes the memberships corresponding to the
|
||||||
|
// `join` membership events IDs in the eventIDs slice.
|
||||||
|
// If the removal fails, or if there is no membership to remove, returns an error
|
||||||
|
func (d *Database) removeMembershipsByEventIDs(
|
||||||
|
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||||
|
) error {
|
||||||
|
return d.memberships.deleteMembershipsByEventIDs(ctx, txn, eventIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateMemberships adds the "join" membership events included in a given state
|
||||||
|
// events array, and removes those which ID is included in a given array of events
|
||||||
|
// IDs. All of the process is run in a transaction, which commits only once/if every
|
||||||
|
// insertion and deletion has been successfully processed.
|
||||||
|
// Returns a SQL error if there was an issue with any part of the process
|
||||||
|
func (d *Database) UpdateMemberships(
|
||||||
|
ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string,
|
||||||
|
) error {
|
||||||
|
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, event := range eventsToAdd {
|
||||||
|
if err := d.newMembership(ctx, txn, event); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMembershipInRoomByLocalpart returns the membership for an user
|
||||||
|
// matching the given localpart if he is a member of the room matching roomID,
|
||||||
|
// if not sql.ErrNoRows is returned.
|
||||||
|
// If there was an issue during the retrieval, returns the SQL error
|
||||||
|
func (d *Database) GetMembershipInRoomByLocalpart(
|
||||||
|
ctx context.Context, localpart, roomID string,
|
||||||
|
) (authtypes.Membership, error) {
|
||||||
|
return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMembershipsByLocalpart returns an array containing the memberships for all
|
||||||
|
// the rooms a user matching a given localpart is a member of
|
||||||
|
// If no membership match the given localpart, returns an empty array
|
||||||
|
// If there was an issue during the retrieval, returns the SQL error
|
||||||
|
func (d *Database) GetMembershipsByLocalpart(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) (memberships []authtypes.Membership, err error) {
|
||||||
|
return d.memberships.selectMembershipsByLocalpart(ctx, localpart)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newMembership saves a new membership in the database.
|
||||||
|
// If the event isn't a valid m.room.member event with type `join`, does nothing.
|
||||||
|
// If an error occurred, returns the SQL error
|
||||||
|
func (d *Database) newMembership(
|
||||||
|
ctx context.Context, txn *sql.Tx, ev gomatrixserverlib.Event,
|
||||||
|
) error {
|
||||||
|
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
|
||||||
|
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// We only want state events from local users
|
||||||
|
if string(serverName) != string(d.serverName) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
eventID := ev.EventID()
|
||||||
|
roomID := ev.RoomID()
|
||||||
|
membership, err := ev.Membership()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only "join" membership events can be considered as new memberships
|
||||||
|
if membership == gomatrixserverlib.Join {
|
||||||
|
if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveAccountData saves new account data for a given user and a given room.
|
||||||
|
// If the account data is not specific to a room, the room ID should be an empty string
|
||||||
|
// If an account data already exists for a given set (user, room, data type), it will
|
||||||
|
// update the corresponding row with the new content
|
||||||
|
// Returns a SQL error if there was an issue with the insertion/update
|
||||||
|
func (d *Database) SaveAccountData(
|
||||||
|
ctx context.Context, localpart, roomID, dataType, content string,
|
||||||
|
) error {
|
||||||
|
return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountData returns account data related to a given localpart
|
||||||
|
// If no account data could be found, returns an empty arrays
|
||||||
|
// Returns an error if there was an issue with the retrieval
|
||||||
|
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
||||||
|
global []gomatrixserverlib.ClientEvent,
|
||||||
|
rooms map[string][]gomatrixserverlib.ClientEvent,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
return d.accountDatas.selectAccountData(ctx, localpart)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountDataByType returns account data matching a given
|
||||||
|
// localpart, room ID and type.
|
||||||
|
// If no account data could be found, returns nil
|
||||||
|
// Returns an error if there was an issue with the retrieval
|
||||||
|
func (d *Database) GetAccountDataByType(
|
||||||
|
ctx context.Context, localpart, roomID, dataType string,
|
||||||
|
) (data *gomatrixserverlib.ClientEvent, err error) {
|
||||||
|
return d.accountDatas.selectAccountDataByType(
|
||||||
|
ctx, localpart, roomID, dataType,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
|
||||||
|
func (d *Database) GetNewNumericLocalpart(
|
||||||
|
ctx context.Context,
|
||||||
|
) (int64, error) {
|
||||||
|
return d.accounts.selectNewNumericLocalpart(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func hashPassword(plaintext string) (hash string, err error) {
|
||||||
|
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
|
||||||
|
return string(hashBytes), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Err3PIDInUse is the error returned when trying to save an association involving
|
||||||
|
// a third-party identifier which is already associated to a local user.
|
||||||
|
var Err3PIDInUse = errors.New("This third-party identifier is already in use")
|
||||||
|
|
||||||
|
// SaveThreePIDAssociation saves the association between a third party identifier
|
||||||
|
// and a local Matrix user (identified by the user's ID's local part).
|
||||||
|
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
|
||||||
|
// Returns an error if there was a problem talking to the database.
|
||||||
|
func (d *Database) SaveThreePIDAssociation(
|
||||||
|
ctx context.Context, threepid, localpart, medium string,
|
||||||
|
) (err error) {
|
||||||
|
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
user, err := d.threepids.selectLocalpartForThreePID(
|
||||||
|
ctx, txn, threepid, medium,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(user) > 0 {
|
||||||
|
return Err3PIDInUse
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveThreePIDAssociation removes the association involving a given third-party
|
||||||
|
// identifier.
|
||||||
|
// If no association exists involving this third-party identifier, returns nothing.
|
||||||
|
// If there was a problem talking to the database, returns an error.
|
||||||
|
func (d *Database) RemoveThreePIDAssociation(
|
||||||
|
ctx context.Context, threepid string, medium string,
|
||||||
|
) (err error) {
|
||||||
|
return d.threepids.deleteThreePID(ctx, threepid, medium)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
|
||||||
|
// identifier.
|
||||||
|
// If no association involves the given third-party idenfitier, returns an empty
|
||||||
|
// string.
|
||||||
|
// Returns an error if there was a problem talking to the database.
|
||||||
|
func (d *Database) GetLocalpartForThreePID(
|
||||||
|
ctx context.Context, threepid string, medium string,
|
||||||
|
) (localpart string, err error) {
|
||||||
|
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
|
||||||
|
// a given local user.
|
||||||
|
// If no association is known for this user, returns an empty slice.
|
||||||
|
// Returns an error if there was an issue talking to the database.
|
||||||
|
func (d *Database) GetThreePIDsForLocalpart(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) (threepids []authtypes.ThreePID, err error) {
|
||||||
|
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFilter looks up the filter associated with a given local user and filter ID.
|
||||||
|
// Returns a filter structure. Otherwise returns an error if no such filter exists
|
||||||
|
// or if there was an error talking to the database.
|
||||||
|
func (d *Database) GetFilter(
|
||||||
|
ctx context.Context, localpart string, filterID string,
|
||||||
|
) (*gomatrixserverlib.Filter, error) {
|
||||||
|
return d.filter.selectFilter(ctx, localpart, filterID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutFilter puts the passed filter into the database.
|
||||||
|
// Returns the filterID as a string. Otherwise returns an error if something
|
||||||
|
// goes wrong.
|
||||||
|
func (d *Database) PutFilter(
|
||||||
|
ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
|
||||||
|
) (string, error) {
|
||||||
|
return d.filter.insertFilter(ctx, filter, localpart)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckAccountAvailability checks if the username/localpart is already present
|
||||||
|
// in the database.
|
||||||
|
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
|
||||||
|
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
|
||||||
|
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountByLocalpart returns the account associated with the given localpart.
|
||||||
|
// This function assumes the request is authenticated or the account data is used only internally.
|
||||||
|
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||||
|
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
|
||||||
|
) (*authtypes.Account, error) {
|
||||||
|
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
||||||
|
}
|
||||||
129
clientapi/auth/storage/accounts/sqlite3/threepid_table.go
Normal file
129
clientapi/auth/storage/accounts/sqlite3/threepid_table.go
Normal file
|
|
@ -0,0 +1,129 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
const threepidSchema = `
|
||||||
|
-- Stores data about third party identifiers
|
||||||
|
CREATE TABLE IF NOT EXISTS account_threepid (
|
||||||
|
-- The third party identifier
|
||||||
|
threepid TEXT NOT NULL,
|
||||||
|
-- The 3PID medium
|
||||||
|
medium TEXT NOT NULL DEFAULT 'email',
|
||||||
|
-- The localpart of the Matrix user ID associated to this 3PID
|
||||||
|
localpart TEXT NOT NULL,
|
||||||
|
|
||||||
|
PRIMARY KEY(threepid, medium)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart);
|
||||||
|
`
|
||||||
|
|
||||||
|
const selectLocalpartForThreePIDSQL = "" +
|
||||||
|
"SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
||||||
|
|
||||||
|
const selectThreePIDsForLocalpartSQL = "" +
|
||||||
|
"SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
|
||||||
|
|
||||||
|
const insertThreePIDSQL = "" +
|
||||||
|
"INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)"
|
||||||
|
|
||||||
|
const deleteThreePIDSQL = "" +
|
||||||
|
"DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
||||||
|
|
||||||
|
type threepidStatements struct {
|
||||||
|
selectLocalpartForThreePIDStmt *sql.Stmt
|
||||||
|
selectThreePIDsForLocalpartStmt *sql.Stmt
|
||||||
|
insertThreePIDStmt *sql.Stmt
|
||||||
|
deleteThreePIDStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *threepidStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
_, err = db.Exec(threepidSchema)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectLocalpartForThreePIDStmt, err = db.Prepare(selectLocalpartForThreePIDSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectThreePIDsForLocalpartStmt, err = db.Prepare(selectThreePIDsForLocalpartSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.insertThreePIDStmt, err = db.Prepare(insertThreePIDSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.deleteThreePIDStmt, err = db.Prepare(deleteThreePIDSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *threepidStatements) selectLocalpartForThreePID(
|
||||||
|
ctx context.Context, txn *sql.Tx, threepid string, medium string,
|
||||||
|
) (localpart string, err error) {
|
||||||
|
stmt := common.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
|
||||||
|
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *threepidStatements) selectThreePIDsForLocalpart(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) (threepids []authtypes.ThreePID, err error) {
|
||||||
|
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer rows.Close() // nolint: errcheck
|
||||||
|
|
||||||
|
threepids = []authtypes.ThreePID{}
|
||||||
|
for rows.Next() {
|
||||||
|
var threepid string
|
||||||
|
var medium string
|
||||||
|
if err = rows.Scan(&threepid, &medium); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
threepids = append(threepids, authtypes.ThreePID{
|
||||||
|
Address: threepid,
|
||||||
|
Medium: medium,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return threepids, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *threepidStatements) insertThreePID(
|
||||||
|
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
|
||||||
|
) (err error) {
|
||||||
|
stmt := common.TxStmt(txn, s.insertThreePIDStmt)
|
||||||
|
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *threepidStatements) deleteThreePID(
|
||||||
|
ctx context.Context, threepid string, medium string) (err error) {
|
||||||
|
_, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
@ -1,392 +1,56 @@
|
||||||
// Copyright 2017 Vector Creations Ltd
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package accounts
|
package accounts
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"errors"
|
"errors"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/postgres"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/sqlite3"
|
||||||
"github.com/matrix-org/dendrite/common"
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
|
|
||||||
// Import the postgres database driver.
|
|
||||||
_ "github.com/lib/pq"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Database represents an account database
|
type Database interface {
|
||||||
type Database struct {
|
common.PartitionStorer
|
||||||
db *sql.DB
|
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*authtypes.Account, error)
|
||||||
common.PartitionOffsetStatements
|
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
|
||||||
accounts accountsStatements
|
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
|
||||||
profiles profilesStatements
|
SetDisplayName(ctx context.Context, localpart string, displayName string) error
|
||||||
memberships membershipStatements
|
CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*authtypes.Account, error)
|
||||||
accountDatas accountDataStatements
|
UpdateMemberships(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error
|
||||||
threepids threepidStatements
|
GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error)
|
||||||
filter filterStatements
|
GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error)
|
||||||
serverName gomatrixserverlib.ServerName
|
SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error
|
||||||
|
GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error)
|
||||||
|
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data *gomatrixserverlib.ClientEvent, err error)
|
||||||
|
GetNewNumericLocalpart(ctx context.Context) (int64, error)
|
||||||
|
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
|
||||||
|
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
|
||||||
|
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
|
||||||
|
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
|
||||||
|
GetFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error)
|
||||||
|
PutFilter(ctx context.Context, localpart string, filter *gomatrixserverlib.Filter) (string, error)
|
||||||
|
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
|
||||||
|
GetAccountByLocalpart(ctx context.Context, localpart string) (*authtypes.Account, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabase creates a new accounts and profiles database
|
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (Database, error) {
|
||||||
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
|
uri, err := url.Parse(dataSourceName)
|
||||||
var db *sql.DB
|
|
||||||
var err error
|
|
||||||
if db, err = sql.Open("postgres", dataSourceName); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
partitions := common.PartitionOffsetStatements{}
|
|
||||||
if err = partitions.Prepare(db, "account"); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
a := accountsStatements{}
|
|
||||||
if err = a.prepare(db, serverName); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
p := profilesStatements{}
|
|
||||||
if err = p.prepare(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
m := membershipStatements{}
|
|
||||||
if err = m.prepare(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ac := accountDataStatements{}
|
|
||||||
if err = ac.prepare(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
t := threepidStatements{}
|
|
||||||
if err = t.prepare(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
f := filterStatements{}
|
|
||||||
if err = f.prepare(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountByPassword returns the account associated with the given localpart and password.
|
|
||||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
|
||||||
func (d *Database) GetAccountByPassword(
|
|
||||||
ctx context.Context, localpart, plaintextPassword string,
|
|
||||||
) (*authtypes.Account, error) {
|
|
||||||
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return postgres.NewDatabase(dataSourceName, serverName)
|
||||||
}
|
}
|
||||||
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
|
switch uri.Scheme {
|
||||||
return nil, err
|
case "postgres":
|
||||||
|
return postgres.NewDatabase(dataSourceName, serverName)
|
||||||
|
case "file":
|
||||||
|
return sqlite3.NewDatabase(dataSourceName, serverName)
|
||||||
|
default:
|
||||||
|
return postgres.NewDatabase(dataSourceName, serverName)
|
||||||
}
|
}
|
||||||
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
|
||||||
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
|
|
||||||
func (d *Database) GetProfileByLocalpart(
|
|
||||||
ctx context.Context, localpart string,
|
|
||||||
) (*authtypes.Profile, error) {
|
|
||||||
return d.profiles.selectProfileByLocalpart(ctx, localpart)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetAvatarURL updates the avatar URL of the profile associated with the given
|
|
||||||
// localpart. Returns an error if something went wrong with the SQL query
|
|
||||||
func (d *Database) SetAvatarURL(
|
|
||||||
ctx context.Context, localpart string, avatarURL string,
|
|
||||||
) error {
|
|
||||||
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDisplayName updates the display name of the profile associated with the given
|
|
||||||
// localpart. Returns an error if something went wrong with the SQL query
|
|
||||||
func (d *Database) SetDisplayName(
|
|
||||||
ctx context.Context, localpart string, displayName string,
|
|
||||||
) error {
|
|
||||||
return d.profiles.setDisplayName(ctx, localpart, displayName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
|
||||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
|
||||||
// account already exists, it will return nil, nil.
|
|
||||||
func (d *Database) CreateAccount(
|
|
||||||
ctx context.Context, localpart, plaintextPassword, appserviceID string,
|
|
||||||
) (*authtypes.Account, error) {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
// Generate a password hash if this is not a password-less user
|
|
||||||
hash := ""
|
|
||||||
if plaintextPassword != "" {
|
|
||||||
hash, err = hashPassword(plaintextPassword)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := d.profiles.insertProfile(ctx, localpart); err != nil {
|
|
||||||
if common.IsUniqueConstraintViolationErr(err) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{
|
|
||||||
"global": {
|
|
||||||
"content": [],
|
|
||||||
"override": [],
|
|
||||||
"room": [],
|
|
||||||
"sender": [],
|
|
||||||
"underride": []
|
|
||||||
}
|
|
||||||
}`); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return d.accounts.insertAccount(ctx, localpart, hash, appserviceID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveMembership saves the user matching a given localpart as a member of a given
|
|
||||||
// room. It also stores the ID of the membership event.
|
|
||||||
// If a membership already exists between the user and the room, or if the
|
|
||||||
// insert fails, returns the SQL error
|
|
||||||
func (d *Database) saveMembership(
|
|
||||||
ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
|
|
||||||
) error {
|
|
||||||
return d.memberships.insertMembership(ctx, txn, localpart, roomID, eventID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeMembershipsByEventIDs removes the memberships corresponding to the
|
|
||||||
// `join` membership events IDs in the eventIDs slice.
|
|
||||||
// If the removal fails, or if there is no membership to remove, returns an error
|
|
||||||
func (d *Database) removeMembershipsByEventIDs(
|
|
||||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
|
||||||
) error {
|
|
||||||
return d.memberships.deleteMembershipsByEventIDs(ctx, txn, eventIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateMemberships adds the "join" membership events included in a given state
|
|
||||||
// events array, and removes those which ID is included in a given array of events
|
|
||||||
// IDs. All of the process is run in a transaction, which commits only once/if every
|
|
||||||
// insertion and deletion has been successfully processed.
|
|
||||||
// Returns a SQL error if there was an issue with any part of the process
|
|
||||||
func (d *Database) UpdateMemberships(
|
|
||||||
ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string,
|
|
||||||
) error {
|
|
||||||
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, event := range eventsToAdd {
|
|
||||||
if err := d.newMembership(ctx, txn, event); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetMembershipInRoomByLocalpart returns the membership for an user
|
|
||||||
// matching the given localpart if he is a member of the room matching roomID,
|
|
||||||
// if not sql.ErrNoRows is returned.
|
|
||||||
// If there was an issue during the retrieval, returns the SQL error
|
|
||||||
func (d *Database) GetMembershipInRoomByLocalpart(
|
|
||||||
ctx context.Context, localpart, roomID string,
|
|
||||||
) (authtypes.Membership, error) {
|
|
||||||
return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetMembershipsByLocalpart returns an array containing the memberships for all
|
|
||||||
// the rooms a user matching a given localpart is a member of
|
|
||||||
// If no membership match the given localpart, returns an empty array
|
|
||||||
// If there was an issue during the retrieval, returns the SQL error
|
|
||||||
func (d *Database) GetMembershipsByLocalpart(
|
|
||||||
ctx context.Context, localpart string,
|
|
||||||
) (memberships []authtypes.Membership, err error) {
|
|
||||||
return d.memberships.selectMembershipsByLocalpart(ctx, localpart)
|
|
||||||
}
|
|
||||||
|
|
||||||
// newMembership saves a new membership in the database.
|
|
||||||
// If the event isn't a valid m.room.member event with type `join`, does nothing.
|
|
||||||
// If an error occurred, returns the SQL error
|
|
||||||
func (d *Database) newMembership(
|
|
||||||
ctx context.Context, txn *sql.Tx, ev gomatrixserverlib.Event,
|
|
||||||
) error {
|
|
||||||
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
|
|
||||||
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// We only want state events from local users
|
|
||||||
if string(serverName) != string(d.serverName) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
eventID := ev.EventID()
|
|
||||||
roomID := ev.RoomID()
|
|
||||||
membership, err := ev.Membership()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only "join" membership events can be considered as new memberships
|
|
||||||
if membership == gomatrixserverlib.Join {
|
|
||||||
if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveAccountData saves new account data for a given user and a given room.
|
|
||||||
// If the account data is not specific to a room, the room ID should be an empty string
|
|
||||||
// If an account data already exists for a given set (user, room, data type), it will
|
|
||||||
// update the corresponding row with the new content
|
|
||||||
// Returns a SQL error if there was an issue with the insertion/update
|
|
||||||
func (d *Database) SaveAccountData(
|
|
||||||
ctx context.Context, localpart, roomID, dataType, content string,
|
|
||||||
) error {
|
|
||||||
return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountData returns account data related to a given localpart
|
|
||||||
// If no account data could be found, returns an empty arrays
|
|
||||||
// Returns an error if there was an issue with the retrieval
|
|
||||||
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
|
||||||
global []gomatrixserverlib.ClientEvent,
|
|
||||||
rooms map[string][]gomatrixserverlib.ClientEvent,
|
|
||||||
err error,
|
|
||||||
) {
|
|
||||||
return d.accountDatas.selectAccountData(ctx, localpart)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountDataByType returns account data matching a given
|
|
||||||
// localpart, room ID and type.
|
|
||||||
// If no account data could be found, returns nil
|
|
||||||
// Returns an error if there was an issue with the retrieval
|
|
||||||
func (d *Database) GetAccountDataByType(
|
|
||||||
ctx context.Context, localpart, roomID, dataType string,
|
|
||||||
) (data *gomatrixserverlib.ClientEvent, err error) {
|
|
||||||
return d.accountDatas.selectAccountDataByType(
|
|
||||||
ctx, localpart, roomID, dataType,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
|
|
||||||
func (d *Database) GetNewNumericLocalpart(
|
|
||||||
ctx context.Context,
|
|
||||||
) (int64, error) {
|
|
||||||
return d.accounts.selectNewNumericLocalpart(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func hashPassword(plaintext string) (hash string, err error) {
|
|
||||||
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
|
|
||||||
return string(hashBytes), err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Err3PIDInUse is the error returned when trying to save an association involving
|
// Err3PIDInUse is the error returned when trying to save an association involving
|
||||||
// a third-party identifier which is already associated to a local user.
|
// a third-party identifier which is already associated to a local user.
|
||||||
var Err3PIDInUse = errors.New("This third-party identifier is already in use")
|
var Err3PIDInUse = errors.New("This third-party identifier is already in use")
|
||||||
|
|
||||||
// SaveThreePIDAssociation saves the association between a third party identifier
|
|
||||||
// and a local Matrix user (identified by the user's ID's local part).
|
|
||||||
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
|
|
||||||
// Returns an error if there was a problem talking to the database.
|
|
||||||
func (d *Database) SaveThreePIDAssociation(
|
|
||||||
ctx context.Context, threepid, localpart, medium string,
|
|
||||||
) (err error) {
|
|
||||||
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
user, err := d.threepids.selectLocalpartForThreePID(
|
|
||||||
ctx, txn, threepid, medium,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(user) > 0 {
|
|
||||||
return Err3PIDInUse
|
|
||||||
}
|
|
||||||
|
|
||||||
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveThreePIDAssociation removes the association involving a given third-party
|
|
||||||
// identifier.
|
|
||||||
// If no association exists involving this third-party identifier, returns nothing.
|
|
||||||
// If there was a problem talking to the database, returns an error.
|
|
||||||
func (d *Database) RemoveThreePIDAssociation(
|
|
||||||
ctx context.Context, threepid string, medium string,
|
|
||||||
) (err error) {
|
|
||||||
return d.threepids.deleteThreePID(ctx, threepid, medium)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
|
|
||||||
// identifier.
|
|
||||||
// If no association involves the given third-party idenfitier, returns an empty
|
|
||||||
// string.
|
|
||||||
// Returns an error if there was a problem talking to the database.
|
|
||||||
func (d *Database) GetLocalpartForThreePID(
|
|
||||||
ctx context.Context, threepid string, medium string,
|
|
||||||
) (localpart string, err error) {
|
|
||||||
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
|
|
||||||
// a given local user.
|
|
||||||
// If no association is known for this user, returns an empty slice.
|
|
||||||
// Returns an error if there was an issue talking to the database.
|
|
||||||
func (d *Database) GetThreePIDsForLocalpart(
|
|
||||||
ctx context.Context, localpart string,
|
|
||||||
) (threepids []authtypes.ThreePID, err error) {
|
|
||||||
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetFilter looks up the filter associated with a given local user and filter ID.
|
|
||||||
// Returns a filter structure. Otherwise returns an error if no such filter exists
|
|
||||||
// or if there was an error talking to the database.
|
|
||||||
func (d *Database) GetFilter(
|
|
||||||
ctx context.Context, localpart string, filterID string,
|
|
||||||
) (*gomatrixserverlib.Filter, error) {
|
|
||||||
return d.filter.selectFilter(ctx, localpart, filterID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// PutFilter puts the passed filter into the database.
|
|
||||||
// Returns the filterID as a string. Otherwise returns an error if something
|
|
||||||
// goes wrong.
|
|
||||||
func (d *Database) PutFilter(
|
|
||||||
ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
|
|
||||||
) (string, error) {
|
|
||||||
return d.filter.insertFilter(ctx, filter, localpart)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CheckAccountAvailability checks if the username/localpart is already present
|
|
||||||
// in the database.
|
|
||||||
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
|
|
||||||
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
|
|
||||||
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountByLocalpart returns the account associated with the given localpart.
|
|
||||||
// This function assumes the request is authenticated or the account data is used only internally.
|
|
||||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
|
||||||
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
|
|
||||||
) (*authtypes.Account, error) {
|
|
||||||
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -12,17 +12,17 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package devices
|
package postgres
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/common"
|
"github.com/lib/pq"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -80,6 +80,9 @@ const deleteDeviceSQL = "" +
|
||||||
const deleteDevicesByLocalpartSQL = "" +
|
const deleteDevicesByLocalpartSQL = "" +
|
||||||
"DELETE FROM device_devices WHERE localpart = $1"
|
"DELETE FROM device_devices WHERE localpart = $1"
|
||||||
|
|
||||||
|
const deleteDevicesSQL = "" +
|
||||||
|
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
|
||||||
|
|
||||||
type devicesStatements struct {
|
type devicesStatements struct {
|
||||||
insertDeviceStmt *sql.Stmt
|
insertDeviceStmt *sql.Stmt
|
||||||
selectDeviceByTokenStmt *sql.Stmt
|
selectDeviceByTokenStmt *sql.Stmt
|
||||||
|
|
@ -88,6 +91,7 @@ type devicesStatements struct {
|
||||||
updateDeviceNameStmt *sql.Stmt
|
updateDeviceNameStmt *sql.Stmt
|
||||||
deleteDeviceStmt *sql.Stmt
|
deleteDeviceStmt *sql.Stmt
|
||||||
deleteDevicesByLocalpartStmt *sql.Stmt
|
deleteDevicesByLocalpartStmt *sql.Stmt
|
||||||
|
deleteDevicesStmt *sql.Stmt
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -117,6 +121,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
|
||||||
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
|
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
s.serverName = server
|
s.serverName = server
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -142,6 +149,7 @@ func (s *devicesStatements) insertDevice(
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// deleteDevice removes a single device by id and user localpart.
|
||||||
func (s *devicesStatements) deleteDevice(
|
func (s *devicesStatements) deleteDevice(
|
||||||
ctx context.Context, txn *sql.Tx, id, localpart string,
|
ctx context.Context, txn *sql.Tx, id, localpart string,
|
||||||
) error {
|
) error {
|
||||||
|
|
@ -150,6 +158,18 @@ func (s *devicesStatements) deleteDevice(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// deleteDevices removes a single or multiple devices by ids and user localpart.
|
||||||
|
// Returns an error if the execution failed.
|
||||||
|
func (s *devicesStatements) deleteDevices(
|
||||||
|
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
||||||
|
) error {
|
||||||
|
stmt := common.TxStmt(txn, s.deleteDevicesStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, localpart, pq.Array(devices))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteDevicesByLocalpart removes all devices for the
|
||||||
|
// given user localpart.
|
||||||
func (s *devicesStatements) deleteDevicesByLocalpart(
|
func (s *devicesStatements) deleteDevicesByLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string,
|
ctx context.Context, txn *sql.Tx, localpart string,
|
||||||
) error {
|
) error {
|
||||||
|
|
@ -206,6 +226,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return devices, err
|
return devices, err
|
||||||
}
|
}
|
||||||
|
defer rows.Close() // nolint: errcheck
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var dev authtypes.Device
|
var dev authtypes.Device
|
||||||
|
|
@ -217,5 +238,5 @@ func (s *devicesStatements) selectDevicesByLocalpart(
|
||||||
devices = append(devices, dev)
|
devices = append(devices, dev)
|
||||||
}
|
}
|
||||||
|
|
||||||
return devices, nil
|
return devices, rows.Err()
|
||||||
}
|
}
|
||||||
182
clientapi/auth/storage/devices/postgres/storage.go
Normal file
182
clientapi/auth/storage/devices/postgres/storage.go
Normal file
|
|
@ -0,0 +1,182 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/base64"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
// The length of generated device IDs
|
||||||
|
var deviceIDByteLength = 6
|
||||||
|
|
||||||
|
// Database represents a device database.
|
||||||
|
type Database struct {
|
||||||
|
db *sql.DB
|
||||||
|
devices devicesStatements
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDatabase creates a new device database
|
||||||
|
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
|
||||||
|
var db *sql.DB
|
||||||
|
var err error
|
||||||
|
if db, err = sql.Open("postgres", dataSourceName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
d := devicesStatements{}
|
||||||
|
if err = d.prepare(db, serverName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Database{db, d}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDeviceByAccessToken returns the device matching the given access token.
|
||||||
|
// Returns sql.ErrNoRows if no matching device was found.
|
||||||
|
func (d *Database) GetDeviceByAccessToken(
|
||||||
|
ctx context.Context, token string,
|
||||||
|
) (*authtypes.Device, error) {
|
||||||
|
return d.devices.selectDeviceByToken(ctx, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDeviceByID returns the device matching the given ID.
|
||||||
|
// Returns sql.ErrNoRows if no matching device was found.
|
||||||
|
func (d *Database) GetDeviceByID(
|
||||||
|
ctx context.Context, localpart, deviceID string,
|
||||||
|
) (*authtypes.Device, error) {
|
||||||
|
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDevicesByLocalpart returns the devices matching the given localpart.
|
||||||
|
func (d *Database) GetDevicesByLocalpart(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) ([]authtypes.Device, error) {
|
||||||
|
return d.devices.selectDevicesByLocalpart(ctx, localpart)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateDevice makes a new device associated with the given user ID localpart.
|
||||||
|
// If there is already a device with the same device ID for this user, that access token will be revoked
|
||||||
|
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
||||||
|
// an error will be returned.
|
||||||
|
// If no device ID is given one is generated.
|
||||||
|
// Returns the device on success.
|
||||||
|
func (d *Database) CreateDevice(
|
||||||
|
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
||||||
|
displayName *string,
|
||||||
|
) (dev *authtypes.Device, returnErr error) {
|
||||||
|
if deviceID != nil {
|
||||||
|
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
var err error
|
||||||
|
// Revoke existing tokens for this device
|
||||||
|
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
// We generate device IDs in a loop in case its already taken.
|
||||||
|
// We cap this at going round 5 times to ensure we don't spin forever
|
||||||
|
var newDeviceID string
|
||||||
|
for i := 1; i <= 5; i++ {
|
||||||
|
newDeviceID, returnErr = generateDeviceID()
|
||||||
|
if returnErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
var err error
|
||||||
|
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if returnErr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateDeviceID creates a new device id. Returns an error if failed to generate
|
||||||
|
// random bytes.
|
||||||
|
func generateDeviceID() (string, error) {
|
||||||
|
b := make([]byte, deviceIDByteLength)
|
||||||
|
_, err := rand.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
// url-safe no padding
|
||||||
|
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDevice updates the given device with the display name.
|
||||||
|
// Returns SQL error if there are problems and nil on success.
|
||||||
|
func (d *Database) UpdateDevice(
|
||||||
|
ctx context.Context, localpart, deviceID string, displayName *string,
|
||||||
|
) error {
|
||||||
|
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveDevice revokes a device by deleting the entry in the database
|
||||||
|
// matching with the given device ID and user ID localpart.
|
||||||
|
// If the device doesn't exist, it will not return an error
|
||||||
|
// If something went wrong during the deletion, it will return the SQL error.
|
||||||
|
func (d *Database) RemoveDevice(
|
||||||
|
ctx context.Context, deviceID, localpart string,
|
||||||
|
) error {
|
||||||
|
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveDevices revokes one or more devices by deleting the entry in the database
|
||||||
|
// matching with the given device IDs and user ID localpart.
|
||||||
|
// If the devices don't exist, it will not return an error
|
||||||
|
// If something went wrong during the deletion, it will return the SQL error.
|
||||||
|
func (d *Database) RemoveDevices(
|
||||||
|
ctx context.Context, localpart string, devices []string,
|
||||||
|
) error {
|
||||||
|
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveAllDevices revokes devices by deleting the entry in the
|
||||||
|
// database matching the given user ID localpart.
|
||||||
|
// If something went wrong during the deletion, it will return the SQL error.
|
||||||
|
func (d *Database) RemoveAllDevices(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) error {
|
||||||
|
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
243
clientapi/auth/storage/devices/sqlite3/devices_table.go
Normal file
243
clientapi/auth/storage/devices/sqlite3/devices_table.go
Normal file
|
|
@ -0,0 +1,243 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
const devicesSchema = `
|
||||||
|
-- This sequence is used for automatic allocation of session_id.
|
||||||
|
-- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
|
||||||
|
|
||||||
|
-- Stores data about devices.
|
||||||
|
CREATE TABLE IF NOT EXISTS device_devices (
|
||||||
|
access_token TEXT PRIMARY KEY,
|
||||||
|
session_id INTEGER,
|
||||||
|
device_id TEXT ,
|
||||||
|
localpart TEXT ,
|
||||||
|
created_ts BIGINT,
|
||||||
|
display_name TEXT,
|
||||||
|
|
||||||
|
UNIQUE (localpart, device_id)
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
const insertDeviceSQL = "" +
|
||||||
|
"INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id)" +
|
||||||
|
" VALUES ($1, $2, $3, $4, $5, $6)"
|
||||||
|
|
||||||
|
const selectDevicesCountSQL = "" +
|
||||||
|
"SELECT COUNT(access_token) FROM device_devices"
|
||||||
|
|
||||||
|
const selectDeviceByTokenSQL = "" +
|
||||||
|
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
|
||||||
|
|
||||||
|
const selectDeviceByIDSQL = "" +
|
||||||
|
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
|
||||||
|
|
||||||
|
const selectDevicesByLocalpartSQL = "" +
|
||||||
|
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
|
||||||
|
|
||||||
|
const updateDeviceNameSQL = "" +
|
||||||
|
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||||
|
|
||||||
|
const deleteDeviceSQL = "" +
|
||||||
|
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
|
||||||
|
|
||||||
|
const deleteDevicesByLocalpartSQL = "" +
|
||||||
|
"DELETE FROM device_devices WHERE localpart = $1"
|
||||||
|
|
||||||
|
const deleteDevicesSQL = "" +
|
||||||
|
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
|
||||||
|
|
||||||
|
type devicesStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
insertDeviceStmt *sql.Stmt
|
||||||
|
selectDevicesCountStmt *sql.Stmt
|
||||||
|
selectDeviceByTokenStmt *sql.Stmt
|
||||||
|
selectDeviceByIDStmt *sql.Stmt
|
||||||
|
selectDevicesByLocalpartStmt *sql.Stmt
|
||||||
|
updateDeviceNameStmt *sql.Stmt
|
||||||
|
deleteDeviceStmt *sql.Stmt
|
||||||
|
deleteDevicesByLocalpartStmt *sql.Stmt
|
||||||
|
serverName gomatrixserverlib.ServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
||||||
|
s.db = db
|
||||||
|
_, err = db.Exec(devicesSchema)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.serverName = server
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
|
||||||
|
// Returns an error if the user already has a device with the given device ID.
|
||||||
|
// Returns the device on success.
|
||||||
|
func (s *devicesStatements) insertDevice(
|
||||||
|
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
||||||
|
displayName *string,
|
||||||
|
) (*authtypes.Device, error) {
|
||||||
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
|
var sessionID int64
|
||||||
|
countStmt := common.TxStmt(txn, s.selectDevicesCountStmt)
|
||||||
|
insertStmt := common.TxStmt(txn, s.insertDeviceStmt)
|
||||||
|
if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sessionID++
|
||||||
|
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &authtypes.Device{
|
||||||
|
ID: id,
|
||||||
|
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||||
|
AccessToken: accessToken,
|
||||||
|
SessionID: sessionID,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *devicesStatements) deleteDevice(
|
||||||
|
ctx context.Context, txn *sql.Tx, id, localpart string,
|
||||||
|
) error {
|
||||||
|
stmt := common.TxStmt(txn, s.deleteDeviceStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, id, localpart)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *devicesStatements) deleteDevices(
|
||||||
|
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
||||||
|
) error {
|
||||||
|
orig := strings.Replace(deleteDevicesSQL, "($1)", common.QueryVariadic(len(devices)), 1)
|
||||||
|
prep, err := s.db.Prepare(orig)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
stmt := common.TxStmt(txn, prep)
|
||||||
|
params := make([]interface{}, len(devices)+1)
|
||||||
|
params[0] = localpart
|
||||||
|
for i, v := range devices {
|
||||||
|
params[i+1] = v
|
||||||
|
}
|
||||||
|
params = append(params, params...)
|
||||||
|
_, err = stmt.ExecContext(ctx, params...)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *devicesStatements) deleteDevicesByLocalpart(
|
||||||
|
ctx context.Context, txn *sql.Tx, localpart string,
|
||||||
|
) error {
|
||||||
|
stmt := common.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, localpart)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *devicesStatements) updateDeviceName(
|
||||||
|
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
||||||
|
) error {
|
||||||
|
stmt := common.TxStmt(txn, s.updateDeviceNameStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *devicesStatements) selectDeviceByToken(
|
||||||
|
ctx context.Context, accessToken string,
|
||||||
|
) (*authtypes.Device, error) {
|
||||||
|
var dev authtypes.Device
|
||||||
|
var localpart string
|
||||||
|
stmt := s.selectDeviceByTokenStmt
|
||||||
|
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
|
||||||
|
if err == nil {
|
||||||
|
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||||
|
dev.AccessToken = accessToken
|
||||||
|
}
|
||||||
|
return &dev, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectDeviceByID retrieves a device from the database with the given user
|
||||||
|
// localpart and deviceID
|
||||||
|
func (s *devicesStatements) selectDeviceByID(
|
||||||
|
ctx context.Context, localpart, deviceID string,
|
||||||
|
) (*authtypes.Device, error) {
|
||||||
|
var dev authtypes.Device
|
||||||
|
var created sql.NullInt64
|
||||||
|
stmt := s.selectDeviceByIDStmt
|
||||||
|
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&created)
|
||||||
|
if err == nil {
|
||||||
|
dev.ID = deviceID
|
||||||
|
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||||
|
}
|
||||||
|
return &dev, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *devicesStatements) selectDevicesByLocalpart(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) ([]authtypes.Device, error) {
|
||||||
|
devices := []authtypes.Device{}
|
||||||
|
|
||||||
|
rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return devices, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var dev authtypes.Device
|
||||||
|
err = rows.Scan(&dev.ID)
|
||||||
|
if err != nil {
|
||||||
|
return devices, err
|
||||||
|
}
|
||||||
|
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||||
|
devices = append(devices, dev)
|
||||||
|
}
|
||||||
|
|
||||||
|
return devices, nil
|
||||||
|
}
|
||||||
184
clientapi/auth/storage/devices/sqlite3/storage.go
Normal file
184
clientapi/auth/storage/devices/sqlite3/storage.go
Normal file
|
|
@ -0,0 +1,184 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/base64"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// The length of generated device IDs
|
||||||
|
var deviceIDByteLength = 6
|
||||||
|
|
||||||
|
// Database represents a device database.
|
||||||
|
type Database struct {
|
||||||
|
db *sql.DB
|
||||||
|
devices devicesStatements
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDatabase creates a new device database
|
||||||
|
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
|
||||||
|
var db *sql.DB
|
||||||
|
var err error
|
||||||
|
if db, err = sql.Open("sqlite3", dataSourceName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
d := devicesStatements{}
|
||||||
|
if err = d.prepare(db, serverName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Database{db, d}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDeviceByAccessToken returns the device matching the given access token.
|
||||||
|
// Returns sql.ErrNoRows if no matching device was found.
|
||||||
|
func (d *Database) GetDeviceByAccessToken(
|
||||||
|
ctx context.Context, token string,
|
||||||
|
) (*authtypes.Device, error) {
|
||||||
|
return d.devices.selectDeviceByToken(ctx, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDeviceByID returns the device matching the given ID.
|
||||||
|
// Returns sql.ErrNoRows if no matching device was found.
|
||||||
|
func (d *Database) GetDeviceByID(
|
||||||
|
ctx context.Context, localpart, deviceID string,
|
||||||
|
) (*authtypes.Device, error) {
|
||||||
|
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDevicesByLocalpart returns the devices matching the given localpart.
|
||||||
|
func (d *Database) GetDevicesByLocalpart(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) ([]authtypes.Device, error) {
|
||||||
|
return d.devices.selectDevicesByLocalpart(ctx, localpart)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateDevice makes a new device associated with the given user ID localpart.
|
||||||
|
// If there is already a device with the same device ID for this user, that access token will be revoked
|
||||||
|
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
||||||
|
// an error will be returned.
|
||||||
|
// If no device ID is given one is generated.
|
||||||
|
// Returns the device on success.
|
||||||
|
func (d *Database) CreateDevice(
|
||||||
|
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
||||||
|
displayName *string,
|
||||||
|
) (dev *authtypes.Device, returnErr error) {
|
||||||
|
if deviceID != nil {
|
||||||
|
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
var err error
|
||||||
|
// Revoke existing tokens for this device
|
||||||
|
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
// We generate device IDs in a loop in case its already taken.
|
||||||
|
// We cap this at going round 5 times to ensure we don't spin forever
|
||||||
|
var newDeviceID string
|
||||||
|
for i := 1; i <= 5; i++ {
|
||||||
|
newDeviceID, returnErr = generateDeviceID()
|
||||||
|
if returnErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
var err error
|
||||||
|
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if returnErr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateDeviceID creates a new device id. Returns an error if failed to generate
|
||||||
|
// random bytes.
|
||||||
|
func generateDeviceID() (string, error) {
|
||||||
|
b := make([]byte, deviceIDByteLength)
|
||||||
|
_, err := rand.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
// url-safe no padding
|
||||||
|
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDevice updates the given device with the display name.
|
||||||
|
// Returns SQL error if there are problems and nil on success.
|
||||||
|
func (d *Database) UpdateDevice(
|
||||||
|
ctx context.Context, localpart, deviceID string, displayName *string,
|
||||||
|
) error {
|
||||||
|
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveDevice revokes a device by deleting the entry in the database
|
||||||
|
// matching with the given device ID and user ID localpart.
|
||||||
|
// If the device doesn't exist, it will not return an error
|
||||||
|
// If something went wrong during the deletion, it will return the SQL error.
|
||||||
|
func (d *Database) RemoveDevice(
|
||||||
|
ctx context.Context, deviceID, localpart string,
|
||||||
|
) error {
|
||||||
|
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveDevices revokes one or more devices by deleting the entry in the database
|
||||||
|
// matching with the given device IDs and user ID localpart.
|
||||||
|
// If the devices don't exist, it will not return an error
|
||||||
|
// If something went wrong during the deletion, it will return the SQL error.
|
||||||
|
func (d *Database) RemoveDevices(
|
||||||
|
ctx context.Context, localpart string, devices []string,
|
||||||
|
) error {
|
||||||
|
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveAllDevices revokes devices by deleting the entry in the
|
||||||
|
// database matching the given user ID localpart.
|
||||||
|
// If something went wrong during the deletion, it will return the SQL error.
|
||||||
|
func (d *Database) RemoveAllDevices(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) error {
|
||||||
|
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
@ -1,167 +1,37 @@
|
||||||
// Copyright 2017 Vector Creations Ltd
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package devices
|
package devices
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"net/url"
|
||||||
"database/sql"
|
|
||||||
"encoding/base64"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/common"
|
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices/postgres"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices/sqlite3"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
// The length of generated device IDs
|
type Database interface {
|
||||||
var deviceIDByteLength = 6
|
GetDeviceByAccessToken(ctx context.Context, token string) (*authtypes.Device, error)
|
||||||
|
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*authtypes.Device, error)
|
||||||
// Database represents a device database.
|
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]authtypes.Device, error)
|
||||||
type Database struct {
|
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string) (dev *authtypes.Device, returnErr error)
|
||||||
db *sql.DB
|
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
|
||||||
devices devicesStatements
|
RemoveDevice(ctx context.Context, deviceID, localpart string) error
|
||||||
|
RemoveDevices(ctx context.Context, localpart string, devices []string) error
|
||||||
|
RemoveAllDevices(ctx context.Context, localpart string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabase creates a new device database
|
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (Database, error) {
|
||||||
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
|
uri, err := url.Parse(dataSourceName)
|
||||||
var db *sql.DB
|
|
||||||
var err error
|
|
||||||
if db, err = sql.Open("postgres", dataSourceName); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
d := devicesStatements{}
|
|
||||||
if err = d.prepare(db, serverName); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &Database{db, d}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDeviceByAccessToken returns the device matching the given access token.
|
|
||||||
// Returns sql.ErrNoRows if no matching device was found.
|
|
||||||
func (d *Database) GetDeviceByAccessToken(
|
|
||||||
ctx context.Context, token string,
|
|
||||||
) (*authtypes.Device, error) {
|
|
||||||
return d.devices.selectDeviceByToken(ctx, token)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDeviceByID returns the device matching the given ID.
|
|
||||||
// Returns sql.ErrNoRows if no matching device was found.
|
|
||||||
func (d *Database) GetDeviceByID(
|
|
||||||
ctx context.Context, localpart, deviceID string,
|
|
||||||
) (*authtypes.Device, error) {
|
|
||||||
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDevicesByLocalpart returns the devices matching the given localpart.
|
|
||||||
func (d *Database) GetDevicesByLocalpart(
|
|
||||||
ctx context.Context, localpart string,
|
|
||||||
) ([]authtypes.Device, error) {
|
|
||||||
return d.devices.selectDevicesByLocalpart(ctx, localpart)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateDevice makes a new device associated with the given user ID localpart.
|
|
||||||
// If there is already a device with the same device ID for this user, that access token will be revoked
|
|
||||||
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
|
||||||
// an error will be returned.
|
|
||||||
// If no device ID is given one is generated.
|
|
||||||
// Returns the device on success.
|
|
||||||
func (d *Database) CreateDevice(
|
|
||||||
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
|
||||||
displayName *string,
|
|
||||||
) (dev *authtypes.Device, returnErr error) {
|
|
||||||
if deviceID != nil {
|
|
||||||
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
var err error
|
|
||||||
// Revoke existing tokens for this device
|
|
||||||
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
// We generate device IDs in a loop in case its already taken.
|
|
||||||
// We cap this at going round 5 times to ensure we don't spin forever
|
|
||||||
var newDeviceID string
|
|
||||||
for i := 1; i <= 5; i++ {
|
|
||||||
newDeviceID, returnErr = generateDeviceID()
|
|
||||||
if returnErr != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
var err error
|
|
||||||
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if returnErr == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateDeviceID creates a new device id. Returns an error if failed to generate
|
|
||||||
// random bytes.
|
|
||||||
func generateDeviceID() (string, error) {
|
|
||||||
b := make([]byte, deviceIDByteLength)
|
|
||||||
_, err := rand.Read(b)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return postgres.NewDatabase(dataSourceName, serverName)
|
||||||
|
}
|
||||||
|
switch uri.Scheme {
|
||||||
|
case "postgres":
|
||||||
|
return postgres.NewDatabase(dataSourceName, serverName)
|
||||||
|
case "file":
|
||||||
|
return sqlite3.NewDatabase(dataSourceName, serverName)
|
||||||
|
default:
|
||||||
|
return postgres.NewDatabase(dataSourceName, serverName)
|
||||||
}
|
}
|
||||||
// url-safe no padding
|
|
||||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateDevice updates the given device with the display name.
|
|
||||||
// Returns SQL error if there are problems and nil on success.
|
|
||||||
func (d *Database) UpdateDevice(
|
|
||||||
ctx context.Context, localpart, deviceID string, displayName *string,
|
|
||||||
) error {
|
|
||||||
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveDevice revokes a device by deleting the entry in the database
|
|
||||||
// matching with the given device ID and user ID localpart.
|
|
||||||
// If the device doesn't exist, it will not return an error
|
|
||||||
// If something went wrong during the deletion, it will return the SQL error.
|
|
||||||
func (d *Database) RemoveDevice(
|
|
||||||
ctx context.Context, deviceID, localpart string,
|
|
||||||
) error {
|
|
||||||
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveAllDevices revokes devices by deleting the entry in the
|
|
||||||
// database matching the given user ID localpart.
|
|
||||||
// If something went wrong during the deletion, it will return the SQL error.
|
|
||||||
func (d *Database) RemoveAllDevices(
|
|
||||||
ctx context.Context, localpart string,
|
|
||||||
) error {
|
|
||||||
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -34,8 +34,8 @@ import (
|
||||||
// component.
|
// component.
|
||||||
func SetupClientAPIComponent(
|
func SetupClientAPIComponent(
|
||||||
base *basecomponent.BaseDendrite,
|
base *basecomponent.BaseDendrite,
|
||||||
deviceDB *devices.Database,
|
deviceDB devices.Database,
|
||||||
accountsDB *accounts.Database,
|
accountsDB accounts.Database,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation *gomatrixserverlib.FederationClient,
|
||||||
keyRing *gomatrixserverlib.KeyRing,
|
keyRing *gomatrixserverlib.KeyRing,
|
||||||
aliasAPI roomserverAPI.RoomserverAliasAPI,
|
aliasAPI roomserverAPI.RoomserverAliasAPI,
|
||||||
|
|
@ -67,7 +67,7 @@ func SetupClientAPIComponent(
|
||||||
}
|
}
|
||||||
|
|
||||||
routing.Setup(
|
routing.Setup(
|
||||||
base.APIMux, *base.Cfg, roomserverProducer, queryAPI, aliasAPI, asAPI,
|
base.APIMux, base.Cfg, roomserverProducer, queryAPI, aliasAPI, asAPI,
|
||||||
accountsDB, deviceDB, federation, *keyRing, userUpdateProducer,
|
accountsDB, deviceDB, federation, *keyRing, userUpdateProducer,
|
||||||
syncProducer, typingProducer, transactionsCache, fedSenderAPI,
|
syncProducer, typingProducer, transactionsCache, fedSenderAPI,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ import (
|
||||||
// OutputRoomEventConsumer consumes events that originated in the room server.
|
// OutputRoomEventConsumer consumes events that originated in the room server.
|
||||||
type OutputRoomEventConsumer struct {
|
type OutputRoomEventConsumer struct {
|
||||||
roomServerConsumer *common.ContinualConsumer
|
roomServerConsumer *common.ContinualConsumer
|
||||||
db *accounts.Database
|
db accounts.Database
|
||||||
query api.RoomserverQueryAPI
|
query api.RoomserverQueryAPI
|
||||||
serverName string
|
serverName string
|
||||||
}
|
}
|
||||||
|
|
@ -40,7 +40,7 @@ type OutputRoomEventConsumer struct {
|
||||||
func NewOutputRoomEventConsumer(
|
func NewOutputRoomEventConsumer(
|
||||||
cfg *config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
kafkaConsumer sarama.Consumer,
|
kafkaConsumer sarama.Consumer,
|
||||||
store *accounts.Database,
|
store accounts.Database,
|
||||||
queryAPI api.RoomserverQueryAPI,
|
queryAPI api.RoomserverQueryAPI,
|
||||||
) *OutputRoomEventConsumer {
|
) *OutputRoomEventConsumer {
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ import (
|
||||||
|
|
||||||
// GetAccountData implements GET /user/{userId}/[rooms/{roomid}/]account_data/{type}
|
// GetAccountData implements GET /user/{userId}/[rooms/{roomid}/]account_data/{type}
|
||||||
func GetAccountData(
|
func GetAccountData(
|
||||||
req *http.Request, accountDB *accounts.Database, device *authtypes.Device,
|
req *http.Request, accountDB accounts.Database, device *authtypes.Device,
|
||||||
userID string, roomID string, dataType string,
|
userID string, roomID string, dataType string,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
if userID != device.UserID {
|
if userID != device.UserID {
|
||||||
|
|
@ -62,7 +62,7 @@ func GetAccountData(
|
||||||
|
|
||||||
// SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type}
|
// SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type}
|
||||||
func SaveAccountData(
|
func SaveAccountData(
|
||||||
req *http.Request, accountDB *accounts.Database, device *authtypes.Device,
|
req *http.Request, accountDB accounts.Database, device *authtypes.Device,
|
||||||
userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer,
|
userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
if userID != device.UserID {
|
if userID != device.UserID {
|
||||||
|
|
|
||||||
|
|
@ -102,7 +102,7 @@ func serveTemplate(w http.ResponseWriter, templateHTML string, data map[string]s
|
||||||
// AuthFallback implements GET and POST /auth/{authType}/fallback/web?session={sessionID}
|
// AuthFallback implements GET and POST /auth/{authType}/fallback/web?session={sessionID}
|
||||||
func AuthFallback(
|
func AuthFallback(
|
||||||
w http.ResponseWriter, req *http.Request, authType string,
|
w http.ResponseWriter, req *http.Request, authType string,
|
||||||
cfg config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
) *util.JSONResponse {
|
) *util.JSONResponse {
|
||||||
sessionID := req.URL.Query().Get("session")
|
sessionID := req.URL.Query().Get("session")
|
||||||
|
|
||||||
|
|
@ -130,7 +130,7 @@ func AuthFallback(
|
||||||
if req.Method == http.MethodGet {
|
if req.Method == http.MethodGet {
|
||||||
// Handle Recaptcha
|
// Handle Recaptcha
|
||||||
if authType == authtypes.LoginTypeRecaptcha {
|
if authType == authtypes.LoginTypeRecaptcha {
|
||||||
if err := checkRecaptchaEnabled(&cfg, w, req); err != nil {
|
if err := checkRecaptchaEnabled(cfg, w, req); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -144,7 +144,7 @@ func AuthFallback(
|
||||||
} else if req.Method == http.MethodPost {
|
} else if req.Method == http.MethodPost {
|
||||||
// Handle Recaptcha
|
// Handle Recaptcha
|
||||||
if authType == authtypes.LoginTypeRecaptcha {
|
if authType == authtypes.LoginTypeRecaptcha {
|
||||||
if err := checkRecaptchaEnabled(&cfg, w, req); err != nil {
|
if err := checkRecaptchaEnabled(cfg, w, req); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -156,7 +156,7 @@ func AuthFallback(
|
||||||
}
|
}
|
||||||
|
|
||||||
response := req.Form.Get("g-recaptcha-response")
|
response := req.Form.Get("g-recaptcha-response")
|
||||||
if err := validateRecaptcha(&cfg, response, clientIP); err != nil {
|
if err := validateRecaptcha(cfg, response, clientIP); err != nil {
|
||||||
util.GetLogger(req.Context()).Error(err)
|
util.GetLogger(req.Context()).Error(err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -134,8 +134,8 @@ type fledglingEvent struct {
|
||||||
// CreateRoom implements /createRoom
|
// CreateRoom implements /createRoom
|
||||||
func CreateRoom(
|
func CreateRoom(
|
||||||
req *http.Request, device *authtypes.Device,
|
req *http.Request, device *authtypes.Device,
|
||||||
cfg config.Dendrite, producer *producers.RoomserverProducer,
|
cfg *config.Dendrite, producer *producers.RoomserverProducer,
|
||||||
accountDB *accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI,
|
accountDB accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI,
|
||||||
asAPI appserviceAPI.AppServiceQueryAPI,
|
asAPI appserviceAPI.AppServiceQueryAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
// TODO (#267): Check room ID doesn't clash with an existing one, and we
|
// TODO (#267): Check room ID doesn't clash with an existing one, and we
|
||||||
|
|
@ -148,8 +148,8 @@ func CreateRoom(
|
||||||
// nolint: gocyclo
|
// nolint: gocyclo
|
||||||
func createRoom(
|
func createRoom(
|
||||||
req *http.Request, device *authtypes.Device,
|
req *http.Request, device *authtypes.Device,
|
||||||
cfg config.Dendrite, roomID string, producer *producers.RoomserverProducer,
|
cfg *config.Dendrite, roomID string, producer *producers.RoomserverProducer,
|
||||||
accountDB *accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI,
|
accountDB accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI,
|
||||||
asAPI appserviceAPI.AppServiceQueryAPI,
|
asAPI appserviceAPI.AppServiceQueryAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
logger := util.GetLogger(req.Context())
|
logger := util.GetLogger(req.Context())
|
||||||
|
|
@ -344,7 +344,7 @@ func createRoom(
|
||||||
func buildEvent(
|
func buildEvent(
|
||||||
builder *gomatrixserverlib.EventBuilder,
|
builder *gomatrixserverlib.EventBuilder,
|
||||||
provider gomatrixserverlib.AuthEventProvider,
|
provider gomatrixserverlib.AuthEventProvider,
|
||||||
cfg config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
evTime time.Time,
|
evTime time.Time,
|
||||||
) (*gomatrixserverlib.Event, error) {
|
) (*gomatrixserverlib.Event, error) {
|
||||||
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
|
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
|
||||||
|
|
|
||||||
|
|
@ -40,9 +40,13 @@ type deviceUpdateJSON struct {
|
||||||
DisplayName *string `json:"display_name"`
|
DisplayName *string `json:"display_name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type devicesDeleteJSON struct {
|
||||||
|
Devices []string `json:"devices"`
|
||||||
|
}
|
||||||
|
|
||||||
// GetDeviceByID handles /devices/{deviceID}
|
// GetDeviceByID handles /devices/{deviceID}
|
||||||
func GetDeviceByID(
|
func GetDeviceByID(
|
||||||
req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
|
req *http.Request, deviceDB devices.Database, device *authtypes.Device,
|
||||||
deviceID string,
|
deviceID string,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||||
|
|
@ -72,7 +76,7 @@ func GetDeviceByID(
|
||||||
|
|
||||||
// GetDevicesByLocalpart handles /devices
|
// GetDevicesByLocalpart handles /devices
|
||||||
func GetDevicesByLocalpart(
|
func GetDevicesByLocalpart(
|
||||||
req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
|
req *http.Request, deviceDB devices.Database, device *authtypes.Device,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -103,7 +107,7 @@ func GetDevicesByLocalpart(
|
||||||
|
|
||||||
// UpdateDeviceByID handles PUT on /devices/{deviceID}
|
// UpdateDeviceByID handles PUT on /devices/{deviceID}
|
||||||
func UpdateDeviceByID(
|
func UpdateDeviceByID(
|
||||||
req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
|
req *http.Request, deviceDB devices.Database, device *authtypes.Device,
|
||||||
deviceID string,
|
deviceID string,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||||
|
|
@ -146,3 +150,54 @@ func UpdateDeviceByID(
|
||||||
JSON: struct{}{},
|
JSON: struct{}{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteDeviceById handles DELETE requests to /devices/{deviceId}
|
||||||
|
func DeleteDeviceById(
|
||||||
|
req *http.Request, deviceDB devices.Database, device *authtypes.Device,
|
||||||
|
deviceID string,
|
||||||
|
) util.JSONResponse {
|
||||||
|
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return httputil.LogThenError(req, err)
|
||||||
|
}
|
||||||
|
ctx := req.Context()
|
||||||
|
|
||||||
|
defer req.Body.Close() // nolint: errcheck
|
||||||
|
|
||||||
|
if err := deviceDB.RemoveDevice(ctx, deviceID, localpart); err != nil {
|
||||||
|
return httputil.LogThenError(req, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusOK,
|
||||||
|
JSON: struct{}{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteDevices handles POST requests to /delete_devices
|
||||||
|
func DeleteDevices(
|
||||||
|
req *http.Request, deviceDB devices.Database, device *authtypes.Device,
|
||||||
|
) util.JSONResponse {
|
||||||
|
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return httputil.LogThenError(req, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := req.Context()
|
||||||
|
payload := devicesDeleteJSON{}
|
||||||
|
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&payload); err != nil {
|
||||||
|
return httputil.LogThenError(req, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer req.Body.Close() // nolint: errcheck
|
||||||
|
|
||||||
|
if err := deviceDB.RemoveDevices(ctx, localpart, payload.Devices); err != nil {
|
||||||
|
return httputil.LogThenError(req, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusOK,
|
||||||
|
JSON: struct{}{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ import (
|
||||||
|
|
||||||
// GetFilter implements GET /_matrix/client/r0/user/{userId}/filter/{filterId}
|
// GetFilter implements GET /_matrix/client/r0/user/{userId}/filter/{filterId}
|
||||||
func GetFilter(
|
func GetFilter(
|
||||||
req *http.Request, device *authtypes.Device, accountDB *accounts.Database, userID string, filterID string,
|
req *http.Request, device *authtypes.Device, accountDB accounts.Database, userID string, filterID string,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
if userID != device.UserID {
|
if userID != device.UserID {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
|
|
@ -63,7 +63,7 @@ type filterResponse struct {
|
||||||
|
|
||||||
//PutFilter implements POST /_matrix/client/r0/user/{userId}/filter
|
//PutFilter implements POST /_matrix/client/r0/user/{userId}/filter
|
||||||
func PutFilter(
|
func PutFilter(
|
||||||
req *http.Request, device *authtypes.Device, accountDB *accounts.Database, userID string,
|
req *http.Request, device *authtypes.Device, accountDB accounts.Database, userID string,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
if userID != device.UserID {
|
if userID != device.UserID {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ type getEventRequest struct {
|
||||||
device *authtypes.Device
|
device *authtypes.Device
|
||||||
roomID string
|
roomID string
|
||||||
eventID string
|
eventID string
|
||||||
cfg config.Dendrite
|
cfg *config.Dendrite
|
||||||
federation *gomatrixserverlib.FederationClient
|
federation *gomatrixserverlib.FederationClient
|
||||||
keyRing gomatrixserverlib.KeyRing
|
keyRing gomatrixserverlib.KeyRing
|
||||||
requestedEvent gomatrixserverlib.Event
|
requestedEvent gomatrixserverlib.Event
|
||||||
|
|
@ -44,7 +44,7 @@ func GetEvent(
|
||||||
device *authtypes.Device,
|
device *authtypes.Device,
|
||||||
roomID string,
|
roomID string,
|
||||||
eventID string,
|
eventID string,
|
||||||
cfg config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
queryAPI api.RoomserverQueryAPI,
|
queryAPI api.RoomserverQueryAPI,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation *gomatrixserverlib.FederationClient,
|
||||||
keyRing gomatrixserverlib.KeyRing,
|
keyRing gomatrixserverlib.KeyRing,
|
||||||
|
|
|
||||||
|
|
@ -39,13 +39,13 @@ func JoinRoomByIDOrAlias(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
device *authtypes.Device,
|
device *authtypes.Device,
|
||||||
roomIDOrAlias string,
|
roomIDOrAlias string,
|
||||||
cfg config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation *gomatrixserverlib.FederationClient,
|
||||||
producer *producers.RoomserverProducer,
|
producer *producers.RoomserverProducer,
|
||||||
queryAPI roomserverAPI.RoomserverQueryAPI,
|
queryAPI roomserverAPI.RoomserverQueryAPI,
|
||||||
aliasAPI roomserverAPI.RoomserverAliasAPI,
|
aliasAPI roomserverAPI.RoomserverAliasAPI,
|
||||||
keyRing gomatrixserverlib.KeyRing,
|
keyRing gomatrixserverlib.KeyRing,
|
||||||
accountDB *accounts.Database,
|
accountDB accounts.Database,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
var content map[string]interface{} // must be a JSON object
|
var content map[string]interface{} // must be a JSON object
|
||||||
if resErr := httputil.UnmarshalJSONRequest(req, &content); resErr != nil {
|
if resErr := httputil.UnmarshalJSONRequest(req, &content); resErr != nil {
|
||||||
|
|
@ -98,7 +98,7 @@ type joinRoomReq struct {
|
||||||
evTime time.Time
|
evTime time.Time
|
||||||
content map[string]interface{}
|
content map[string]interface{}
|
||||||
userID string
|
userID string
|
||||||
cfg config.Dendrite
|
cfg *config.Dendrite
|
||||||
federation *gomatrixserverlib.FederationClient
|
federation *gomatrixserverlib.FederationClient
|
||||||
producer *producers.RoomserverProducer
|
producer *producers.RoomserverProducer
|
||||||
queryAPI roomserverAPI.RoomserverQueryAPI
|
queryAPI roomserverAPI.RoomserverQueryAPI
|
||||||
|
|
|
||||||
|
|
@ -70,8 +70,8 @@ func passwordLogin() loginFlows {
|
||||||
|
|
||||||
// Login implements GET and POST /login
|
// Login implements GET and POST /login
|
||||||
func Login(
|
func Login(
|
||||||
req *http.Request, accountDB *accounts.Database, deviceDB *devices.Database,
|
req *http.Request, accountDB accounts.Database, deviceDB devices.Database,
|
||||||
cfg config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
if req.Method == http.MethodGet { // TODO: support other forms of login other than password, depending on config options
|
if req.Method == http.MethodGet { // TODO: support other forms of login other than password, depending on config options
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
|
|
@ -153,7 +153,7 @@ func Login(
|
||||||
func getDevice(
|
func getDevice(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
r passwordRequest,
|
r passwordRequest,
|
||||||
deviceDB *devices.Database,
|
deviceDB devices.Database,
|
||||||
acc *authtypes.Account,
|
acc *authtypes.Account,
|
||||||
token string,
|
token string,
|
||||||
) (dev *authtypes.Device, err error) {
|
) (dev *authtypes.Device, err error) {
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ import (
|
||||||
|
|
||||||
// Logout handles POST /logout
|
// Logout handles POST /logout
|
||||||
func Logout(
|
func Logout(
|
||||||
req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
|
req *http.Request, deviceDB devices.Database, device *authtypes.Device,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -45,7 +45,7 @@ func Logout(
|
||||||
|
|
||||||
// LogoutAll handles POST /logout/all
|
// LogoutAll handles POST /logout/all
|
||||||
func LogoutAll(
|
func LogoutAll(
|
||||||
req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
|
req *http.Request, deviceDB devices.Database, device *authtypes.Device,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -40,8 +40,8 @@ var errMissingUserID = errors.New("'user_id' must be supplied")
|
||||||
// SendMembership implements PUT /rooms/{roomID}/(join|kick|ban|unban|leave|invite)
|
// SendMembership implements PUT /rooms/{roomID}/(join|kick|ban|unban|leave|invite)
|
||||||
// by building a m.room.member event then sending it to the room server
|
// by building a m.room.member event then sending it to the room server
|
||||||
func SendMembership(
|
func SendMembership(
|
||||||
req *http.Request, accountDB *accounts.Database, device *authtypes.Device,
|
req *http.Request, accountDB accounts.Database, device *authtypes.Device,
|
||||||
roomID string, membership string, cfg config.Dendrite,
|
roomID string, membership string, cfg *config.Dendrite,
|
||||||
queryAPI roomserverAPI.RoomserverQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
|
queryAPI roomserverAPI.RoomserverQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
|
||||||
producer *producers.RoomserverProducer,
|
producer *producers.RoomserverProducer,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
|
@ -116,10 +116,10 @@ func SendMembership(
|
||||||
|
|
||||||
func buildMembershipEvent(
|
func buildMembershipEvent(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
body threepid.MembershipRequest, accountDB *accounts.Database,
|
body threepid.MembershipRequest, accountDB accounts.Database,
|
||||||
device *authtypes.Device,
|
device *authtypes.Device,
|
||||||
membership, roomID string,
|
membership, roomID string,
|
||||||
cfg config.Dendrite, evTime time.Time,
|
cfg *config.Dendrite, evTime time.Time,
|
||||||
queryAPI roomserverAPI.RoomserverQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
|
queryAPI roomserverAPI.RoomserverQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
|
||||||
) (*gomatrixserverlib.Event, error) {
|
) (*gomatrixserverlib.Event, error) {
|
||||||
stateKey, reason, err := getMembershipStateKey(body, device, membership)
|
stateKey, reason, err := getMembershipStateKey(body, device, membership)
|
||||||
|
|
@ -165,8 +165,8 @@ func buildMembershipEvent(
|
||||||
func loadProfile(
|
func loadProfile(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
userID string,
|
userID string,
|
||||||
cfg config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
accountDB *accounts.Database,
|
accountDB accounts.Database,
|
||||||
asAPI appserviceAPI.AppServiceQueryAPI,
|
asAPI appserviceAPI.AppServiceQueryAPI,
|
||||||
) (*authtypes.Profile, error) {
|
) (*authtypes.Profile, error) {
|
||||||
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
|
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
|
|
@ -214,9 +214,9 @@ func checkAndProcessThreepid(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
device *authtypes.Device,
|
device *authtypes.Device,
|
||||||
body *threepid.MembershipRequest,
|
body *threepid.MembershipRequest,
|
||||||
cfg config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
queryAPI roomserverAPI.RoomserverQueryAPI,
|
queryAPI roomserverAPI.RoomserverQueryAPI,
|
||||||
accountDB *accounts.Database,
|
accountDB accounts.Database,
|
||||||
producer *producers.RoomserverProducer,
|
producer *producers.RoomserverProducer,
|
||||||
membership, roomID string,
|
membership, roomID string,
|
||||||
evTime time.Time,
|
evTime time.Time,
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ type response struct {
|
||||||
// GetMemberships implements GET /rooms/{roomId}/members
|
// GetMemberships implements GET /rooms/{roomId}/members
|
||||||
func GetMemberships(
|
func GetMemberships(
|
||||||
req *http.Request, device *authtypes.Device, roomID string, joinedOnly bool,
|
req *http.Request, device *authtypes.Device, roomID string, joinedOnly bool,
|
||||||
_ config.Dendrite,
|
_ *config.Dendrite,
|
||||||
queryAPI api.RoomserverQueryAPI,
|
queryAPI api.RoomserverQueryAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
queryReq := api.QueryMembershipsForRoomRequest{
|
queryReq := api.QueryMembershipsForRoomRequest{
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ import (
|
||||||
|
|
||||||
// GetProfile implements GET /profile/{userID}
|
// GetProfile implements GET /profile/{userID}
|
||||||
func GetProfile(
|
func GetProfile(
|
||||||
req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite,
|
req *http.Request, accountDB accounts.Database, cfg *config.Dendrite,
|
||||||
userID string,
|
userID string,
|
||||||
asAPI appserviceAPI.AppServiceQueryAPI,
|
asAPI appserviceAPI.AppServiceQueryAPI,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation *gomatrixserverlib.FederationClient,
|
||||||
|
|
@ -64,7 +64,7 @@ func GetProfile(
|
||||||
|
|
||||||
// GetAvatarURL implements GET /profile/{userID}/avatar_url
|
// GetAvatarURL implements GET /profile/{userID}/avatar_url
|
||||||
func GetAvatarURL(
|
func GetAvatarURL(
|
||||||
req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite,
|
req *http.Request, accountDB accounts.Database, cfg *config.Dendrite,
|
||||||
userID string, asAPI appserviceAPI.AppServiceQueryAPI,
|
userID string, asAPI appserviceAPI.AppServiceQueryAPI,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation *gomatrixserverlib.FederationClient,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
|
@ -90,7 +90,7 @@ func GetAvatarURL(
|
||||||
|
|
||||||
// SetAvatarURL implements PUT /profile/{userID}/avatar_url
|
// SetAvatarURL implements PUT /profile/{userID}/avatar_url
|
||||||
func SetAvatarURL(
|
func SetAvatarURL(
|
||||||
req *http.Request, accountDB *accounts.Database, device *authtypes.Device,
|
req *http.Request, accountDB accounts.Database, device *authtypes.Device,
|
||||||
userID string, producer *producers.UserUpdateProducer, cfg *config.Dendrite,
|
userID string, producer *producers.UserUpdateProducer, cfg *config.Dendrite,
|
||||||
rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI,
|
rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
|
@ -170,7 +170,7 @@ func SetAvatarURL(
|
||||||
|
|
||||||
// GetDisplayName implements GET /profile/{userID}/displayname
|
// GetDisplayName implements GET /profile/{userID}/displayname
|
||||||
func GetDisplayName(
|
func GetDisplayName(
|
||||||
req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite,
|
req *http.Request, accountDB accounts.Database, cfg *config.Dendrite,
|
||||||
userID string, asAPI appserviceAPI.AppServiceQueryAPI,
|
userID string, asAPI appserviceAPI.AppServiceQueryAPI,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation *gomatrixserverlib.FederationClient,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
|
@ -196,7 +196,7 @@ func GetDisplayName(
|
||||||
|
|
||||||
// SetDisplayName implements PUT /profile/{userID}/displayname
|
// SetDisplayName implements PUT /profile/{userID}/displayname
|
||||||
func SetDisplayName(
|
func SetDisplayName(
|
||||||
req *http.Request, accountDB *accounts.Database, device *authtypes.Device,
|
req *http.Request, accountDB accounts.Database, device *authtypes.Device,
|
||||||
userID string, producer *producers.UserUpdateProducer, cfg *config.Dendrite,
|
userID string, producer *producers.UserUpdateProducer, cfg *config.Dendrite,
|
||||||
rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI,
|
rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
|
@ -279,7 +279,7 @@ func SetDisplayName(
|
||||||
// Returns an error when something goes wrong or specifically
|
// Returns an error when something goes wrong or specifically
|
||||||
// common.ErrProfileNoExists when the profile doesn't exist.
|
// common.ErrProfileNoExists when the profile doesn't exist.
|
||||||
func getProfile(
|
func getProfile(
|
||||||
ctx context.Context, accountDB *accounts.Database, cfg *config.Dendrite,
|
ctx context.Context, accountDB accounts.Database, cfg *config.Dendrite,
|
||||||
userID string,
|
userID string,
|
||||||
asAPI appserviceAPI.AppServiceQueryAPI,
|
asAPI appserviceAPI.AppServiceQueryAPI,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation *gomatrixserverlib.FederationClient,
|
||||||
|
|
@ -343,7 +343,7 @@ func buildMembershipEvents(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
event, err := common.BuildEvent(ctx, &builder, *cfg, evTime, queryAPI, nil)
|
event, err := common.BuildEvent(ctx, &builder, cfg, evTime, queryAPI, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -440,8 +440,8 @@ func validateApplicationService(
|
||||||
// http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register
|
// http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register
|
||||||
func Register(
|
func Register(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
accountDB *accounts.Database,
|
accountDB accounts.Database,
|
||||||
deviceDB *devices.Database,
|
deviceDB devices.Database,
|
||||||
cfg *config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
|
||||||
|
|
@ -513,8 +513,8 @@ func handleGuestRegistration(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
r registerRequest,
|
r registerRequest,
|
||||||
cfg *config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
accountDB *accounts.Database,
|
accountDB accounts.Database,
|
||||||
deviceDB *devices.Database,
|
deviceDB devices.Database,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
|
||||||
//Generate numeric local part for guest user
|
//Generate numeric local part for guest user
|
||||||
|
|
@ -570,8 +570,8 @@ func handleRegistrationFlow(
|
||||||
r registerRequest,
|
r registerRequest,
|
||||||
sessionID string,
|
sessionID string,
|
||||||
cfg *config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
accountDB *accounts.Database,
|
accountDB accounts.Database,
|
||||||
deviceDB *devices.Database,
|
deviceDB devices.Database,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
// TODO: Shared secret registration (create new user scripts)
|
// TODO: Shared secret registration (create new user scripts)
|
||||||
// TODO: Enable registration config flag
|
// TODO: Enable registration config flag
|
||||||
|
|
@ -668,8 +668,8 @@ func handleApplicationServiceRegistration(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
r registerRequest,
|
r registerRequest,
|
||||||
cfg *config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
accountDB *accounts.Database,
|
accountDB accounts.Database,
|
||||||
deviceDB *devices.Database,
|
deviceDB devices.Database,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
// Check if we previously had issues extracting the access token from the
|
// Check if we previously had issues extracting the access token from the
|
||||||
// request.
|
// request.
|
||||||
|
|
@ -707,8 +707,8 @@ func checkAndCompleteFlow(
|
||||||
r registerRequest,
|
r registerRequest,
|
||||||
sessionID string,
|
sessionID string,
|
||||||
cfg *config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
accountDB *accounts.Database,
|
accountDB accounts.Database,
|
||||||
deviceDB *devices.Database,
|
deviceDB devices.Database,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
|
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
|
||||||
// This flow was completed, registration can continue
|
// This flow was completed, registration can continue
|
||||||
|
|
@ -730,8 +730,8 @@ func checkAndCompleteFlow(
|
||||||
// LegacyRegister process register requests from the legacy v1 API
|
// LegacyRegister process register requests from the legacy v1 API
|
||||||
func LegacyRegister(
|
func LegacyRegister(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
accountDB *accounts.Database,
|
accountDB accounts.Database,
|
||||||
deviceDB *devices.Database,
|
deviceDB devices.Database,
|
||||||
cfg *config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
var r legacyRegisterRequest
|
var r legacyRegisterRequest
|
||||||
|
|
@ -814,8 +814,8 @@ func parseAndValidateLegacyLogin(req *http.Request, r *legacyRegisterRequest) *u
|
||||||
// not all
|
// not all
|
||||||
func completeRegistration(
|
func completeRegistration(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
accountDB *accounts.Database,
|
accountDB accounts.Database,
|
||||||
deviceDB *devices.Database,
|
deviceDB devices.Database,
|
||||||
username, password, appserviceID string,
|
username, password, appserviceID string,
|
||||||
inhibitLogin common.WeakBoolean,
|
inhibitLogin common.WeakBoolean,
|
||||||
displayName, deviceID *string,
|
displayName, deviceID *string,
|
||||||
|
|
@ -991,8 +991,8 @@ type availableResponse struct {
|
||||||
// RegisterAvailable checks if the username is already taken or invalid.
|
// RegisterAvailable checks if the username is already taken or invalid.
|
||||||
func RegisterAvailable(
|
func RegisterAvailable(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
cfg config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
accountDB *accounts.Database,
|
accountDB accounts.Database,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
username := req.URL.Query().Get("username")
|
username := req.URL.Query().Get("username")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,7 @@ func newTag() gomatrix.TagContent {
|
||||||
// GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags
|
// GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags
|
||||||
func GetTags(
|
func GetTags(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
accountDB *accounts.Database,
|
accountDB accounts.Database,
|
||||||
device *authtypes.Device,
|
device *authtypes.Device,
|
||||||
userID string,
|
userID string,
|
||||||
roomID string,
|
roomID string,
|
||||||
|
|
@ -77,7 +77,7 @@ func GetTags(
|
||||||
// the tag to the "map" and saving the new "map" to the DB
|
// the tag to the "map" and saving the new "map" to the DB
|
||||||
func PutTag(
|
func PutTag(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
accountDB *accounts.Database,
|
accountDB accounts.Database,
|
||||||
device *authtypes.Device,
|
device *authtypes.Device,
|
||||||
userID string,
|
userID string,
|
||||||
roomID string,
|
roomID string,
|
||||||
|
|
@ -134,7 +134,7 @@ func PutTag(
|
||||||
// the "map" and then saving the new "map" in the DB
|
// the "map" and then saving the new "map" in the DB
|
||||||
func DeleteTag(
|
func DeleteTag(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
accountDB *accounts.Database,
|
accountDB accounts.Database,
|
||||||
device *authtypes.Device,
|
device *authtypes.Device,
|
||||||
userID string,
|
userID string,
|
||||||
roomID string,
|
roomID string,
|
||||||
|
|
@ -203,7 +203,7 @@ func obtainSavedTags(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
userID string,
|
userID string,
|
||||||
roomID string,
|
roomID string,
|
||||||
accountDB *accounts.Database,
|
accountDB accounts.Database,
|
||||||
) (string, *gomatrixserverlib.ClientEvent, error) {
|
) (string, *gomatrixserverlib.ClientEvent, error) {
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
|
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -222,7 +222,7 @@ func saveTagData(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
localpart string,
|
localpart string,
|
||||||
roomID string,
|
roomID string,
|
||||||
accountDB *accounts.Database,
|
accountDB accounts.Database,
|
||||||
Tag gomatrix.TagContent,
|
Tag gomatrix.TagContent,
|
||||||
) error {
|
) error {
|
||||||
newTagData, err := json.Marshal(Tag)
|
newTagData, err := json.Marshal(Tag)
|
||||||
|
|
|
||||||
|
|
@ -47,13 +47,13 @@ const pathPrefixUnstable = "/_matrix/client/unstable"
|
||||||
// applied:
|
// applied:
|
||||||
// nolint: gocyclo
|
// nolint: gocyclo
|
||||||
func Setup(
|
func Setup(
|
||||||
apiMux *mux.Router, cfg config.Dendrite,
|
apiMux *mux.Router, cfg *config.Dendrite,
|
||||||
producer *producers.RoomserverProducer,
|
producer *producers.RoomserverProducer,
|
||||||
queryAPI roomserverAPI.RoomserverQueryAPI,
|
queryAPI roomserverAPI.RoomserverQueryAPI,
|
||||||
aliasAPI roomserverAPI.RoomserverAliasAPI,
|
aliasAPI roomserverAPI.RoomserverAliasAPI,
|
||||||
asAPI appserviceAPI.AppServiceQueryAPI,
|
asAPI appserviceAPI.AppServiceQueryAPI,
|
||||||
accountDB *accounts.Database,
|
accountDB accounts.Database,
|
||||||
deviceDB *devices.Database,
|
deviceDB devices.Database,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation *gomatrixserverlib.FederationClient,
|
||||||
keyRing gomatrixserverlib.KeyRing,
|
keyRing gomatrixserverlib.KeyRing,
|
||||||
userUpdateProducer *producers.UserUpdateProducer,
|
userUpdateProducer *producers.UserUpdateProducer,
|
||||||
|
|
@ -170,11 +170,11 @@ func Setup(
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
|
||||||
r0mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
|
r0mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
|
||||||
return Register(req, accountDB, deviceDB, &cfg)
|
return Register(req, accountDB, deviceDB, cfg)
|
||||||
})).Methods(http.MethodPost, http.MethodOptions)
|
})).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
v1mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
|
v1mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
|
||||||
return LegacyRegister(req, accountDB, deviceDB, &cfg)
|
return LegacyRegister(req, accountDB, deviceDB, cfg)
|
||||||
})).Methods(http.MethodPost, http.MethodOptions)
|
})).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
r0mux.Handle("/register/available", common.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse {
|
r0mux.Handle("/register/available", common.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse {
|
||||||
|
|
@ -187,7 +187,7 @@ func Setup(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return DirectoryRoom(req, vars["roomAlias"], federation, &cfg, aliasAPI, federationSender)
|
return DirectoryRoom(req, vars["roomAlias"], federation, cfg, aliasAPI, federationSender)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
|
|
@ -197,7 +197,7 @@ func Setup(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return SetLocalAlias(req, device, vars["roomAlias"], &cfg, aliasAPI)
|
return SetLocalAlias(req, device, vars["roomAlias"], cfg, aliasAPI)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
|
||||||
|
|
@ -301,7 +301,7 @@ func Setup(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return GetProfile(req, accountDB, &cfg, vars["userID"], asAPI, federation)
|
return GetProfile(req, accountDB, cfg, vars["userID"], asAPI, federation)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
|
|
@ -311,7 +311,7 @@ func Setup(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return GetAvatarURL(req, accountDB, &cfg, vars["userID"], asAPI, federation)
|
return GetAvatarURL(req, accountDB, cfg, vars["userID"], asAPI, federation)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
|
|
@ -321,7 +321,7 @@ func Setup(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return SetAvatarURL(req, accountDB, device, vars["userID"], userUpdateProducer, &cfg, producer, queryAPI)
|
return SetAvatarURL(req, accountDB, device, vars["userID"], userUpdateProducer, cfg, producer, queryAPI)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows
|
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows
|
||||||
|
|
@ -333,7 +333,7 @@ func Setup(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return GetDisplayName(req, accountDB, &cfg, vars["userID"], asAPI, federation)
|
return GetDisplayName(req, accountDB, cfg, vars["userID"], asAPI, federation)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
|
|
@ -343,7 +343,7 @@ func Setup(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return SetDisplayName(req, accountDB, device, vars["userID"], userUpdateProducer, &cfg, producer, queryAPI)
|
return SetDisplayName(req, accountDB, device, vars["userID"], userUpdateProducer, cfg, producer, queryAPI)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows
|
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows
|
||||||
|
|
@ -503,6 +503,22 @@ func Setup(
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
|
||||||
|
r0mux.Handle("/devices/{deviceID}",
|
||||||
|
common.MakeAuthAPI("delete_device", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
|
||||||
|
vars, err := common.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return DeleteDeviceById(req, deviceDB, device, vars["deviceID"])
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodDelete, http.MethodOptions)
|
||||||
|
|
||||||
|
r0mux.Handle("/delete_devices",
|
||||||
|
common.MakeAuthAPI("delete_devices", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
|
||||||
|
return DeleteDevices(req, deviceDB, device)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
// Stub implementations for sytest
|
// Stub implementations for sytest
|
||||||
r0mux.Handle("/events",
|
r0mux.Handle("/events",
|
||||||
common.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {
|
common.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ func SendEvent(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
device *authtypes.Device,
|
device *authtypes.Device,
|
||||||
roomID, eventType string, txnID, stateKey *string,
|
roomID, eventType string, txnID, stateKey *string,
|
||||||
cfg config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
queryAPI api.RoomserverQueryAPI,
|
queryAPI api.RoomserverQueryAPI,
|
||||||
producer *producers.RoomserverProducer,
|
producer *producers.RoomserverProducer,
|
||||||
txnCache *transactions.Cache,
|
txnCache *transactions.Cache,
|
||||||
|
|
@ -93,7 +93,7 @@ func generateSendEvent(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
device *authtypes.Device,
|
device *authtypes.Device,
|
||||||
roomID, eventType string, stateKey *string,
|
roomID, eventType string, stateKey *string,
|
||||||
cfg config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
queryAPI api.RoomserverQueryAPI,
|
queryAPI api.RoomserverQueryAPI,
|
||||||
) (*gomatrixserverlib.Event, *util.JSONResponse) {
|
) (*gomatrixserverlib.Event, *util.JSONResponse) {
|
||||||
// parse the incoming http request
|
// parse the incoming http request
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ type typingContentJSON struct {
|
||||||
// sends the typing events to client API typingProducer
|
// sends the typing events to client API typingProducer
|
||||||
func SendTyping(
|
func SendTyping(
|
||||||
req *http.Request, device *authtypes.Device, roomID string,
|
req *http.Request, device *authtypes.Device, roomID string,
|
||||||
userID string, accountDB *accounts.Database,
|
userID string, accountDB accounts.Database,
|
||||||
typingProducer *producers.TypingServerProducer,
|
typingProducer *producers.TypingServerProducer,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
if device.UserID != userID {
|
if device.UserID != userID {
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ type threePIDsResponse struct {
|
||||||
// RequestEmailToken implements:
|
// RequestEmailToken implements:
|
||||||
// POST /account/3pid/email/requestToken
|
// POST /account/3pid/email/requestToken
|
||||||
// POST /register/email/requestToken
|
// POST /register/email/requestToken
|
||||||
func RequestEmailToken(req *http.Request, accountDB *accounts.Database, cfg config.Dendrite) util.JSONResponse {
|
func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *config.Dendrite) util.JSONResponse {
|
||||||
var body threepid.EmailAssociationRequest
|
var body threepid.EmailAssociationRequest
|
||||||
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
|
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
|
||||||
return *reqErr
|
return *reqErr
|
||||||
|
|
@ -82,8 +82,8 @@ func RequestEmailToken(req *http.Request, accountDB *accounts.Database, cfg conf
|
||||||
|
|
||||||
// CheckAndSave3PIDAssociation implements POST /account/3pid
|
// CheckAndSave3PIDAssociation implements POST /account/3pid
|
||||||
func CheckAndSave3PIDAssociation(
|
func CheckAndSave3PIDAssociation(
|
||||||
req *http.Request, accountDB *accounts.Database, device *authtypes.Device,
|
req *http.Request, accountDB accounts.Database, device *authtypes.Device,
|
||||||
cfg config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
var body threepid.EmailAssociationCheckRequest
|
var body threepid.EmailAssociationCheckRequest
|
||||||
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
|
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
|
||||||
|
|
@ -142,7 +142,7 @@ func CheckAndSave3PIDAssociation(
|
||||||
|
|
||||||
// GetAssociated3PIDs implements GET /account/3pid
|
// GetAssociated3PIDs implements GET /account/3pid
|
||||||
func GetAssociated3PIDs(
|
func GetAssociated3PIDs(
|
||||||
req *http.Request, accountDB *accounts.Database, device *authtypes.Device,
|
req *http.Request, accountDB accounts.Database, device *authtypes.Device,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -161,7 +161,7 @@ func GetAssociated3PIDs(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Forget3PID implements POST /account/3pid/delete
|
// Forget3PID implements POST /account/3pid/delete
|
||||||
func Forget3PID(req *http.Request, accountDB *accounts.Database) util.JSONResponse {
|
func Forget3PID(req *http.Request, accountDB accounts.Database) util.JSONResponse {
|
||||||
var body authtypes.ThreePID
|
var body authtypes.ThreePID
|
||||||
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
|
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
|
||||||
return *reqErr
|
return *reqErr
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ import (
|
||||||
|
|
||||||
// RequestTurnServer implements:
|
// RequestTurnServer implements:
|
||||||
// GET /voip/turnServer
|
// GET /voip/turnServer
|
||||||
func RequestTurnServer(req *http.Request, device *authtypes.Device, cfg config.Dendrite) util.JSONResponse {
|
func RequestTurnServer(req *http.Request, device *authtypes.Device, cfg *config.Dendrite) util.JSONResponse {
|
||||||
turnConfig := cfg.TURN
|
turnConfig := cfg.TURN
|
||||||
|
|
||||||
// TODO Guest Support
|
// TODO Guest Support
|
||||||
|
|
|
||||||
|
|
@ -86,8 +86,8 @@ var (
|
||||||
// can be emitted.
|
// can be emitted.
|
||||||
func CheckAndProcessInvite(
|
func CheckAndProcessInvite(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
device *authtypes.Device, body *MembershipRequest, cfg config.Dendrite,
|
device *authtypes.Device, body *MembershipRequest, cfg *config.Dendrite,
|
||||||
queryAPI api.RoomserverQueryAPI, db *accounts.Database,
|
queryAPI api.RoomserverQueryAPI, db accounts.Database,
|
||||||
producer *producers.RoomserverProducer, membership string, roomID string,
|
producer *producers.RoomserverProducer, membership string, roomID string,
|
||||||
evTime time.Time,
|
evTime time.Time,
|
||||||
) (inviteStoredOnIDServer bool, err error) {
|
) (inviteStoredOnIDServer bool, err error) {
|
||||||
|
|
@ -137,7 +137,7 @@ func CheckAndProcessInvite(
|
||||||
// Returns an error if a check or a request failed.
|
// Returns an error if a check or a request failed.
|
||||||
func queryIDServer(
|
func queryIDServer(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
db *accounts.Database, cfg config.Dendrite, device *authtypes.Device,
|
db accounts.Database, cfg *config.Dendrite, device *authtypes.Device,
|
||||||
body *MembershipRequest, roomID string,
|
body *MembershipRequest, roomID string,
|
||||||
) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) {
|
) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) {
|
||||||
if err = isTrusted(body.IDServer, cfg); err != nil {
|
if err = isTrusted(body.IDServer, cfg); err != nil {
|
||||||
|
|
@ -206,7 +206,7 @@ func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServe
|
||||||
// Returns an error if the request failed to send or if the response couldn't be parsed.
|
// Returns an error if the request failed to send or if the response couldn't be parsed.
|
||||||
func queryIDServerStoreInvite(
|
func queryIDServerStoreInvite(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
db *accounts.Database, cfg config.Dendrite, device *authtypes.Device,
|
db accounts.Database, cfg *config.Dendrite, device *authtypes.Device,
|
||||||
body *MembershipRequest, roomID string,
|
body *MembershipRequest, roomID string,
|
||||||
) (*idServerStoreInviteResponse, error) {
|
) (*idServerStoreInviteResponse, error) {
|
||||||
// Retrieve the sender's profile to get their display name
|
// Retrieve the sender's profile to get their display name
|
||||||
|
|
@ -330,7 +330,7 @@ func checkIDServerSignatures(
|
||||||
func emit3PIDInviteEvent(
|
func emit3PIDInviteEvent(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
body *MembershipRequest, res *idServerStoreInviteResponse,
|
body *MembershipRequest, res *idServerStoreInviteResponse,
|
||||||
device *authtypes.Device, roomID string, cfg config.Dendrite,
|
device *authtypes.Device, roomID string, cfg *config.Dendrite,
|
||||||
queryAPI api.RoomserverQueryAPI, producer *producers.RoomserverProducer,
|
queryAPI api.RoomserverQueryAPI, producer *producers.RoomserverProducer,
|
||||||
evTime time.Time,
|
evTime time.Time,
|
||||||
) error {
|
) error {
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,7 @@ type Credentials struct {
|
||||||
// Returns an error if there was a problem sending the request or decoding the
|
// Returns an error if there was a problem sending the request or decoding the
|
||||||
// response, or if the identity server responded with a non-OK status.
|
// response, or if the identity server responded with a non-OK status.
|
||||||
func CreateSession(
|
func CreateSession(
|
||||||
ctx context.Context, req EmailAssociationRequest, cfg config.Dendrite,
|
ctx context.Context, req EmailAssociationRequest, cfg *config.Dendrite,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
if err := isTrusted(req.IDServer, cfg); err != nil {
|
if err := isTrusted(req.IDServer, cfg); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
|
@ -101,7 +101,7 @@ func CreateSession(
|
||||||
// Returns an error if there was a problem sending the request or decoding the
|
// Returns an error if there was a problem sending the request or decoding the
|
||||||
// response, or if the identity server responded with a non-OK status.
|
// response, or if the identity server responded with a non-OK status.
|
||||||
func CheckAssociation(
|
func CheckAssociation(
|
||||||
ctx context.Context, creds Credentials, cfg config.Dendrite,
|
ctx context.Context, creds Credentials, cfg *config.Dendrite,
|
||||||
) (bool, string, string, error) {
|
) (bool, string, string, error) {
|
||||||
if err := isTrusted(creds.IDServer, cfg); err != nil {
|
if err := isTrusted(creds.IDServer, cfg); err != nil {
|
||||||
return false, "", "", err
|
return false, "", "", err
|
||||||
|
|
@ -142,7 +142,7 @@ func CheckAssociation(
|
||||||
// identifier and a Matrix ID.
|
// identifier and a Matrix ID.
|
||||||
// Returns an error if there was a problem sending the request or decoding the
|
// Returns an error if there was a problem sending the request or decoding the
|
||||||
// response, or if the identity server responded with a non-OK status.
|
// response, or if the identity server responded with a non-OK status.
|
||||||
func PublishAssociation(creds Credentials, userID string, cfg config.Dendrite) error {
|
func PublishAssociation(creds Credentials, userID string, cfg *config.Dendrite) error {
|
||||||
if err := isTrusted(creds.IDServer, cfg); err != nil {
|
if err := isTrusted(creds.IDServer, cfg); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -177,7 +177,7 @@ func PublishAssociation(creds Credentials, userID string, cfg config.Dendrite) e
|
||||||
// isTrusted checks if a given identity server is part of the list of trusted
|
// isTrusted checks if a given identity server is part of the list of trusted
|
||||||
// identity servers in the configuration file.
|
// identity servers in the configuration file.
|
||||||
// Returns an error if the server isn't trusted.
|
// Returns an error if the server isn't trusted.
|
||||||
func isTrusted(idServer string, cfg config.Dendrite) error {
|
func isTrusted(idServer string, cfg *config.Dendrite) error {
|
||||||
for _, server := range cfg.Matrix.TrustedIDServers {
|
for _, server := range cfg.Matrix.TrustedIDServers {
|
||||||
if idServer == server {
|
if idServer == server {
|
||||||
return nil
|
return nil
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/Shopify/sarama"
|
sarama "gopkg.in/Shopify/sarama.v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
const usage = `Usage: %s
|
const usage = `Usage: %s
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
|
|
||||||
|
|
@ -68,7 +69,13 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string) *BaseDendrite {
|
||||||
logrus.WithError(err).Panicf("failed to start opentracing")
|
logrus.WithError(err).Panicf("failed to start opentracing")
|
||||||
}
|
}
|
||||||
|
|
||||||
kafkaConsumer, kafkaProducer := setupKafka(cfg)
|
var kafkaConsumer sarama.Consumer
|
||||||
|
var kafkaProducer sarama.SyncProducer
|
||||||
|
if cfg.Kafka.UseNaffka {
|
||||||
|
kafkaConsumer, kafkaProducer = setupNaffka(cfg)
|
||||||
|
} else {
|
||||||
|
kafkaConsumer, kafkaProducer = setupKafka(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
return &BaseDendrite{
|
return &BaseDendrite{
|
||||||
componentName: componentName,
|
componentName: componentName,
|
||||||
|
|
@ -118,7 +125,7 @@ func (b *BaseDendrite) CreateHTTPFederationSenderAPIs() federationSenderAPI.Fede
|
||||||
|
|
||||||
// CreateDeviceDB creates a new instance of the device database. Should only be
|
// CreateDeviceDB creates a new instance of the device database. Should only be
|
||||||
// called once per component.
|
// called once per component.
|
||||||
func (b *BaseDendrite) CreateDeviceDB() *devices.Database {
|
func (b *BaseDendrite) CreateDeviceDB() devices.Database {
|
||||||
db, err := devices.NewDatabase(string(b.Cfg.Database.Device), b.Cfg.Matrix.ServerName)
|
db, err := devices.NewDatabase(string(b.Cfg.Database.Device), b.Cfg.Matrix.ServerName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Panicf("failed to connect to devices db")
|
logrus.WithError(err).Panicf("failed to connect to devices db")
|
||||||
|
|
@ -129,7 +136,7 @@ func (b *BaseDendrite) CreateDeviceDB() *devices.Database {
|
||||||
|
|
||||||
// CreateAccountsDB creates a new instance of the accounts database. Should only
|
// CreateAccountsDB creates a new instance of the accounts database. Should only
|
||||||
// be called once per component.
|
// be called once per component.
|
||||||
func (b *BaseDendrite) CreateAccountsDB() *accounts.Database {
|
func (b *BaseDendrite) CreateAccountsDB() accounts.Database {
|
||||||
db, err := accounts.NewDatabase(string(b.Cfg.Database.Account), b.Cfg.Matrix.ServerName)
|
db, err := accounts.NewDatabase(string(b.Cfg.Database.Account), b.Cfg.Matrix.ServerName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Panicf("failed to connect to accounts db")
|
logrus.WithError(err).Panicf("failed to connect to accounts db")
|
||||||
|
|
@ -186,28 +193,8 @@ func (b *BaseDendrite) SetupAndServeHTTP(bindaddr string, listenaddr string) {
|
||||||
logrus.Infof("Stopped %s server on %s", b.componentName, addr)
|
logrus.Infof("Stopped %s server on %s", b.componentName, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// setupKafka creates kafka consumer/producer pair from the config. Checks if
|
// setupKafka creates kafka consumer/producer pair from the config.
|
||||||
// should use naffka.
|
|
||||||
func setupKafka(cfg *config.Dendrite) (sarama.Consumer, sarama.SyncProducer) {
|
func setupKafka(cfg *config.Dendrite) (sarama.Consumer, sarama.SyncProducer) {
|
||||||
if cfg.Kafka.UseNaffka {
|
|
||||||
db, err := sql.Open("postgres", string(cfg.Database.Naffka))
|
|
||||||
if err != nil {
|
|
||||||
logrus.WithError(err).Panic("Failed to open naffka database")
|
|
||||||
}
|
|
||||||
|
|
||||||
naffkaDB, err := naffka.NewPostgresqlDatabase(db)
|
|
||||||
if err != nil {
|
|
||||||
logrus.WithError(err).Panic("Failed to setup naffka database")
|
|
||||||
}
|
|
||||||
|
|
||||||
naff, err := naffka.New(naffkaDB)
|
|
||||||
if err != nil {
|
|
||||||
logrus.WithError(err).Panic("Failed to setup naffka")
|
|
||||||
}
|
|
||||||
|
|
||||||
return naff, naff
|
|
||||||
}
|
|
||||||
|
|
||||||
consumer, err := sarama.NewConsumer(cfg.Kafka.Addresses, nil)
|
consumer, err := sarama.NewConsumer(cfg.Kafka.Addresses, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Panic("failed to start kafka consumer")
|
logrus.WithError(err).Panic("failed to start kafka consumer")
|
||||||
|
|
@ -220,3 +207,44 @@ func setupKafka(cfg *config.Dendrite) (sarama.Consumer, sarama.SyncProducer) {
|
||||||
|
|
||||||
return consumer, producer
|
return consumer, producer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setupNaffka creates kafka consumer/producer pair from the config.
|
||||||
|
func setupNaffka(cfg *config.Dendrite) (sarama.Consumer, sarama.SyncProducer) {
|
||||||
|
var err error
|
||||||
|
var db *sql.DB
|
||||||
|
var naffkaDB *naffka.DatabaseImpl
|
||||||
|
|
||||||
|
uri, err := url.Parse(string(cfg.Database.Naffka))
|
||||||
|
if err != nil || uri.Scheme == "file" {
|
||||||
|
db, err = sql.Open("sqlite3", string(cfg.Database.Naffka))
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Panic("Failed to open naffka database")
|
||||||
|
}
|
||||||
|
|
||||||
|
naffkaDB, err = naffka.NewSqliteDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Panic("Failed to setup naffka database")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
db, err = sql.Open("postgres", string(cfg.Database.Naffka))
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Panic("Failed to open naffka database")
|
||||||
|
}
|
||||||
|
|
||||||
|
naffkaDB, err = naffka.NewPostgresqlDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Panic("Failed to setup naffka database")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if naffkaDB == nil {
|
||||||
|
panic("naffka connection string not understood")
|
||||||
|
}
|
||||||
|
|
||||||
|
naff, err := naffka.New(naffkaDB)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Panic("Failed to setup naffka")
|
||||||
|
}
|
||||||
|
|
||||||
|
return naff, naff
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ var ErrRoomNoExists = errors.New("Room does not exist")
|
||||||
// Returns an error if something else went wrong
|
// Returns an error if something else went wrong
|
||||||
func BuildEvent(
|
func BuildEvent(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
builder *gomatrixserverlib.EventBuilder, cfg config.Dendrite, evTime time.Time,
|
builder *gomatrixserverlib.EventBuilder, cfg *config.Dendrite, evTime time.Time,
|
||||||
queryAPI api.RoomserverQueryAPI, queryRes *api.QueryLatestEventsAndStateResponse,
|
queryAPI api.RoomserverQueryAPI, queryRes *api.QueryLatestEventsAndStateResponse,
|
||||||
) (*gomatrixserverlib.Event, error) {
|
) (*gomatrixserverlib.Event, error) {
|
||||||
err := AddPrevEventsToEvent(ctx, builder, queryAPI, queryRes)
|
err := AddPrevEventsToEvent(ctx, builder, queryAPI, queryRes)
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/common/keydb/postgres"
|
"github.com/matrix-org/dendrite/common/keydb/postgres"
|
||||||
|
"github.com/matrix-org/dendrite/common/keydb/sqlite3"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -44,6 +45,8 @@ func NewDatabase(
|
||||||
switch uri.Scheme {
|
switch uri.Scheme {
|
||||||
case "postgres":
|
case "postgres":
|
||||||
return postgres.NewDatabase(dataSourceName, serverName, serverKey, serverKeyID)
|
return postgres.NewDatabase(dataSourceName, serverName, serverKey, serverKeyID)
|
||||||
|
case "file":
|
||||||
|
return sqlite3.NewDatabase(dataSourceName, serverName, serverKey, serverKeyID)
|
||||||
default:
|
default:
|
||||||
return postgres.NewDatabase(dataSourceName, serverName, serverKey, serverKeyID)
|
return postgres.NewDatabase(dataSourceName, serverName, serverKey, serverKeyID)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -117,7 +117,7 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
|
||||||
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
|
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return results, nil
|
return results, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serverKeyStatements) upsertServerKeys(
|
func (s *serverKeyStatements) upsertServerKeys(
|
||||||
|
|
|
||||||
115
common/keydb/sqlite3/keydb.go
Normal file
115
common/keydb/sqlite3/keydb.go
Normal file
|
|
@ -0,0 +1,115 @@
|
||||||
|
// Copyright 2017-2018 New Vector Ltd
|
||||||
|
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ed25519"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A Database implements gomatrixserverlib.KeyDatabase and is used to store
|
||||||
|
// the public keys for other matrix servers.
|
||||||
|
type Database struct {
|
||||||
|
statements serverKeyStatements
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDatabase prepares a new key database.
|
||||||
|
// It creates the necessary tables if they don't already exist.
|
||||||
|
// It prepares all the SQL statements that it will use.
|
||||||
|
// Returns an error if there was a problem talking to the database.
|
||||||
|
func NewDatabase(
|
||||||
|
dataSourceName string,
|
||||||
|
serverName gomatrixserverlib.ServerName,
|
||||||
|
serverKey ed25519.PublicKey,
|
||||||
|
serverKeyID gomatrixserverlib.KeyID,
|
||||||
|
) (*Database, error) {
|
||||||
|
db, err := sql.Open("sqlite3", dataSourceName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
d := &Database{}
|
||||||
|
err = d.statements.prepare(db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Store our own keys so that we don't end up making HTTP requests to find our
|
||||||
|
// own keys
|
||||||
|
index := gomatrixserverlib.PublicKeyLookupRequest{
|
||||||
|
ServerName: serverName,
|
||||||
|
KeyID: serverKeyID,
|
||||||
|
}
|
||||||
|
value := gomatrixserverlib.PublicKeyLookupResult{
|
||||||
|
VerifyKey: gomatrixserverlib.VerifyKey{
|
||||||
|
Key: gomatrixserverlib.Base64String(serverKey),
|
||||||
|
},
|
||||||
|
ValidUntilTS: math.MaxUint64 >> 1,
|
||||||
|
ExpiredTS: gomatrixserverlib.PublicKeyNotExpired,
|
||||||
|
}
|
||||||
|
err = d.StoreKeys(
|
||||||
|
context.Background(),
|
||||||
|
map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{
|
||||||
|
index: value,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetcherName implements KeyFetcher
|
||||||
|
func (d Database) FetcherName() string {
|
||||||
|
return "KeyDatabase"
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchKeys implements gomatrixserverlib.KeyDatabase
|
||||||
|
func (d *Database) FetchKeys(
|
||||||
|
ctx context.Context,
|
||||||
|
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
|
||||||
|
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
|
||||||
|
return d.statements.bulkSelectServerKeys(ctx, requests)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StoreKeys implements gomatrixserverlib.KeyDatabase
|
||||||
|
func (d *Database) StoreKeys(
|
||||||
|
ctx context.Context,
|
||||||
|
keyMap map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
|
||||||
|
) error {
|
||||||
|
// TODO: Inserting all the keys within a single transaction may
|
||||||
|
// be more efficient since the transaction overhead can be quite
|
||||||
|
// high for a single insert statement.
|
||||||
|
var lastErr error
|
||||||
|
for request, keys := range keyMap {
|
||||||
|
if err := d.statements.upsertServerKeys(ctx, request, keys); err != nil {
|
||||||
|
// Rather than returning immediately on error we try to insert the
|
||||||
|
// remaining keys.
|
||||||
|
// Since we are inserting the keys outside of a transaction it is
|
||||||
|
// possible for some of the inserts to succeed even though some
|
||||||
|
// of the inserts have failed.
|
||||||
|
// Ensuring that we always insert all the keys we can means that
|
||||||
|
// this behaviour won't depend on the iteration order of the map.
|
||||||
|
lastErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
142
common/keydb/sqlite3/server_key_table.go
Normal file
142
common/keydb/sqlite3/server_key_table.go
Normal file
|
|
@ -0,0 +1,142 @@
|
||||||
|
// Copyright 2017-2018 New Vector Ltd
|
||||||
|
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/lib/pq"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
const serverKeysSchema = `
|
||||||
|
-- A cache of signing keys downloaded from remote servers.
|
||||||
|
CREATE TABLE IF NOT EXISTS keydb_server_keys (
|
||||||
|
-- The name of the matrix server the key is for.
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
|
-- The ID of the server key.
|
||||||
|
server_key_id TEXT NOT NULL,
|
||||||
|
-- Combined server name and key ID separated by the ASCII unit separator
|
||||||
|
-- to make it easier to run bulk queries.
|
||||||
|
server_name_and_key_id TEXT NOT NULL,
|
||||||
|
-- When the key is valid until as a millisecond timestamp.
|
||||||
|
-- 0 if this is an expired key (in which case expired_ts will be non-zero)
|
||||||
|
valid_until_ts BIGINT NOT NULL,
|
||||||
|
-- When the key expired as a millisecond timestamp.
|
||||||
|
-- 0 if this is an active key (in which case valid_until_ts will be non-zero)
|
||||||
|
expired_ts BIGINT NOT NULL,
|
||||||
|
-- The base64-encoded public key.
|
||||||
|
server_key TEXT NOT NULL,
|
||||||
|
UNIQUE (server_name, server_key_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS keydb_server_name_and_key_id ON keydb_server_keys (server_name_and_key_id);
|
||||||
|
`
|
||||||
|
|
||||||
|
const bulkSelectServerKeysSQL = "" +
|
||||||
|
"SELECT server_name, server_key_id, valid_until_ts, expired_ts, " +
|
||||||
|
" server_key FROM keydb_server_keys" +
|
||||||
|
" WHERE server_name_and_key_id IN ($1)"
|
||||||
|
|
||||||
|
const upsertServerKeysSQL = "" +
|
||||||
|
"INSERT INTO keydb_server_keys (server_name, server_key_id," +
|
||||||
|
" server_name_and_key_id, valid_until_ts, expired_ts, server_key)" +
|
||||||
|
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||||
|
" ON CONFLICT (server_name, server_key_id)" +
|
||||||
|
" DO UPDATE SET valid_until_ts = $4, expired_ts = $5, server_key = $6"
|
||||||
|
|
||||||
|
type serverKeyStatements struct {
|
||||||
|
bulkSelectServerKeysStmt *sql.Stmt
|
||||||
|
upsertServerKeysStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
_, err = db.Exec(serverKeysSchema)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.bulkSelectServerKeysStmt, err = db.Prepare(bulkSelectServerKeysSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.upsertServerKeysStmt, err = db.Prepare(upsertServerKeysSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serverKeyStatements) bulkSelectServerKeys(
|
||||||
|
ctx context.Context,
|
||||||
|
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
|
||||||
|
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
|
||||||
|
var nameAndKeyIDs []string
|
||||||
|
for request := range requests {
|
||||||
|
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
|
||||||
|
}
|
||||||
|
stmt := s.bulkSelectServerKeysStmt
|
||||||
|
rows, err := stmt.QueryContext(ctx, pq.StringArray(nameAndKeyIDs))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close() // nolint: errcheck
|
||||||
|
results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
|
||||||
|
for rows.Next() {
|
||||||
|
var serverName string
|
||||||
|
var keyID string
|
||||||
|
var key string
|
||||||
|
var validUntilTS int64
|
||||||
|
var expiredTS int64
|
||||||
|
if err = rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
r := gomatrixserverlib.PublicKeyLookupRequest{
|
||||||
|
ServerName: gomatrixserverlib.ServerName(serverName),
|
||||||
|
KeyID: gomatrixserverlib.KeyID(keyID),
|
||||||
|
}
|
||||||
|
vk := gomatrixserverlib.VerifyKey{}
|
||||||
|
err = vk.Key.Decode(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
results[r] = gomatrixserverlib.PublicKeyLookupResult{
|
||||||
|
VerifyKey: vk,
|
||||||
|
ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS),
|
||||||
|
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serverKeyStatements) upsertServerKeys(
|
||||||
|
ctx context.Context,
|
||||||
|
request gomatrixserverlib.PublicKeyLookupRequest,
|
||||||
|
key gomatrixserverlib.PublicKeyLookupResult,
|
||||||
|
) error {
|
||||||
|
_, err := s.upsertServerKeysStmt.ExecContext(
|
||||||
|
ctx,
|
||||||
|
string(request.ServerName),
|
||||||
|
string(request.KeyID),
|
||||||
|
nameAndKeyID(request),
|
||||||
|
key.ValidUntilTS,
|
||||||
|
key.ExpiredTS,
|
||||||
|
key.Key.Encode(),
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func nameAndKeyID(request gomatrixserverlib.PublicKeyLookupRequest) string {
|
||||||
|
return string(request.ServerName) + "\x1F" + string(request.KeyID)
|
||||||
|
}
|
||||||
|
|
@ -29,7 +29,7 @@ CREATE TABLE IF NOT EXISTS ${prefix}_partition_offsets (
|
||||||
partition INTEGER NOT NULL,
|
partition INTEGER NOT NULL,
|
||||||
-- The 64-bit offset.
|
-- The 64-bit offset.
|
||||||
partition_offset BIGINT NOT NULL,
|
partition_offset BIGINT NOT NULL,
|
||||||
CONSTRAINT ${prefix}_topic_partition_unique UNIQUE (topic, partition)
|
UNIQUE (topic, partition)
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
|
|
@ -38,7 +38,7 @@ const selectPartitionOffsetsSQL = "" +
|
||||||
|
|
||||||
const upsertPartitionOffsetsSQL = "" +
|
const upsertPartitionOffsetsSQL = "" +
|
||||||
"INSERT INTO ${prefix}_partition_offsets (topic, partition, partition_offset) VALUES ($1, $2, $3)" +
|
"INSERT INTO ${prefix}_partition_offsets (topic, partition, partition_offset) VALUES ($1, $2, $3)" +
|
||||||
" ON CONFLICT ON CONSTRAINT ${prefix}_topic_partition_unique" +
|
" ON CONFLICT (topic, partition)" +
|
||||||
" DO UPDATE SET partition_offset = $3"
|
" DO UPDATE SET partition_offset = $3"
|
||||||
|
|
||||||
// PartitionOffsetStatements represents a set of statements that can be run on a partition_offsets table.
|
// PartitionOffsetStatements represents a set of statements that can be run on a partition_offsets table.
|
||||||
|
|
@ -99,7 +99,7 @@ func (s *PartitionOffsetStatements) selectPartitionOffsets(
|
||||||
}
|
}
|
||||||
results = append(results, offset)
|
results = append(results, offset)
|
||||||
}
|
}
|
||||||
return results, nil
|
return results, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpsertPartitionOffset updates or inserts the partition offset for the given topic.
|
// UpsertPartitionOffset updates or inserts the partition offset for the given topic.
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
)
|
)
|
||||||
|
|
@ -30,11 +31,13 @@ type Transaction interface {
|
||||||
|
|
||||||
// EndTransaction ends a transaction.
|
// EndTransaction ends a transaction.
|
||||||
// If the transaction succeeded then it is committed, otherwise it is rolledback.
|
// If the transaction succeeded then it is committed, otherwise it is rolledback.
|
||||||
func EndTransaction(txn Transaction, succeeded *bool) {
|
// You MUST check the error returned from this function to be sure that the transaction
|
||||||
|
// was applied correctly. For example, 'database is locked' errors in sqlite will happen here.
|
||||||
|
func EndTransaction(txn Transaction, succeeded *bool) error {
|
||||||
if *succeeded {
|
if *succeeded {
|
||||||
txn.Commit() // nolint: errcheck
|
return txn.Commit() // nolint: errcheck
|
||||||
} else {
|
} else {
|
||||||
txn.Rollback() // nolint: errcheck
|
return txn.Rollback() // nolint: errcheck
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -47,7 +50,12 @@ func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
succeeded := false
|
succeeded := false
|
||||||
defer EndTransaction(txn, &succeeded)
|
defer func() {
|
||||||
|
err2 := EndTransaction(txn, &succeeded)
|
||||||
|
if err == nil && err2 != nil { // failed to commit/rollback
|
||||||
|
err = err2
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
err = fn(txn)
|
err = fn(txn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -74,3 +82,20 @@ func IsUniqueConstraintViolationErr(err error) bool {
|
||||||
pqErr, ok := err.(*pq.Error)
|
pqErr, ok := err.(*pq.Error)
|
||||||
return ok && pqErr.Code == "23505"
|
return ok && pqErr.Code == "23505"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Hack of the century
|
||||||
|
func QueryVariadic(count int) string {
|
||||||
|
return QueryVariadicOffset(count, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func QueryVariadicOffset(count, offset int) string {
|
||||||
|
str := "("
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
str += fmt.Sprintf("$%d", i+offset+1)
|
||||||
|
if i < (count - 1) {
|
||||||
|
str += ", "
|
||||||
|
}
|
||||||
|
}
|
||||||
|
str += ")"
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -111,6 +111,7 @@ func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*con
|
||||||
// Bind to the same address as the listen address
|
// Bind to the same address as the listen address
|
||||||
// All microservices are run on the same host in testing
|
// All microservices are run on the same host in testing
|
||||||
cfg.Bind.ClientAPI = cfg.Listen.ClientAPI
|
cfg.Bind.ClientAPI = cfg.Listen.ClientAPI
|
||||||
|
cfg.Bind.AppServiceAPI = cfg.Listen.AppServiceAPI
|
||||||
cfg.Bind.FederationAPI = cfg.Listen.FederationAPI
|
cfg.Bind.FederationAPI = cfg.Listen.FederationAPI
|
||||||
cfg.Bind.MediaAPI = cfg.Listen.MediaAPI
|
cfg.Bind.MediaAPI = cfg.Listen.MediaAPI
|
||||||
cfg.Bind.RoomServer = cfg.Listen.RoomServer
|
cfg.Bind.RoomServer = cfg.Listen.RoomServer
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,13 @@
|
||||||
|
<<<<<<< HEAD
|
||||||
|
FROM docker.io/golang:1.13.7-alpine3.11
|
||||||
|
=======
|
||||||
FROM docker.io/golang:1.13.6-alpine
|
FROM docker.io/golang:1.13.6-alpine
|
||||||
|
>>>>>>> master
|
||||||
|
|
||||||
RUN mkdir /build
|
RUN mkdir /build
|
||||||
|
|
||||||
WORKDIR /build
|
WORKDIR /build
|
||||||
|
|
||||||
RUN apk --update --no-cache add openssl bash git
|
RUN apk --update --no-cache add openssl bash git build-base
|
||||||
|
|
||||||
CMD ["bash", "docker/build.sh"]
|
CMD ["bash", "docker/build.sh"]
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,21 @@
|
||||||
version: "3.4"
|
version: "3.4"
|
||||||
services:
|
services:
|
||||||
|
riot:
|
||||||
|
image: vectorim/riot-web
|
||||||
|
networks:
|
||||||
|
- internal
|
||||||
|
ports:
|
||||||
|
- "8500:80"
|
||||||
|
|
||||||
monolith:
|
monolith:
|
||||||
container_name: dendrite_monolith
|
container_name: dendrite_monolith
|
||||||
hostname: monolith
|
hostname: monolith
|
||||||
entrypoint: ["bash", "./docker/services/monolith.sh"]
|
entrypoint: ["bash", "./docker/services/monolith.sh", "--config", "/etc/dendrite/dendrite.yaml"]
|
||||||
build: ./
|
build: ./
|
||||||
volumes:
|
volumes:
|
||||||
- ..:/build
|
- ..:/build
|
||||||
- ./build/bin:/build/bin
|
- ./build/bin:/build/bin
|
||||||
|
- ../cfg:/etc/dendrite
|
||||||
networks:
|
networks:
|
||||||
- internal
|
- internal
|
||||||
depends_on:
|
depends_on:
|
||||||
|
|
|
||||||
|
|
@ -44,6 +44,7 @@ args:
|
||||||
user: dendrite
|
user: dendrite
|
||||||
database: dendrite
|
database: dendrite
|
||||||
host: 127.0.0.1
|
host: 127.0.0.1
|
||||||
|
sslmode: disable
|
||||||
type: pg
|
type: pg
|
||||||
EOF
|
EOF
|
||||||
```
|
```
|
||||||
|
|
|
||||||
|
|
@ -32,8 +32,8 @@ import (
|
||||||
// FederationAPI component.
|
// FederationAPI component.
|
||||||
func SetupFederationAPIComponent(
|
func SetupFederationAPIComponent(
|
||||||
base *basecomponent.BaseDendrite,
|
base *basecomponent.BaseDendrite,
|
||||||
accountsDB *accounts.Database,
|
accountsDB accounts.Database,
|
||||||
deviceDB *devices.Database,
|
deviceDB devices.Database,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation *gomatrixserverlib.FederationClient,
|
||||||
keyRing *gomatrixserverlib.KeyRing,
|
keyRing *gomatrixserverlib.KeyRing,
|
||||||
aliasAPI roomserverAPI.RoomserverAliasAPI,
|
aliasAPI roomserverAPI.RoomserverAliasAPI,
|
||||||
|
|
@ -45,8 +45,8 @@ func SetupFederationAPIComponent(
|
||||||
roomserverProducer := producers.NewRoomserverProducer(inputAPI)
|
roomserverProducer := producers.NewRoomserverProducer(inputAPI)
|
||||||
|
|
||||||
routing.Setup(
|
routing.Setup(
|
||||||
base.APIMux, *base.Cfg, queryAPI, aliasAPI, asAPI,
|
base.APIMux, base.Cfg, queryAPI, aliasAPI, asAPI,
|
||||||
roomserverProducer, federationSenderAPI, *keyRing, federation, accountsDB,
|
roomserverProducer, federationSenderAPI, *keyRing,
|
||||||
deviceDB,
|
federation, accountsDB, deviceDB,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ func Backfill(
|
||||||
request *gomatrixserverlib.FederationRequest,
|
request *gomatrixserverlib.FederationRequest,
|
||||||
query api.RoomserverQueryAPI,
|
query api.RoomserverQueryAPI,
|
||||||
roomID string,
|
roomID string,
|
||||||
cfg config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
var res api.QueryBackfillResponse
|
var res api.QueryBackfillResponse
|
||||||
var eIDs []string
|
var eIDs []string
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ type userDevicesResponse struct {
|
||||||
// GetUserDevices for the given user id
|
// GetUserDevices for the given user id
|
||||||
func GetUserDevices(
|
func GetUserDevices(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
deviceDB *devices.Database,
|
deviceDB devices.Database,
|
||||||
userID string,
|
userID string,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
localpart, err := userutil.ParseUsernameParam(userID, nil)
|
localpart, err := userutil.ParseUsernameParam(userID, nil)
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ func Invite(
|
||||||
request *gomatrixserverlib.FederationRequest,
|
request *gomatrixserverlib.FederationRequest,
|
||||||
roomID string,
|
roomID string,
|
||||||
eventID string,
|
eventID string,
|
||||||
cfg config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
producer *producers.RoomserverProducer,
|
producer *producers.RoomserverProducer,
|
||||||
keys gomatrixserverlib.KeyRing,
|
keys gomatrixserverlib.KeyRing,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ import (
|
||||||
func MakeJoin(
|
func MakeJoin(
|
||||||
httpReq *http.Request,
|
httpReq *http.Request,
|
||||||
request *gomatrixserverlib.FederationRequest,
|
request *gomatrixserverlib.FederationRequest,
|
||||||
cfg config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
query api.RoomserverQueryAPI,
|
query api.RoomserverQueryAPI,
|
||||||
roomID, userID string,
|
roomID, userID string,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
|
@ -97,7 +97,7 @@ func MakeJoin(
|
||||||
func SendJoin(
|
func SendJoin(
|
||||||
httpReq *http.Request,
|
httpReq *http.Request,
|
||||||
request *gomatrixserverlib.FederationRequest,
|
request *gomatrixserverlib.FederationRequest,
|
||||||
cfg config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
query api.RoomserverQueryAPI,
|
query api.RoomserverQueryAPI,
|
||||||
producer *producers.RoomserverProducer,
|
producer *producers.RoomserverProducer,
|
||||||
keys gomatrixserverlib.KeyRing,
|
keys gomatrixserverlib.KeyRing,
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ import (
|
||||||
|
|
||||||
// LocalKeys returns the local keys for the server.
|
// LocalKeys returns the local keys for the server.
|
||||||
// See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys
|
// See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys
|
||||||
func LocalKeys(cfg config.Dendrite) util.JSONResponse {
|
func LocalKeys(cfg *config.Dendrite) util.JSONResponse {
|
||||||
keys, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod))
|
keys, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
|
|
@ -35,7 +35,7 @@ func LocalKeys(cfg config.Dendrite) util.JSONResponse {
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: keys}
|
return util.JSONResponse{Code: http.StatusOK, JSON: keys}
|
||||||
}
|
}
|
||||||
|
|
||||||
func localKeys(cfg config.Dendrite, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) {
|
func localKeys(cfg *config.Dendrite, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) {
|
||||||
var keys gomatrixserverlib.ServerKeys
|
var keys gomatrixserverlib.ServerKeys
|
||||||
|
|
||||||
keys.ServerName = cfg.Matrix.ServerName
|
keys.ServerName = cfg.Matrix.ServerName
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ import (
|
||||||
func MakeLeave(
|
func MakeLeave(
|
||||||
httpReq *http.Request,
|
httpReq *http.Request,
|
||||||
request *gomatrixserverlib.FederationRequest,
|
request *gomatrixserverlib.FederationRequest,
|
||||||
cfg config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
query api.RoomserverQueryAPI,
|
query api.RoomserverQueryAPI,
|
||||||
roomID, userID string,
|
roomID, userID string,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
|
@ -95,7 +95,7 @@ func MakeLeave(
|
||||||
func SendLeave(
|
func SendLeave(
|
||||||
httpReq *http.Request,
|
httpReq *http.Request,
|
||||||
request *gomatrixserverlib.FederationRequest,
|
request *gomatrixserverlib.FederationRequest,
|
||||||
cfg config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
producer *producers.RoomserverProducer,
|
producer *producers.RoomserverProducer,
|
||||||
keys gomatrixserverlib.KeyRing,
|
keys gomatrixserverlib.KeyRing,
|
||||||
roomID, eventID string,
|
roomID, eventID string,
|
||||||
|
|
|
||||||
|
|
@ -30,8 +30,8 @@ import (
|
||||||
// GetProfile implements GET /_matrix/federation/v1/query/profile
|
// GetProfile implements GET /_matrix/federation/v1/query/profile
|
||||||
func GetProfile(
|
func GetProfile(
|
||||||
httpReq *http.Request,
|
httpReq *http.Request,
|
||||||
accountDB *accounts.Database,
|
accountDB accounts.Database,
|
||||||
cfg config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
asAPI appserviceAPI.AppServiceQueryAPI,
|
asAPI appserviceAPI.AppServiceQueryAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
userID, field := httpReq.FormValue("user_id"), httpReq.FormValue("field")
|
userID, field := httpReq.FormValue("user_id"), httpReq.FormValue("field")
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ import (
|
||||||
func RoomAliasToID(
|
func RoomAliasToID(
|
||||||
httpReq *http.Request,
|
httpReq *http.Request,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation *gomatrixserverlib.FederationClient,
|
||||||
cfg config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
aliasAPI roomserverAPI.RoomserverAliasAPI,
|
aliasAPI roomserverAPI.RoomserverAliasAPI,
|
||||||
senderAPI federationSenderAPI.FederationSenderQueryAPI,
|
senderAPI federationSenderAPI.FederationSenderQueryAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ const (
|
||||||
// nolint: gocyclo
|
// nolint: gocyclo
|
||||||
func Setup(
|
func Setup(
|
||||||
apiMux *mux.Router,
|
apiMux *mux.Router,
|
||||||
cfg config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
query roomserverAPI.RoomserverQueryAPI,
|
query roomserverAPI.RoomserverQueryAPI,
|
||||||
aliasAPI roomserverAPI.RoomserverAliasAPI,
|
aliasAPI roomserverAPI.RoomserverAliasAPI,
|
||||||
asAPI appserviceAPI.AppServiceQueryAPI,
|
asAPI appserviceAPI.AppServiceQueryAPI,
|
||||||
|
|
@ -51,8 +51,8 @@ func Setup(
|
||||||
federationSenderAPI federationSenderAPI.FederationSenderQueryAPI,
|
federationSenderAPI federationSenderAPI.FederationSenderQueryAPI,
|
||||||
keys gomatrixserverlib.KeyRing,
|
keys gomatrixserverlib.KeyRing,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation *gomatrixserverlib.FederationClient,
|
||||||
accountDB *accounts.Database,
|
accountDB accounts.Database,
|
||||||
deviceDB *devices.Database,
|
deviceDB devices.Database,
|
||||||
) {
|
) {
|
||||||
v2keysmux := apiMux.PathPrefix(pathPrefixV2Keys).Subrouter()
|
v2keysmux := apiMux.PathPrefix(pathPrefixV2Keys).Subrouter()
|
||||||
v1fedmux := apiMux.PathPrefix(pathPrefixV1Federation).Subrouter()
|
v1fedmux := apiMux.PathPrefix(pathPrefixV1Federation).Subrouter()
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ func Send(
|
||||||
httpReq *http.Request,
|
httpReq *http.Request,
|
||||||
request *gomatrixserverlib.FederationRequest,
|
request *gomatrixserverlib.FederationRequest,
|
||||||
txnID gomatrixserverlib.TransactionID,
|
txnID gomatrixserverlib.TransactionID,
|
||||||
cfg config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
query api.RoomserverQueryAPI,
|
query api.RoomserverQueryAPI,
|
||||||
producer *producers.RoomserverProducer,
|
producer *producers.RoomserverProducer,
|
||||||
keys gomatrixserverlib.KeyRing,
|
keys gomatrixserverlib.KeyRing,
|
||||||
|
|
|
||||||
|
|
@ -59,9 +59,9 @@ var (
|
||||||
// CreateInvitesFrom3PIDInvites implements POST /_matrix/federation/v1/3pid/onbind
|
// CreateInvitesFrom3PIDInvites implements POST /_matrix/federation/v1/3pid/onbind
|
||||||
func CreateInvitesFrom3PIDInvites(
|
func CreateInvitesFrom3PIDInvites(
|
||||||
req *http.Request, queryAPI roomserverAPI.RoomserverQueryAPI,
|
req *http.Request, queryAPI roomserverAPI.RoomserverQueryAPI,
|
||||||
asAPI appserviceAPI.AppServiceQueryAPI, cfg config.Dendrite,
|
asAPI appserviceAPI.AppServiceQueryAPI, cfg *config.Dendrite,
|
||||||
producer *producers.RoomserverProducer, federation *gomatrixserverlib.FederationClient,
|
producer *producers.RoomserverProducer, federation *gomatrixserverlib.FederationClient,
|
||||||
accountDB *accounts.Database,
|
accountDB accounts.Database,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
var body invites
|
var body invites
|
||||||
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
|
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
|
||||||
|
|
@ -98,7 +98,7 @@ func ExchangeThirdPartyInvite(
|
||||||
request *gomatrixserverlib.FederationRequest,
|
request *gomatrixserverlib.FederationRequest,
|
||||||
roomID string,
|
roomID string,
|
||||||
queryAPI roomserverAPI.RoomserverQueryAPI,
|
queryAPI roomserverAPI.RoomserverQueryAPI,
|
||||||
cfg config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation *gomatrixserverlib.FederationClient,
|
||||||
producer *producers.RoomserverProducer,
|
producer *producers.RoomserverProducer,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
|
@ -172,9 +172,9 @@ func ExchangeThirdPartyInvite(
|
||||||
// necessary data to do so.
|
// necessary data to do so.
|
||||||
func createInviteFrom3PIDInvite(
|
func createInviteFrom3PIDInvite(
|
||||||
ctx context.Context, queryAPI roomserverAPI.RoomserverQueryAPI,
|
ctx context.Context, queryAPI roomserverAPI.RoomserverQueryAPI,
|
||||||
asAPI appserviceAPI.AppServiceQueryAPI, cfg config.Dendrite,
|
asAPI appserviceAPI.AppServiceQueryAPI, cfg *config.Dendrite,
|
||||||
inv invite, federation *gomatrixserverlib.FederationClient,
|
inv invite, federation *gomatrixserverlib.FederationClient,
|
||||||
accountDB *accounts.Database,
|
accountDB accounts.Database,
|
||||||
) (*gomatrixserverlib.Event, error) {
|
) (*gomatrixserverlib.Event, error) {
|
||||||
_, server, err := gomatrixserverlib.SplitID('@', inv.MXID)
|
_, server, err := gomatrixserverlib.SplitID('@', inv.MXID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -230,7 +230,7 @@ func createInviteFrom3PIDInvite(
|
||||||
func buildMembershipEvent(
|
func buildMembershipEvent(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
builder *gomatrixserverlib.EventBuilder, queryAPI roomserverAPI.RoomserverQueryAPI,
|
builder *gomatrixserverlib.EventBuilder, queryAPI roomserverAPI.RoomserverQueryAPI,
|
||||||
cfg config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
) (*gomatrixserverlib.Event, error) {
|
) (*gomatrixserverlib.Event, error) {
|
||||||
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
|
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -290,7 +290,7 @@ func buildMembershipEvent(
|
||||||
// them responded with an error.
|
// them responded with an error.
|
||||||
func sendToRemoteServer(
|
func sendToRemoteServer(
|
||||||
ctx context.Context, inv invite,
|
ctx context.Context, inv invite,
|
||||||
federation *gomatrixserverlib.FederationClient, _ config.Dendrite,
|
federation *gomatrixserverlib.FederationClient, _ *config.Dendrite,
|
||||||
builder gomatrixserverlib.EventBuilder,
|
builder gomatrixserverlib.EventBuilder,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
remoteServers := make([]gomatrixserverlib.ServerName, 2)
|
remoteServers := make([]gomatrixserverlib.ServerName, 2)
|
||||||
|
|
|
||||||
|
|
@ -132,5 +132,5 @@ func joinedHostsFromStmt(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -87,7 +87,7 @@ func (d *Database) UpdateRoom(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if lastSentEventID != oldEventID {
|
if lastSentEventID != "" && lastSentEventID != oldEventID {
|
||||||
return types.EventIDMismatchError{
|
return types.EventIDMismatchError{
|
||||||
DatabaseID: lastSentEventID, RoomServerID: oldEventID,
|
DatabaseID: lastSentEventID, RoomServerID: oldEventID,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
139
federationsender/storage/sqlite3/joined_hosts_table.go
Normal file
139
federationsender/storage/sqlite3/joined_hosts_table.go
Normal file
|
|
@ -0,0 +1,139 @@
|
||||||
|
// Copyright 2017-2018 New Vector Ltd
|
||||||
|
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
|
"github.com/matrix-org/dendrite/federationsender/types"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
const joinedHostsSchema = `
|
||||||
|
-- The joined_hosts table stores a list of m.room.member event ids in the
|
||||||
|
-- current state for each room where the membership is "join".
|
||||||
|
-- There will be an entry for every user that is joined to the room.
|
||||||
|
CREATE TABLE IF NOT EXISTS federationsender_joined_hosts (
|
||||||
|
-- The string ID of the room.
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
-- The event ID of the m.room.member join event.
|
||||||
|
event_id TEXT NOT NULL,
|
||||||
|
-- The domain part of the user ID the m.room.member event is for.
|
||||||
|
server_name TEXT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS federatonsender_joined_hosts_event_id_idx
|
||||||
|
ON federationsender_joined_hosts (event_id);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS federatonsender_joined_hosts_room_id_idx
|
||||||
|
ON federationsender_joined_hosts (room_id)
|
||||||
|
`
|
||||||
|
|
||||||
|
const insertJoinedHostsSQL = "" +
|
||||||
|
"INSERT INTO federationsender_joined_hosts (room_id, event_id, server_name)" +
|
||||||
|
" VALUES ($1, $2, $3)"
|
||||||
|
|
||||||
|
const deleteJoinedHostsSQL = "" +
|
||||||
|
"DELETE FROM federationsender_joined_hosts WHERE event_id = $1"
|
||||||
|
|
||||||
|
const selectJoinedHostsSQL = "" +
|
||||||
|
"SELECT event_id, server_name FROM federationsender_joined_hosts" +
|
||||||
|
" WHERE room_id = $1"
|
||||||
|
|
||||||
|
type joinedHostsStatements struct {
|
||||||
|
insertJoinedHostsStmt *sql.Stmt
|
||||||
|
deleteJoinedHostsStmt *sql.Stmt
|
||||||
|
selectJoinedHostsStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
_, err = db.Exec(joinedHostsSchema)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.insertJoinedHostsStmt, err = db.Prepare(insertJoinedHostsSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.deleteJoinedHostsStmt, err = db.Prepare(deleteJoinedHostsSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *joinedHostsStatements) insertJoinedHosts(
|
||||||
|
ctx context.Context,
|
||||||
|
txn *sql.Tx,
|
||||||
|
roomID, eventID string,
|
||||||
|
serverName gomatrixserverlib.ServerName,
|
||||||
|
) error {
|
||||||
|
stmt := common.TxStmt(txn, s.insertJoinedHostsStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *joinedHostsStatements) deleteJoinedHosts(
|
||||||
|
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||||
|
) error {
|
||||||
|
for _, eventID := range eventIDs {
|
||||||
|
stmt := common.TxStmt(txn, s.deleteJoinedHostsStmt)
|
||||||
|
if _, err := stmt.ExecContext(ctx, eventID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *joinedHostsStatements) selectJoinedHostsWithTx(
|
||||||
|
ctx context.Context, txn *sql.Tx, roomID string,
|
||||||
|
) ([]types.JoinedHost, error) {
|
||||||
|
stmt := common.TxStmt(txn, s.selectJoinedHostsStmt)
|
||||||
|
return joinedHostsFromStmt(ctx, stmt, roomID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *joinedHostsStatements) selectJoinedHosts(
|
||||||
|
ctx context.Context, roomID string,
|
||||||
|
) ([]types.JoinedHost, error) {
|
||||||
|
return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func joinedHostsFromStmt(
|
||||||
|
ctx context.Context, stmt *sql.Stmt, roomID string,
|
||||||
|
) ([]types.JoinedHost, error) {
|
||||||
|
rows, err := stmt.QueryContext(ctx, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close() // nolint: errcheck
|
||||||
|
|
||||||
|
var result []types.JoinedHost
|
||||||
|
for rows.Next() {
|
||||||
|
var eventID, serverName string
|
||||||
|
if err = rows.Scan(&eventID, &serverName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, types.JoinedHost{
|
||||||
|
MemberEventID: eventID,
|
||||||
|
ServerName: gomatrixserverlib.ServerName(serverName),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
101
federationsender/storage/sqlite3/room_table.go
Normal file
101
federationsender/storage/sqlite3/room_table.go
Normal file
|
|
@ -0,0 +1,101 @@
|
||||||
|
// Copyright 2017-2018 New Vector Ltd
|
||||||
|
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
const roomSchema = `
|
||||||
|
CREATE TABLE IF NOT EXISTS federationsender_rooms (
|
||||||
|
-- The string ID of the room
|
||||||
|
room_id TEXT PRIMARY KEY,
|
||||||
|
-- The most recent event state by the room server.
|
||||||
|
-- We can use this to tell if our view of the room state has become
|
||||||
|
-- desynchronised.
|
||||||
|
last_event_id TEXT NOT NULL
|
||||||
|
);`
|
||||||
|
|
||||||
|
const insertRoomSQL = "" +
|
||||||
|
"INSERT INTO federationsender_rooms (room_id, last_event_id) VALUES ($1, '')" +
|
||||||
|
" ON CONFLICT DO NOTHING"
|
||||||
|
|
||||||
|
const selectRoomForUpdateSQL = "" +
|
||||||
|
"SELECT last_event_id FROM federationsender_rooms WHERE room_id = $1"
|
||||||
|
|
||||||
|
const updateRoomSQL = "" +
|
||||||
|
"UPDATE federationsender_rooms SET last_event_id = $2 WHERE room_id = $1"
|
||||||
|
|
||||||
|
type roomStatements struct {
|
||||||
|
insertRoomStmt *sql.Stmt
|
||||||
|
selectRoomForUpdateStmt *sql.Stmt
|
||||||
|
updateRoomStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roomStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
_, err = db.Exec(roomSchema)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.insertRoomStmt, err = db.Prepare(insertRoomSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectRoomForUpdateStmt, err = db.Prepare(selectRoomForUpdateSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.updateRoomStmt, err = db.Prepare(updateRoomSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// insertRoom inserts the room if it didn't already exist.
|
||||||
|
// If the room didn't exist then last_event_id is set to the empty string.
|
||||||
|
func (s *roomStatements) insertRoom(
|
||||||
|
ctx context.Context, txn *sql.Tx, roomID string,
|
||||||
|
) error {
|
||||||
|
_, err := common.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectRoomForUpdate locks the row for the room and returns the last_event_id.
|
||||||
|
// The row must already exist in the table. Callers can ensure that the row
|
||||||
|
// exists by calling insertRoom first.
|
||||||
|
func (s *roomStatements) selectRoomForUpdate(
|
||||||
|
ctx context.Context, txn *sql.Tx, roomID string,
|
||||||
|
) (string, error) {
|
||||||
|
var lastEventID string
|
||||||
|
stmt := common.TxStmt(txn, s.selectRoomForUpdateStmt)
|
||||||
|
err := stmt.QueryRowContext(ctx, roomID).Scan(&lastEventID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return lastEventID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateRoom updates the last_event_id for the room. selectRoomForUpdate should
|
||||||
|
// have already been called earlier within the transaction.
|
||||||
|
func (s *roomStatements) updateRoom(
|
||||||
|
ctx context.Context, txn *sql.Tx, roomID, lastEventID string,
|
||||||
|
) error {
|
||||||
|
stmt := common.TxStmt(txn, s.updateRoomStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, roomID, lastEventID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
124
federationsender/storage/sqlite3/storage.go
Normal file
124
federationsender/storage/sqlite3/storage.go
Normal file
|
|
@ -0,0 +1,124 @@
|
||||||
|
// Copyright 2017-2018 New Vector Ltd
|
||||||
|
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
|
"github.com/matrix-org/dendrite/federationsender/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Database stores information needed by the federation sender
|
||||||
|
type Database struct {
|
||||||
|
joinedHostsStatements
|
||||||
|
roomStatements
|
||||||
|
common.PartitionOffsetStatements
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDatabase opens a new database
|
||||||
|
func NewDatabase(dataSourceName string) (*Database, error) {
|
||||||
|
var result Database
|
||||||
|
var err error
|
||||||
|
if result.db, err = sql.Open("sqlite3", dataSourceName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err = result.prepare(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) prepare() error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if err = d.joinedHostsStatements.prepare(d.db); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = d.roomStatements.prepare(d.db); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.PartitionOffsetStatements.Prepare(d.db, "federationsender")
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRoom updates the joined hosts for a room and returns what the joined
|
||||||
|
// hosts were before the update, or nil if this was a duplicate message.
|
||||||
|
// This is called when we receive a message from kafka, so we pass in
|
||||||
|
// oldEventID and newEventID to check that we haven't missed any messages or
|
||||||
|
// this isn't a duplicate message.
|
||||||
|
func (d *Database) UpdateRoom(
|
||||||
|
ctx context.Context,
|
||||||
|
roomID, oldEventID, newEventID string,
|
||||||
|
addHosts []types.JoinedHost,
|
||||||
|
removeHosts []string,
|
||||||
|
) (joinedHosts []types.JoinedHost, err error) {
|
||||||
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
err = d.insertRoom(ctx, txn, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
lastSentEventID, err := d.selectRoomForUpdate(ctx, txn, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastSentEventID == newEventID {
|
||||||
|
// We've handled this message before, so let's just ignore it.
|
||||||
|
// We can only get a duplicate for the last message we processed,
|
||||||
|
// so its enough just to compare the newEventID with lastSentEventID
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastSentEventID != "" && lastSentEventID != oldEventID {
|
||||||
|
return types.EventIDMismatchError{
|
||||||
|
DatabaseID: lastSentEventID, RoomServerID: oldEventID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
joinedHosts, err = d.selectJoinedHostsWithTx(ctx, txn, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, add := range addHosts {
|
||||||
|
err = d.insertJoinedHosts(ctx, txn, roomID, add.MemberEventID, add.ServerName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err = d.deleteJoinedHosts(ctx, txn, removeHosts); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return d.updateRoom(ctx, txn, roomID, newEventID)
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetJoinedHosts returns the currently joined hosts for room,
|
||||||
|
// as known to federationserver.
|
||||||
|
// Returns an error if something goes wrong.
|
||||||
|
func (d *Database) GetJoinedHosts(
|
||||||
|
ctx context.Context, roomID string,
|
||||||
|
) ([]types.JoinedHost, error) {
|
||||||
|
return d.selectJoinedHosts(ctx, roomID)
|
||||||
|
}
|
||||||
|
|
@ -20,6 +20,7 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/common"
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/dendrite/federationsender/storage/postgres"
|
"github.com/matrix-org/dendrite/federationsender/storage/postgres"
|
||||||
|
"github.com/matrix-org/dendrite/federationsender/storage/sqlite3"
|
||||||
"github.com/matrix-org/dendrite/federationsender/types"
|
"github.com/matrix-org/dendrite/federationsender/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -36,6 +37,8 @@ func NewDatabase(dataSourceName string) (Database, error) {
|
||||||
return postgres.NewDatabase(dataSourceName)
|
return postgres.NewDatabase(dataSourceName)
|
||||||
}
|
}
|
||||||
switch uri.Scheme {
|
switch uri.Scheme {
|
||||||
|
case "file":
|
||||||
|
return sqlite3.NewDatabase(dataSourceName)
|
||||||
case "postgres":
|
case "postgres":
|
||||||
return postgres.NewDatabase(dataSourceName)
|
return postgres.NewDatabase(dataSourceName)
|
||||||
default:
|
default:
|
||||||
|
|
|
||||||
22
go.mod
22
go.mod
|
|
@ -1,32 +1,30 @@
|
||||||
module github.com/matrix-org/dendrite
|
module github.com/matrix-org/dendrite
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/Shopify/sarama v0.0.0-20170127151855-574d3147eee3
|
github.com/DataDog/zstd v1.4.4 // indirect
|
||||||
github.com/Shopify/toxiproxy v2.1.4+incompatible // indirect
|
github.com/Shopify/toxiproxy v2.1.4+incompatible // indirect
|
||||||
github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd // indirect
|
github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd // indirect
|
||||||
github.com/eapache/go-resiliency v0.0.0-20160104191539-b86b1ec0dd42 // indirect
|
github.com/eapache/go-resiliency v1.2.0 // indirect
|
||||||
github.com/eapache/go-xerial-snappy v0.0.0-20160609142408-bb955e01b934 // indirect
|
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 // indirect
|
||||||
github.com/eapache/queue v1.1.0 // indirect
|
github.com/eapache/queue v1.1.0 // indirect
|
||||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
|
github.com/frankban/quicktest v1.7.2 // indirect
|
||||||
|
github.com/golang/snappy v0.0.1 // indirect
|
||||||
github.com/gorilla/mux v1.7.3
|
github.com/gorilla/mux v1.7.3
|
||||||
github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6 // indirect
|
|
||||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
|
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
|
||||||
github.com/lib/pq v1.2.0
|
github.com/lib/pq v1.2.0
|
||||||
github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5
|
github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26
|
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200124100636-0c2ec91d1df5
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20200124100636-0c2ec91d1df5
|
||||||
github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0
|
github.com/matrix-org/naffka v0.0.0-20200127221512-0716baaabaf1
|
||||||
github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5
|
github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5
|
||||||
|
github.com/mattn/go-sqlite3 v2.0.2+incompatible
|
||||||
github.com/miekg/dns v1.1.12 // indirect
|
github.com/miekg/dns v1.1.12 // indirect
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
|
||||||
github.com/modern-go/reflect2 v1.0.1 // indirect
|
|
||||||
github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5
|
github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5
|
||||||
github.com/opentracing/opentracing-go v1.0.2
|
github.com/opentracing/opentracing-go v1.0.2
|
||||||
github.com/pierrec/lz4 v0.0.0-20161206202305-5c9560bfa9ac // indirect
|
github.com/pierrec/lz4 v2.4.1+incompatible // indirect
|
||||||
github.com/pierrec/xxHash v0.0.0-20160112165351-5a004441f897 // indirect
|
|
||||||
github.com/pkg/errors v0.8.1
|
github.com/pkg/errors v0.8.1
|
||||||
github.com/prometheus/client_golang v1.2.1
|
github.com/prometheus/client_golang v1.2.1
|
||||||
github.com/rcrowley/go-metrics v0.0.0-20161128210544-1f30fe9094a5 // indirect
|
github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563 // indirect
|
||||||
github.com/sirupsen/logrus v1.4.2
|
github.com/sirupsen/logrus v1.4.2
|
||||||
github.com/stretchr/testify v1.4.0 // indirect
|
github.com/stretchr/testify v1.4.0 // indirect
|
||||||
github.com/uber-go/atomic v1.3.0 // indirect
|
github.com/uber-go/atomic v1.3.0 // indirect
|
||||||
|
|
@ -36,7 +34,7 @@ require (
|
||||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550
|
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550
|
||||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859 // indirect
|
golang.org/x/net v0.0.0-20190620200207-3b0461eec859 // indirect
|
||||||
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6 // indirect
|
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6 // indirect
|
||||||
gopkg.in/Shopify/sarama.v1 v1.11.0
|
gopkg.in/Shopify/sarama.v1 v1.20.1
|
||||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
|
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
|
||||||
gopkg.in/h2non/bimg.v1 v1.0.18
|
gopkg.in/h2non/bimg.v1 v1.0.18
|
||||||
gopkg.in/yaml.v2 v2.2.2
|
gopkg.in/yaml.v2 v2.2.2
|
||||||
|
|
|
||||||
43
go.sum
43
go.sum
|
|
@ -1,5 +1,5 @@
|
||||||
github.com/Shopify/sarama v0.0.0-20170127151855-574d3147eee3 h1:j6BAEHYn1kUyW2j7kY0mOJ/R8A0qWwXpvUAEHGemm/g=
|
github.com/DataDog/zstd v1.4.4 h1:+IawcoXhCBylN7ccwdwf8LOH2jKq7NavGpEPanrlTzE=
|
||||||
github.com/Shopify/sarama v0.0.0-20170127151855-574d3147eee3/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo=
|
github.com/DataDog/zstd v1.4.4/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo=
|
||||||
github.com/Shopify/toxiproxy v2.1.4+incompatible h1:TKdv8HiTLgE5wdJuEML90aBgNWsokNbMijUGhmcoBJc=
|
github.com/Shopify/toxiproxy v2.1.4+incompatible h1:TKdv8HiTLgE5wdJuEML90aBgNWsokNbMijUGhmcoBJc=
|
||||||
github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
|
github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
|
||||||
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||||
|
|
@ -18,13 +18,15 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/eapache/go-resiliency v0.0.0-20160104191539-b86b1ec0dd42 h1:f8ERmXYuaC+kCSv2w+y3rBK/oVu6If4DEm3jywJJ0hc=
|
github.com/eapache/go-resiliency v1.2.0 h1:v7g92e/KSN71Rq7vSThKaWIq68fL4YHvWyiUKorFR1Q=
|
||||||
github.com/eapache/go-resiliency v0.0.0-20160104191539-b86b1ec0dd42/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
|
github.com/eapache/go-resiliency v1.2.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
|
||||||
github.com/eapache/go-xerial-snappy v0.0.0-20160609142408-bb955e01b934 h1:oGLoaVIefp3tiOgi7+KInR/nNPvEpPM6GFo+El7fd14=
|
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 h1:YEetp8/yCZMuEPMUDHG0CW/brkkEp8mzqk2+ODEitlw=
|
||||||
github.com/eapache/go-xerial-snappy v0.0.0-20160609142408-bb955e01b934/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
|
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
|
||||||
github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc=
|
github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc=
|
||||||
github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I=
|
github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I=
|
||||||
github.com/frankban/quicktest v1.0.0/go.mod h1:R98jIehRai+d1/3Hv2//jOVCTJhW1VBavT6B6CuGq2k=
|
github.com/frankban/quicktest v1.0.0/go.mod h1:R98jIehRai+d1/3Hv2//jOVCTJhW1VBavT6B6CuGq2k=
|
||||||
|
github.com/frankban/quicktest v1.7.2 h1:2QxQoC1TS09S7fhCPsrvqYdvP1H5M1P1ih5ABm3BTYk=
|
||||||
|
github.com/frankban/quicktest v1.7.2/go.mod h1:jaStnuzAqU1AJdCO0l53JDCJrVDKcS03DbaAcR7Ks/o=
|
||||||
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||||
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||||
github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=
|
github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=
|
||||||
|
|
@ -35,11 +37,13 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
|
||||||
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||||
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
|
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
|
||||||
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w=
|
github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
|
||||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||||
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
|
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
|
||||||
github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY=
|
github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY=
|
||||||
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||||
|
github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg=
|
||||||
|
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw=
|
github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw=
|
||||||
github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
|
github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
|
||||||
|
|
@ -48,8 +52,6 @@ github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplb
|
||||||
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
|
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
|
||||||
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||||
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
|
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
|
||||||
github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6 h1:KAZ1BW2TCmT6PRihDPpocIy1QTtsAsrx6TneU/4+CMg=
|
|
||||||
github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6/go.mod h1:+ZoRqAPRLkC4NPOvfYeR5KNOrY6TD+/sAC3HXPZgDYg=
|
|
||||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=
|
github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=
|
||||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||||
|
|
@ -69,10 +71,12 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
|
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200124100636-0c2ec91d1df5 h1:kmRjpmFOenVpOaV/DRlo9p6z/IbOKlUC+hhKsAAh8Qg=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20200124100636-0c2ec91d1df5 h1:kmRjpmFOenVpOaV/DRlo9p6z/IbOKlUC+hhKsAAh8Qg=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200124100636-0c2ec91d1df5/go.mod h1:FsKa2pWE/bpQql9H7U4boOPXFoJX/QcqaZZ6ijLkaZI=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20200124100636-0c2ec91d1df5/go.mod h1:FsKa2pWE/bpQql9H7U4boOPXFoJX/QcqaZZ6ijLkaZI=
|
||||||
github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0 h1:p7WTwG+aXM86+yVrYAiCMW3ZHSmotVvuRbjtt3jC+4A=
|
github.com/matrix-org/naffka v0.0.0-20200127221512-0716baaabaf1 h1:osLoFdOy+ChQqVUn2PeTDETFftVkl4w9t/OW18g3lnk=
|
||||||
github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0/go.mod h1:cXoYQIENbdWIQHt1SyCo6Bl3C3raHwJ0wgVrXHSqf+A=
|
github.com/matrix-org/naffka v0.0.0-20200127221512-0716baaabaf1/go.mod h1:cXoYQIENbdWIQHt1SyCo6Bl3C3raHwJ0wgVrXHSqf+A=
|
||||||
github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5 h1:W7l5CP4V7wPyPb4tYE11dbmeAOwtFQBTW0rf4OonOS8=
|
github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5 h1:W7l5CP4V7wPyPb4tYE11dbmeAOwtFQBTW0rf4OonOS8=
|
||||||
github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5/go.mod h1:lePuOiXLNDott7NZfnQvJk0lAZ5HgvIuWGhel6J+RLA=
|
github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5/go.mod h1:lePuOiXLNDott7NZfnQvJk0lAZ5HgvIuWGhel6J+RLA=
|
||||||
|
github.com/mattn/go-sqlite3 v2.0.2+incompatible h1:qzw9c2GNT8UFrgWNDhCTqRqYUSmu/Dav/9Z58LGpk7U=
|
||||||
|
github.com/mattn/go-sqlite3 v2.0.2+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||||
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
|
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
|
||||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||||
github.com/miekg/dns v1.1.4 h1:rCMZsU2ScVSYcAsOXgmC6+AKOK+6pmQTOcw03nfwYV0=
|
github.com/miekg/dns v1.1.4 h1:rCMZsU2ScVSYcAsOXgmC6+AKOK+6pmQTOcw03nfwYV0=
|
||||||
|
|
@ -90,10 +94,8 @@ github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5 h1:BvoENQQU+fZ9uukda/R
|
||||||
github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||||
github.com/opentracing/opentracing-go v1.0.2 h1:3jA2P6O1F9UOrWVpwrIo17pu01KWvNWg4X946/Y5Zwg=
|
github.com/opentracing/opentracing-go v1.0.2 h1:3jA2P6O1F9UOrWVpwrIo17pu01KWvNWg4X946/Y5Zwg=
|
||||||
github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o=
|
github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o=
|
||||||
github.com/pierrec/lz4 v0.0.0-20161206202305-5c9560bfa9ac h1:tKcxwAA5OHUQjL6sWsuCIcP9OnzN+RwKfvomtIOsfy8=
|
github.com/pierrec/lz4 v2.4.1+incompatible h1:mFe7ttWaflA46Mhqh+jUfjp2qTbPYxLB2/OyBppH9dg=
|
||||||
github.com/pierrec/lz4 v0.0.0-20161206202305-5c9560bfa9ac/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
|
github.com/pierrec/lz4 v2.4.1+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
|
||||||
github.com/pierrec/xxHash v0.0.0-20160112165351-5a004441f897 h1:jp3jc/PyyTrTKjJJ6rWnhTbmo7tGgBFyG9AL5FIrO1I=
|
|
||||||
github.com/pierrec/xxHash v0.0.0-20160112165351-5a004441f897/go.mod h1:w2waW5Zoa/Wc4Yqe0wgrIYAGKqRMf7czn2HNKXmuL+I=
|
|
||||||
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
||||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
|
@ -114,8 +116,8 @@ github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R
|
||||||
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
|
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
|
||||||
github.com/prometheus/procfs v0.0.5 h1:3+auTFlqw+ZaQYJARz6ArODtkaIwtvBTx3N2NehQlL8=
|
github.com/prometheus/procfs v0.0.5 h1:3+auTFlqw+ZaQYJARz6ArODtkaIwtvBTx3N2NehQlL8=
|
||||||
github.com/prometheus/procfs v0.0.5/go.mod h1:4A/X28fw3Fc593LaREMrKMqOKvUAntwMDaekg4FpcdQ=
|
github.com/prometheus/procfs v0.0.5/go.mod h1:4A/X28fw3Fc593LaREMrKMqOKvUAntwMDaekg4FpcdQ=
|
||||||
github.com/rcrowley/go-metrics v0.0.0-20161128210544-1f30fe9094a5 h1:gwcdIpH6NU2iF8CmcqD+CP6+1CkRBOhHaPR+iu6raBY=
|
github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563 h1:dY6ETXrvDG7Sa4vE8ZQG4yqWg6UnOcbqTAahkV813vQ=
|
||||||
github.com/rcrowley/go-metrics v0.0.0-20161128210544-1f30fe9094a5/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
|
github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
|
||||||
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
||||||
github.com/sirupsen/logrus v1.3.0 h1:hI/7Q+DtNZ2kINb6qt/lS+IyXnHQe9e90POfeewL/ME=
|
github.com/sirupsen/logrus v1.3.0 h1:hI/7Q+DtNZ2kINb6qt/lS+IyXnHQe9e90POfeewL/ME=
|
||||||
github.com/sirupsen/logrus v1.3.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
github.com/sirupsen/logrus v1.3.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
||||||
|
|
@ -172,8 +174,8 @@ golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||||
golang.org/x/sys v0.0.0-20191010194322-b09406accb47 h1:/XfQ9z7ib8eEJX2hdgFTZJ/ntt0swNk5oYBziWeTCvY=
|
golang.org/x/sys v0.0.0-20191010194322-b09406accb47 h1:/XfQ9z7ib8eEJX2hdgFTZJ/ntt0swNk5oYBziWeTCvY=
|
||||||
golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
gopkg.in/Shopify/sarama.v1 v1.11.0 h1:/3kaCyeYaPbr59IBjeqhIcUOB1vXlIVqXAYa5g5C5F0=
|
gopkg.in/Shopify/sarama.v1 v1.20.1 h1:Gi09A3fJXm0Jgt8kuKZ8YK+r60GfYn7MQuEmI3oq6hE=
|
||||||
gopkg.in/Shopify/sarama.v1 v1.11.0/go.mod h1:AxnvoaevB2nBjNK17cG61A3LleFcWFwVBHBt+cot4Oc=
|
gopkg.in/Shopify/sarama.v1 v1.20.1/go.mod h1:AxnvoaevB2nBjNK17cG61A3LleFcWFwVBHBt+cot4Oc=
|
||||||
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
|
@ -183,6 +185,7 @@ gopkg.in/h2non/bimg.v1 v1.0.18 h1:qn6/RpBHt+7WQqoBcK+aF2puc6nC78eZj5LexxoalT4=
|
||||||
gopkg.in/h2non/bimg.v1 v1.0.18/go.mod h1:PgsZL7dLwUbsGm1NYps320GxGgvQNTnecMCZqxV11So=
|
gopkg.in/h2non/bimg.v1 v1.0.18/go.mod h1:PgsZL7dLwUbsGm1NYps320GxGgvQNTnecMCZqxV11So=
|
||||||
gopkg.in/h2non/gock.v1 v1.0.14 h1:fTeu9fcUvSnLNacYvYI54h+1/XEteDyHvrVCZEEEYNM=
|
gopkg.in/h2non/gock.v1 v1.0.14 h1:fTeu9fcUvSnLNacYvYI54h+1/XEteDyHvrVCZEEEYNM=
|
||||||
gopkg.in/h2non/gock.v1 v1.0.14/go.mod h1:sX4zAkdYX1TRGJ2JY156cFspQn4yRWn6p9EMdODlynE=
|
gopkg.in/h2non/gock.v1 v1.0.14/go.mod h1:sX4zAkdYX1TRGJ2JY156cFspQn4yRWn6p9EMdODlynE=
|
||||||
|
gopkg.in/macaroon.v2 v2.1.0 h1:HZcsjBCzq9t0eBPMKqTN/uSN6JOm78ZJ2INbqcBQOUI=
|
||||||
gopkg.in/macaroon.v2 v2.1.0/go.mod h1:OUb+TQP/OP0WOerC2Jp/3CwhIKyIa9kQjuc7H24e6/o=
|
gopkg.in/macaroon.v2 v2.1.0/go.mod h1:OUb+TQP/OP0WOerC2Jp/3CwhIKyIa9kQjuc7H24e6/o=
|
||||||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ import (
|
||||||
// component.
|
// component.
|
||||||
func SetupMediaAPIComponent(
|
func SetupMediaAPIComponent(
|
||||||
base *basecomponent.BaseDendrite,
|
base *basecomponent.BaseDendrite,
|
||||||
deviceDB *devices.Database,
|
deviceDB devices.Database,
|
||||||
) {
|
) {
|
||||||
mediaDB, err := storage.Open(string(base.Cfg.Database.MediaAPI))
|
mediaDB, err := storage.Open(string(base.Cfg.Database.MediaAPI))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,7 @@ func Setup(
|
||||||
apiMux *mux.Router,
|
apiMux *mux.Router,
|
||||||
cfg *config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
db storage.Database,
|
db storage.Database,
|
||||||
deviceDB *devices.Database,
|
deviceDB devices.Database,
|
||||||
client *gomatrixserverlib.Client,
|
client *gomatrixserverlib.Client,
|
||||||
) {
|
) {
|
||||||
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
|
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
|
||||||
|
|
|
||||||
|
|
@ -144,6 +144,7 @@ func (s *thumbnailStatements) selectThumbnails(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
defer rows.Close() // nolint: errcheck
|
||||||
|
|
||||||
var thumbnails []*types.ThumbnailMetadata
|
var thumbnails []*types.ThumbnailMetadata
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
|
|
@ -167,5 +168,5 @@ func (s *thumbnailStatements) selectThumbnails(
|
||||||
thumbnails = append(thumbnails, &thumbnailMetadata)
|
thumbnails = append(thumbnails, &thumbnailMetadata)
|
||||||
}
|
}
|
||||||
|
|
||||||
return thumbnails, err
|
return thumbnails, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ import (
|
||||||
// component.
|
// component.
|
||||||
func SetupPublicRoomsAPIComponent(
|
func SetupPublicRoomsAPIComponent(
|
||||||
base *basecomponent.BaseDendrite,
|
base *basecomponent.BaseDendrite,
|
||||||
deviceDB *devices.Database,
|
deviceDB devices.Database,
|
||||||
rsQueryAPI roomserverAPI.RoomserverQueryAPI,
|
rsQueryAPI roomserverAPI.RoomserverQueryAPI,
|
||||||
) {
|
) {
|
||||||
publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI))
|
publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI))
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ const pathPrefixR0 = "/_matrix/client/r0"
|
||||||
// Due to Setup being used to call many other functions, a gocyclo nolint is
|
// Due to Setup being used to call many other functions, a gocyclo nolint is
|
||||||
// applied:
|
// applied:
|
||||||
// nolint: gocyclo
|
// nolint: gocyclo
|
||||||
func Setup(apiMux *mux.Router, deviceDB *devices.Database, publicRoomsDB storage.Database) {
|
func Setup(apiMux *mux.Router, deviceDB devices.Database, publicRoomsDB storage.Database) {
|
||||||
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
|
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
|
||||||
|
|
||||||
authData := auth.Data{
|
authData := auth.Data{
|
||||||
|
|
|
||||||
|
|
@ -203,6 +203,7 @@ func (s *publicRoomsStatements) selectPublicRooms(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []types.PublicRoom{}, nil
|
return []types.PublicRoom{}, nil
|
||||||
}
|
}
|
||||||
|
defer rows.Close() // nolint: errcheck
|
||||||
|
|
||||||
rooms := []types.PublicRoom{}
|
rooms := []types.PublicRoom{}
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
|
|
@ -222,7 +223,7 @@ func (s *publicRoomsStatements) selectPublicRooms(
|
||||||
rooms = append(rooms, r)
|
rooms = append(rooms, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
return rooms, nil
|
return rooms, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *publicRoomsStatements) selectRoomVisibility(
|
func (s *publicRoomsStatements) selectRoomVisibility(
|
||||||
|
|
|
||||||
36
publicroomsapi/storage/sqlite3/prepare.go
Normal file
36
publicroomsapi/storage/sqlite3/prepare.go
Normal file
|
|
@ -0,0 +1,36 @@
|
||||||
|
// Copyright 2017-2018 New Vector Ltd
|
||||||
|
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
|
||||||
|
type statementList []struct {
|
||||||
|
statement **sql.Stmt
|
||||||
|
sql string
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepare the SQL for each statement in the list and assign the result to the prepared statement.
|
||||||
|
func (s statementList) prepare(db *sql.DB) (err error) {
|
||||||
|
for _, statement := range s {
|
||||||
|
if *statement.statement, err = db.Prepare(statement.sql); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
277
publicroomsapi/storage/sqlite3/public_rooms_table.go
Normal file
277
publicroomsapi/storage/sqlite3/public_rooms_table.go
Normal file
|
|
@ -0,0 +1,277 @@
|
||||||
|
// Copyright 2017-2018 New Vector Ltd
|
||||||
|
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/lib/pq"
|
||||||
|
"github.com/matrix-org/dendrite/publicroomsapi/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
var editableAttributes = []string{
|
||||||
|
"aliases",
|
||||||
|
"canonical_alias",
|
||||||
|
"name",
|
||||||
|
"topic",
|
||||||
|
"world_readable",
|
||||||
|
"guest_can_join",
|
||||||
|
"avatar_url",
|
||||||
|
"visibility",
|
||||||
|
}
|
||||||
|
|
||||||
|
const publicRoomsSchema = `
|
||||||
|
-- Stores all of the rooms with data needed to create the server's room directory
|
||||||
|
CREATE TABLE IF NOT EXISTS publicroomsapi_public_rooms(
|
||||||
|
-- The room's ID
|
||||||
|
room_id TEXT NOT NULL PRIMARY KEY,
|
||||||
|
-- Number of joined members in the room
|
||||||
|
joined_members INTEGER NOT NULL DEFAULT 0,
|
||||||
|
-- Aliases of the room (empty array if none)
|
||||||
|
aliases TEXT[] NOT NULL DEFAULT '{}'::TEXT[],
|
||||||
|
-- Canonical alias of the room (empty string if none)
|
||||||
|
canonical_alias TEXT NOT NULL DEFAULT '',
|
||||||
|
-- Name of the room (empty string if none)
|
||||||
|
name TEXT NOT NULL DEFAULT '',
|
||||||
|
-- Topic of the room (empty string if none)
|
||||||
|
topic TEXT NOT NULL DEFAULT '',
|
||||||
|
-- Is the room world readable?
|
||||||
|
world_readable BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
-- Can guest join the room?
|
||||||
|
guest_can_join BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
-- URL of the room avatar (empty string if none)
|
||||||
|
avatar_url TEXT NOT NULL DEFAULT '',
|
||||||
|
-- Visibility of the room: true means the room is publicly visible, false
|
||||||
|
-- means the room is private
|
||||||
|
visibility BOOLEAN NOT NULL DEFAULT false
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
const countPublicRoomsSQL = "" +
|
||||||
|
"SELECT COUNT(*) FROM publicroomsapi_public_rooms" +
|
||||||
|
" WHERE visibility = true"
|
||||||
|
|
||||||
|
const selectPublicRoomsSQL = "" +
|
||||||
|
"SELECT room_id, joined_members, aliases, canonical_alias, name, topic, world_readable, guest_can_join, avatar_url" +
|
||||||
|
" FROM publicroomsapi_public_rooms WHERE visibility = true" +
|
||||||
|
" ORDER BY joined_members DESC" +
|
||||||
|
" OFFSET $1"
|
||||||
|
|
||||||
|
const selectPublicRoomsWithLimitSQL = "" +
|
||||||
|
"SELECT room_id, joined_members, aliases, canonical_alias, name, topic, world_readable, guest_can_join, avatar_url" +
|
||||||
|
" FROM publicroomsapi_public_rooms WHERE visibility = true" +
|
||||||
|
" ORDER BY joined_members DESC" +
|
||||||
|
" OFFSET $1 LIMIT $2"
|
||||||
|
|
||||||
|
const selectPublicRoomsWithFilterSQL = "" +
|
||||||
|
"SELECT room_id, joined_members, aliases, canonical_alias, name, topic, world_readable, guest_can_join, avatar_url" +
|
||||||
|
" FROM publicroomsapi_public_rooms" +
|
||||||
|
" WHERE visibility = true" +
|
||||||
|
" AND (LOWER(name) LIKE LOWER($1)" +
|
||||||
|
" OR LOWER(topic) LIKE LOWER($1)" +
|
||||||
|
" OR LOWER(ARRAY_TO_STRING(aliases, ',')) LIKE LOWER($1))" +
|
||||||
|
" ORDER BY joined_members DESC" +
|
||||||
|
" OFFSET $2"
|
||||||
|
|
||||||
|
const selectPublicRoomsWithLimitAndFilterSQL = "" +
|
||||||
|
"SELECT room_id, joined_members, aliases, canonical_alias, name, topic, world_readable, guest_can_join, avatar_url" +
|
||||||
|
" FROM publicroomsapi_public_rooms" +
|
||||||
|
" WHERE visibility = true" +
|
||||||
|
" AND (LOWER(name) LIKE LOWER($1)" +
|
||||||
|
" OR LOWER(topic) LIKE LOWER($1)" +
|
||||||
|
" OR LOWER(ARRAY_TO_STRING(aliases, ',')) LIKE LOWER($1))" +
|
||||||
|
" ORDER BY joined_members DESC" +
|
||||||
|
" OFFSET $2 LIMIT $3"
|
||||||
|
|
||||||
|
const selectRoomVisibilitySQL = "" +
|
||||||
|
"SELECT visibility FROM publicroomsapi_public_rooms" +
|
||||||
|
" WHERE room_id = $1"
|
||||||
|
|
||||||
|
const insertNewRoomSQL = "" +
|
||||||
|
"INSERT INTO publicroomsapi_public_rooms(room_id)" +
|
||||||
|
" VALUES ($1)"
|
||||||
|
|
||||||
|
const incrementJoinedMembersInRoomSQL = "" +
|
||||||
|
"UPDATE publicroomsapi_public_rooms" +
|
||||||
|
" SET joined_members = joined_members + 1" +
|
||||||
|
" WHERE room_id = $1"
|
||||||
|
|
||||||
|
const decrementJoinedMembersInRoomSQL = "" +
|
||||||
|
"UPDATE publicroomsapi_public_rooms" +
|
||||||
|
" SET joined_members = joined_members - 1" +
|
||||||
|
" WHERE room_id = $1"
|
||||||
|
|
||||||
|
const updateRoomAttributeSQL = "" +
|
||||||
|
"UPDATE publicroomsapi_public_rooms" +
|
||||||
|
" SET %s = $1" +
|
||||||
|
" WHERE room_id = $2"
|
||||||
|
|
||||||
|
type publicRoomsStatements struct {
|
||||||
|
countPublicRoomsStmt *sql.Stmt
|
||||||
|
selectPublicRoomsStmt *sql.Stmt
|
||||||
|
selectPublicRoomsWithLimitStmt *sql.Stmt
|
||||||
|
selectPublicRoomsWithFilterStmt *sql.Stmt
|
||||||
|
selectPublicRoomsWithLimitAndFilterStmt *sql.Stmt
|
||||||
|
selectRoomVisibilityStmt *sql.Stmt
|
||||||
|
insertNewRoomStmt *sql.Stmt
|
||||||
|
incrementJoinedMembersInRoomStmt *sql.Stmt
|
||||||
|
decrementJoinedMembersInRoomStmt *sql.Stmt
|
||||||
|
updateRoomAttributeStmts map[string]*sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *publicRoomsStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
_, err = db.Exec(publicRoomsSchema)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
stmts := statementList{
|
||||||
|
{&s.countPublicRoomsStmt, countPublicRoomsSQL},
|
||||||
|
{&s.selectPublicRoomsStmt, selectPublicRoomsSQL},
|
||||||
|
{&s.selectPublicRoomsWithLimitStmt, selectPublicRoomsWithLimitSQL},
|
||||||
|
{&s.selectPublicRoomsWithFilterStmt, selectPublicRoomsWithFilterSQL},
|
||||||
|
{&s.selectPublicRoomsWithLimitAndFilterStmt, selectPublicRoomsWithLimitAndFilterSQL},
|
||||||
|
{&s.selectRoomVisibilityStmt, selectRoomVisibilitySQL},
|
||||||
|
{&s.insertNewRoomStmt, insertNewRoomSQL},
|
||||||
|
{&s.incrementJoinedMembersInRoomStmt, incrementJoinedMembersInRoomSQL},
|
||||||
|
{&s.decrementJoinedMembersInRoomStmt, decrementJoinedMembersInRoomSQL},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = stmts.prepare(db); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.updateRoomAttributeStmts = make(map[string]*sql.Stmt)
|
||||||
|
for _, editable := range editableAttributes {
|
||||||
|
stmt := fmt.Sprintf(updateRoomAttributeSQL, editable)
|
||||||
|
if s.updateRoomAttributeStmts[editable], err = db.Prepare(stmt); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *publicRoomsStatements) countPublicRooms(ctx context.Context) (nb int64, err error) {
|
||||||
|
err = s.countPublicRoomsStmt.QueryRowContext(ctx).Scan(&nb)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *publicRoomsStatements) selectPublicRooms(
|
||||||
|
ctx context.Context, offset int64, limit int16, filter string,
|
||||||
|
) ([]types.PublicRoom, error) {
|
||||||
|
var rows *sql.Rows
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if len(filter) > 0 {
|
||||||
|
pattern := "%" + filter + "%"
|
||||||
|
if limit == 0 {
|
||||||
|
rows, err = s.selectPublicRoomsWithFilterStmt.QueryContext(
|
||||||
|
ctx, pattern, offset,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
rows, err = s.selectPublicRoomsWithLimitAndFilterStmt.QueryContext(
|
||||||
|
ctx, pattern, offset, limit,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if limit == 0 {
|
||||||
|
rows, err = s.selectPublicRoomsStmt.QueryContext(ctx, offset)
|
||||||
|
} else {
|
||||||
|
rows, err = s.selectPublicRoomsWithLimitStmt.QueryContext(
|
||||||
|
ctx, offset, limit,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return []types.PublicRoom{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rooms := []types.PublicRoom{}
|
||||||
|
for rows.Next() {
|
||||||
|
var r types.PublicRoom
|
||||||
|
var aliases pq.StringArray
|
||||||
|
|
||||||
|
err = rows.Scan(
|
||||||
|
&r.RoomID, &r.NumJoinedMembers, &aliases, &r.CanonicalAlias,
|
||||||
|
&r.Name, &r.Topic, &r.WorldReadable, &r.GuestCanJoin, &r.AvatarURL,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return rooms, err
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Aliases = aliases
|
||||||
|
|
||||||
|
rooms = append(rooms, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rooms, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *publicRoomsStatements) selectRoomVisibility(
|
||||||
|
ctx context.Context, roomID string,
|
||||||
|
) (v bool, err error) {
|
||||||
|
err = s.selectRoomVisibilityStmt.QueryRowContext(ctx, roomID).Scan(&v)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *publicRoomsStatements) insertNewRoom(
|
||||||
|
ctx context.Context, roomID string,
|
||||||
|
) error {
|
||||||
|
_, err := s.insertNewRoomStmt.ExecContext(ctx, roomID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *publicRoomsStatements) incrementJoinedMembersInRoom(
|
||||||
|
ctx context.Context, roomID string,
|
||||||
|
) error {
|
||||||
|
_, err := s.incrementJoinedMembersInRoomStmt.ExecContext(ctx, roomID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *publicRoomsStatements) decrementJoinedMembersInRoom(
|
||||||
|
ctx context.Context, roomID string,
|
||||||
|
) error {
|
||||||
|
_, err := s.decrementJoinedMembersInRoomStmt.ExecContext(ctx, roomID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *publicRoomsStatements) updateRoomAttribute(
|
||||||
|
ctx context.Context, attrName string, attrValue attributeValue, roomID string,
|
||||||
|
) error {
|
||||||
|
stmt, isEditable := s.updateRoomAttributeStmts[attrName]
|
||||||
|
|
||||||
|
if !isEditable {
|
||||||
|
return errors.New("Cannot edit " + attrName)
|
||||||
|
}
|
||||||
|
|
||||||
|
var value interface{}
|
||||||
|
switch v := attrValue.(type) {
|
||||||
|
case []string:
|
||||||
|
value = pq.StringArray(v)
|
||||||
|
case bool, string:
|
||||||
|
value = attrValue
|
||||||
|
default:
|
||||||
|
return errors.New("Unsupported attribute type, must be bool, string or []string")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := stmt.ExecContext(ctx, value, roomID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
256
publicroomsapi/storage/sqlite3/storage.go
Normal file
256
publicroomsapi/storage/sqlite3/storage.go
Normal file
|
|
@ -0,0 +1,256 @@
|
||||||
|
// Copyright 2017-2018 New Vector Ltd
|
||||||
|
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
|
"github.com/matrix-org/dendrite/publicroomsapi/types"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PublicRoomsServerDatabase represents a public rooms server database.
|
||||||
|
type PublicRoomsServerDatabase struct {
|
||||||
|
db *sql.DB
|
||||||
|
common.PartitionOffsetStatements
|
||||||
|
statements publicRoomsStatements
|
||||||
|
}
|
||||||
|
|
||||||
|
type attributeValue interface{}
|
||||||
|
|
||||||
|
// NewPublicRoomsServerDatabase creates a new public rooms server database.
|
||||||
|
func NewPublicRoomsServerDatabase(dataSourceName string) (*PublicRoomsServerDatabase, error) {
|
||||||
|
var db *sql.DB
|
||||||
|
var err error
|
||||||
|
if db, err = sql.Open("sqlite3", dataSourceName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
storage := PublicRoomsServerDatabase{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
if err = storage.PartitionOffsetStatements.Prepare(db, "publicroomsapi"); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err = storage.statements.prepare(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &storage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRoomVisibility returns the room visibility as a boolean: true if the room
|
||||||
|
// is publicly visible, false if not.
|
||||||
|
// Returns an error if the retrieval failed.
|
||||||
|
func (d *PublicRoomsServerDatabase) GetRoomVisibility(
|
||||||
|
ctx context.Context, roomID string,
|
||||||
|
) (bool, error) {
|
||||||
|
return d.statements.selectRoomVisibility(ctx, roomID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRoomVisibility updates the visibility attribute of a room. This attribute
|
||||||
|
// must be set to true if the room is publicly visible, false if not.
|
||||||
|
// Returns an error if the update failed.
|
||||||
|
func (d *PublicRoomsServerDatabase) SetRoomVisibility(
|
||||||
|
ctx context.Context, visible bool, roomID string,
|
||||||
|
) error {
|
||||||
|
return d.statements.updateRoomAttribute(ctx, "visibility", visible, roomID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountPublicRooms returns the number of room set as publicly visible on the server.
|
||||||
|
// Returns an error if the retrieval failed.
|
||||||
|
func (d *PublicRoomsServerDatabase) CountPublicRooms(ctx context.Context) (int64, error) {
|
||||||
|
return d.statements.countPublicRooms(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPublicRooms returns an array containing the local rooms set as publicly visible, ordered by their number
|
||||||
|
// of joined members. This array can be limited by a given number of elements, and offset by a given value.
|
||||||
|
// If the limit is 0, doesn't limit the number of results. If the offset is 0 too, the array contains all
|
||||||
|
// the rooms set as publicly visible on the server.
|
||||||
|
// Returns an error if the retrieval failed.
|
||||||
|
func (d *PublicRoomsServerDatabase) GetPublicRooms(
|
||||||
|
ctx context.Context, offset int64, limit int16, filter string,
|
||||||
|
) ([]types.PublicRoom, error) {
|
||||||
|
return d.statements.selectPublicRooms(ctx, offset, limit, filter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRoomFromEvents iterate over a slice of state events and call
|
||||||
|
// UpdateRoomFromEvent on each of them to update the database representation of
|
||||||
|
// the rooms updated by each event.
|
||||||
|
// The slice of events to remove is used to update the number of joined members
|
||||||
|
// for the room in the database.
|
||||||
|
// If the update triggered by one of the events failed, aborts the process and
|
||||||
|
// returns an error.
|
||||||
|
func (d *PublicRoomsServerDatabase) UpdateRoomFromEvents(
|
||||||
|
ctx context.Context,
|
||||||
|
eventsToAdd []gomatrixserverlib.Event,
|
||||||
|
eventsToRemove []gomatrixserverlib.Event,
|
||||||
|
) error {
|
||||||
|
for _, event := range eventsToAdd {
|
||||||
|
if err := d.UpdateRoomFromEvent(ctx, event); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, event := range eventsToRemove {
|
||||||
|
if event.Type() == "m.room.member" {
|
||||||
|
if err := d.updateNumJoinedUsers(ctx, event, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRoomFromEvent updates the database representation of a room from a Matrix event, by
|
||||||
|
// checking the event's type to know which attribute to change and using the event's content
|
||||||
|
// to define the new value of the attribute.
|
||||||
|
// If the event doesn't match with any property used to compute the public room directory,
|
||||||
|
// does nothing.
|
||||||
|
// If something went wrong during the process, returns an error.
|
||||||
|
func (d *PublicRoomsServerDatabase) UpdateRoomFromEvent(
|
||||||
|
ctx context.Context, event gomatrixserverlib.Event,
|
||||||
|
) error {
|
||||||
|
// Process the event according to its type
|
||||||
|
switch event.Type() {
|
||||||
|
case "m.room.create":
|
||||||
|
return d.statements.insertNewRoom(ctx, event.RoomID())
|
||||||
|
case "m.room.member":
|
||||||
|
return d.updateNumJoinedUsers(ctx, event, false)
|
||||||
|
case "m.room.aliases":
|
||||||
|
return d.updateRoomAliases(ctx, event)
|
||||||
|
case "m.room.canonical_alias":
|
||||||
|
var content common.CanonicalAliasContent
|
||||||
|
field := &(content.Alias)
|
||||||
|
attrName := "canonical_alias"
|
||||||
|
return d.updateStringAttribute(ctx, attrName, event, &content, field)
|
||||||
|
case "m.room.name":
|
||||||
|
var content common.NameContent
|
||||||
|
field := &(content.Name)
|
||||||
|
attrName := "name"
|
||||||
|
return d.updateStringAttribute(ctx, attrName, event, &content, field)
|
||||||
|
case "m.room.topic":
|
||||||
|
var content common.TopicContent
|
||||||
|
field := &(content.Topic)
|
||||||
|
attrName := "topic"
|
||||||
|
return d.updateStringAttribute(ctx, attrName, event, &content, field)
|
||||||
|
case "m.room.avatar":
|
||||||
|
var content common.AvatarContent
|
||||||
|
field := &(content.URL)
|
||||||
|
attrName := "avatar_url"
|
||||||
|
return d.updateStringAttribute(ctx, attrName, event, &content, field)
|
||||||
|
case "m.room.history_visibility":
|
||||||
|
var content common.HistoryVisibilityContent
|
||||||
|
field := &(content.HistoryVisibility)
|
||||||
|
attrName := "world_readable"
|
||||||
|
strForTrue := "world_readable"
|
||||||
|
return d.updateBooleanAttribute(ctx, attrName, event, &content, field, strForTrue)
|
||||||
|
case "m.room.guest_access":
|
||||||
|
var content common.GuestAccessContent
|
||||||
|
field := &(content.GuestAccess)
|
||||||
|
attrName := "guest_can_join"
|
||||||
|
strForTrue := "can_join"
|
||||||
|
return d.updateBooleanAttribute(ctx, attrName, event, &content, field, strForTrue)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the event type didn't match, return with no error
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateNumJoinedUsers updates the number of joined user in the database representation
|
||||||
|
// of a room using a given "m.room.member" Matrix event.
|
||||||
|
// If the membership property of the event isn't "join", ignores it and returs nil.
|
||||||
|
// If the remove parameter is set to false, increments the joined members counter in the
|
||||||
|
// database, if set to truem decrements it.
|
||||||
|
// Returns an error if the update failed.
|
||||||
|
func (d *PublicRoomsServerDatabase) updateNumJoinedUsers(
|
||||||
|
ctx context.Context, membershipEvent gomatrixserverlib.Event, remove bool,
|
||||||
|
) error {
|
||||||
|
membership, err := membershipEvent.Membership()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if membership != gomatrixserverlib.Join {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if remove {
|
||||||
|
return d.statements.decrementJoinedMembersInRoom(ctx, membershipEvent.RoomID())
|
||||||
|
}
|
||||||
|
return d.statements.incrementJoinedMembersInRoom(ctx, membershipEvent.RoomID())
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateStringAttribute updates a given string attribute in the database
|
||||||
|
// representation of a room using a given string data field from content of the
|
||||||
|
// Matrix event triggering the update.
|
||||||
|
// Returns an error if decoding the Matrix event's content or updating the attribute
|
||||||
|
// failed.
|
||||||
|
func (d *PublicRoomsServerDatabase) updateStringAttribute(
|
||||||
|
ctx context.Context, attrName string, event gomatrixserverlib.Event,
|
||||||
|
content interface{}, field *string,
|
||||||
|
) error {
|
||||||
|
if err := json.Unmarshal(event.Content(), content); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.statements.updateRoomAttribute(ctx, attrName, *field, event.RoomID())
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateBooleanAttribute updates a given boolean attribute in the database
|
||||||
|
// representation of a room using a given string data field from content of the
|
||||||
|
// Matrix event triggering the update.
|
||||||
|
// The attribute is set to true if the field matches a given string, false if not.
|
||||||
|
// Returns an error if decoding the Matrix event's content or updating the attribute
|
||||||
|
// failed.
|
||||||
|
func (d *PublicRoomsServerDatabase) updateBooleanAttribute(
|
||||||
|
ctx context.Context, attrName string, event gomatrixserverlib.Event,
|
||||||
|
content interface{}, field *string, strForTrue string,
|
||||||
|
) error {
|
||||||
|
if err := json.Unmarshal(event.Content(), content); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var attrValue bool
|
||||||
|
if *field == strForTrue {
|
||||||
|
attrValue = true
|
||||||
|
} else {
|
||||||
|
attrValue = false
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.statements.updateRoomAttribute(ctx, attrName, attrValue, event.RoomID())
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateRoomAliases decodes the content of a "m.room.aliases" Matrix event and update the list of aliases of
|
||||||
|
// a given room with it.
|
||||||
|
// Returns an error if decoding the Matrix event or updating the list failed.
|
||||||
|
func (d *PublicRoomsServerDatabase) updateRoomAliases(
|
||||||
|
ctx context.Context, aliasesEvent gomatrixserverlib.Event,
|
||||||
|
) error {
|
||||||
|
var content common.AliasesContent
|
||||||
|
if err := json.Unmarshal(aliasesEvent.Content(), &content); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.statements.updateRoomAttribute(
|
||||||
|
ctx, "aliases", content.Aliases, aliasesEvent.RoomID(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
@ -196,7 +196,12 @@ func processInviteEvent(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
succeeded := false
|
succeeded := false
|
||||||
defer common.EndTransaction(updater, &succeeded)
|
defer func() {
|
||||||
|
txerr := common.EndTransaction(updater, &succeeded)
|
||||||
|
if err == nil && txerr != nil {
|
||||||
|
err = txerr
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if updater.IsJoin() {
|
if updater.IsJoin() {
|
||||||
// If the user is joined to the room then that takes precedence over this
|
// If the user is joined to the room then that takes precedence over this
|
||||||
|
|
|
||||||
|
|
@ -60,7 +60,12 @@ func updateLatestEvents(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
succeeded := false
|
succeeded := false
|
||||||
defer common.EndTransaction(updater, &succeeded)
|
defer func() {
|
||||||
|
txerr := common.EndTransaction(updater, &succeeded)
|
||||||
|
if err == nil && txerr != nil {
|
||||||
|
err = txerr
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
u := latestEventsUpdater{
|
u := latestEventsUpdater{
|
||||||
ctx: ctx, db: db, updater: updater, ow: ow, roomNID: roomNID,
|
ctx: ctx, db: db, updater: updater, ow: ow, roomNID: roomNID,
|
||||||
|
|
|
||||||
|
|
@ -102,5 +102,5 @@ func (s *eventJSONStatements) bulkSelectEventJSON(
|
||||||
}
|
}
|
||||||
result.EventNID = types.EventNID(eventNID)
|
result.EventNID = types.EventNID(eventNID)
|
||||||
}
|
}
|
||||||
return results[:i], nil
|
return results[:i], rows.Err()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -125,7 +125,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
|
||||||
}
|
}
|
||||||
result[stateKey] = types.EventStateKeyNID(stateKeyNID)
|
result[stateKey] = types.EventStateKeyNID(stateKeyNID)
|
||||||
}
|
}
|
||||||
return result, nil
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStateKeyStatements) bulkSelectEventStateKey(
|
func (s *eventStateKeyStatements) bulkSelectEventStateKey(
|
||||||
|
|
@ -150,5 +150,5 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKey(
|
||||||
}
|
}
|
||||||
result[types.EventStateKeyNID(stateKeyNID)] = stateKey
|
result[types.EventStateKeyNID(stateKeyNID)] = stateKey
|
||||||
}
|
}
|
||||||
return result, nil
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -143,5 +143,5 @@ func (s *eventTypeStatements) bulkSelectEventTypeNID(
|
||||||
}
|
}
|
||||||
result[eventType] = types.EventTypeNID(eventTypeNID)
|
result[eventType] = types.EventTypeNID(eventTypeNID)
|
||||||
}
|
}
|
||||||
return result, nil
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -209,6 +209,9 @@ func (s *eventStatements) bulkSelectStateEventByID(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err = rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if i != len(eventIDs) {
|
if i != len(eventIDs) {
|
||||||
// If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have.
|
// If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have.
|
||||||
// We don't know which ones were missing because we don't return the string IDs in the query.
|
// We don't know which ones were missing because we don't return the string IDs in the query.
|
||||||
|
|
@ -219,7 +222,7 @@ func (s *eventStatements) bulkSelectStateEventByID(
|
||||||
fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventIDs)),
|
fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventIDs)),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
return results, err
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
|
// bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
|
||||||
|
|
@ -251,12 +254,15 @@ func (s *eventStatements) bulkSelectStateAtEventByID(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err = rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if i != len(eventIDs) {
|
if i != len(eventIDs) {
|
||||||
return nil, types.MissingEventError(
|
return nil, types.MissingEventError(
|
||||||
fmt.Sprintf("storage: event IDs missing from the database (%d != %d)", i, len(eventIDs)),
|
fmt.Sprintf("storage: event IDs missing from the database (%d != %d)", i, len(eventIDs)),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
return results, err
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) updateEventState(
|
func (s *eventStatements) updateEventState(
|
||||||
|
|
@ -321,6 +327,9 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference(
|
||||||
result.EventID = eventID
|
result.EventID = eventID
|
||||||
result.EventSHA256 = eventSHA256
|
result.EventSHA256 = eventSHA256
|
||||||
}
|
}
|
||||||
|
if err = rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if i != len(eventNIDs) {
|
if i != len(eventNIDs) {
|
||||||
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
|
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
|
||||||
}
|
}
|
||||||
|
|
@ -343,6 +352,9 @@ func (s *eventStatements) bulkSelectEventReference(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err = rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if i != len(eventNIDs) {
|
if i != len(eventNIDs) {
|
||||||
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
|
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
|
||||||
}
|
}
|
||||||
|
|
@ -366,6 +378,9 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, eventNIDs []typ
|
||||||
}
|
}
|
||||||
results[types.EventNID(eventNID)] = eventID
|
results[types.EventNID(eventNID)] = eventID
|
||||||
}
|
}
|
||||||
|
if err = rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if i != len(eventNIDs) {
|
if i != len(eventNIDs) {
|
||||||
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
|
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
|
||||||
}
|
}
|
||||||
|
|
@ -389,7 +404,7 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, eventIDs []str
|
||||||
}
|
}
|
||||||
results[eventID] = types.EventNID(eventNID)
|
results[eventID] = types.EventNID(eventNID)
|
||||||
}
|
}
|
||||||
return results, nil
|
return results, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []types.EventNID) (int64, error) {
|
func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []types.EventNID) (int64, error) {
|
||||||
|
|
|
||||||
|
|
@ -114,21 +114,23 @@ func (s *inviteStatements) insertInviteEvent(
|
||||||
func (s *inviteStatements) updateInviteRetired(
|
func (s *inviteStatements) updateInviteRetired(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
) (eventIDs []string, err error) {
|
) ([]string, error) {
|
||||||
stmt := common.TxStmt(txn, s.updateInviteRetiredStmt)
|
stmt := common.TxStmt(txn, s.updateInviteRetiredStmt)
|
||||||
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
|
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer (func() { err = rows.Close() })()
|
defer rows.Close() // nolint: errcheck
|
||||||
|
|
||||||
|
var eventIDs []string
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var inviteEventID string
|
var inviteEventID string
|
||||||
if err := rows.Scan(&inviteEventID); err != nil {
|
if err = rows.Scan(&inviteEventID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
eventIDs = append(eventIDs, inviteEventID)
|
eventIDs = append(eventIDs, inviteEventID)
|
||||||
}
|
}
|
||||||
return
|
return eventIDs, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// selectInviteActiveForUserInRoom returns a list of sender state key NIDs
|
// selectInviteActiveForUserInRoom returns a list of sender state key NIDs
|
||||||
|
|
@ -151,5 +153,5 @@ func (s *inviteStatements) selectInviteActiveForUserInRoom(
|
||||||
}
|
}
|
||||||
result = append(result, types.EventStateKeyNID(senderUserNID))
|
result = append(result, types.EventStateKeyNID(senderUserNID))
|
||||||
}
|
}
|
||||||
return result, nil
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -151,6 +151,7 @@ func (s *membershipStatements) selectMembershipsFromRoom(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
defer rows.Close() // nolint: errcheck
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var eNID types.EventNID
|
var eNID types.EventNID
|
||||||
|
|
@ -159,8 +160,9 @@ func (s *membershipStatements) selectMembershipsFromRoom(
|
||||||
}
|
}
|
||||||
eventNIDs = append(eventNIDs, eNID)
|
eventNIDs = append(eventNIDs, eNID)
|
||||||
}
|
}
|
||||||
return
|
return eventNIDs, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
|
func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
roomNID types.RoomNID, membership membershipState,
|
roomNID types.RoomNID, membership membershipState,
|
||||||
|
|
@ -170,6 +172,7 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
defer rows.Close() // nolint: errcheck
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var eNID types.EventNID
|
var eNID types.EventNID
|
||||||
|
|
@ -178,7 +181,7 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
|
||||||
}
|
}
|
||||||
eventNIDs = append(eventNIDs, eNID)
|
eventNIDs = append(eventNIDs, eNID)
|
||||||
}
|
}
|
||||||
return
|
return eventNIDs, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) updateMembership(
|
func (s *membershipStatements) updateMembership(
|
||||||
|
|
|
||||||
|
|
@ -90,23 +90,23 @@ func (s *roomAliasesStatements) selectRoomIDFromAlias(
|
||||||
|
|
||||||
func (s *roomAliasesStatements) selectAliasesFromRoomID(
|
func (s *roomAliasesStatements) selectAliasesFromRoomID(
|
||||||
ctx context.Context, roomID string,
|
ctx context.Context, roomID string,
|
||||||
) (aliases []string, err error) {
|
) ([]string, error) {
|
||||||
aliases = []string{}
|
|
||||||
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
|
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
defer rows.Close() // nolint: errcheck
|
||||||
|
|
||||||
|
var aliases []string
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var alias string
|
var alias string
|
||||||
if err = rows.Scan(&alias); err != nil {
|
if err = rows.Scan(&alias); err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
aliases = append(aliases, alias)
|
aliases = append(aliases, alias)
|
||||||
}
|
}
|
||||||
|
return aliases, rows.Err()
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomAliasesStatements) selectCreatorIDFromAlias(
|
func (s *roomAliasesStatements) selectCreatorIDFromAlias(
|
||||||
|
|
|
||||||
|
|
@ -152,7 +152,7 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries(
|
||||||
eventNID int64
|
eventNID int64
|
||||||
entry types.StateEntry
|
entry types.StateEntry
|
||||||
)
|
)
|
||||||
if err := rows.Scan(
|
if err = rows.Scan(
|
||||||
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
|
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -169,10 +169,13 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries(
|
||||||
}
|
}
|
||||||
current.StateEntries = append(current.StateEntries, entry)
|
current.StateEntries = append(current.StateEntries, entry)
|
||||||
}
|
}
|
||||||
|
if err = rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if i != len(stateBlockNIDs) {
|
if i != len(stateBlockNIDs) {
|
||||||
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateBlockNIDs))
|
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateBlockNIDs))
|
||||||
}
|
}
|
||||||
return results, nil
|
return results, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
|
func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
|
||||||
|
|
@ -237,7 +240,7 @@ func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
|
||||||
if current.StateEntries != nil {
|
if current.StateEntries != nil {
|
||||||
results = append(results, current)
|
results = append(results, current)
|
||||||
}
|
}
|
||||||
return results, nil
|
return results, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array {
|
func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array {
|
||||||
|
|
|
||||||
|
|
@ -104,7 +104,7 @@ func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
|
||||||
for ; rows.Next(); i++ {
|
for ; rows.Next(); i++ {
|
||||||
result := &results[i]
|
result := &results[i]
|
||||||
var stateBlockNIDs pq.Int64Array
|
var stateBlockNIDs pq.Int64Array
|
||||||
if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
|
if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
result.StateBlockNIDs = make([]types.StateBlockNID, len(stateBlockNIDs))
|
result.StateBlockNIDs = make([]types.StateBlockNID, len(stateBlockNIDs))
|
||||||
|
|
@ -112,6 +112,9 @@ func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
|
||||||
result.StateBlockNIDs[k] = types.StateBlockNID(stateBlockNIDs[k])
|
result.StateBlockNIDs[k] = types.StateBlockNID(stateBlockNIDs[k])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err = rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if i != len(stateNIDs) {
|
if i != len(stateNIDs) {
|
||||||
return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs))
|
return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue