Refactor threepid sessions storage so that postgres and sqlite3 schemas are separate. Generate session id in database

This commit is contained in:
Piotr Kozimor 2021-09-13 17:01:14 +02:00
parent bae4313b4b
commit 316adf66d1
11 changed files with 346 additions and 219 deletions

View file

@ -438,7 +438,7 @@ type CreateSessionRequest struct {
} }
type CreateSessionResponse struct { type CreateSessionResponse struct {
Sid string Sid int64
} }
type ValidateSessionRequest struct { type ValidateSessionRequest struct {
@ -451,13 +451,16 @@ type GetThreePidForSessionResponse struct {
} }
type SessionOwnership struct { type SessionOwnership struct {
Sid, ClientSecret string Sid int64
ClientSecret string
} }
type Session struct { type Session struct {
Sid, ClientSecret, ThreePid, Token, NextLink string ClientSecret, ThreePid, Token, NextLink string
SendAttempt, ValidatedAt int Sid int64
Validated bool SendAttempt int
ValidatedAt int64
Validated bool
} }
type IsSessionValidatedResponse struct { type IsSessionValidatedResponse struct {

View file

@ -5,6 +5,7 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"net/url" "net/url"
"strconv"
"time" "time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
@ -28,19 +29,14 @@ func (a *UserInternalAPI) CreateSession(ctx context.Context, req *api.CreateSess
if errB != nil { if errB != nil {
return errB return errB
} }
sid, errB := internal.GenerateBlob(sessionIdByteLength)
if errB != nil {
return errB
}
s = &api.Session{ s = &api.Session{
Sid: sid,
ClientSecret: req.ClientSecret, ClientSecret: req.ClientSecret,
ThreePid: req.ThreePid, ThreePid: req.ThreePid,
SendAttempt: req.SendAttempt, SendAttempt: req.SendAttempt,
Token: token, Token: token,
NextLink: req.NextLink, NextLink: req.NextLink,
} }
err = a.ThreePidDB.InsertSession(ctx, s) s.Sid, err = a.ThreePidDB.InsertSession(ctx, req.ClientSecret, req.ThreePid, token, req.NextLink, 0, false, req.SendAttempt)
if err != nil { if err != nil {
return err return err
} }
@ -60,7 +56,7 @@ func (a *UserInternalAPI) CreateSession(ctx context.Context, req *api.CreateSess
} }
res.Sid = s.Sid res.Sid = s.Sid
query := url.Values{ query := url.Values{
"sid": []string{s.Sid}, "sid": []string{strconv.Itoa(int(s.Sid))},
"client_secret": []string{s.ClientSecret}, "client_secret": []string{s.ClientSecret},
"token": []string{s.Token}, "token": []string{s.Token},
} }
@ -88,7 +84,7 @@ func (a *UserInternalAPI) ValidateSession(ctx context.Context, req *api.Validate
if s.Token != req.Token { if s.Token != req.Token {
return ErrBadSession return ErrBadSession
} }
return a.ThreePidDB.ValidateSession(ctx, s.Sid, int(time.Now().Unix())) return a.ThreePidDB.ValidateSession(ctx, s.Sid, time.Now().Unix())
} }
func (a *UserInternalAPI) GetThreePidForSession(ctx context.Context, req *api.SessionOwnership, res *api.GetThreePidForSessionResponse) error { func (a *UserInternalAPI) GetThreePidForSession(ctx context.Context, req *api.SessionOwnership, res *api.GetThreePidForSessionResponse) error {
@ -114,7 +110,7 @@ func (a *UserInternalAPI) IsSessionValidated(ctx context.Context, req *api.Sessi
return err return err
} }
res.Validated = s.Validated res.Validated = s.Validated
res.ValidatedAt = s.ValidatedAt res.ValidatedAt = int(s.ValidatedAt)
return nil return nil
} }

View file

@ -0,0 +1,59 @@
package postgres
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/threepid/shared"
)
type Database struct {
*shared.Database
}
const threePidSessionsSchema = `
-- This sequence is used for automatic allocation of session_id.
-- CREATE SEQUENCE IF NOT EXISTS threepid_session_id_seq START 1;
-- Stores data about devices.
CREATE TABLE IF NOT EXISTS threepid_sessions (
sid SERIAL PRIMARY KEY,
client_secret VARCHAR(255),
threepid TEXT ,
token VARCHAR(255) ,
next_link TEXT,
validated_at_ts BIGINT,
validated BOOLEAN,
send_attempt INT
);
CREATE UNIQUE INDEX IF NOT EXISTS threepid_sessions_threepids
ON threepid_sessions (threepid, client_secret)
`
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, fmt.Errorf("sqlutil.Open: %w", err)
}
// Create the tables.
if err := create(db); err != nil {
return nil, err
}
sharedDb, err := shared.NewDatabase(db, sqlutil.NewDummyWriter())
if err != nil {
return nil, err
}
return &Database{
Database: sharedDb,
}, nil
}
func create(db *sql.DB) error {
_, err := db.Exec(threePidSessionsSchema)
return err
}

View file

@ -0,0 +1,136 @@
package shared
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
)
type Database struct {
writer sqlutil.Writer
stm *ThreePidSessionStatements
}
func (d *Database) InsertSession(
ctx context.Context, clientSecret, threepid, token, nextlink string, validatedAt int64, validated bool, sendAttempt int) (int64, error) {
var sid int64
return sid, d.writer.Do(nil, nil, func(_ *sql.Tx) error {
err := d.stm.insertSessionStmt.QueryRowContext(ctx, clientSecret, threepid, token, nextlink, validatedAt, validated, sendAttempt).Scan(&sid)
// _, err := d.stm.insertSessionStmt.ExecContext(ctx, clientSecret, threepid, token, nextlink, sendAttempt, validatedAt, validated)
return err
// if err != nil {
// return err
// }
// err = d.stm.selectSidStmt.QueryRowContext(ctx).Scan(&sid)
})
}
func (d *Database) GetSession(ctx context.Context, sid int64) (*api.Session, error) {
s := api.Session{}
err := d.stm.selectSessionStmt.QueryRowContext(ctx, sid).Scan(&s.ClientSecret, &s.ThreePid, &s.Token, &s.NextLink, &s.ValidatedAt, &s.Validated, &s.SendAttempt)
s.Sid = sid
return &s, err
}
func (d *Database) GetSessionByThreePidAndSecret(ctx context.Context, threePid, ClientSecret string) (*api.Session, error) {
s := api.Session{}
err := d.stm.selectSessionByThreePidAndCLientSecretStmt.
QueryRowContext(ctx, threePid, ClientSecret).Scan(
&s.Sid, &s.Token, &s.NextLink, &s.ValidatedAt, &s.Validated, &s.SendAttempt)
s.ThreePid = threePid
s.ClientSecret = ClientSecret
return &s, err
}
func (d *Database) UpdateSendAttemptNextLink(ctx context.Context, sid int64, nextLink string) error {
return d.writer.Do(nil, nil, func(_ *sql.Tx) error {
_, err := d.stm.updateSendAttemptNextLinkStmt.ExecContext(ctx, nextLink, sid)
return err
})
}
func (d *Database) DeleteSession(ctx context.Context, sid int64) error {
return d.writer.Do(nil, nil, func(_ *sql.Tx) error {
_, err := d.stm.deleteSessionStmt.ExecContext(ctx, sid)
return err
})
}
func (d *Database) ValidateSession(ctx context.Context, sid int64, validatedAt int64) error {
return d.writer.Do(nil, nil, func(_ *sql.Tx) error {
_, err := d.stm.validateSessionStmt.ExecContext(ctx, true, validatedAt, sid)
return err
})
}
func NewDatabase(db *sql.DB, writer sqlutil.Writer) (*Database, error) {
threePidSessionsTable, err := PrepareThreePidSessionsTable(db)
if err != nil {
return nil, err
}
d := Database{
writer: writer,
stm: threePidSessionsTable,
}
return &d, nil
}
// func newSQLiteDatabase(dbProperties *config.DatabaseOptions) (*Db, error) {
// db, err := sqlutil.Open(dbProperties)
// if err != nil {
// return nil, err
// }
// writer := sqlutil.NewExclusiveWriter()
// stmt := sessionStatements{
// db: db,
// writer: writer,
// }
// // Create tables before executing migrations so we don't fail if the table is missing,
// // and THEN prepare statements so we don't fail due to referencing new columns
// if err = stmt.execSchema(db); err != nil {
// return nil, err
// }
// if err = stmt.prepare(); err != nil {
// return nil, err
// }
// handler := func(f func(tx *sql.Tx) error) error {
// return writer.Do(nil, nil, f)
// }
// return &Db{db, writer, &stmt, handler}, nil
// }
// func newPostgresDatabase(dbProperties *config.DatabaseOptions) (*Db, error) {
// db, err := sqlutil.Open(dbProperties)
// if err != nil {
// return nil, err
// }
// stmt := sessionStatements{
// db: db,
// }
// // Create tables before executing migrations so we don't fail if the table is missing,
// // and THEN prepare statements so we don't fail due to referencing new columns
// if err = stmt.execSchema(db); err != nil {
// return nil, err
// }
// if err = stmt.prepare(); err != nil {
// return nil, err
// }
// handler := func(f func(tx *sql.Tx) error) error {
// return f(nil)
// }
// return &Db{db, nil, &stmt, handler}, nil
// }
// func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) {
// switch {
// case dbProperties.ConnectionString.IsSQLite():
// return newSQLiteDatabase(dbProperties)
// case dbProperties.ConnectionString.IsPostgres():
// return newPostgresDatabase(dbProperties)
// default:
// return nil, fmt.Errorf("unexpected database type")
// }
// }

View file

@ -0,0 +1,49 @@
package shared
import (
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const (
insertSessionSQL = "" +
"INSERT INTO threepid_sessions (client_secret, threepid, token, next_link, validated_at_ts, validated, send_attempt)" +
"VALUES ($1, $2, $3, $4, $5, $6, $7)" +
"RETURNING sid;"
// selectSidSQL = "" +
// "SELECT last_insert_rowid();"
selectSessionSQL = "" +
"SELECT client_secret, threepid, token, next_link, validated_at_ts, validated, send_attempt FROM threepid_sessions WHERE sid = $1"
selectSessionByThreePidAndCLientSecretSQL = "" +
"SELECT sid, token, next_link, validated_at_ts, validated, send_attempt FROM threepid_sessions WHERE threepid = $1 AND client_secret = $2"
deleteSessionSQL = "" +
"DELETE FROM threepid_sessions WHERE sid = $1"
validateSessionSQL = "" +
"UPDATE threepid_sessions SET validated = $1, validated_at_ts = $2 WHERE sid = $3"
updateSendAttemptNextLinkSQL = "" +
"UPDATE threepid_sessions SET send_attempt = send_attempt + 1, next_link = $1 WHERE sid = $2"
)
type ThreePidSessionStatements struct {
insertSessionStmt *sql.Stmt
// selectSidStmt *sql.Stmt
selectSessionStmt *sql.Stmt
selectSessionByThreePidAndCLientSecretStmt *sql.Stmt
deleteSessionStmt *sql.Stmt
validateSessionStmt *sql.Stmt
updateSendAttemptNextLinkStmt *sql.Stmt
}
func PrepareThreePidSessionsTable(db *sql.DB) (*ThreePidSessionStatements, error) {
s := ThreePidSessionStatements{}
return &s, sqlutil.StatementList{
{&s.insertSessionStmt, insertSessionSQL},
// {&s.selectSidStmt, selectSidSQL},
{&s.selectSessionStmt, selectSessionSQL},
{&s.selectSessionByThreePidAndCLientSecretStmt, selectSessionByThreePidAndCLientSecretSQL},
{&s.deleteSessionStmt, deleteSessionSQL},
{&s.validateSessionStmt, validateSessionSQL},
{&s.updateSendAttemptNextLinkStmt, updateSendAttemptNextLinkSQL},
}.Prepare(db)
}

View file

@ -0,0 +1,59 @@
package sqlite3
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/threepid/shared"
)
type Database struct {
*shared.Database
}
const threePidSessionsSchema = `
-- This sequence is used for automatic allocation of session_id.
-- CREATE SEQUENCE IF NOT EXISTS threepid_session_id_seq START 1;
-- Stores data about devices.
CREATE TABLE IF NOT EXISTS threepid_sessions (
sid INTEGER PRIMARY KEY,
client_secret VARCHAR(255),
threepid TEXT ,
token VARCHAR(255) ,
next_link TEXT,
validated_at_ts BIGINT,
validated BOOLEAN,
send_attempt INT
);
CREATE UNIQUE INDEX IF NOT EXISTS threepid_sessions_threepids
ON threepid_sessions (threepid, client_secret)
`
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, fmt.Errorf("sqlutil.Open: %w", err)
}
// Create the tables.
if err := create(db); err != nil {
return nil, err
}
sharedDb, err := shared.NewDatabase(db, sqlutil.NewExclusiveWriter())
if err != nil {
return nil, err
}
return &Database{
Database: sharedDb,
}, nil
}
func create(db *sql.DB) error {
_, err := db.Exec(threePidSessionsSchema)
return err
}

View file

@ -1,81 +0,0 @@
package threepid
import (
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const sessionsSchema = `
-- This sequence is used for automatic allocation of session_id.
-- CREATE SEQUENCE IF NOT EXISTS threepid_session_id_seq START 1;
-- Stores data about devices.
CREATE TABLE IF NOT EXISTS threepid_sessions (
sid VARCHAR(255) PRIMARY KEY,
client_secret VARCHAR(255),
threepid TEXT ,
token VARCHAR(255) ,
next_link TEXT,
validated_at_ts BIGINT,
validated BOOLEAN,
send_attempt INT
);
CREATE UNIQUE INDEX IF NOT EXISTS threepid_sessions_threepids
ON threepid_sessions (threepid, client_secret)
`
const (
insertSessionSQL = "" +
"INSERT INTO threepid_sessions (sid, client_secret, threepid, token, next_link, send_attempt, validated_at_ts, validated)" +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
selectSessionSQL = "" +
"SELECT client_secret, threepid, token, next_link, validated, validated_at_ts, send_attempt FROM threepid_sessions WHERE sid = $1"
selectSessionByThreePidAndCLientSecretSQL = "" +
"SELECT sid, token, next_link, validated, validated_at_ts, send_attempt FROM threepid_sessions WHERE threepid = $1 AND client_secret = $2"
deleteSessionSQL = "" +
"DELETE FROM threepid_sessions WHERE sid = $1"
validateSessionSQL = "" +
"UPDATE threepid_sessions SET validated = $1, validated_at_ts = $2 WHERE sid = $3"
updateSendAttemptNextLinkSQL = "" +
"UPDATE threepid_sessions SET send_attempt = send_attempt + 1, next_link = $1 WHERE sid = $2"
)
type sessionStatements struct {
db *sql.DB
writer sqlutil.Writer
insertSessionStmt *sql.Stmt
selectSessionStmt *sql.Stmt
selectSessionByThreePidAndCLientSecretStmt *sql.Stmt
deleteSessionStmt *sql.Stmt
validateSessionStmt *sql.Stmt
updateSendAttemptNextLinkStmt *sql.Stmt
}
func (s *sessionStatements) prepare() (err error) {
if s.insertSessionStmt, err = s.db.Prepare(insertSessionSQL); err != nil {
return
}
if s.selectSessionStmt, err = s.db.Prepare(selectSessionSQL); err != nil {
return
}
if s.selectSessionByThreePidAndCLientSecretStmt, err = s.db.Prepare(selectSessionByThreePidAndCLientSecretSQL); err != nil {
return
}
if s.deleteSessionStmt, err = s.db.Prepare(deleteSessionSQL); err != nil {
return
}
if s.validateSessionStmt, err = s.db.Prepare(validateSessionSQL); err != nil {
return
}
if s.updateSendAttemptNextLinkStmt, err = s.db.Prepare(updateSendAttemptNextLinkSQL); err != nil {
return
}
return
}
func (s *sessionStatements) execSchema(db *sql.DB) error {
_, err := db.Exec(sessionsSchema)
return err
}

View file

@ -2,132 +2,30 @@ package threepid
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/threepid/postgres"
"github.com/matrix-org/dendrite/userapi/storage/threepid/sqlite3"
) )
type Database interface { type Database interface {
InsertSession(context.Context, *api.Session) error InsertSession(ctx context.Context, clientSecret, threepid, token, nextlink string, validatedAt int64, validated bool, sendAttempt int) (int64, error)
GetSession(ctx context.Context, sid string) (*api.Session, error) GetSession(ctx context.Context, sid int64) (*api.Session, error)
GetSessionByThreePidAndSecret(ctx context.Context, threePid, ClientSecret string) (*api.Session, error) GetSessionByThreePidAndSecret(ctx context.Context, threePid, ClientSecret string) (*api.Session, error)
UpdateSendAttemptNextLink(ctx context.Context, sid, nextLink string) error UpdateSendAttemptNextLink(ctx context.Context, sid int64, nextLink string) error
DeleteSession(ctx context.Context, sid string) error DeleteSession(ctx context.Context, sid int64) error
ValidateSession(ctx context.Context, sid string, validatedAt int) error ValidateSession(ctx context.Context, sid int64, validatedAt int64) error
} }
type Db struct { // Open opens a database connection.
db *sql.DB func Open(dbProperties *config.DatabaseOptions) (Database, error) {
writer sqlutil.Writer
stm *sessionStatements
writeHandler func(func(*sql.Tx) error) error
}
func (d *Db) InsertSession(ctx context.Context, s *api.Session) error {
h := func(_ *sql.Tx) error {
_, err := d.stm.insertSessionStmt.ExecContext(ctx, s.Sid, s.ClientSecret, s.ThreePid, s.Token, s.NextLink, s.SendAttempt, s.ValidatedAt, s.Validated)
return err
}
return d.writeHandler(h)
}
func (d *Db) GetSession(ctx context.Context, sid string) (*api.Session, error) {
s := api.Session{}
err := d.stm.selectSessionStmt.QueryRowContext(ctx, sid).Scan(&s.ClientSecret, &s.ThreePid, &s.Token, &s.NextLink, &s.Validated, &s.ValidatedAt, &s.SendAttempt)
s.Sid = sid
return &s, err
}
func (d *Db) GetSessionByThreePidAndSecret(ctx context.Context, threePid, ClientSecret string) (*api.Session, error) {
s := api.Session{}
err := d.stm.selectSessionByThreePidAndCLientSecretStmt.
QueryRowContext(ctx, threePid, ClientSecret).Scan(
&s.Sid, &s.Token, &s.NextLink, &s.Validated, &s.ValidatedAt, &s.SendAttempt)
s.ThreePid = threePid
s.ClientSecret = ClientSecret
return &s, err
}
func (d *Db) UpdateSendAttemptNextLink(ctx context.Context, sid, nextLink string) error {
h := func(_ *sql.Tx) error {
_, err := d.stm.updateSendAttemptNextLinkStmt.ExecContext(ctx, nextLink, sid)
return err
}
return d.writeHandler(h)
}
func (d *Db) DeleteSession(ctx context.Context, sid string) error {
h := func(_ *sql.Tx) error {
_, err := d.stm.deleteSessionStmt.ExecContext(ctx, sid)
return err
}
return d.writeHandler(h)
}
func (d *Db) ValidateSession(ctx context.Context, sid string, validatedAt int) error {
h := func(_ *sql.Tx) error {
_, err := d.stm.validateSessionStmt.ExecContext(ctx, true, validatedAt, sid)
return err
}
return d.writeHandler(h)
}
func newSQLiteDatabase(dbProperties *config.DatabaseOptions) (*Db, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
writer := sqlutil.NewExclusiveWriter()
stmt := sessionStatements{
db: db,
writer: writer,
}
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
if err = stmt.execSchema(db); err != nil {
return nil, err
}
if err = stmt.prepare(); err != nil {
return nil, err
}
handler := func(f func(tx *sql.Tx) error) error {
return writer.Do(nil, nil, f)
}
return &Db{db, writer, &stmt, handler}, nil
}
func newPostgresDatabase(dbProperties *config.DatabaseOptions) (*Db, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
stmt := sessionStatements{
db: db,
}
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
if err = stmt.execSchema(db); err != nil {
return nil, err
}
if err = stmt.prepare(); err != nil {
return nil, err
}
handler := func(f func(tx *sql.Tx) error) error {
return f(nil)
}
return &Db{db, nil, &stmt, handler}, nil
}
func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) {
switch { switch {
case dbProperties.ConnectionString.IsSQLite(): case dbProperties.ConnectionString.IsSQLite():
return newSQLiteDatabase(dbProperties) return sqlite3.Open(dbProperties)
case dbProperties.ConnectionString.IsPostgres(): case dbProperties.ConnectionString.IsPostgres():
return newPostgresDatabase(dbProperties) return postgres.Open(dbProperties)
default: default:
return nil, fmt.Errorf("unexpected database type") return nil, fmt.Errorf("unexpected database type")
} }

View file

@ -14,7 +14,6 @@ import (
) )
var testSession = api.Session{ var testSession = api.Session{
// Sid: "123",
ClientSecret: "099azAZ.=_-", ClientSecret: "099azAZ.=_-",
ThreePid: "azAZ09!#$%&'*+-/=?^_`{|}~@bar09.com", ThreePid: "azAZ09!#$%&'*+-/=?^_`{|}~@bar09.com",
Token: "fooBAR123", Token: "fooBAR123",
@ -26,15 +25,24 @@ var testSession = api.Session{
var testCtx = context.Background() var testCtx = context.Background()
func mustNewDatabaseWithTestSession(is *is.I) *Db { func mustNewDatabaseWithTestSession(is *is.I) Database {
randPostfix := strconv.Itoa(rand.Int()) randPostfix := strconv.Itoa(rand.Int())
dbPath := os.TempDir() + "/dendrite-" + randPostfix dbPath := os.TempDir() + "/dendrite-" + randPostfix
println(dbPath) println(dbPath)
dut, err := newSQLiteDatabase(&config.DatabaseOptions{ dut, err := Open(&config.DatabaseOptions{
ConnectionString: config.DataSource("file:" + dbPath), ConnectionString: config.DataSource("file:" + dbPath),
}) })
is.NoErr(err) is.NoErr(err)
err = dut.InsertSession(testCtx, &testSession) sid, err := dut.InsertSession(
testCtx,
testSession.ClientSecret,
testSession.ThreePid,
testSession.Token,
testSession.NextLink,
testSession.ValidatedAt,
testSession.Validated,
testSession.SendAttempt)
testSession.Sid = sid
is.NoErr(err) is.NoErr(err)
return dut return dut
} }
@ -78,7 +86,7 @@ func TestDeleteSession(t *testing.T) {
func TestValidateSession(t *testing.T) { func TestValidateSession(t *testing.T) {
is := is.New(t) is := is.New(t)
dut := mustNewDatabaseWithTestSession(is) dut := mustNewDatabaseWithTestSession(is)
validatedAt := 1_623_406_296 validatedAt := int64(1_623_406_296)
err := dut.ValidateSession(testCtx, testSession.Sid, validatedAt) err := dut.ValidateSession(testCtx, testSession.Sid, validatedAt)
is.NoErr(err) is.NoErr(err)
session, err := dut.GetSession(testCtx, testSession.Sid) session, err := dut.GetSession(testCtx, testSession.Sid)

View file

@ -43,7 +43,7 @@ func NewInternalAPI(
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to connect to device db") logrus.WithError(err).Panicf("failed to connect to device db")
} }
threepidDb, err := threepid.NewDatabase(&cfg.ThreepidDatabase) threepidDb, err := threepid.Open(&cfg.ThreepidDatabase)
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to connect to threepid db") logrus.WithError(err).Panicf("failed to connect to threepid db")
} }

View file

@ -171,7 +171,7 @@ func TestCreateSession_Twice(t *testing.T) {
resp := api.CreateSessionResponse{} resp := api.CreateSessionResponse{}
err := internalApi.CreateSession(ctx, testReq, &resp) err := internalApi.CreateSession(ctx, testReq, &resp)
is.NoErr(err) is.NoErr(err)
is.Equal(len(resp.Sid), 43) is.Equal(resp.Sid, int64(1))
select { select {
case <-mailer.c[api.AccountPassword]: case <-mailer.c[api.AccountPassword]:
t.Fatal("email was received, but sent attempt was not increased") t.Fatal("email was received, but sent attempt was not increased")
@ -189,7 +189,7 @@ func TestCreateSession_Twice_IncreaseSendAttempt(t *testing.T) {
testReqBumped.SendAttempt = 1 testReqBumped.SendAttempt = 1
err := internalApi.CreateSession(ctx, &testReqBumped, &resp) err := internalApi.CreateSession(ctx, &testReqBumped, &resp)
is.NoErr(err) is.NoErr(err)
is.Equal(len(resp.Sid), 43) is.Equal(resp.Sid, int64(1))
sub := <-mailer.c[api.AccountPassword] sub := <-mailer.c[api.AccountPassword]
is.Equal(len(sub.Token), 64) is.Equal(len(sub.Token), 64)
is.Equal(sub.To, testReq.ThreePid) is.Equal(sub.To, testReq.ThreePid)
@ -235,7 +235,7 @@ func mustCreateSession(is *is.I, i *internal.UserInternalAPI) (resp *api.CreateS
i.Mail = mailer i.Mail = mailer
err := i.CreateSession(ctx, testReq, resp) err := i.CreateSession(ctx, testReq, resp)
is.NoErr(err) is.NoErr(err)
is.Equal(len(resp.Sid), 43) is.Equal(resp.Sid, int64(1))
sub := <-mailer.c[api.AccountPassword] sub := <-mailer.c[api.AccountPassword]
is.Equal(len(sub.Token), 64) is.Equal(len(sub.Token), 64)
is.Equal(sub.To, testReq.ThreePid) is.Equal(sub.To, testReq.ThreePid)
@ -244,14 +244,14 @@ func mustCreateSession(is *is.I, i *internal.UserInternalAPI) (resp *api.CreateS
is.Equal(submitUrl.Host, "example.com") is.Equal(submitUrl.Host, "example.com")
is.Equal(submitUrl.Path, "/_matrix/client/r0/account/password/email/submitToken") is.Equal(submitUrl.Path, "/_matrix/client/r0/account/password/email/submitToken")
q := submitUrl.Query() q := submitUrl.Query()
is.Equal(len(q["sid"][0]), 43) is.Equal(q["sid"][0], "1")
is.Equal(q["token"][0], sub.Token) is.Equal(q["token"][0], sub.Token)
is.Equal(q["client_secret"][0], "foobar") is.Equal(q["client_secret"][0], "foobar")
token = sub.Token token = sub.Token
return return
} }
func mustValidateSesson(is *is.I, i *internal.UserInternalAPI, secret, token, sid string) { func mustValidateSesson(is *is.I, i *internal.UserInternalAPI, secret, token string, sid int64) {
err := i.ValidateSession(ctx, &api.ValidateSessionRequest{ err := i.ValidateSession(ctx, &api.ValidateSessionRequest{
SessionOwnership: api.SessionOwnership{ SessionOwnership: api.SessionOwnership{
Sid: sid, Sid: sid,