mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-29 01:33:10 -06:00
Refactor threepid sessions storage so that postgres and sqlite3 schemas are separate. Generate session id in database
This commit is contained in:
parent
bae4313b4b
commit
316adf66d1
|
|
@ -438,7 +438,7 @@ type CreateSessionRequest struct {
|
|||
}
|
||||
|
||||
type CreateSessionResponse struct {
|
||||
Sid string
|
||||
Sid int64
|
||||
}
|
||||
|
||||
type ValidateSessionRequest struct {
|
||||
|
|
@ -451,13 +451,16 @@ type GetThreePidForSessionResponse struct {
|
|||
}
|
||||
|
||||
type SessionOwnership struct {
|
||||
Sid, ClientSecret string
|
||||
Sid int64
|
||||
ClientSecret string
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
Sid, ClientSecret, ThreePid, Token, NextLink string
|
||||
SendAttempt, ValidatedAt int
|
||||
Validated bool
|
||||
ClientSecret, ThreePid, Token, NextLink string
|
||||
Sid int64
|
||||
SendAttempt int
|
||||
ValidatedAt int64
|
||||
Validated bool
|
||||
}
|
||||
|
||||
type IsSessionValidatedResponse struct {
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import (
|
|||
"database/sql"
|
||||
"errors"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
|
|
@ -28,19 +29,14 @@ func (a *UserInternalAPI) CreateSession(ctx context.Context, req *api.CreateSess
|
|||
if errB != nil {
|
||||
return errB
|
||||
}
|
||||
sid, errB := internal.GenerateBlob(sessionIdByteLength)
|
||||
if errB != nil {
|
||||
return errB
|
||||
}
|
||||
s = &api.Session{
|
||||
Sid: sid,
|
||||
ClientSecret: req.ClientSecret,
|
||||
ThreePid: req.ThreePid,
|
||||
SendAttempt: req.SendAttempt,
|
||||
Token: token,
|
||||
NextLink: req.NextLink,
|
||||
}
|
||||
err = a.ThreePidDB.InsertSession(ctx, s)
|
||||
s.Sid, err = a.ThreePidDB.InsertSession(ctx, req.ClientSecret, req.ThreePid, token, req.NextLink, 0, false, req.SendAttempt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -60,7 +56,7 @@ func (a *UserInternalAPI) CreateSession(ctx context.Context, req *api.CreateSess
|
|||
}
|
||||
res.Sid = s.Sid
|
||||
query := url.Values{
|
||||
"sid": []string{s.Sid},
|
||||
"sid": []string{strconv.Itoa(int(s.Sid))},
|
||||
"client_secret": []string{s.ClientSecret},
|
||||
"token": []string{s.Token},
|
||||
}
|
||||
|
|
@ -88,7 +84,7 @@ func (a *UserInternalAPI) ValidateSession(ctx context.Context, req *api.Validate
|
|||
if s.Token != req.Token {
|
||||
return ErrBadSession
|
||||
}
|
||||
return a.ThreePidDB.ValidateSession(ctx, s.Sid, int(time.Now().Unix()))
|
||||
return a.ThreePidDB.ValidateSession(ctx, s.Sid, time.Now().Unix())
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) GetThreePidForSession(ctx context.Context, req *api.SessionOwnership, res *api.GetThreePidForSessionResponse) error {
|
||||
|
|
@ -114,7 +110,7 @@ func (a *UserInternalAPI) IsSessionValidated(ctx context.Context, req *api.Sessi
|
|||
return err
|
||||
}
|
||||
res.Validated = s.Validated
|
||||
res.ValidatedAt = s.ValidatedAt
|
||||
res.ValidatedAt = int(s.ValidatedAt)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
59
userapi/storage/threepid/postgres/storage.go
Normal file
59
userapi/storage/threepid/postgres/storage.go
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/threepid/shared"
|
||||
)
|
||||
|
||||
type Database struct {
|
||||
*shared.Database
|
||||
}
|
||||
|
||||
const threePidSessionsSchema = `
|
||||
-- This sequence is used for automatic allocation of session_id.
|
||||
-- CREATE SEQUENCE IF NOT EXISTS threepid_session_id_seq START 1;
|
||||
|
||||
-- Stores data about devices.
|
||||
CREATE TABLE IF NOT EXISTS threepid_sessions (
|
||||
sid SERIAL PRIMARY KEY,
|
||||
client_secret VARCHAR(255),
|
||||
threepid TEXT ,
|
||||
token VARCHAR(255) ,
|
||||
next_link TEXT,
|
||||
validated_at_ts BIGINT,
|
||||
validated BOOLEAN,
|
||||
send_attempt INT
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS threepid_sessions_threepids
|
||||
ON threepid_sessions (threepid, client_secret)
|
||||
`
|
||||
|
||||
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||
db, err := sqlutil.Open(dbProperties)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlutil.Open: %w", err)
|
||||
}
|
||||
|
||||
// Create the tables.
|
||||
if err := create(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sharedDb, err := shared.NewDatabase(db, sqlutil.NewDummyWriter())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Database{
|
||||
Database: sharedDb,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func create(db *sql.DB) error {
|
||||
_, err := db.Exec(threePidSessionsSchema)
|
||||
return err
|
||||
}
|
||||
136
userapi/storage/threepid/shared/storage.go
Normal file
136
userapi/storage/threepid/shared/storage.go
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
package shared
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
)
|
||||
|
||||
type Database struct {
|
||||
writer sqlutil.Writer
|
||||
stm *ThreePidSessionStatements
|
||||
}
|
||||
|
||||
func (d *Database) InsertSession(
|
||||
ctx context.Context, clientSecret, threepid, token, nextlink string, validatedAt int64, validated bool, sendAttempt int) (int64, error) {
|
||||
var sid int64
|
||||
return sid, d.writer.Do(nil, nil, func(_ *sql.Tx) error {
|
||||
err := d.stm.insertSessionStmt.QueryRowContext(ctx, clientSecret, threepid, token, nextlink, validatedAt, validated, sendAttempt).Scan(&sid)
|
||||
// _, err := d.stm.insertSessionStmt.ExecContext(ctx, clientSecret, threepid, token, nextlink, sendAttempt, validatedAt, validated)
|
||||
return err
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// err = d.stm.selectSidStmt.QueryRowContext(ctx).Scan(&sid)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) GetSession(ctx context.Context, sid int64) (*api.Session, error) {
|
||||
s := api.Session{}
|
||||
err := d.stm.selectSessionStmt.QueryRowContext(ctx, sid).Scan(&s.ClientSecret, &s.ThreePid, &s.Token, &s.NextLink, &s.ValidatedAt, &s.Validated, &s.SendAttempt)
|
||||
s.Sid = sid
|
||||
return &s, err
|
||||
}
|
||||
|
||||
func (d *Database) GetSessionByThreePidAndSecret(ctx context.Context, threePid, ClientSecret string) (*api.Session, error) {
|
||||
s := api.Session{}
|
||||
err := d.stm.selectSessionByThreePidAndCLientSecretStmt.
|
||||
QueryRowContext(ctx, threePid, ClientSecret).Scan(
|
||||
&s.Sid, &s.Token, &s.NextLink, &s.ValidatedAt, &s.Validated, &s.SendAttempt)
|
||||
s.ThreePid = threePid
|
||||
s.ClientSecret = ClientSecret
|
||||
return &s, err
|
||||
}
|
||||
|
||||
func (d *Database) UpdateSendAttemptNextLink(ctx context.Context, sid int64, nextLink string) error {
|
||||
return d.writer.Do(nil, nil, func(_ *sql.Tx) error {
|
||||
_, err := d.stm.updateSendAttemptNextLinkStmt.ExecContext(ctx, nextLink, sid)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) DeleteSession(ctx context.Context, sid int64) error {
|
||||
return d.writer.Do(nil, nil, func(_ *sql.Tx) error {
|
||||
_, err := d.stm.deleteSessionStmt.ExecContext(ctx, sid)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) ValidateSession(ctx context.Context, sid int64, validatedAt int64) error {
|
||||
return d.writer.Do(nil, nil, func(_ *sql.Tx) error {
|
||||
_, err := d.stm.validateSessionStmt.ExecContext(ctx, true, validatedAt, sid)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func NewDatabase(db *sql.DB, writer sqlutil.Writer) (*Database, error) {
|
||||
threePidSessionsTable, err := PrepareThreePidSessionsTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d := Database{
|
||||
writer: writer,
|
||||
stm: threePidSessionsTable,
|
||||
}
|
||||
return &d, nil
|
||||
}
|
||||
|
||||
// func newSQLiteDatabase(dbProperties *config.DatabaseOptions) (*Db, error) {
|
||||
// db, err := sqlutil.Open(dbProperties)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// writer := sqlutil.NewExclusiveWriter()
|
||||
// stmt := sessionStatements{
|
||||
// db: db,
|
||||
// writer: writer,
|
||||
// }
|
||||
|
||||
// // Create tables before executing migrations so we don't fail if the table is missing,
|
||||
// // and THEN prepare statements so we don't fail due to referencing new columns
|
||||
// if err = stmt.execSchema(db); err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// if err = stmt.prepare(); err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// handler := func(f func(tx *sql.Tx) error) error {
|
||||
// return writer.Do(nil, nil, f)
|
||||
// }
|
||||
// return &Db{db, writer, &stmt, handler}, nil
|
||||
// }
|
||||
|
||||
// func newPostgresDatabase(dbProperties *config.DatabaseOptions) (*Db, error) {
|
||||
// db, err := sqlutil.Open(dbProperties)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// stmt := sessionStatements{
|
||||
// db: db,
|
||||
// }
|
||||
// // Create tables before executing migrations so we don't fail if the table is missing,
|
||||
// // and THEN prepare statements so we don't fail due to referencing new columns
|
||||
// if err = stmt.execSchema(db); err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// if err = stmt.prepare(); err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// handler := func(f func(tx *sql.Tx) error) error {
|
||||
// return f(nil)
|
||||
// }
|
||||
// return &Db{db, nil, &stmt, handler}, nil
|
||||
// }
|
||||
|
||||
// func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) {
|
||||
// switch {
|
||||
// case dbProperties.ConnectionString.IsSQLite():
|
||||
// return newSQLiteDatabase(dbProperties)
|
||||
// case dbProperties.ConnectionString.IsPostgres():
|
||||
// return newPostgresDatabase(dbProperties)
|
||||
// default:
|
||||
// return nil, fmt.Errorf("unexpected database type")
|
||||
// }
|
||||
// }
|
||||
49
userapi/storage/threepid/shared/threepid_sessions_table.go
Normal file
49
userapi/storage/threepid/shared/threepid_sessions_table.go
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
package shared
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
)
|
||||
|
||||
const (
|
||||
insertSessionSQL = "" +
|
||||
"INSERT INTO threepid_sessions (client_secret, threepid, token, next_link, validated_at_ts, validated, send_attempt)" +
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7)" +
|
||||
"RETURNING sid;"
|
||||
// selectSidSQL = "" +
|
||||
// "SELECT last_insert_rowid();"
|
||||
selectSessionSQL = "" +
|
||||
"SELECT client_secret, threepid, token, next_link, validated_at_ts, validated, send_attempt FROM threepid_sessions WHERE sid = $1"
|
||||
selectSessionByThreePidAndCLientSecretSQL = "" +
|
||||
"SELECT sid, token, next_link, validated_at_ts, validated, send_attempt FROM threepid_sessions WHERE threepid = $1 AND client_secret = $2"
|
||||
deleteSessionSQL = "" +
|
||||
"DELETE FROM threepid_sessions WHERE sid = $1"
|
||||
validateSessionSQL = "" +
|
||||
"UPDATE threepid_sessions SET validated = $1, validated_at_ts = $2 WHERE sid = $3"
|
||||
updateSendAttemptNextLinkSQL = "" +
|
||||
"UPDATE threepid_sessions SET send_attempt = send_attempt + 1, next_link = $1 WHERE sid = $2"
|
||||
)
|
||||
|
||||
type ThreePidSessionStatements struct {
|
||||
insertSessionStmt *sql.Stmt
|
||||
// selectSidStmt *sql.Stmt
|
||||
selectSessionStmt *sql.Stmt
|
||||
selectSessionByThreePidAndCLientSecretStmt *sql.Stmt
|
||||
deleteSessionStmt *sql.Stmt
|
||||
validateSessionStmt *sql.Stmt
|
||||
updateSendAttemptNextLinkStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func PrepareThreePidSessionsTable(db *sql.DB) (*ThreePidSessionStatements, error) {
|
||||
s := ThreePidSessionStatements{}
|
||||
return &s, sqlutil.StatementList{
|
||||
{&s.insertSessionStmt, insertSessionSQL},
|
||||
// {&s.selectSidStmt, selectSidSQL},
|
||||
{&s.selectSessionStmt, selectSessionSQL},
|
||||
{&s.selectSessionByThreePidAndCLientSecretStmt, selectSessionByThreePidAndCLientSecretSQL},
|
||||
{&s.deleteSessionStmt, deleteSessionSQL},
|
||||
{&s.validateSessionStmt, validateSessionSQL},
|
||||
{&s.updateSendAttemptNextLinkStmt, updateSendAttemptNextLinkSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
59
userapi/storage/threepid/sqlite3/storage.go
Normal file
59
userapi/storage/threepid/sqlite3/storage.go
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
package sqlite3
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/threepid/shared"
|
||||
)
|
||||
|
||||
type Database struct {
|
||||
*shared.Database
|
||||
}
|
||||
|
||||
const threePidSessionsSchema = `
|
||||
-- This sequence is used for automatic allocation of session_id.
|
||||
-- CREATE SEQUENCE IF NOT EXISTS threepid_session_id_seq START 1;
|
||||
|
||||
-- Stores data about devices.
|
||||
CREATE TABLE IF NOT EXISTS threepid_sessions (
|
||||
sid INTEGER PRIMARY KEY,
|
||||
client_secret VARCHAR(255),
|
||||
threepid TEXT ,
|
||||
token VARCHAR(255) ,
|
||||
next_link TEXT,
|
||||
validated_at_ts BIGINT,
|
||||
validated BOOLEAN,
|
||||
send_attempt INT
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS threepid_sessions_threepids
|
||||
ON threepid_sessions (threepid, client_secret)
|
||||
`
|
||||
|
||||
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||
db, err := sqlutil.Open(dbProperties)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlutil.Open: %w", err)
|
||||
}
|
||||
|
||||
// Create the tables.
|
||||
if err := create(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sharedDb, err := shared.NewDatabase(db, sqlutil.NewExclusiveWriter())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Database{
|
||||
Database: sharedDb,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func create(db *sql.DB) error {
|
||||
_, err := db.Exec(threePidSessionsSchema)
|
||||
return err
|
||||
}
|
||||
|
|
@ -1,81 +0,0 @@
|
|||
package threepid
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
)
|
||||
|
||||
const sessionsSchema = `
|
||||
-- This sequence is used for automatic allocation of session_id.
|
||||
-- CREATE SEQUENCE IF NOT EXISTS threepid_session_id_seq START 1;
|
||||
|
||||
-- Stores data about devices.
|
||||
CREATE TABLE IF NOT EXISTS threepid_sessions (
|
||||
sid VARCHAR(255) PRIMARY KEY,
|
||||
client_secret VARCHAR(255),
|
||||
threepid TEXT ,
|
||||
token VARCHAR(255) ,
|
||||
next_link TEXT,
|
||||
validated_at_ts BIGINT,
|
||||
validated BOOLEAN,
|
||||
send_attempt INT
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS threepid_sessions_threepids
|
||||
ON threepid_sessions (threepid, client_secret)
|
||||
`
|
||||
|
||||
const (
|
||||
insertSessionSQL = "" +
|
||||
"INSERT INTO threepid_sessions (sid, client_secret, threepid, token, next_link, send_attempt, validated_at_ts, validated)" +
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
|
||||
selectSessionSQL = "" +
|
||||
"SELECT client_secret, threepid, token, next_link, validated, validated_at_ts, send_attempt FROM threepid_sessions WHERE sid = $1"
|
||||
selectSessionByThreePidAndCLientSecretSQL = "" +
|
||||
"SELECT sid, token, next_link, validated, validated_at_ts, send_attempt FROM threepid_sessions WHERE threepid = $1 AND client_secret = $2"
|
||||
deleteSessionSQL = "" +
|
||||
"DELETE FROM threepid_sessions WHERE sid = $1"
|
||||
validateSessionSQL = "" +
|
||||
"UPDATE threepid_sessions SET validated = $1, validated_at_ts = $2 WHERE sid = $3"
|
||||
updateSendAttemptNextLinkSQL = "" +
|
||||
"UPDATE threepid_sessions SET send_attempt = send_attempt + 1, next_link = $1 WHERE sid = $2"
|
||||
)
|
||||
|
||||
type sessionStatements struct {
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
insertSessionStmt *sql.Stmt
|
||||
selectSessionStmt *sql.Stmt
|
||||
selectSessionByThreePidAndCLientSecretStmt *sql.Stmt
|
||||
deleteSessionStmt *sql.Stmt
|
||||
validateSessionStmt *sql.Stmt
|
||||
updateSendAttemptNextLinkStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *sessionStatements) prepare() (err error) {
|
||||
if s.insertSessionStmt, err = s.db.Prepare(insertSessionSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectSessionStmt, err = s.db.Prepare(selectSessionSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectSessionByThreePidAndCLientSecretStmt, err = s.db.Prepare(selectSessionByThreePidAndCLientSecretSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteSessionStmt, err = s.db.Prepare(deleteSessionSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.validateSessionStmt, err = s.db.Prepare(validateSessionSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.updateSendAttemptNextLinkStmt, err = s.db.Prepare(updateSendAttemptNextLinkSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *sessionStatements) execSchema(db *sql.DB) error {
|
||||
_, err := db.Exec(sessionsSchema)
|
||||
return err
|
||||
}
|
||||
|
|
@ -2,132 +2,30 @@ package threepid
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/threepid/postgres"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/threepid/sqlite3"
|
||||
)
|
||||
|
||||
type Database interface {
|
||||
InsertSession(context.Context, *api.Session) error
|
||||
GetSession(ctx context.Context, sid string) (*api.Session, error)
|
||||
InsertSession(ctx context.Context, clientSecret, threepid, token, nextlink string, validatedAt int64, validated bool, sendAttempt int) (int64, error)
|
||||
GetSession(ctx context.Context, sid int64) (*api.Session, error)
|
||||
GetSessionByThreePidAndSecret(ctx context.Context, threePid, ClientSecret string) (*api.Session, error)
|
||||
UpdateSendAttemptNextLink(ctx context.Context, sid, nextLink string) error
|
||||
DeleteSession(ctx context.Context, sid string) error
|
||||
ValidateSession(ctx context.Context, sid string, validatedAt int) error
|
||||
UpdateSendAttemptNextLink(ctx context.Context, sid int64, nextLink string) error
|
||||
DeleteSession(ctx context.Context, sid int64) error
|
||||
ValidateSession(ctx context.Context, sid int64, validatedAt int64) error
|
||||
}
|
||||
|
||||
type Db struct {
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
stm *sessionStatements
|
||||
writeHandler func(func(*sql.Tx) error) error
|
||||
}
|
||||
|
||||
func (d *Db) InsertSession(ctx context.Context, s *api.Session) error {
|
||||
h := func(_ *sql.Tx) error {
|
||||
_, err := d.stm.insertSessionStmt.ExecContext(ctx, s.Sid, s.ClientSecret, s.ThreePid, s.Token, s.NextLink, s.SendAttempt, s.ValidatedAt, s.Validated)
|
||||
return err
|
||||
}
|
||||
return d.writeHandler(h)
|
||||
}
|
||||
|
||||
func (d *Db) GetSession(ctx context.Context, sid string) (*api.Session, error) {
|
||||
s := api.Session{}
|
||||
err := d.stm.selectSessionStmt.QueryRowContext(ctx, sid).Scan(&s.ClientSecret, &s.ThreePid, &s.Token, &s.NextLink, &s.Validated, &s.ValidatedAt, &s.SendAttempt)
|
||||
s.Sid = sid
|
||||
return &s, err
|
||||
}
|
||||
|
||||
func (d *Db) GetSessionByThreePidAndSecret(ctx context.Context, threePid, ClientSecret string) (*api.Session, error) {
|
||||
s := api.Session{}
|
||||
err := d.stm.selectSessionByThreePidAndCLientSecretStmt.
|
||||
QueryRowContext(ctx, threePid, ClientSecret).Scan(
|
||||
&s.Sid, &s.Token, &s.NextLink, &s.Validated, &s.ValidatedAt, &s.SendAttempt)
|
||||
s.ThreePid = threePid
|
||||
s.ClientSecret = ClientSecret
|
||||
return &s, err
|
||||
}
|
||||
|
||||
func (d *Db) UpdateSendAttemptNextLink(ctx context.Context, sid, nextLink string) error {
|
||||
h := func(_ *sql.Tx) error {
|
||||
_, err := d.stm.updateSendAttemptNextLinkStmt.ExecContext(ctx, nextLink, sid)
|
||||
return err
|
||||
}
|
||||
return d.writeHandler(h)
|
||||
}
|
||||
|
||||
func (d *Db) DeleteSession(ctx context.Context, sid string) error {
|
||||
h := func(_ *sql.Tx) error {
|
||||
_, err := d.stm.deleteSessionStmt.ExecContext(ctx, sid)
|
||||
return err
|
||||
}
|
||||
return d.writeHandler(h)
|
||||
}
|
||||
|
||||
func (d *Db) ValidateSession(ctx context.Context, sid string, validatedAt int) error {
|
||||
h := func(_ *sql.Tx) error {
|
||||
_, err := d.stm.validateSessionStmt.ExecContext(ctx, true, validatedAt, sid)
|
||||
return err
|
||||
}
|
||||
return d.writeHandler(h)
|
||||
}
|
||||
|
||||
func newSQLiteDatabase(dbProperties *config.DatabaseOptions) (*Db, error) {
|
||||
db, err := sqlutil.Open(dbProperties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
writer := sqlutil.NewExclusiveWriter()
|
||||
stmt := sessionStatements{
|
||||
db: db,
|
||||
writer: writer,
|
||||
}
|
||||
|
||||
// Create tables before executing migrations so we don't fail if the table is missing,
|
||||
// and THEN prepare statements so we don't fail due to referencing new columns
|
||||
if err = stmt.execSchema(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = stmt.prepare(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
handler := func(f func(tx *sql.Tx) error) error {
|
||||
return writer.Do(nil, nil, f)
|
||||
}
|
||||
return &Db{db, writer, &stmt, handler}, nil
|
||||
}
|
||||
|
||||
func newPostgresDatabase(dbProperties *config.DatabaseOptions) (*Db, error) {
|
||||
db, err := sqlutil.Open(dbProperties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stmt := sessionStatements{
|
||||
db: db,
|
||||
}
|
||||
// Create tables before executing migrations so we don't fail if the table is missing,
|
||||
// and THEN prepare statements so we don't fail due to referencing new columns
|
||||
if err = stmt.execSchema(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = stmt.prepare(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
handler := func(f func(tx *sql.Tx) error) error {
|
||||
return f(nil)
|
||||
}
|
||||
return &Db{db, nil, &stmt, handler}, nil
|
||||
}
|
||||
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) {
|
||||
// Open opens a database connection.
|
||||
func Open(dbProperties *config.DatabaseOptions) (Database, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return newSQLiteDatabase(dbProperties)
|
||||
return sqlite3.Open(dbProperties)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return newPostgresDatabase(dbProperties)
|
||||
return postgres.Open(dbProperties)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected database type")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ import (
|
|||
)
|
||||
|
||||
var testSession = api.Session{
|
||||
// Sid: "123",
|
||||
ClientSecret: "099azAZ.=_-",
|
||||
ThreePid: "azAZ09!#$%&'*+-/=?^_`{|}~@bar09.com",
|
||||
Token: "fooBAR123",
|
||||
|
|
@ -26,15 +25,24 @@ var testSession = api.Session{
|
|||
|
||||
var testCtx = context.Background()
|
||||
|
||||
func mustNewDatabaseWithTestSession(is *is.I) *Db {
|
||||
func mustNewDatabaseWithTestSession(is *is.I) Database {
|
||||
randPostfix := strconv.Itoa(rand.Int())
|
||||
dbPath := os.TempDir() + "/dendrite-" + randPostfix
|
||||
println(dbPath)
|
||||
dut, err := newSQLiteDatabase(&config.DatabaseOptions{
|
||||
dut, err := Open(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource("file:" + dbPath),
|
||||
})
|
||||
is.NoErr(err)
|
||||
err = dut.InsertSession(testCtx, &testSession)
|
||||
sid, err := dut.InsertSession(
|
||||
testCtx,
|
||||
testSession.ClientSecret,
|
||||
testSession.ThreePid,
|
||||
testSession.Token,
|
||||
testSession.NextLink,
|
||||
testSession.ValidatedAt,
|
||||
testSession.Validated,
|
||||
testSession.SendAttempt)
|
||||
testSession.Sid = sid
|
||||
is.NoErr(err)
|
||||
return dut
|
||||
}
|
||||
|
|
@ -78,7 +86,7 @@ func TestDeleteSession(t *testing.T) {
|
|||
func TestValidateSession(t *testing.T) {
|
||||
is := is.New(t)
|
||||
dut := mustNewDatabaseWithTestSession(is)
|
||||
validatedAt := 1_623_406_296
|
||||
validatedAt := int64(1_623_406_296)
|
||||
err := dut.ValidateSession(testCtx, testSession.Sid, validatedAt)
|
||||
is.NoErr(err)
|
||||
session, err := dut.GetSession(testCtx, testSession.Sid)
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ func NewInternalAPI(
|
|||
if err != nil {
|
||||
logrus.WithError(err).Panicf("failed to connect to device db")
|
||||
}
|
||||
threepidDb, err := threepid.NewDatabase(&cfg.ThreepidDatabase)
|
||||
threepidDb, err := threepid.Open(&cfg.ThreepidDatabase)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panicf("failed to connect to threepid db")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ func TestCreateSession_Twice(t *testing.T) {
|
|||
resp := api.CreateSessionResponse{}
|
||||
err := internalApi.CreateSession(ctx, testReq, &resp)
|
||||
is.NoErr(err)
|
||||
is.Equal(len(resp.Sid), 43)
|
||||
is.Equal(resp.Sid, int64(1))
|
||||
select {
|
||||
case <-mailer.c[api.AccountPassword]:
|
||||
t.Fatal("email was received, but sent attempt was not increased")
|
||||
|
|
@ -189,7 +189,7 @@ func TestCreateSession_Twice_IncreaseSendAttempt(t *testing.T) {
|
|||
testReqBumped.SendAttempt = 1
|
||||
err := internalApi.CreateSession(ctx, &testReqBumped, &resp)
|
||||
is.NoErr(err)
|
||||
is.Equal(len(resp.Sid), 43)
|
||||
is.Equal(resp.Sid, int64(1))
|
||||
sub := <-mailer.c[api.AccountPassword]
|
||||
is.Equal(len(sub.Token), 64)
|
||||
is.Equal(sub.To, testReq.ThreePid)
|
||||
|
|
@ -235,7 +235,7 @@ func mustCreateSession(is *is.I, i *internal.UserInternalAPI) (resp *api.CreateS
|
|||
i.Mail = mailer
|
||||
err := i.CreateSession(ctx, testReq, resp)
|
||||
is.NoErr(err)
|
||||
is.Equal(len(resp.Sid), 43)
|
||||
is.Equal(resp.Sid, int64(1))
|
||||
sub := <-mailer.c[api.AccountPassword]
|
||||
is.Equal(len(sub.Token), 64)
|
||||
is.Equal(sub.To, testReq.ThreePid)
|
||||
|
|
@ -244,14 +244,14 @@ func mustCreateSession(is *is.I, i *internal.UserInternalAPI) (resp *api.CreateS
|
|||
is.Equal(submitUrl.Host, "example.com")
|
||||
is.Equal(submitUrl.Path, "/_matrix/client/r0/account/password/email/submitToken")
|
||||
q := submitUrl.Query()
|
||||
is.Equal(len(q["sid"][0]), 43)
|
||||
is.Equal(q["sid"][0], "1")
|
||||
is.Equal(q["token"][0], sub.Token)
|
||||
is.Equal(q["client_secret"][0], "foobar")
|
||||
token = sub.Token
|
||||
return
|
||||
}
|
||||
|
||||
func mustValidateSesson(is *is.I, i *internal.UserInternalAPI, secret, token, sid string) {
|
||||
func mustValidateSesson(is *is.I, i *internal.UserInternalAPI, secret, token string, sid int64) {
|
||||
err := i.ValidateSession(ctx, &api.ValidateSessionRequest{
|
||||
SessionOwnership: api.SessionOwnership{
|
||||
Sid: sid,
|
||||
|
|
|
|||
Loading…
Reference in a new issue