mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-16 11:23:11 -06:00
156 lines
4.3 KiB
Go
156 lines
4.3 KiB
Go
// Copyright 2019 Sumukha PK
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package storage
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"strings"
|
|
|
|
"github.com/matrix-org/dendrite/common"
|
|
"github.com/matrix-org/dendrite/encryptoapi/types"
|
|
)
|
|
|
|
// Database represents a presence database.
|
|
type Database struct {
|
|
db *sql.DB
|
|
keyStatements keyStatements
|
|
alStatements alStatements
|
|
}
|
|
|
|
// NewDatabase creates a new presence database
|
|
func NewDatabase(dataSourceName string) (*Database, error) {
|
|
var db *sql.DB
|
|
var err error
|
|
if db, err = sql.Open("postgres", dataSourceName); err != nil {
|
|
return nil, err
|
|
}
|
|
keyStatement := keyStatements{}
|
|
alStatement := alStatements{}
|
|
if err = keyStatement.prepare(db); err != nil {
|
|
return nil, err
|
|
}
|
|
if err = alStatement.prepare(db); err != nil {
|
|
return nil, err
|
|
}
|
|
return &Database{db: db, keyStatements: keyStatement, alStatements: alStatement}, nil
|
|
}
|
|
|
|
// InsertKey insert device key
|
|
func (d *Database) InsertKey(
|
|
ctx context.Context,
|
|
deviceID, userID, keyID, keyTyp, keyInfo, al, sig string,
|
|
) (err error) {
|
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
return d.keyStatements.insertKey(ctx, txn, deviceID, userID, keyID, keyTyp, keyInfo, al, sig)
|
|
})
|
|
return
|
|
}
|
|
|
|
// SelectOneTimeKeyCount provides the number of un-claimed OTKeys
|
|
func (d *Database) SelectOneTimeKeyCount(
|
|
ctx context.Context,
|
|
deviceID, userID string,
|
|
) (m map[string]int, err error) {
|
|
m = make(map[string]int)
|
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
elems, err := d.keyStatements.selectKey(ctx, txn, deviceID, userID)
|
|
for _, val := range elems {
|
|
if _, ok := m[val.KeyAlgorithm]; !ok {
|
|
m[val.KeyAlgorithm] = 0
|
|
}
|
|
if val.KeyType == "one_time_key" {
|
|
m[val.KeyAlgorithm]++
|
|
}
|
|
}
|
|
return err
|
|
})
|
|
return
|
|
}
|
|
|
|
// QueryInRange query keys in a range of devices
|
|
func (d *Database) QueryInRange(
|
|
ctx context.Context,
|
|
userID string,
|
|
arr []string,
|
|
) (res []types.KeyHolder, err error) {
|
|
res, err = d.keyStatements.selectInKeys(ctx, userID, arr)
|
|
return
|
|
}
|
|
|
|
// InsertAl persist algorithms
|
|
func (d *Database) InsertAl(
|
|
ctx context.Context, uid, device string, al []string,
|
|
) (err error) {
|
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) (err error) {
|
|
err = d.alStatements.insertAl(ctx, txn, uid, device, strings.Join(al, ","))
|
|
return
|
|
})
|
|
return
|
|
}
|
|
|
|
// SelectAl select algorithms
|
|
func (d *Database) SelectAlgo(
|
|
ctx context.Context, uid, device string,
|
|
) (res []string, err error) {
|
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) (err error) {
|
|
holder, err := d.alStatements.selectAl(ctx, txn, uid, device)
|
|
res = strings.Split(holder.SupportedAlgorithm, ",")
|
|
return
|
|
})
|
|
return
|
|
}
|
|
|
|
// SelectOneTimeKeySingle claim for one time key one for once
|
|
func (d *Database) SelectOneTimeKeySingle(
|
|
ctx context.Context,
|
|
userID, deviceID, algorithm string,
|
|
) (holder types.KeyHolder, err error) {
|
|
holder, err = d.keyStatements.selectSingleKey(ctx, userID, deviceID, algorithm)
|
|
return
|
|
}
|
|
|
|
// SyncOneTimeCount for sync device_one_time_keys_count extension
|
|
func (d *Database) SyncOneTimeCount(
|
|
ctx context.Context,
|
|
userID, deviceID string,
|
|
) (holder map[string]int, err error) {
|
|
holder, err = d.keyStatements.selectOneTimeKeyCount(ctx, userID, deviceID)
|
|
return
|
|
}
|
|
|
|
// InsertChanges inserts the changes to the DB
|
|
func (d *Database) InsertChanges(
|
|
ctx context.Context,
|
|
readID int, userID, status string,
|
|
) error {
|
|
err := common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
return d.keyStatements.insertChanges(ctx, txn, readID, userID, status)
|
|
})
|
|
return err
|
|
}
|
|
|
|
// GetKeyChanges gets the changed keys from the DB
|
|
func (d *Database) GetKeyChanges(
|
|
ctx context.Context,
|
|
readID int, userID string,
|
|
) (keyChanges types.KeyChanges, err error) {
|
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) (err error) {
|
|
keyChanges, err = d.keyStatements.selectChanges(ctx, txn, readID, userID)
|
|
return
|
|
})
|
|
return
|
|
}
|