2023-10-18 23:01:16 -05:00
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"encoding/json"
"time"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
// oneTimeCryptoIDsSchema creates the table that stores users' one-time
// cryptoIDs, plus an index on user_id. NewPostgresOneTimeCryptoIDsTable runs
// it on construction; every statement uses IF NOT EXISTS so re-running it
// against an existing database is harmless.
var oneTimeCryptoIDsSchema = `
-- Stores one-time cryptoIDs for users
CREATE TABLE IF NOT EXISTS keyserver_one_time_cryptoids (
	user_id TEXT NOT NULL,
	key_id TEXT NOT NULL,
	algorithm TEXT NOT NULL,
	ts_added_secs BIGINT NOT NULL,
	key_json TEXT NOT NULL,
	-- Clobber based on 3-tuple of user/key/algorithm.
	CONSTRAINT keyserver_one_time_cryptoids_unique UNIQUE (user_id, key_id, algorithm)
);

CREATE INDEX IF NOT EXISTS keyserver_one_time_cryptoids_idx ON keyserver_one_time_cryptoids (user_id);
`
// SQL statements for the keyserver_one_time_cryptoids table. Each one is
// prepared once in NewPostgresOneTimeCryptoIDsTable.
const (
	// upsertCryptoIDsSQL stores a cryptoID. If the (user_id, key_id, algorithm)
	// triple already exists only key_json is overwritten; ts_added_secs keeps
	// the value from the first insert.
	upsertCryptoIDsSQL = "" +
		"INSERT INTO keyserver_one_time_cryptoids (user_id, key_id, algorithm, ts_added_secs, key_json)" +
		" VALUES ($1, $2, $3, $4, $5)" +
		" ON CONFLICT ON CONSTRAINT keyserver_one_time_cryptoids_unique" +
		" DO UPDATE SET key_json = $5"

	// selectOneTimeCryptoIDsSQL fetches the requested cryptoIDs for a user.
	// $2 is an array of "algorithm:key_id" strings, matched against the same
	// concatenation computed per row.
	selectOneTimeCryptoIDsSQL = "" +
		"SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_cryptoids WHERE user_id=$1 AND concat(algorithm, ':', key_id) = ANY($2);"

	// selectCryptoIDsCountSQL counts a user's cryptoIDs per algorithm. The
	// inner query stops after 100 rows, so the counts sum to at most 100.
	selectCryptoIDsCountSQL = "" +
		"SELECT algorithm, COUNT(key_id) FROM " +
		" (SELECT algorithm, key_id FROM keyserver_one_time_cryptoids WHERE user_id = $1 LIMIT 100)" +
		" x GROUP BY algorithm"

	// deleteOneTimeCryptoIDSQL removes a single cryptoID identified by
	// user, algorithm and key ID.
	deleteOneTimeCryptoIDSQL = "" +
		"DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3"

	// selectCryptoIDByAlgorithmSQL picks any one cryptoID of the given
	// algorithm for a user.
	selectCryptoIDByAlgorithmSQL = "" +
		"SELECT key_id, key_json FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1"

	// deleteOneTimeCryptoIDsSQL removes every cryptoID a user has.
	deleteOneTimeCryptoIDsSQL = "" +
		"DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1"
)
// oneTimeCryptoIDsStatements is the Postgres-backed tables.OneTimeCryptoIDs
// returned by NewPostgresOneTimeCryptoIDsTable. It keeps the database handle
// and one prepared statement per query the table issues.
type oneTimeCryptoIDsStatements struct {
	db *sql.DB

	upsertCryptoIDsStmt           *sql.Stmt // upsertCryptoIDsSQL
	selectCryptoIDsStmt           *sql.Stmt // selectOneTimeCryptoIDsSQL
	selectCryptoIDsCountStmt      *sql.Stmt // selectCryptoIDsCountSQL
	selectCryptoIDByAlgorithmStmt *sql.Stmt // selectCryptoIDByAlgorithmSQL
	deleteOneTimeCryptoIDStmt     *sql.Stmt // deleteOneTimeCryptoIDSQL
	deleteOneTimeCryptoIDsStmt    *sql.Stmt // deleteOneTimeCryptoIDsSQL
}
2023-11-17 18:34:01 -06:00
func NewPostgresOneTimeCryptoIDsTable ( db * sql . DB ) ( tables . OneTimeCryptoIDs , error ) {
s := & oneTimeCryptoIDsStatements {
2023-10-18 23:01:16 -05:00
db : db ,
}
2023-11-17 18:34:01 -06:00
_ , err := db . Exec ( oneTimeCryptoIDsSchema )
2023-10-18 23:01:16 -05:00
if err != nil {
return nil , err
}
return s , sqlutil . StatementList {
2023-11-17 18:34:01 -06:00
{ & s . upsertCryptoIDsStmt , upsertCryptoIDsSQL } ,
{ & s . selectCryptoIDsStmt , selectOneTimeCryptoIDsSQL } ,
{ & s . selectCryptoIDsCountStmt , selectCryptoIDsCountSQL } ,
{ & s . selectCryptoIDByAlgorithmStmt , selectCryptoIDByAlgorithmSQL } ,
{ & s . deleteOneTimeCryptoIDStmt , deleteOneTimeCryptoIDSQL } ,
{ & s . deleteOneTimeCryptoIDsStmt , deleteOneTimeCryptoIDsSQL } ,
2023-10-18 23:01:16 -05:00
} . Prepare ( db )
}
2023-11-17 18:34:01 -06:00
func ( s * oneTimeCryptoIDsStatements ) SelectOneTimeCryptoIDs ( ctx context . Context , userID string , keyIDsWithAlgorithms [ ] string ) ( map [ string ] json . RawMessage , error ) {
rows , err := s . selectCryptoIDsStmt . QueryContext ( ctx , userID , pq . Array ( keyIDsWithAlgorithms ) )
2023-10-18 23:01:16 -05:00
if err != nil {
return nil , err
}
2023-11-17 18:34:01 -06:00
defer internal . CloseAndLogIfError ( ctx , rows , "selectCryptoIDsStmt: rows.close() failed" )
2023-10-18 23:01:16 -05:00
result := make ( map [ string ] json . RawMessage )
var (
algorithmWithID string
keyJSONStr string
)
for rows . Next ( ) {
if err := rows . Scan ( & algorithmWithID , & keyJSONStr ) ; err != nil {
return nil , err
}
result [ algorithmWithID ] = json . RawMessage ( keyJSONStr )
}
return result , rows . Err ( )
}
2023-11-17 18:34:01 -06:00
func ( s * oneTimeCryptoIDsStatements ) CountOneTimeCryptoIDs ( ctx context . Context , userID string ) ( * api . OneTimeCryptoIDsCount , error ) {
counts := & api . OneTimeCryptoIDsCount {
2023-10-18 23:01:16 -05:00
UserID : userID ,
KeyCount : make ( map [ string ] int ) ,
}
2023-11-17 18:34:01 -06:00
rows , err := s . selectCryptoIDsCountStmt . QueryContext ( ctx , userID )
2023-10-18 23:01:16 -05:00
if err != nil {
return nil , err
}
2023-11-17 18:34:01 -06:00
defer internal . CloseAndLogIfError ( ctx , rows , "selectCryptoIDsCountStmt: rows.close() failed" )
2023-10-18 23:01:16 -05:00
for rows . Next ( ) {
var algorithm string
var count int
if err = rows . Scan ( & algorithm , & count ) ; err != nil {
return nil , err
}
counts . KeyCount [ algorithm ] = count
}
return counts , nil
}
2023-11-17 18:34:01 -06:00
func ( s * oneTimeCryptoIDsStatements ) InsertOneTimeCryptoIDs ( ctx context . Context , txn * sql . Tx , keys api . OneTimeCryptoIDs ) ( * api . OneTimeCryptoIDsCount , error ) {
2023-10-18 23:01:16 -05:00
now := time . Now ( ) . Unix ( )
2023-11-17 18:34:01 -06:00
counts := & api . OneTimeCryptoIDsCount {
2023-10-18 23:01:16 -05:00
UserID : keys . UserID ,
KeyCount : make ( map [ string ] int ) ,
}
for keyIDWithAlgo , keyJSON := range keys . KeyJSON {
algo , keyID := keys . Split ( keyIDWithAlgo )
2023-11-17 18:34:01 -06:00
_ , err := sqlutil . TxStmt ( txn , s . upsertCryptoIDsStmt ) . ExecContext (
2023-10-18 23:01:16 -05:00
ctx , keys . UserID , keyID , algo , now , string ( keyJSON ) ,
)
if err != nil {
return nil , err
}
}
2023-11-17 18:34:01 -06:00
rows , err := sqlutil . TxStmt ( txn , s . selectCryptoIDsCountStmt ) . QueryContext ( ctx , keys . UserID )
2023-10-18 23:01:16 -05:00
if err != nil {
return nil , err
}
2023-11-17 18:34:01 -06:00
defer internal . CloseAndLogIfError ( ctx , rows , "selectCryptoIDsCountStmt: rows.close() failed" )
2023-10-18 23:01:16 -05:00
for rows . Next ( ) {
var algorithm string
var count int
if err = rows . Scan ( & algorithm , & count ) ; err != nil {
return nil , err
}
counts . KeyCount [ algorithm ] = count
}
return counts , rows . Err ( )
}
2023-11-17 18:34:01 -06:00
func ( s * oneTimeCryptoIDsStatements ) SelectAndDeleteOneTimeCryptoID (
2023-10-18 23:01:16 -05:00
ctx context . Context , txn * sql . Tx , userID , algorithm string ,
) ( map [ string ] json . RawMessage , error ) {
var keyID string
var keyJSON string
2023-11-17 18:34:01 -06:00
err := sqlutil . TxStmtContext ( ctx , txn , s . selectCryptoIDByAlgorithmStmt ) . QueryRowContext ( ctx , userID , algorithm ) . Scan ( & keyID , & keyJSON )
2023-10-18 23:01:16 -05:00
if err != nil {
if err == sql . ErrNoRows {
return nil , nil
}
return nil , err
}
2023-11-17 18:34:01 -06:00
_ , err = sqlutil . TxStmtContext ( ctx , txn , s . deleteOneTimeCryptoIDStmt ) . ExecContext ( ctx , userID , algorithm , keyID )
2023-10-18 23:01:16 -05:00
return map [ string ] json . RawMessage {
algorithm + ":" + keyID : json . RawMessage ( keyJSON ) ,
} , err
}
2023-11-17 18:34:01 -06:00
func ( s * oneTimeCryptoIDsStatements ) DeleteOneTimeCryptoIDs ( ctx context . Context , txn * sql . Tx , userID string ) error {
_ , err := sqlutil . TxStmt ( txn , s . deleteOneTimeCryptoIDsStmt ) . ExecContext ( ctx , userID )
2023-10-18 23:01:16 -05:00
return err
}