// Copyright 2018 Vector Creations Ltd // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package storage import ( "context" "database/sql" "strings" "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/encryptoapi/types" ) // Database represents a presence database. type Database struct { db *sql.DB keyStatements keyStatements alStatements alStatements } // NewDatabase creates a new presence database func NewDatabase(dataSourceName string) (*Database, error) { var db *sql.DB var err error if db, err = sql.Open("postgres", dataSourceName); err != nil { return nil, err } keyStatement := keyStatements{} alStatement := alStatements{} if err = keyStatement.prepare(db); err != nil { return nil, err } if err = alStatement.prepare(db); err != nil { return nil, err } return &Database{db: db, keyStatements: keyStatement, alStatements: alStatement}, nil } // InsertKey insert device key func (d *Database) InsertKey( ctx context.Context, deviceID, userID, keyID, keyTyp, keyInfo, al, sig string, ) (err error) { err = common.WithTransaction(d.db, func(txn *sql.Tx) error { return d.keyStatements.insertKey(ctx, txn, deviceID, userID, keyID, keyTyp, keyInfo, al, sig) }) return } // SelectOneTimeKeyCount provides the number of un-claimed OTKeys func (d *Database) SelectOneTimeKeyCount( ctx context.Context, deviceID, userID string, ) (m map[string]int, err error) { m = make(map[string]int) err = common.WithTransaction(d.db, func(txn *sql.Tx) error { elems, err := d.keyStatements.selectKey(ctx, txn, deviceID, userID) for _, val := range elems { if _, ok := m[val.KeyAlgorithm]; !ok { m[val.KeyAlgorithm] = 0 } if val.KeyType == "one_time_key" { m[val.KeyAlgorithm]++ } } return err }) return } // QueryInRange query keys in a range of devices func (d *Database) QueryInRange( ctx context.Context, userID string, arr []string, ) (res []types.KeyHolder, err error) { res, err = d.keyStatements.selectInKeys(ctx, userID, arr) return } // InsertAl persist algorithms func (d *Database) InsertAl( ctx context.Context, uid, device string, al []string, ) (err error) { err = common.WithTransaction(d.db, func(txn *sql.Tx) (err error) { err = d.alStatements.insertAl(ctx, txn, uid, device, strings.Join(al, ",")) return }) return } // SelectAl select algorithms func (d *Database) SelectAlgo( ctx context.Context, uid, device string, ) (res []string, err error) { err = common.WithTransaction(d.db, func(txn *sql.Tx) (err error) { holder, err := d.alStatements.selectAl(ctx, txn, uid, device) res = strings.Split(holder.SupportedAlgorithm, ",") return }) return } // SelectOneTimeKeySingle claim for one time key one for once func (d *Database) SelectOneTimeKeySingle( ctx context.Context, userID, deviceID, algorithm string, ) (holder types.KeyHolder, err error) { holder, err = d.keyStatements.selectSingleKey(ctx, userID, deviceID, algorithm) return } // SyncOneTimeCount for sync device_one_time_keys_count extension func (d *Database) SyncOneTimeCount( ctx context.Context, userID, deviceID string, ) (holder map[string]int, err error) { holder, err = d.keyStatements.selectOneTimeKeyCount(ctx, userID, deviceID) return } // InsertChanges inserts the changes to the DB func (d *Database) InsertChanges( ctx context.Context, readID int, userID, status string, ) error { err := common.WithTransaction(d.db, func(txn *sql.Tx) error { return d.keyStatements.insertChanges(ctx, txn, readID, userID, status) }) return err } // GetKeyChanges gets the changed keys from the DB func (d *Database) GetKeyChanges( ctx context.Context, readID int, userID string, ) (keyChanges types.KeyChanges, err error) { err = common.WithTransaction(d.db, func(txn *sql.Tx) (err error) { keyChanges, err = d.keyStatements.selectChanges(ctx, txn, readID, userID) return }) return }