mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-29 01:33:10 -06:00
Refactor threepid sessions storage so that postgres and sqlite3 schemas are separate. Generate session id in database
This commit is contained in:
parent
bae4313b4b
commit
316adf66d1
|
|
@ -438,7 +438,7 @@ type CreateSessionRequest struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateSessionResponse struct {
|
type CreateSessionResponse struct {
|
||||||
Sid string
|
Sid int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type ValidateSessionRequest struct {
|
type ValidateSessionRequest struct {
|
||||||
|
|
@ -451,13 +451,16 @@ type GetThreePidForSessionResponse struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type SessionOwnership struct {
|
type SessionOwnership struct {
|
||||||
Sid, ClientSecret string
|
Sid int64
|
||||||
|
ClientSecret string
|
||||||
}
|
}
|
||||||
|
|
||||||
type Session struct {
|
type Session struct {
|
||||||
Sid, ClientSecret, ThreePid, Token, NextLink string
|
ClientSecret, ThreePid, Token, NextLink string
|
||||||
SendAttempt, ValidatedAt int
|
Sid int64
|
||||||
Validated bool
|
SendAttempt int
|
||||||
|
ValidatedAt int64
|
||||||
|
Validated bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type IsSessionValidatedResponse struct {
|
type IsSessionValidatedResponse struct {
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
|
@ -28,19 +29,14 @@ func (a *UserInternalAPI) CreateSession(ctx context.Context, req *api.CreateSess
|
||||||
if errB != nil {
|
if errB != nil {
|
||||||
return errB
|
return errB
|
||||||
}
|
}
|
||||||
sid, errB := internal.GenerateBlob(sessionIdByteLength)
|
|
||||||
if errB != nil {
|
|
||||||
return errB
|
|
||||||
}
|
|
||||||
s = &api.Session{
|
s = &api.Session{
|
||||||
Sid: sid,
|
|
||||||
ClientSecret: req.ClientSecret,
|
ClientSecret: req.ClientSecret,
|
||||||
ThreePid: req.ThreePid,
|
ThreePid: req.ThreePid,
|
||||||
SendAttempt: req.SendAttempt,
|
SendAttempt: req.SendAttempt,
|
||||||
Token: token,
|
Token: token,
|
||||||
NextLink: req.NextLink,
|
NextLink: req.NextLink,
|
||||||
}
|
}
|
||||||
err = a.ThreePidDB.InsertSession(ctx, s)
|
s.Sid, err = a.ThreePidDB.InsertSession(ctx, req.ClientSecret, req.ThreePid, token, req.NextLink, 0, false, req.SendAttempt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -60,7 +56,7 @@ func (a *UserInternalAPI) CreateSession(ctx context.Context, req *api.CreateSess
|
||||||
}
|
}
|
||||||
res.Sid = s.Sid
|
res.Sid = s.Sid
|
||||||
query := url.Values{
|
query := url.Values{
|
||||||
"sid": []string{s.Sid},
|
"sid": []string{strconv.Itoa(int(s.Sid))},
|
||||||
"client_secret": []string{s.ClientSecret},
|
"client_secret": []string{s.ClientSecret},
|
||||||
"token": []string{s.Token},
|
"token": []string{s.Token},
|
||||||
}
|
}
|
||||||
|
|
@ -88,7 +84,7 @@ func (a *UserInternalAPI) ValidateSession(ctx context.Context, req *api.Validate
|
||||||
if s.Token != req.Token {
|
if s.Token != req.Token {
|
||||||
return ErrBadSession
|
return ErrBadSession
|
||||||
}
|
}
|
||||||
return a.ThreePidDB.ValidateSession(ctx, s.Sid, int(time.Now().Unix()))
|
return a.ThreePidDB.ValidateSession(ctx, s.Sid, time.Now().Unix())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) GetThreePidForSession(ctx context.Context, req *api.SessionOwnership, res *api.GetThreePidForSessionResponse) error {
|
func (a *UserInternalAPI) GetThreePidForSession(ctx context.Context, req *api.SessionOwnership, res *api.GetThreePidForSessionResponse) error {
|
||||||
|
|
@ -114,7 +110,7 @@ func (a *UserInternalAPI) IsSessionValidated(ctx context.Context, req *api.Sessi
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
res.Validated = s.Validated
|
res.Validated = s.Validated
|
||||||
res.ValidatedAt = s.ValidatedAt
|
res.ValidatedAt = int(s.ValidatedAt)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
59
userapi/storage/threepid/postgres/storage.go
Normal file
59
userapi/storage/threepid/postgres/storage.go
Normal file
|
|
@ -0,0 +1,59 @@
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/threepid/shared"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Database struct {
|
||||||
|
*shared.Database
|
||||||
|
}
|
||||||
|
|
||||||
|
const threePidSessionsSchema = `
|
||||||
|
-- This sequence is used for automatic allocation of session_id.
|
||||||
|
-- CREATE SEQUENCE IF NOT EXISTS threepid_session_id_seq START 1;
|
||||||
|
|
||||||
|
-- Stores data about devices.
|
||||||
|
CREATE TABLE IF NOT EXISTS threepid_sessions (
|
||||||
|
sid SERIAL PRIMARY KEY,
|
||||||
|
client_secret VARCHAR(255),
|
||||||
|
threepid TEXT ,
|
||||||
|
token VARCHAR(255) ,
|
||||||
|
next_link TEXT,
|
||||||
|
validated_at_ts BIGINT,
|
||||||
|
validated BOOLEAN,
|
||||||
|
send_attempt INT
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS threepid_sessions_threepids
|
||||||
|
ON threepid_sessions (threepid, client_secret)
|
||||||
|
`
|
||||||
|
|
||||||
|
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||||
|
db, err := sqlutil.Open(dbProperties)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("sqlutil.Open: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the tables.
|
||||||
|
if err := create(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sharedDb, err := shared.NewDatabase(db, sqlutil.NewDummyWriter())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Database{
|
||||||
|
Database: sharedDb,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func create(db *sql.DB) error {
|
||||||
|
_, err := db.Exec(threePidSessionsSchema)
|
||||||
|
return err
|
||||||
|
}
|
||||||
136
userapi/storage/threepid/shared/storage.go
Normal file
136
userapi/storage/threepid/shared/storage.go
Normal file
|
|
@ -0,0 +1,136 @@
|
||||||
|
package shared
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Database struct {
|
||||||
|
writer sqlutil.Writer
|
||||||
|
stm *ThreePidSessionStatements
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) InsertSession(
|
||||||
|
ctx context.Context, clientSecret, threepid, token, nextlink string, validatedAt int64, validated bool, sendAttempt int) (int64, error) {
|
||||||
|
var sid int64
|
||||||
|
return sid, d.writer.Do(nil, nil, func(_ *sql.Tx) error {
|
||||||
|
err := d.stm.insertSessionStmt.QueryRowContext(ctx, clientSecret, threepid, token, nextlink, validatedAt, validated, sendAttempt).Scan(&sid)
|
||||||
|
// _, err := d.stm.insertSessionStmt.ExecContext(ctx, clientSecret, threepid, token, nextlink, sendAttempt, validatedAt, validated)
|
||||||
|
return err
|
||||||
|
// if err != nil {
|
||||||
|
// return err
|
||||||
|
// }
|
||||||
|
// err = d.stm.selectSidStmt.QueryRowContext(ctx).Scan(&sid)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) GetSession(ctx context.Context, sid int64) (*api.Session, error) {
|
||||||
|
s := api.Session{}
|
||||||
|
err := d.stm.selectSessionStmt.QueryRowContext(ctx, sid).Scan(&s.ClientSecret, &s.ThreePid, &s.Token, &s.NextLink, &s.ValidatedAt, &s.Validated, &s.SendAttempt)
|
||||||
|
s.Sid = sid
|
||||||
|
return &s, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) GetSessionByThreePidAndSecret(ctx context.Context, threePid, ClientSecret string) (*api.Session, error) {
|
||||||
|
s := api.Session{}
|
||||||
|
err := d.stm.selectSessionByThreePidAndCLientSecretStmt.
|
||||||
|
QueryRowContext(ctx, threePid, ClientSecret).Scan(
|
||||||
|
&s.Sid, &s.Token, &s.NextLink, &s.ValidatedAt, &s.Validated, &s.SendAttempt)
|
||||||
|
s.ThreePid = threePid
|
||||||
|
s.ClientSecret = ClientSecret
|
||||||
|
return &s, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) UpdateSendAttemptNextLink(ctx context.Context, sid int64, nextLink string) error {
|
||||||
|
return d.writer.Do(nil, nil, func(_ *sql.Tx) error {
|
||||||
|
_, err := d.stm.updateSendAttemptNextLinkStmt.ExecContext(ctx, nextLink, sid)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) DeleteSession(ctx context.Context, sid int64) error {
|
||||||
|
return d.writer.Do(nil, nil, func(_ *sql.Tx) error {
|
||||||
|
_, err := d.stm.deleteSessionStmt.ExecContext(ctx, sid)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) ValidateSession(ctx context.Context, sid int64, validatedAt int64) error {
|
||||||
|
return d.writer.Do(nil, nil, func(_ *sql.Tx) error {
|
||||||
|
_, err := d.stm.validateSessionStmt.ExecContext(ctx, true, validatedAt, sid)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDatabase(db *sql.DB, writer sqlutil.Writer) (*Database, error) {
|
||||||
|
threePidSessionsTable, err := PrepareThreePidSessionsTable(db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
d := Database{
|
||||||
|
writer: writer,
|
||||||
|
stm: threePidSessionsTable,
|
||||||
|
}
|
||||||
|
return &d, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// func newSQLiteDatabase(dbProperties *config.DatabaseOptions) (*Db, error) {
|
||||||
|
// db, err := sqlutil.Open(dbProperties)
|
||||||
|
// if err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
// writer := sqlutil.NewExclusiveWriter()
|
||||||
|
// stmt := sessionStatements{
|
||||||
|
// db: db,
|
||||||
|
// writer: writer,
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // Create tables before executing migrations so we don't fail if the table is missing,
|
||||||
|
// // and THEN prepare statements so we don't fail due to referencing new columns
|
||||||
|
// if err = stmt.execSchema(db); err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
// if err = stmt.prepare(); err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
// handler := func(f func(tx *sql.Tx) error) error {
|
||||||
|
// return writer.Do(nil, nil, f)
|
||||||
|
// }
|
||||||
|
// return &Db{db, writer, &stmt, handler}, nil
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func newPostgresDatabase(dbProperties *config.DatabaseOptions) (*Db, error) {
|
||||||
|
// db, err := sqlutil.Open(dbProperties)
|
||||||
|
// if err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
// stmt := sessionStatements{
|
||||||
|
// db: db,
|
||||||
|
// }
|
||||||
|
// // Create tables before executing migrations so we don't fail if the table is missing,
|
||||||
|
// // and THEN prepare statements so we don't fail due to referencing new columns
|
||||||
|
// if err = stmt.execSchema(db); err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
// if err = stmt.prepare(); err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
// handler := func(f func(tx *sql.Tx) error) error {
|
||||||
|
// return f(nil)
|
||||||
|
// }
|
||||||
|
// return &Db{db, nil, &stmt, handler}, nil
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) {
|
||||||
|
// switch {
|
||||||
|
// case dbProperties.ConnectionString.IsSQLite():
|
||||||
|
// return newSQLiteDatabase(dbProperties)
|
||||||
|
// case dbProperties.ConnectionString.IsPostgres():
|
||||||
|
// return newPostgresDatabase(dbProperties)
|
||||||
|
// default:
|
||||||
|
// return nil, fmt.Errorf("unexpected database type")
|
||||||
|
// }
|
||||||
|
// }
|
||||||
49
userapi/storage/threepid/shared/threepid_sessions_table.go
Normal file
49
userapi/storage/threepid/shared/threepid_sessions_table.go
Normal file
|
|
@ -0,0 +1,49 @@
|
||||||
|
package shared
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
insertSessionSQL = "" +
|
||||||
|
"INSERT INTO threepid_sessions (client_secret, threepid, token, next_link, validated_at_ts, validated, send_attempt)" +
|
||||||
|
"VALUES ($1, $2, $3, $4, $5, $6, $7)" +
|
||||||
|
"RETURNING sid;"
|
||||||
|
// selectSidSQL = "" +
|
||||||
|
// "SELECT last_insert_rowid();"
|
||||||
|
selectSessionSQL = "" +
|
||||||
|
"SELECT client_secret, threepid, token, next_link, validated_at_ts, validated, send_attempt FROM threepid_sessions WHERE sid = $1"
|
||||||
|
selectSessionByThreePidAndCLientSecretSQL = "" +
|
||||||
|
"SELECT sid, token, next_link, validated_at_ts, validated, send_attempt FROM threepid_sessions WHERE threepid = $1 AND client_secret = $2"
|
||||||
|
deleteSessionSQL = "" +
|
||||||
|
"DELETE FROM threepid_sessions WHERE sid = $1"
|
||||||
|
validateSessionSQL = "" +
|
||||||
|
"UPDATE threepid_sessions SET validated = $1, validated_at_ts = $2 WHERE sid = $3"
|
||||||
|
updateSendAttemptNextLinkSQL = "" +
|
||||||
|
"UPDATE threepid_sessions SET send_attempt = send_attempt + 1, next_link = $1 WHERE sid = $2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ThreePidSessionStatements struct {
|
||||||
|
insertSessionStmt *sql.Stmt
|
||||||
|
// selectSidStmt *sql.Stmt
|
||||||
|
selectSessionStmt *sql.Stmt
|
||||||
|
selectSessionByThreePidAndCLientSecretStmt *sql.Stmt
|
||||||
|
deleteSessionStmt *sql.Stmt
|
||||||
|
validateSessionStmt *sql.Stmt
|
||||||
|
updateSendAttemptNextLinkStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func PrepareThreePidSessionsTable(db *sql.DB) (*ThreePidSessionStatements, error) {
|
||||||
|
s := ThreePidSessionStatements{}
|
||||||
|
return &s, sqlutil.StatementList{
|
||||||
|
{&s.insertSessionStmt, insertSessionSQL},
|
||||||
|
// {&s.selectSidStmt, selectSidSQL},
|
||||||
|
{&s.selectSessionStmt, selectSessionSQL},
|
||||||
|
{&s.selectSessionByThreePidAndCLientSecretStmt, selectSessionByThreePidAndCLientSecretSQL},
|
||||||
|
{&s.deleteSessionStmt, deleteSessionSQL},
|
||||||
|
{&s.validateSessionStmt, validateSessionSQL},
|
||||||
|
{&s.updateSendAttemptNextLinkStmt, updateSendAttemptNextLinkSQL},
|
||||||
|
}.Prepare(db)
|
||||||
|
}
|
||||||
59
userapi/storage/threepid/sqlite3/storage.go
Normal file
59
userapi/storage/threepid/sqlite3/storage.go
Normal file
|
|
@ -0,0 +1,59 @@
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/threepid/shared"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Database struct {
|
||||||
|
*shared.Database
|
||||||
|
}
|
||||||
|
|
||||||
|
const threePidSessionsSchema = `
|
||||||
|
-- This sequence is used for automatic allocation of session_id.
|
||||||
|
-- CREATE SEQUENCE IF NOT EXISTS threepid_session_id_seq START 1;
|
||||||
|
|
||||||
|
-- Stores data about devices.
|
||||||
|
CREATE TABLE IF NOT EXISTS threepid_sessions (
|
||||||
|
sid INTEGER PRIMARY KEY,
|
||||||
|
client_secret VARCHAR(255),
|
||||||
|
threepid TEXT ,
|
||||||
|
token VARCHAR(255) ,
|
||||||
|
next_link TEXT,
|
||||||
|
validated_at_ts BIGINT,
|
||||||
|
validated BOOLEAN,
|
||||||
|
send_attempt INT
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS threepid_sessions_threepids
|
||||||
|
ON threepid_sessions (threepid, client_secret)
|
||||||
|
`
|
||||||
|
|
||||||
|
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||||
|
db, err := sqlutil.Open(dbProperties)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("sqlutil.Open: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the tables.
|
||||||
|
if err := create(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sharedDb, err := shared.NewDatabase(db, sqlutil.NewExclusiveWriter())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Database{
|
||||||
|
Database: sharedDb,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func create(db *sql.DB) error {
|
||||||
|
_, err := db.Exec(threePidSessionsSchema)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
@ -1,81 +0,0 @@
|
||||||
package threepid
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
const sessionsSchema = `
|
|
||||||
-- This sequence is used for automatic allocation of session_id.
|
|
||||||
-- CREATE SEQUENCE IF NOT EXISTS threepid_session_id_seq START 1;
|
|
||||||
|
|
||||||
-- Stores data about devices.
|
|
||||||
CREATE TABLE IF NOT EXISTS threepid_sessions (
|
|
||||||
sid VARCHAR(255) PRIMARY KEY,
|
|
||||||
client_secret VARCHAR(255),
|
|
||||||
threepid TEXT ,
|
|
||||||
token VARCHAR(255) ,
|
|
||||||
next_link TEXT,
|
|
||||||
validated_at_ts BIGINT,
|
|
||||||
validated BOOLEAN,
|
|
||||||
send_attempt INT
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE UNIQUE INDEX IF NOT EXISTS threepid_sessions_threepids
|
|
||||||
ON threepid_sessions (threepid, client_secret)
|
|
||||||
`
|
|
||||||
|
|
||||||
const (
|
|
||||||
insertSessionSQL = "" +
|
|
||||||
"INSERT INTO threepid_sessions (sid, client_secret, threepid, token, next_link, send_attempt, validated_at_ts, validated)" +
|
|
||||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
|
|
||||||
selectSessionSQL = "" +
|
|
||||||
"SELECT client_secret, threepid, token, next_link, validated, validated_at_ts, send_attempt FROM threepid_sessions WHERE sid = $1"
|
|
||||||
selectSessionByThreePidAndCLientSecretSQL = "" +
|
|
||||||
"SELECT sid, token, next_link, validated, validated_at_ts, send_attempt FROM threepid_sessions WHERE threepid = $1 AND client_secret = $2"
|
|
||||||
deleteSessionSQL = "" +
|
|
||||||
"DELETE FROM threepid_sessions WHERE sid = $1"
|
|
||||||
validateSessionSQL = "" +
|
|
||||||
"UPDATE threepid_sessions SET validated = $1, validated_at_ts = $2 WHERE sid = $3"
|
|
||||||
updateSendAttemptNextLinkSQL = "" +
|
|
||||||
"UPDATE threepid_sessions SET send_attempt = send_attempt + 1, next_link = $1 WHERE sid = $2"
|
|
||||||
)
|
|
||||||
|
|
||||||
type sessionStatements struct {
|
|
||||||
db *sql.DB
|
|
||||||
writer sqlutil.Writer
|
|
||||||
insertSessionStmt *sql.Stmt
|
|
||||||
selectSessionStmt *sql.Stmt
|
|
||||||
selectSessionByThreePidAndCLientSecretStmt *sql.Stmt
|
|
||||||
deleteSessionStmt *sql.Stmt
|
|
||||||
validateSessionStmt *sql.Stmt
|
|
||||||
updateSendAttemptNextLinkStmt *sql.Stmt
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *sessionStatements) prepare() (err error) {
|
|
||||||
if s.insertSessionStmt, err = s.db.Prepare(insertSessionSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.selectSessionStmt, err = s.db.Prepare(selectSessionSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.selectSessionByThreePidAndCLientSecretStmt, err = s.db.Prepare(selectSessionByThreePidAndCLientSecretSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteSessionStmt, err = s.db.Prepare(deleteSessionSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.validateSessionStmt, err = s.db.Prepare(validateSessionSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.updateSendAttemptNextLinkStmt, err = s.db.Prepare(updateSendAttemptNextLinkSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *sessionStatements) execSchema(db *sql.DB) error {
|
|
||||||
_, err := db.Exec(sessionsSchema)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
@ -2,132 +2,30 @@ package threepid
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/threepid/postgres"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/threepid/sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Database interface {
|
type Database interface {
|
||||||
InsertSession(context.Context, *api.Session) error
|
InsertSession(ctx context.Context, clientSecret, threepid, token, nextlink string, validatedAt int64, validated bool, sendAttempt int) (int64, error)
|
||||||
GetSession(ctx context.Context, sid string) (*api.Session, error)
|
GetSession(ctx context.Context, sid int64) (*api.Session, error)
|
||||||
GetSessionByThreePidAndSecret(ctx context.Context, threePid, ClientSecret string) (*api.Session, error)
|
GetSessionByThreePidAndSecret(ctx context.Context, threePid, ClientSecret string) (*api.Session, error)
|
||||||
UpdateSendAttemptNextLink(ctx context.Context, sid, nextLink string) error
|
UpdateSendAttemptNextLink(ctx context.Context, sid int64, nextLink string) error
|
||||||
DeleteSession(ctx context.Context, sid string) error
|
DeleteSession(ctx context.Context, sid int64) error
|
||||||
ValidateSession(ctx context.Context, sid string, validatedAt int) error
|
ValidateSession(ctx context.Context, sid int64, validatedAt int64) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Db struct {
|
// Open opens a database connection.
|
||||||
db *sql.DB
|
func Open(dbProperties *config.DatabaseOptions) (Database, error) {
|
||||||
writer sqlutil.Writer
|
|
||||||
stm *sessionStatements
|
|
||||||
writeHandler func(func(*sql.Tx) error) error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Db) InsertSession(ctx context.Context, s *api.Session) error {
|
|
||||||
h := func(_ *sql.Tx) error {
|
|
||||||
_, err := d.stm.insertSessionStmt.ExecContext(ctx, s.Sid, s.ClientSecret, s.ThreePid, s.Token, s.NextLink, s.SendAttempt, s.ValidatedAt, s.Validated)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return d.writeHandler(h)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Db) GetSession(ctx context.Context, sid string) (*api.Session, error) {
|
|
||||||
s := api.Session{}
|
|
||||||
err := d.stm.selectSessionStmt.QueryRowContext(ctx, sid).Scan(&s.ClientSecret, &s.ThreePid, &s.Token, &s.NextLink, &s.Validated, &s.ValidatedAt, &s.SendAttempt)
|
|
||||||
s.Sid = sid
|
|
||||||
return &s, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Db) GetSessionByThreePidAndSecret(ctx context.Context, threePid, ClientSecret string) (*api.Session, error) {
|
|
||||||
s := api.Session{}
|
|
||||||
err := d.stm.selectSessionByThreePidAndCLientSecretStmt.
|
|
||||||
QueryRowContext(ctx, threePid, ClientSecret).Scan(
|
|
||||||
&s.Sid, &s.Token, &s.NextLink, &s.Validated, &s.ValidatedAt, &s.SendAttempt)
|
|
||||||
s.ThreePid = threePid
|
|
||||||
s.ClientSecret = ClientSecret
|
|
||||||
return &s, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Db) UpdateSendAttemptNextLink(ctx context.Context, sid, nextLink string) error {
|
|
||||||
h := func(_ *sql.Tx) error {
|
|
||||||
_, err := d.stm.updateSendAttemptNextLinkStmt.ExecContext(ctx, nextLink, sid)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return d.writeHandler(h)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Db) DeleteSession(ctx context.Context, sid string) error {
|
|
||||||
h := func(_ *sql.Tx) error {
|
|
||||||
_, err := d.stm.deleteSessionStmt.ExecContext(ctx, sid)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return d.writeHandler(h)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Db) ValidateSession(ctx context.Context, sid string, validatedAt int) error {
|
|
||||||
h := func(_ *sql.Tx) error {
|
|
||||||
_, err := d.stm.validateSessionStmt.ExecContext(ctx, true, validatedAt, sid)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return d.writeHandler(h)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSQLiteDatabase(dbProperties *config.DatabaseOptions) (*Db, error) {
|
|
||||||
db, err := sqlutil.Open(dbProperties)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
writer := sqlutil.NewExclusiveWriter()
|
|
||||||
stmt := sessionStatements{
|
|
||||||
db: db,
|
|
||||||
writer: writer,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create tables before executing migrations so we don't fail if the table is missing,
|
|
||||||
// and THEN prepare statements so we don't fail due to referencing new columns
|
|
||||||
if err = stmt.execSchema(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = stmt.prepare(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
handler := func(f func(tx *sql.Tx) error) error {
|
|
||||||
return writer.Do(nil, nil, f)
|
|
||||||
}
|
|
||||||
return &Db{db, writer, &stmt, handler}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newPostgresDatabase(dbProperties *config.DatabaseOptions) (*Db, error) {
|
|
||||||
db, err := sqlutil.Open(dbProperties)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
stmt := sessionStatements{
|
|
||||||
db: db,
|
|
||||||
}
|
|
||||||
// Create tables before executing migrations so we don't fail if the table is missing,
|
|
||||||
// and THEN prepare statements so we don't fail due to referencing new columns
|
|
||||||
if err = stmt.execSchema(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = stmt.prepare(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
handler := func(f func(tx *sql.Tx) error) error {
|
|
||||||
return f(nil)
|
|
||||||
}
|
|
||||||
return &Db{db, nil, &stmt, handler}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) {
|
|
||||||
switch {
|
switch {
|
||||||
case dbProperties.ConnectionString.IsSQLite():
|
case dbProperties.ConnectionString.IsSQLite():
|
||||||
return newSQLiteDatabase(dbProperties)
|
return sqlite3.Open(dbProperties)
|
||||||
case dbProperties.ConnectionString.IsPostgres():
|
case dbProperties.ConnectionString.IsPostgres():
|
||||||
return newPostgresDatabase(dbProperties)
|
return postgres.Open(dbProperties)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unexpected database type")
|
return nil, fmt.Errorf("unexpected database type")
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var testSession = api.Session{
|
var testSession = api.Session{
|
||||||
// Sid: "123",
|
|
||||||
ClientSecret: "099azAZ.=_-",
|
ClientSecret: "099azAZ.=_-",
|
||||||
ThreePid: "azAZ09!#$%&'*+-/=?^_`{|}~@bar09.com",
|
ThreePid: "azAZ09!#$%&'*+-/=?^_`{|}~@bar09.com",
|
||||||
Token: "fooBAR123",
|
Token: "fooBAR123",
|
||||||
|
|
@ -26,15 +25,24 @@ var testSession = api.Session{
|
||||||
|
|
||||||
var testCtx = context.Background()
|
var testCtx = context.Background()
|
||||||
|
|
||||||
func mustNewDatabaseWithTestSession(is *is.I) *Db {
|
func mustNewDatabaseWithTestSession(is *is.I) Database {
|
||||||
randPostfix := strconv.Itoa(rand.Int())
|
randPostfix := strconv.Itoa(rand.Int())
|
||||||
dbPath := os.TempDir() + "/dendrite-" + randPostfix
|
dbPath := os.TempDir() + "/dendrite-" + randPostfix
|
||||||
println(dbPath)
|
println(dbPath)
|
||||||
dut, err := newSQLiteDatabase(&config.DatabaseOptions{
|
dut, err := Open(&config.DatabaseOptions{
|
||||||
ConnectionString: config.DataSource("file:" + dbPath),
|
ConnectionString: config.DataSource("file:" + dbPath),
|
||||||
})
|
})
|
||||||
is.NoErr(err)
|
is.NoErr(err)
|
||||||
err = dut.InsertSession(testCtx, &testSession)
|
sid, err := dut.InsertSession(
|
||||||
|
testCtx,
|
||||||
|
testSession.ClientSecret,
|
||||||
|
testSession.ThreePid,
|
||||||
|
testSession.Token,
|
||||||
|
testSession.NextLink,
|
||||||
|
testSession.ValidatedAt,
|
||||||
|
testSession.Validated,
|
||||||
|
testSession.SendAttempt)
|
||||||
|
testSession.Sid = sid
|
||||||
is.NoErr(err)
|
is.NoErr(err)
|
||||||
return dut
|
return dut
|
||||||
}
|
}
|
||||||
|
|
@ -78,7 +86,7 @@ func TestDeleteSession(t *testing.T) {
|
||||||
func TestValidateSession(t *testing.T) {
|
func TestValidateSession(t *testing.T) {
|
||||||
is := is.New(t)
|
is := is.New(t)
|
||||||
dut := mustNewDatabaseWithTestSession(is)
|
dut := mustNewDatabaseWithTestSession(is)
|
||||||
validatedAt := 1_623_406_296
|
validatedAt := int64(1_623_406_296)
|
||||||
err := dut.ValidateSession(testCtx, testSession.Sid, validatedAt)
|
err := dut.ValidateSession(testCtx, testSession.Sid, validatedAt)
|
||||||
is.NoErr(err)
|
is.NoErr(err)
|
||||||
session, err := dut.GetSession(testCtx, testSession.Sid)
|
session, err := dut.GetSession(testCtx, testSession.Sid)
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ func NewInternalAPI(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Panicf("failed to connect to device db")
|
logrus.WithError(err).Panicf("failed to connect to device db")
|
||||||
}
|
}
|
||||||
threepidDb, err := threepid.NewDatabase(&cfg.ThreepidDatabase)
|
threepidDb, err := threepid.Open(&cfg.ThreepidDatabase)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Panicf("failed to connect to threepid db")
|
logrus.WithError(err).Panicf("failed to connect to threepid db")
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -171,7 +171,7 @@ func TestCreateSession_Twice(t *testing.T) {
|
||||||
resp := api.CreateSessionResponse{}
|
resp := api.CreateSessionResponse{}
|
||||||
err := internalApi.CreateSession(ctx, testReq, &resp)
|
err := internalApi.CreateSession(ctx, testReq, &resp)
|
||||||
is.NoErr(err)
|
is.NoErr(err)
|
||||||
is.Equal(len(resp.Sid), 43)
|
is.Equal(resp.Sid, int64(1))
|
||||||
select {
|
select {
|
||||||
case <-mailer.c[api.AccountPassword]:
|
case <-mailer.c[api.AccountPassword]:
|
||||||
t.Fatal("email was received, but sent attempt was not increased")
|
t.Fatal("email was received, but sent attempt was not increased")
|
||||||
|
|
@ -189,7 +189,7 @@ func TestCreateSession_Twice_IncreaseSendAttempt(t *testing.T) {
|
||||||
testReqBumped.SendAttempt = 1
|
testReqBumped.SendAttempt = 1
|
||||||
err := internalApi.CreateSession(ctx, &testReqBumped, &resp)
|
err := internalApi.CreateSession(ctx, &testReqBumped, &resp)
|
||||||
is.NoErr(err)
|
is.NoErr(err)
|
||||||
is.Equal(len(resp.Sid), 43)
|
is.Equal(resp.Sid, int64(1))
|
||||||
sub := <-mailer.c[api.AccountPassword]
|
sub := <-mailer.c[api.AccountPassword]
|
||||||
is.Equal(len(sub.Token), 64)
|
is.Equal(len(sub.Token), 64)
|
||||||
is.Equal(sub.To, testReq.ThreePid)
|
is.Equal(sub.To, testReq.ThreePid)
|
||||||
|
|
@ -235,7 +235,7 @@ func mustCreateSession(is *is.I, i *internal.UserInternalAPI) (resp *api.CreateS
|
||||||
i.Mail = mailer
|
i.Mail = mailer
|
||||||
err := i.CreateSession(ctx, testReq, resp)
|
err := i.CreateSession(ctx, testReq, resp)
|
||||||
is.NoErr(err)
|
is.NoErr(err)
|
||||||
is.Equal(len(resp.Sid), 43)
|
is.Equal(resp.Sid, int64(1))
|
||||||
sub := <-mailer.c[api.AccountPassword]
|
sub := <-mailer.c[api.AccountPassword]
|
||||||
is.Equal(len(sub.Token), 64)
|
is.Equal(len(sub.Token), 64)
|
||||||
is.Equal(sub.To, testReq.ThreePid)
|
is.Equal(sub.To, testReq.ThreePid)
|
||||||
|
|
@ -244,14 +244,14 @@ func mustCreateSession(is *is.I, i *internal.UserInternalAPI) (resp *api.CreateS
|
||||||
is.Equal(submitUrl.Host, "example.com")
|
is.Equal(submitUrl.Host, "example.com")
|
||||||
is.Equal(submitUrl.Path, "/_matrix/client/r0/account/password/email/submitToken")
|
is.Equal(submitUrl.Path, "/_matrix/client/r0/account/password/email/submitToken")
|
||||||
q := submitUrl.Query()
|
q := submitUrl.Query()
|
||||||
is.Equal(len(q["sid"][0]), 43)
|
is.Equal(q["sid"][0], "1")
|
||||||
is.Equal(q["token"][0], sub.Token)
|
is.Equal(q["token"][0], sub.Token)
|
||||||
is.Equal(q["client_secret"][0], "foobar")
|
is.Equal(q["client_secret"][0], "foobar")
|
||||||
token = sub.Token
|
token = sub.Token
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func mustValidateSesson(is *is.I, i *internal.UserInternalAPI, secret, token, sid string) {
|
func mustValidateSesson(is *is.I, i *internal.UserInternalAPI, secret, token string, sid int64) {
|
||||||
err := i.ValidateSession(ctx, &api.ValidateSessionRequest{
|
err := i.ValidateSession(ctx, &api.ValidateSessionRequest{
|
||||||
SessionOwnership: api.SessionOwnership{
|
SessionOwnership: api.SessionOwnership{
|
||||||
Sid: sid,
|
Sid: sid,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue