From 316adf66d1ea2ac0a8375b40b10958b37b91bd8b Mon Sep 17 00:00:00 2001 From: Piotr Kozimor Date: Mon, 13 Sep 2021 17:01:14 +0200 Subject: [PATCH] Refactor threepid sessions storage so that postgres and sqlite3 schemas are separate. Generate session id in database --- userapi/api/api.go | 13 +- userapi/internal/threepid.go | 14 +- userapi/storage/threepid/postgres/storage.go | 59 ++++++++ userapi/storage/threepid/shared/storage.go | 136 ++++++++++++++++++ .../shared/threepid_sessions_table.go | 49 +++++++ userapi/storage/threepid/sqlite3/storage.go | 59 ++++++++ userapi/storage/threepid/stmt.go | 81 ----------- userapi/storage/threepid/storage.go | 124 ++-------------- userapi/storage/threepid/storage_test.go | 18 ++- userapi/userapi.go | 2 +- userapi/userapi_test.go | 10 +- 11 files changed, 346 insertions(+), 219 deletions(-) create mode 100644 userapi/storage/threepid/postgres/storage.go create mode 100644 userapi/storage/threepid/shared/storage.go create mode 100644 userapi/storage/threepid/shared/threepid_sessions_table.go create mode 100644 userapi/storage/threepid/sqlite3/storage.go delete mode 100644 userapi/storage/threepid/stmt.go diff --git a/userapi/api/api.go b/userapi/api/api.go index 22d7a6779..ea6f23a9f 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -438,7 +438,7 @@ type CreateSessionRequest struct { } type CreateSessionResponse struct { - Sid string + Sid int64 } type ValidateSessionRequest struct { @@ -451,13 +451,16 @@ type GetThreePidForSessionResponse struct { } type SessionOwnership struct { - Sid, ClientSecret string + Sid int64 + ClientSecret string } type Session struct { - Sid, ClientSecret, ThreePid, Token, NextLink string - SendAttempt, ValidatedAt int - Validated bool + ClientSecret, ThreePid, Token, NextLink string + Sid int64 + SendAttempt int + ValidatedAt int64 + Validated bool } type IsSessionValidatedResponse struct { diff --git a/userapi/internal/threepid.go b/userapi/internal/threepid.go index e793677a8..a06816739 100644 --- a/userapi/internal/threepid.go +++ b/userapi/internal/threepid.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "net/url" + "strconv" "time" "github.com/matrix-org/dendrite/internal" @@ -28,19 +29,14 @@ func (a *UserInternalAPI) CreateSession(ctx context.Context, req *api.CreateSess if errB != nil { return errB } - sid, errB := internal.GenerateBlob(sessionIdByteLength) - if errB != nil { - return errB - } s = &api.Session{ - Sid: sid, ClientSecret: req.ClientSecret, ThreePid: req.ThreePid, SendAttempt: req.SendAttempt, Token: token, NextLink: req.NextLink, } - err = a.ThreePidDB.InsertSession(ctx, s) + s.Sid, err = a.ThreePidDB.InsertSession(ctx, req.ClientSecret, req.ThreePid, token, req.NextLink, 0, false, req.SendAttempt) if err != nil { return err } @@ -60,7 +56,7 @@ func (a *UserInternalAPI) CreateSession(ctx context.Context, req *api.CreateSess } res.Sid = s.Sid query := url.Values{ - "sid": []string{s.Sid}, + "sid": []string{strconv.Itoa(int(s.Sid))}, "client_secret": []string{s.ClientSecret}, "token": []string{s.Token}, } @@ -88,7 +84,7 @@ func (a *UserInternalAPI) ValidateSession(ctx context.Context, req *api.Validate if s.Token != req.Token { return ErrBadSession } - return a.ThreePidDB.ValidateSession(ctx, s.Sid, int(time.Now().Unix())) + return a.ThreePidDB.ValidateSession(ctx, s.Sid, time.Now().Unix()) } func (a *UserInternalAPI) GetThreePidForSession(ctx context.Context, req *api.SessionOwnership, res *api.GetThreePidForSessionResponse) error { @@ -114,7 +110,7 @@ func (a *UserInternalAPI) IsSessionValidated(ctx context.Context, req *api.Sessi return err } res.Validated = s.Validated - res.ValidatedAt = s.ValidatedAt + res.ValidatedAt = int(s.ValidatedAt) return nil } diff --git a/userapi/storage/threepid/postgres/storage.go b/userapi/storage/threepid/postgres/storage.go new file mode 100644 index 000000000..ab2c88c08 --- /dev/null +++ b/userapi/storage/threepid/postgres/storage.go @@ -0,0 +1,59 @@ +package postgres + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/userapi/storage/threepid/shared" +) + +type Database struct { + *shared.Database +} + +const threePidSessionsSchema = ` +-- This sequence is used for automatic allocation of session_id. +-- CREATE SEQUENCE IF NOT EXISTS threepid_session_id_seq START 1; + +-- Stores data about devices. +CREATE TABLE IF NOT EXISTS threepid_sessions ( + sid SERIAL PRIMARY KEY, + client_secret VARCHAR(255), + threepid TEXT , + token VARCHAR(255) , + next_link TEXT, + validated_at_ts BIGINT, + validated BOOLEAN, + send_attempt INT +); + +CREATE UNIQUE INDEX IF NOT EXISTS threepid_sessions_threepids + ON threepid_sessions (threepid, client_secret) +` + +func Open(dbProperties *config.DatabaseOptions) (*Database, error) { + db, err := sqlutil.Open(dbProperties) + if err != nil { + return nil, fmt.Errorf("sqlutil.Open: %w", err) + } + + // Create the tables. + if err := create(db); err != nil { + return nil, err + } + + sharedDb, err := shared.NewDatabase(db, sqlutil.NewDummyWriter()) + if err != nil { + return nil, err + } + return &Database{ + Database: sharedDb, + }, nil +} + +func create(db *sql.DB) error { + _, err := db.Exec(threePidSessionsSchema) + return err +} diff --git a/userapi/storage/threepid/shared/storage.go b/userapi/storage/threepid/shared/storage.go new file mode 100644 index 000000000..7e3749832 --- /dev/null +++ b/userapi/storage/threepid/shared/storage.go @@ -0,0 +1,136 @@ +package shared + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" +) + +type Database struct { + writer sqlutil.Writer + stm *ThreePidSessionStatements +} + +func (d *Database) InsertSession( + ctx context.Context, clientSecret, threepid, token, nextlink string, validatedAt int64, validated bool, sendAttempt int) (int64, error) { + var sid int64 + return sid, d.writer.Do(nil, nil, func(_ *sql.Tx) error { + err := d.stm.insertSessionStmt.QueryRowContext(ctx, clientSecret, threepid, token, nextlink, validatedAt, validated, sendAttempt).Scan(&sid) + // _, err := d.stm.insertSessionStmt.ExecContext(ctx, clientSecret, threepid, token, nextlink, sendAttempt, validatedAt, validated) + return err + // if err != nil { + // return err + // } + // err = d.stm.selectSidStmt.QueryRowContext(ctx).Scan(&sid) + }) +} + +func (d *Database) GetSession(ctx context.Context, sid int64) (*api.Session, error) { + s := api.Session{} + err := d.stm.selectSessionStmt.QueryRowContext(ctx, sid).Scan(&s.ClientSecret, &s.ThreePid, &s.Token, &s.NextLink, &s.ValidatedAt, &s.Validated, &s.SendAttempt) + s.Sid = sid + return &s, err +} + +func (d *Database) GetSessionByThreePidAndSecret(ctx context.Context, threePid, ClientSecret string) (*api.Session, error) { + s := api.Session{} + err := d.stm.selectSessionByThreePidAndCLientSecretStmt. + QueryRowContext(ctx, threePid, ClientSecret).Scan( + &s.Sid, &s.Token, &s.NextLink, &s.ValidatedAt, &s.Validated, &s.SendAttempt) + s.ThreePid = threePid + s.ClientSecret = ClientSecret + return &s, err +} + +func (d *Database) UpdateSendAttemptNextLink(ctx context.Context, sid int64, nextLink string) error { + return d.writer.Do(nil, nil, func(_ *sql.Tx) error { + _, err := d.stm.updateSendAttemptNextLinkStmt.ExecContext(ctx, nextLink, sid) + return err + }) +} + +func (d *Database) DeleteSession(ctx context.Context, sid int64) error { + return d.writer.Do(nil, nil, func(_ *sql.Tx) error { + _, err := d.stm.deleteSessionStmt.ExecContext(ctx, sid) + return err + }) +} + +func (d *Database) ValidateSession(ctx context.Context, sid int64, validatedAt int64) error { + return d.writer.Do(nil, nil, func(_ *sql.Tx) error { + _, err := d.stm.validateSessionStmt.ExecContext(ctx, true, validatedAt, sid) + return err + }) +} + +func NewDatabase(db *sql.DB, writer sqlutil.Writer) (*Database, error) { + threePidSessionsTable, err := PrepareThreePidSessionsTable(db) + if err != nil { + return nil, err + } + d := Database{ + writer: writer, + stm: threePidSessionsTable, + } + return &d, nil +} + +// func newSQLiteDatabase(dbProperties *config.DatabaseOptions) (*Db, error) { +// db, err := sqlutil.Open(dbProperties) +// if err != nil { +// return nil, err +// } +// writer := sqlutil.NewExclusiveWriter() +// stmt := sessionStatements{ +// db: db, +// writer: writer, +// } + +// // Create tables before executing migrations so we don't fail if the table is missing, +// // and THEN prepare statements so we don't fail due to referencing new columns +// if err = stmt.execSchema(db); err != nil { +// return nil, err +// } +// if err = stmt.prepare(); err != nil { +// return nil, err +// } +// handler := func(f func(tx *sql.Tx) error) error { +// return writer.Do(nil, nil, f) +// } +// return &Db{db, writer, &stmt, handler}, nil +// } + +// func newPostgresDatabase(dbProperties *config.DatabaseOptions) (*Db, error) { +// db, err := sqlutil.Open(dbProperties) +// if err != nil { +// return nil, err +// } +// stmt := sessionStatements{ +// db: db, +// } +// // Create tables before executing migrations so we don't fail if the table is missing, +// // and THEN prepare statements so we don't fail due to referencing new columns +// if err = stmt.execSchema(db); err != nil { +// return nil, err +// } +// if err = stmt.prepare(); err != nil { +// return nil, err +// } +// handler := func(f func(tx *sql.Tx) error) error { +// return f(nil) +// } +// return &Db{db, nil, &stmt, handler}, nil +// } + +// func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) { +// switch { +// case dbProperties.ConnectionString.IsSQLite(): +// return newSQLiteDatabase(dbProperties) +// case dbProperties.ConnectionString.IsPostgres(): +// return newPostgresDatabase(dbProperties) +// default: +// return nil, fmt.Errorf("unexpected database type") +// } +// } diff --git a/userapi/storage/threepid/shared/threepid_sessions_table.go b/userapi/storage/threepid/shared/threepid_sessions_table.go new file mode 100644 index 000000000..8a62982c0 --- /dev/null +++ b/userapi/storage/threepid/shared/threepid_sessions_table.go @@ -0,0 +1,49 @@ +package shared + +import ( + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +const ( + insertSessionSQL = "" + + "INSERT INTO threepid_sessions (client_secret, threepid, token, next_link, validated_at_ts, validated, send_attempt)" + + "VALUES ($1, $2, $3, $4, $5, $6, $7)" + + "RETURNING sid;" + // selectSidSQL = "" + + // "SELECT last_insert_rowid();" + selectSessionSQL = "" + + "SELECT client_secret, threepid, token, next_link, validated_at_ts, validated, send_attempt FROM threepid_sessions WHERE sid = $1" + selectSessionByThreePidAndCLientSecretSQL = "" + + "SELECT sid, token, next_link, validated_at_ts, validated, send_attempt FROM threepid_sessions WHERE threepid = $1 AND client_secret = $2" + deleteSessionSQL = "" + + "DELETE FROM threepid_sessions WHERE sid = $1" + validateSessionSQL = "" + + "UPDATE threepid_sessions SET validated = $1, validated_at_ts = $2 WHERE sid = $3" + updateSendAttemptNextLinkSQL = "" + + "UPDATE threepid_sessions SET send_attempt = send_attempt + 1, next_link = $1 WHERE sid = $2" +) + +type ThreePidSessionStatements struct { + insertSessionStmt *sql.Stmt + // selectSidStmt *sql.Stmt + selectSessionStmt *sql.Stmt + selectSessionByThreePidAndCLientSecretStmt *sql.Stmt + deleteSessionStmt *sql.Stmt + validateSessionStmt *sql.Stmt + updateSendAttemptNextLinkStmt *sql.Stmt +} + +func PrepareThreePidSessionsTable(db *sql.DB) (*ThreePidSessionStatements, error) { + s := ThreePidSessionStatements{} + return &s, sqlutil.StatementList{ + {&s.insertSessionStmt, insertSessionSQL}, + // {&s.selectSidStmt, selectSidSQL}, + {&s.selectSessionStmt, selectSessionSQL}, + {&s.selectSessionByThreePidAndCLientSecretStmt, selectSessionByThreePidAndCLientSecretSQL}, + {&s.deleteSessionStmt, deleteSessionSQL}, + {&s.validateSessionStmt, validateSessionSQL}, + {&s.updateSendAttemptNextLinkStmt, updateSendAttemptNextLinkSQL}, + }.Prepare(db) +} diff --git a/userapi/storage/threepid/sqlite3/storage.go b/userapi/storage/threepid/sqlite3/storage.go new file mode 100644 index 000000000..22b3e36df --- /dev/null +++ b/userapi/storage/threepid/sqlite3/storage.go @@ -0,0 +1,59 @@ +package sqlite3 + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/userapi/storage/threepid/shared" +) + +type Database struct { + *shared.Database +} + +const threePidSessionsSchema = ` +-- This sequence is used for automatic allocation of session_id. +-- CREATE SEQUENCE IF NOT EXISTS threepid_session_id_seq START 1; + +-- Stores data about devices. +CREATE TABLE IF NOT EXISTS threepid_sessions ( + sid INTEGER PRIMARY KEY, + client_secret VARCHAR(255), + threepid TEXT , + token VARCHAR(255) , + next_link TEXT, + validated_at_ts BIGINT, + validated BOOLEAN, + send_attempt INT +); + +CREATE UNIQUE INDEX IF NOT EXISTS threepid_sessions_threepids + ON threepid_sessions (threepid, client_secret) +` + +func Open(dbProperties *config.DatabaseOptions) (*Database, error) { + db, err := sqlutil.Open(dbProperties) + if err != nil { + return nil, fmt.Errorf("sqlutil.Open: %w", err) + } + + // Create the tables. + if err := create(db); err != nil { + return nil, err + } + + sharedDb, err := shared.NewDatabase(db, sqlutil.NewExclusiveWriter()) + if err != nil { + return nil, err + } + return &Database{ + Database: sharedDb, + }, nil +} + +func create(db *sql.DB) error { + _, err := db.Exec(threePidSessionsSchema) + return err +} diff --git a/userapi/storage/threepid/stmt.go b/userapi/storage/threepid/stmt.go deleted file mode 100644 index 02ff1c443..000000000 --- a/userapi/storage/threepid/stmt.go +++ /dev/null @@ -1,81 +0,0 @@ -package threepid - -import ( - "database/sql" - - "github.com/matrix-org/dendrite/internal/sqlutil" -) - -const sessionsSchema = ` --- This sequence is used for automatic allocation of session_id. --- CREATE SEQUENCE IF NOT EXISTS threepid_session_id_seq START 1; - --- Stores data about devices. -CREATE TABLE IF NOT EXISTS threepid_sessions ( - sid VARCHAR(255) PRIMARY KEY, - client_secret VARCHAR(255), - threepid TEXT , - token VARCHAR(255) , - next_link TEXT, - validated_at_ts BIGINT, - validated BOOLEAN, - send_attempt INT -); - -CREATE UNIQUE INDEX IF NOT EXISTS threepid_sessions_threepids - ON threepid_sessions (threepid, client_secret) -` - -const ( - insertSessionSQL = "" + - "INSERT INTO threepid_sessions (sid, client_secret, threepid, token, next_link, send_attempt, validated_at_ts, validated)" + - "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" - selectSessionSQL = "" + - "SELECT client_secret, threepid, token, next_link, validated, validated_at_ts, send_attempt FROM threepid_sessions WHERE sid = $1" - selectSessionByThreePidAndCLientSecretSQL = "" + - "SELECT sid, token, next_link, validated, validated_at_ts, send_attempt FROM threepid_sessions WHERE threepid = $1 AND client_secret = $2" - deleteSessionSQL = "" + - "DELETE FROM threepid_sessions WHERE sid = $1" - validateSessionSQL = "" + - "UPDATE threepid_sessions SET validated = $1, validated_at_ts = $2 WHERE sid = $3" - updateSendAttemptNextLinkSQL = "" + - "UPDATE threepid_sessions SET send_attempt = send_attempt + 1, next_link = $1 WHERE sid = $2" -) - -type sessionStatements struct { - db *sql.DB - writer sqlutil.Writer - insertSessionStmt *sql.Stmt - selectSessionStmt *sql.Stmt - selectSessionByThreePidAndCLientSecretStmt *sql.Stmt - deleteSessionStmt *sql.Stmt - validateSessionStmt *sql.Stmt - updateSendAttemptNextLinkStmt *sql.Stmt -} - -func (s *sessionStatements) prepare() (err error) { - if s.insertSessionStmt, err = s.db.Prepare(insertSessionSQL); err != nil { - return - } - if s.selectSessionStmt, err = s.db.Prepare(selectSessionSQL); err != nil { - return - } - if s.selectSessionByThreePidAndCLientSecretStmt, err = s.db.Prepare(selectSessionByThreePidAndCLientSecretSQL); err != nil { - return - } - if s.deleteSessionStmt, err = s.db.Prepare(deleteSessionSQL); err != nil { - return - } - if s.validateSessionStmt, err = s.db.Prepare(validateSessionSQL); err != nil { - return - } - if s.updateSendAttemptNextLinkStmt, err = s.db.Prepare(updateSendAttemptNextLinkSQL); err != nil { - return - } - return -} - -func (s *sessionStatements) execSchema(db *sql.DB) error { - _, err := db.Exec(sessionsSchema) - return err -} diff --git a/userapi/storage/threepid/storage.go b/userapi/storage/threepid/storage.go index a3cafa839..dd3474ac9 100644 --- a/userapi/storage/threepid/storage.go +++ b/userapi/storage/threepid/storage.go @@ -2,132 +2,30 @@ package threepid import ( "context" - "database/sql" "fmt" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/threepid/postgres" + "github.com/matrix-org/dendrite/userapi/storage/threepid/sqlite3" ) type Database interface { - InsertSession(context.Context, *api.Session) error - GetSession(ctx context.Context, sid string) (*api.Session, error) + InsertSession(ctx context.Context, clientSecret, threepid, token, nextlink string, validatedAt int64, validated bool, sendAttempt int) (int64, error) + GetSession(ctx context.Context, sid int64) (*api.Session, error) GetSessionByThreePidAndSecret(ctx context.Context, threePid, ClientSecret string) (*api.Session, error) - UpdateSendAttemptNextLink(ctx context.Context, sid, nextLink string) error - DeleteSession(ctx context.Context, sid string) error - ValidateSession(ctx context.Context, sid string, validatedAt int) error + UpdateSendAttemptNextLink(ctx context.Context, sid int64, nextLink string) error + DeleteSession(ctx context.Context, sid int64) error + ValidateSession(ctx context.Context, sid int64, validatedAt int64) error } -type Db struct { - db *sql.DB - writer sqlutil.Writer - stm *sessionStatements - writeHandler func(func(*sql.Tx) error) error -} - -func (d *Db) InsertSession(ctx context.Context, s *api.Session) error { - h := func(_ *sql.Tx) error { - _, err := d.stm.insertSessionStmt.ExecContext(ctx, s.Sid, s.ClientSecret, s.ThreePid, s.Token, s.NextLink, s.SendAttempt, s.ValidatedAt, s.Validated) - return err - } - return d.writeHandler(h) -} - -func (d *Db) GetSession(ctx context.Context, sid string) (*api.Session, error) { - s := api.Session{} - err := d.stm.selectSessionStmt.QueryRowContext(ctx, sid).Scan(&s.ClientSecret, &s.ThreePid, &s.Token, &s.NextLink, &s.Validated, &s.ValidatedAt, &s.SendAttempt) - s.Sid = sid - return &s, err -} - -func (d *Db) GetSessionByThreePidAndSecret(ctx context.Context, threePid, ClientSecret string) (*api.Session, error) { - s := api.Session{} - err := d.stm.selectSessionByThreePidAndCLientSecretStmt. - QueryRowContext(ctx, threePid, ClientSecret).Scan( - &s.Sid, &s.Token, &s.NextLink, &s.Validated, &s.ValidatedAt, &s.SendAttempt) - s.ThreePid = threePid - s.ClientSecret = ClientSecret - return &s, err -} - -func (d *Db) UpdateSendAttemptNextLink(ctx context.Context, sid, nextLink string) error { - h := func(_ *sql.Tx) error { - _, err := d.stm.updateSendAttemptNextLinkStmt.ExecContext(ctx, nextLink, sid) - return err - } - return d.writeHandler(h) -} - -func (d *Db) DeleteSession(ctx context.Context, sid string) error { - h := func(_ *sql.Tx) error { - _, err := d.stm.deleteSessionStmt.ExecContext(ctx, sid) - return err - } - return d.writeHandler(h) -} - -func (d *Db) ValidateSession(ctx context.Context, sid string, validatedAt int) error { - h := func(_ *sql.Tx) error { - _, err := d.stm.validateSessionStmt.ExecContext(ctx, true, validatedAt, sid) - return err - } - return d.writeHandler(h) -} - -func newSQLiteDatabase(dbProperties *config.DatabaseOptions) (*Db, error) { - db, err := sqlutil.Open(dbProperties) - if err != nil { - return nil, err - } - writer := sqlutil.NewExclusiveWriter() - stmt := sessionStatements{ - db: db, - writer: writer, - } - - // Create tables before executing migrations so we don't fail if the table is missing, - // and THEN prepare statements so we don't fail due to referencing new columns - if err = stmt.execSchema(db); err != nil { - return nil, err - } - if err = stmt.prepare(); err != nil { - return nil, err - } - handler := func(f func(tx *sql.Tx) error) error { - return writer.Do(nil, nil, f) - } - return &Db{db, writer, &stmt, handler}, nil -} - -func newPostgresDatabase(dbProperties *config.DatabaseOptions) (*Db, error) { - db, err := sqlutil.Open(dbProperties) - if err != nil { - return nil, err - } - stmt := sessionStatements{ - db: db, - } - // Create tables before executing migrations so we don't fail if the table is missing, - // and THEN prepare statements so we don't fail due to referencing new columns - if err = stmt.execSchema(db); err != nil { - return nil, err - } - if err = stmt.prepare(); err != nil { - return nil, err - } - handler := func(f func(tx *sql.Tx) error) error { - return f(nil) - } - return &Db{db, nil, &stmt, handler}, nil -} - -func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) { +// Open opens a database connection. +func Open(dbProperties *config.DatabaseOptions) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return newSQLiteDatabase(dbProperties) + return sqlite3.Open(dbProperties) case dbProperties.ConnectionString.IsPostgres(): - return newPostgresDatabase(dbProperties) + return postgres.Open(dbProperties) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/userapi/storage/threepid/storage_test.go b/userapi/storage/threepid/storage_test.go index ed9f79a84..bc1e0e039 100644 --- a/userapi/storage/threepid/storage_test.go +++ b/userapi/storage/threepid/storage_test.go @@ -14,7 +14,6 @@ import ( ) var testSession = api.Session{ - // Sid: "123", ClientSecret: "099azAZ.=_-", ThreePid: "azAZ09!#$%&'*+-/=?^_`{|}~@bar09.com", Token: "fooBAR123", @@ -26,15 +25,24 @@ var testSession = api.Session{ var testCtx = context.Background() -func mustNewDatabaseWithTestSession(is *is.I) *Db { +func mustNewDatabaseWithTestSession(is *is.I) Database { randPostfix := strconv.Itoa(rand.Int()) dbPath := os.TempDir() + "/dendrite-" + randPostfix println(dbPath) - dut, err := newSQLiteDatabase(&config.DatabaseOptions{ + dut, err := Open(&config.DatabaseOptions{ ConnectionString: config.DataSource("file:" + dbPath), }) is.NoErr(err) - err = dut.InsertSession(testCtx, &testSession) + sid, err := dut.InsertSession( + testCtx, + testSession.ClientSecret, + testSession.ThreePid, + testSession.Token, + testSession.NextLink, + testSession.ValidatedAt, + testSession.Validated, + testSession.SendAttempt) + testSession.Sid = sid is.NoErr(err) return dut } @@ -78,7 +86,7 @@ func TestDeleteSession(t *testing.T) { func TestValidateSession(t *testing.T) { is := is.New(t) dut := mustNewDatabaseWithTestSession(is) - validatedAt := 1_623_406_296 + validatedAt := int64(1_623_406_296) err := dut.ValidateSession(testCtx, testSession.Sid, validatedAt) is.NoErr(err) session, err := dut.GetSession(testCtx, testSession.Sid) diff --git a/userapi/userapi.go b/userapi/userapi.go index b620d3ec4..d404b92f8 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -43,7 +43,7 @@ func NewInternalAPI( if err != nil { logrus.WithError(err).Panicf("failed to connect to device db") } - threepidDb, err := threepid.NewDatabase(&cfg.ThreepidDatabase) + threepidDb, err := threepid.Open(&cfg.ThreepidDatabase) if err != nil { logrus.WithError(err).Panicf("failed to connect to threepid db") } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 0567bf66c..66ae44abc 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -171,7 +171,7 @@ func TestCreateSession_Twice(t *testing.T) { resp := api.CreateSessionResponse{} err := internalApi.CreateSession(ctx, testReq, &resp) is.NoErr(err) - is.Equal(len(resp.Sid), 43) + is.Equal(resp.Sid, int64(1)) select { case <-mailer.c[api.AccountPassword]: t.Fatal("email was received, but sent attempt was not increased") @@ -189,7 +189,7 @@ func TestCreateSession_Twice_IncreaseSendAttempt(t *testing.T) { testReqBumped.SendAttempt = 1 err := internalApi.CreateSession(ctx, &testReqBumped, &resp) is.NoErr(err) - is.Equal(len(resp.Sid), 43) + is.Equal(resp.Sid, int64(1)) sub := <-mailer.c[api.AccountPassword] is.Equal(len(sub.Token), 64) is.Equal(sub.To, testReq.ThreePid) @@ -235,7 +235,7 @@ func mustCreateSession(is *is.I, i *internal.UserInternalAPI) (resp *api.CreateS i.Mail = mailer err := i.CreateSession(ctx, testReq, resp) is.NoErr(err) - is.Equal(len(resp.Sid), 43) + is.Equal(resp.Sid, int64(1)) sub := <-mailer.c[api.AccountPassword] is.Equal(len(sub.Token), 64) is.Equal(sub.To, testReq.ThreePid) @@ -244,14 +244,14 @@ func mustCreateSession(is *is.I, i *internal.UserInternalAPI) (resp *api.CreateS is.Equal(submitUrl.Host, "example.com") is.Equal(submitUrl.Path, "/_matrix/client/r0/account/password/email/submitToken") q := submitUrl.Query() - is.Equal(len(q["sid"][0]), 43) + is.Equal(q["sid"][0], "1") is.Equal(q["token"][0], sub.Token) is.Equal(q["client_secret"][0], "foobar") token = sub.Token return } -func mustValidateSesson(is *is.I, i *internal.UserInternalAPI, secret, token, sid string) { +func mustValidateSesson(is *is.I, i *internal.UserInternalAPI, secret, token string, sid int64) { err := i.ValidateSession(ctx, &api.ValidateSessionRequest{ SessionOwnership: api.SessionOwnership{ Sid: sid,