2017-05-19 04:27:03 -05:00
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
2020-02-13 11:27:33 -06:00
package postgres
2017-05-19 04:27:03 -05:00
import (
2017-09-18 08:15:27 -05:00
"context"
2017-05-19 04:27:03 -05:00
"database/sql"
"time"
2022-02-16 11:55:38 -06:00
"github.com/matrix-org/gomatrixserverlib"
2018-05-31 09:21:13 -05:00
"github.com/matrix-org/dendrite/clientapi/userutil"
2020-09-24 05:10:14 -05:00
"github.com/matrix-org/dendrite/internal/sqlutil"
2020-06-17 05:22:26 -05:00
"github.com/matrix-org/dendrite/userapi/api"
2022-02-18 07:51:59 -06:00
"github.com/matrix-org/dendrite/userapi/storage/tables"
2018-06-29 05:55:29 -05:00
log "github.com/sirupsen/logrus"
2017-05-19 04:27:03 -05:00
)
// accountsSchema creates the account_accounts table and the sequence used to
// allocate autogenerated numeric localparts. Both statements use IF NOT EXISTS,
// so the schema can safely be executed again against an existing database.
const accountsSchema = `
-- Stores data about accounts.
CREATE TABLE IF NOT EXISTS account_accounts (
    -- The Matrix user ID localpart for this account
    localpart TEXT NOT NULL PRIMARY KEY,
    -- When this account was first created, as a unix timestamp (ms resolution).
    created_ts BIGINT NOT NULL,
    -- The password hash for this account. Can be NULL if this is a passwordless account.
    password_hash TEXT,
    -- Identifies which application service this account belongs to, if any.
    appservice_id TEXT,
    -- TRUE if the account has been deactivated
    is_deactivated BOOLEAN DEFAULT FALSE,
    -- The account_type (user = 1, guest = 2, admin = 3, appservice = 4)
    account_type SMALLINT NOT NULL,
    -- The policy version this user has accepted
    policy_version TEXT,
    -- The policy version the user received from the server notices room
    policy_version_sent TEXT,
    -- The ID of this user's server notices room; NULL until one is stored
    server_notice_room_id TEXT
    -- TODO:
    -- upgraded_ts, devices, any email reset stuff?
);
-- Create sequence for autogenerated numeric usernames
CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
`
// SQL statements for the account_accounts table. They are prepared once in
// NewPostgresAccountsTable and executed through accountsStatements.
const (
	// Creates an account row. server_notice_room_id and policy_version_sent
	// are not set here and start out NULL.
	insertAccountSQL = "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type, policy_version) VALUES ($1, $2, $3, $4, $5, $6)"

	// Replaces the stored password hash.
	updatePasswordSQL = "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"

	// Flags an account as deactivated; the row itself is kept.
	deactivateAccountSQL = "UPDATE account_accounts SET is_deactivated = TRUE WHERE localpart = $1"

	// Loads the fields needed to build an api.Account.
	selectAccountByLocalpartSQL = "SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1"

	// Only matches accounts that have not been deactivated.
	selectPasswordHashSQL = "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"

	// Draws the next value from the numeric localpart sequence.
	selectNewNumericLocalpartSQL = "SELECT nextval('numeric_username_seq')"

	// Reads the policy version an account has accepted (may be NULL).
	selectPrivacyPolicySQL = "SELECT policy_version FROM account_accounts WHERE localpart = $1"

	// Localparts that have neither accepted nor been sent policy version $1;
	// NULL in either column counts as "not this version".
	batchSelectPrivacyPolicySQL = "SELECT localpart FROM account_accounts WHERE (policy_version IS NULL OR policy_version <> $1) AND (policy_version_sent IS NULL OR policy_version_sent <> $1)"

	// Records the policy version an account accepted.
	updatePolicyVersionSQL = "UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2"

	// Records the policy version sent to an account via server notices.
	updatePolicyVersionServerNoticeSQL = "UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2"

	// Reads the server notices room ID (may be NULL).
	selectServerNoticeRoomSQL = "SELECT server_notice_room_id FROM account_accounts WHERE localpart = $1"

	// Stores the server notices room ID.
	updateServerNoticeRoomSQL = "UPDATE account_accounts SET server_notice_room_id = $1 WHERE localpart = $2"
)
2017-05-19 04:27:03 -05:00
type accountsStatements struct {
2022-02-21 07:26:00 -06:00
insertAccountStmt * sql . Stmt
updatePasswordStmt * sql . Stmt
deactivateAccountStmt * sql . Stmt
selectAccountByLocalpartStmt * sql . Stmt
selectPasswordHashStmt * sql . Stmt
selectNewNumericLocalpartStmt * sql . Stmt
selectPrivacyPolicyStmt * sql . Stmt
batchSelectPrivacyPolicyStmt * sql . Stmt
updatePolicyVersionStmt * sql . Stmt
updatePolicyVersionServerNoticeStmt * sql . Stmt
2022-03-07 02:41:25 -06:00
selectServerNoticeRoomStmt * sql . Stmt
updateServerNoticeRoomStmt * sql . Stmt
2022-02-21 07:26:00 -06:00
serverName gomatrixserverlib . ServerName
2017-05-19 04:27:03 -05:00
}
2022-02-18 07:51:59 -06:00
func NewPostgresAccountsTable ( db * sql . DB , serverName gomatrixserverlib . ServerName ) ( tables . AccountsTable , error ) {
s := & accountsStatements {
serverName : serverName ,
}
2020-10-15 12:09:41 -05:00
_ , err := db . Exec ( accountsSchema )
2022-02-18 07:51:59 -06:00
if err != nil {
return nil , err
}
return s , sqlutil . StatementList {
2021-07-28 12:30:04 -05:00
{ & s . insertAccountStmt , insertAccountSQL } ,
{ & s . updatePasswordStmt , updatePasswordSQL } ,
{ & s . deactivateAccountStmt , deactivateAccountSQL } ,
{ & s . selectAccountByLocalpartStmt , selectAccountByLocalpartSQL } ,
{ & s . selectPasswordHashStmt , selectPasswordHashSQL } ,
{ & s . selectNewNumericLocalpartStmt , selectNewNumericLocalpartSQL } ,
2022-02-14 07:02:13 -06:00
{ & s . selectPrivacyPolicyStmt , selectPrivacyPolicySQL } ,
{ & s . batchSelectPrivacyPolicyStmt , batchSelectPrivacyPolicySQL } ,
2022-02-14 08:08:00 -06:00
{ & s . updatePolicyVersionStmt , updatePolicyVersionSQL } ,
2022-02-21 07:26:00 -06:00
{ & s . updatePolicyVersionServerNoticeStmt , updatePolicyVersionServerNoticeSQL } ,
2022-03-07 02:41:25 -06:00
{ & s . selectServerNoticeRoomStmt , selectServerNoticeRoomSQL } ,
{ & s . updateServerNoticeRoomStmt , updateServerNoticeRoomSQL } ,
2021-07-28 12:30:04 -05:00
} . Prepare ( db )
2017-05-19 04:27:03 -05:00
}
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
// this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success.
2022-02-18 07:51:59 -06:00
func ( s * accountsStatements ) InsertAccount (
2022-02-21 05:08:03 -06:00
ctx context . Context , txn * sql . Tx , localpart , hash , appserviceID , policyVersion string , accountType api . AccountType ,
2020-06-17 05:22:26 -05:00
) ( * api . Account , error ) {
2017-05-19 04:27:03 -05:00
createdTimeMS := time . Now ( ) . UnixNano ( ) / 1000000
2020-09-24 05:10:14 -05:00
stmt := sqlutil . TxStmt ( txn , s . insertAccountStmt )
2018-02-08 05:02:48 -06:00
2022-03-07 02:41:25 -06:00
_ , err := stmt . ExecContext ( ctx , localpart , createdTimeMS , hash , nil , accountType , policyVersion )
2018-02-08 05:02:48 -06:00
if err != nil {
2017-09-18 08:15:27 -05:00
return nil , err
2017-05-19 04:27:03 -05:00
}
2018-02-08 05:02:48 -06:00
2020-06-17 05:22:26 -05:00
return & api . Account {
2018-02-08 05:02:48 -06:00
Localpart : localpart ,
2018-05-31 09:21:13 -05:00
UserID : userutil . MakeUserID ( localpart , s . serverName ) ,
2018-02-08 05:02:48 -06:00
ServerName : s . serverName ,
AppServiceID : appserviceID ,
2022-02-16 11:55:38 -06:00
AccountType : accountType ,
2017-09-18 08:15:27 -05:00
} , nil
2017-05-19 04:27:03 -05:00
}
2022-02-18 07:51:59 -06:00
func ( s * accountsStatements ) UpdatePassword (
2020-09-04 09:16:13 -05:00
ctx context . Context , localpart , passwordHash string ,
) ( err error ) {
_ , err = s . updatePasswordStmt . ExecContext ( ctx , passwordHash , localpart )
return
}
2022-02-18 07:51:59 -06:00
func ( s * accountsStatements ) DeactivateAccount (
2020-10-02 11:18:20 -05:00
ctx context . Context , localpart string ,
) ( err error ) {
_ , err = s . deactivateAccountStmt . ExecContext ( ctx , localpart )
return
}
2022-02-18 07:51:59 -06:00
func ( s * accountsStatements ) SelectPasswordHash (
2017-09-18 08:15:27 -05:00
ctx context . Context , localpart string ,
) ( hash string , err error ) {
err = s . selectPasswordHashStmt . QueryRowContext ( ctx , localpart ) . Scan ( & hash )
2017-05-19 04:27:03 -05:00
return
}
2022-02-18 07:51:59 -06:00
func ( s * accountsStatements ) SelectAccountByLocalpart (
2017-09-18 08:15:27 -05:00
ctx context . Context , localpart string ,
2020-06-17 05:22:26 -05:00
) ( * api . Account , error ) {
2018-06-29 05:55:29 -05:00
var appserviceIDPtr sql . NullString
2020-06-17 05:22:26 -05:00
var acc api . Account
2018-06-29 05:55:29 -05:00
2017-09-18 08:15:27 -05:00
stmt := s . selectAccountByLocalpartStmt
2022-02-16 11:55:38 -06:00
err := stmt . QueryRowContext ( ctx , localpart ) . Scan ( & acc . Localpart , & appserviceIDPtr , & acc . AccountType )
2018-06-29 05:55:29 -05:00
if err != nil {
if err != sql . ErrNoRows {
log . WithError ( err ) . Error ( "Unable to retrieve user from the db" )
}
return nil , err
2017-05-19 04:27:03 -05:00
}
2018-06-29 05:55:29 -05:00
if appserviceIDPtr . Valid {
acc . AppServiceID = appserviceIDPtr . String
}
acc . UserID = userutil . MakeUserID ( localpart , s . serverName )
acc . ServerName = s . serverName
return & acc , nil
2017-05-19 04:27:03 -05:00
}
2018-05-31 09:36:15 -05:00
2022-02-18 07:51:59 -06:00
func ( s * accountsStatements ) SelectNewNumericLocalpart (
2020-03-06 12:00:07 -06:00
ctx context . Context , txn * sql . Tx ,
2018-05-31 09:36:15 -05:00
) ( id int64 , err error ) {
2020-03-06 12:00:07 -06:00
stmt := s . selectNewNumericLocalpartStmt
if txn != nil {
2020-09-24 05:10:14 -05:00
stmt = sqlutil . TxStmt ( txn , stmt )
2020-03-06 12:00:07 -06:00
}
err = stmt . QueryRowContext ( ctx ) . Scan ( & id )
2018-05-31 09:36:15 -05:00
return
}
2022-02-14 07:02:13 -06:00
// selectPrivacyPolicy gets the current privacy policy a specific user accepted
2022-02-21 05:12:07 -06:00
func ( s * accountsStatements ) SelectPrivacyPolicy (
2022-02-14 07:02:13 -06:00
ctx context . Context , txn * sql . Tx , localPart string ,
) ( policy string , err error ) {
stmt := s . selectPrivacyPolicyStmt
if txn != nil {
stmt = sqlutil . TxStmt ( txn , stmt )
}
err = stmt . QueryRowContext ( ctx , localPart ) . Scan ( & policy )
return
}
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
2022-02-21 05:12:07 -06:00
func ( s * accountsStatements ) BatchSelectPrivacyPolicy (
2022-02-14 07:02:13 -06:00
ctx context . Context , txn * sql . Tx , policyVersion string ,
) ( userIDs [ ] string , err error ) {
stmt := s . batchSelectPrivacyPolicyStmt
if txn != nil {
stmt = sqlutil . TxStmt ( txn , stmt )
}
rows , err := stmt . QueryContext ( ctx , policyVersion )
defer rows . Close ( )
for rows . Next ( ) {
var userID string
if err := rows . Scan ( & userID ) ; err != nil {
return userIDs , err
}
userIDs = append ( userIDs , userID )
}
return userIDs , rows . Err ( )
}
2022-02-14 08:08:00 -06:00
// updatePolicyVersion sets the policy_version for a specific user
2022-02-21 05:12:07 -06:00
func ( s * accountsStatements ) UpdatePolicyVersion (
2022-02-21 07:26:00 -06:00
ctx context . Context , txn * sql . Tx , policyVersion , localpart string , serverNotice bool ,
2022-02-14 08:08:00 -06:00
) ( err error ) {
stmt := s . updatePolicyVersionStmt
2022-02-21 07:26:00 -06:00
if serverNotice {
stmt = s . updatePolicyVersionServerNoticeStmt
}
2022-02-14 08:08:00 -06:00
if txn != nil {
stmt = sqlutil . TxStmt ( txn , stmt )
}
_ , err = stmt . ExecContext ( ctx , policyVersion , localpart )
return err
}
2022-03-07 02:41:25 -06:00
// SelectServerNoticeRoomID returns the server notices room ID stored for the
// given account, or "" when none is stored (the column is NULL). Any query
// error, including sql.ErrNoRows for an unknown account, is returned together
// with an empty room ID. It runs inside txn when txn is non-nil.
func (s *accountsStatements) SelectServerNoticeRoomID(
	ctx context.Context, txn *sql.Tx, localpart string,
) (roomID string, err error) {
	query := s.selectServerNoticeRoomStmt
	if txn != nil {
		query = sqlutil.TxStmt(txn, query)
	}

	var stored sql.NullString
	if err = query.QueryRowContext(ctx, localpart).Scan(&stored); err != nil {
		return "", err
	}
	// NullString.String is "" when the column was NULL.
	return stored.String, nil
}
// UpdateServerNoticeRoomID stores roomID as the server notices room of the
// given account. The update runs inside txn when txn is non-nil. Rows
// affected are not checked, so an unknown localpart is not reported as an
// error.
func (s *accountsStatements) UpdateServerNoticeRoomID(
	ctx context.Context, txn *sql.Tx, localpart, roomID string,
) (err error) {
	exec := s.updateServerNoticeRoomStmt
	if txn != nil {
		exec = sqlutil.TxStmt(txn, exec)
	}
	if _, err = exec.ExecContext(ctx, roomID, localpart); err != nil {
		return err
	}
	return nil
}