From 3082a5dee95b239645d7e0302cdd75f6a0e793a5 Mon Sep 17 00:00:00 2001 From: Piotr Kozimor Date: Fri, 11 Jun 2021 12:40:58 +0200 Subject: [PATCH] Implement threepid sessions storage --- go.mod | 1 + go.sum | 2 + threepid/storage.go | 8 -- userapi/storage/threepid/stmt.go | 81 ++++++++++++++ userapi/storage/threepid/storage.go | 134 +++++++++++++++++++++++ userapi/storage/threepid/storage_test.go | 88 +++++++++++++++ 6 files changed, 306 insertions(+), 8 deletions(-) delete mode 100644 threepid/storage.go create mode 100644 userapi/storage/threepid/stmt.go create mode 100644 userapi/storage/threepid/storage.go create mode 100644 userapi/storage/threepid/storage_test.go diff --git a/go.mod b/go.mod index 7273da647..eff077fdf 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/matrix-org/naffka v0.0.0-20201009174903-d26a3b9cb161 github.com/matrix-org/pinecone v0.0.0-20210602111459-5cb0e6aa1a6a github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 + github.com/matryer/is v1.4.0 github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6 diff --git a/go.sum b/go.sum index eba9a60b1..fb6c93735 100644 --- a/go.sum +++ b/go.sum @@ -711,6 +711,8 @@ github.com/matrix-org/pinecone v0.0.0-20210602111459-5cb0e6aa1a6a/go.mod h1:UQzJ github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= +github.com/matryer/is v1.4.0 h1:sosSmIWwkYITGrxZ25ULNDeKiMNzFSr4V/eqBQP0PeE= +github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= diff --git a/threepid/storage.go b/threepid/storage.go deleted file mode 100644 index 8da3b2666..000000000 --- a/threepid/storage.go +++ /dev/null @@ -1,8 +0,0 @@ -package threepid - -type storage interface { - InsertSession(*Session) error - GetSession(sid string) (*Session, error) - GetSessionByThreePidAndSecret(threePid, ClientSecret string) (*Session, error) - RemoveSession(sid string) error -} diff --git a/userapi/storage/threepid/stmt.go b/userapi/storage/threepid/stmt.go new file mode 100644 index 000000000..4a660fcb6 --- /dev/null +++ b/userapi/storage/threepid/stmt.go @@ -0,0 +1,81 @@ +package threepid + +import ( + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +const sessionsSchema = ` +-- This sequence is used for automatic allocation of session_id. +-- CREATE SEQUENCE IF NOT EXISTS threepid_session_id_seq START 1; + +-- Stores data about devices. +CREATE TABLE IF NOT EXISTS threepid_sessions ( + sid VARCHAR(255) PRIMARY KEY, + client_secret VARCHAR(255), + threepid TEXT , + token VARCHAR(255) , + next_link TEXT, + validated_at_ts BIGINT, + validated BOOLEAN, + send_attempt INT +); + +CREATE UNIQUE INDEX IF NOT EXISTS threepid_sessions_threepids + ON threepid_sessions (threepid, client_secret) +` + +const ( + insertSessionSQL = "" + + "INSERT INTO threepid_sessions (sid, client_secret, threepid, token, next_link, send_attempt, validated_at_ts, validated)" + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + selectSessionSQL = "" + + "SELECT client_secret, threepid, token, next_link, validated, validated_at_ts, send_attempt FROM threepid_sessions WHERE sid == $1" + selectSessionByThreePidAndCLientSecretSQL = "" + + "SELECT sid, token, next_link, validated, validated_at_ts, send_attempt FROM threepid_sessions WHERE threepid == $1 AND client_secret == $2" + deleteSessionSQL = "" + + "DELETE FROM threepid_sessions WHERE sid = $1" + validateSessionSQL = "" + + "UPDATE threepid_sessions SET validated = $1, validated_at_ts = $2 WHERE sid = $3" + updateSendAttemptNextLinkSQL = "" + + "UPDATE threepid_sessions SET send_attempt = send_attempt + 1, next_link = $1 WHERE sid = $2" +) + +type sessionStatements struct { + db *sql.DB + writer sqlutil.Writer + insertSessionStmt *sql.Stmt + selectSessionStmt *sql.Stmt + selectSessionByThreePidAndCLientSecretStmt *sql.Stmt + deleteSessionStmt *sql.Stmt + validateSessionStmt *sql.Stmt + updateSendAttemptNextLinkStmt *sql.Stmt +} + +func (s *sessionStatements) prepare() (err error) { + if s.insertSessionStmt, err = s.db.Prepare(insertSessionSQL); err != nil { + return + } + if s.selectSessionStmt, err = s.db.Prepare(selectSessionSQL); err != nil { + return + } + if s.selectSessionByThreePidAndCLientSecretStmt, err = s.db.Prepare(selectSessionByThreePidAndCLientSecretSQL); err != nil { + return + } + if s.deleteSessionStmt, err = s.db.Prepare(deleteSessionSQL); err != nil { + return + } + if s.validateSessionStmt, err = s.db.Prepare(validateSessionSQL); err != nil { + return + } + if s.updateSendAttemptNextLinkStmt, err = s.db.Prepare(updateSendAttemptNextLinkSQL); err != nil { + return + } + return +} + +func (s *sessionStatements) execSchema(db *sql.DB) error { + _, err := db.Exec(sessionsSchema) + return err +} diff --git a/userapi/storage/threepid/storage.go b/userapi/storage/threepid/storage.go new file mode 100644 index 000000000..bb59030d5 --- /dev/null +++ b/userapi/storage/threepid/storage.go @@ -0,0 +1,134 @@ +package threepid + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/userapi/api" +) + +type Database interface { + InsertSession(context.Context, *api.Session) error + GetSession(ctx context.Context, sid string) (*api.Session, error) + GetSessionByThreePidAndSecret(ctx context.Context, threePid, ClientSecret string) (*api.Session, error) + UpdateSendAttemptNextLink(ctx context.Context, sid, nextLink string) error + RemoveSession(ctx context.Context, sid string) error + ValidateSession(ctx context.Context, sid string, validatedAt int) error +} + +type Db struct { + db *sql.DB + writer sqlutil.Writer + stm *sessionStatements + writeHandler func(func(*sql.Tx) error) error +} + +func (d *Db) InsertSession(ctx context.Context, s *api.Session) error { + h := func(_ *sql.Tx) error { + _, err := d.stm.insertSessionStmt.ExecContext(ctx, s.Sid, s.ClientSecret, s.ThreePid, s.Token, s.NextLink, s.SendAttempt, s.ValidatedAt, s.Validated) + return err + } + return d.writeHandler(h) +} + +func (d *Db) GetSession(ctx context.Context, sid string) (*api.Session, error) { + s := api.Session{} + err := d.stm.selectSessionStmt.QueryRowContext(ctx, sid).Scan(&s.ClientSecret, &s.ThreePid, &s.Token, &s.NextLink, &s.Validated, &s.ValidatedAt, &s.SendAttempt) + s.Sid = sid + return &s, err +} + +func (d *Db) GetSessionByThreePidAndSecret(ctx context.Context, threePid, ClientSecret string) (*api.Session, error) { + s := api.Session{} + err := d.stm.selectSessionByThreePidAndCLientSecretStmt. + QueryRowContext(ctx, threePid, ClientSecret).Scan( + &s.Sid, &s.Token, &s.NextLink, &s.Validated, &s.ValidatedAt, &s.SendAttempt) + s.ThreePid = threePid + s.ClientSecret = ClientSecret + return &s, err +} + +func (d *Db) UpdateSendAttemptNextLink(ctx context.Context, sid, nextLink string) error { + h := func(_ *sql.Tx) error { + _, err := d.stm.updateSendAttemptNextLinkStmt.ExecContext(ctx, nextLink, sid) + return err + } + return d.writeHandler(h) +} + +func (d *Db) RemoveSession(ctx context.Context, sid string) error { + h := func(_ *sql.Tx) error { + _, err := d.stm.deleteSessionStmt.ExecContext(ctx, sid) + return err + } + return d.writeHandler(h) +} + +func (d *Db) ValidateSession(ctx context.Context, sid string, validatedAt int) error { + h := func(_ *sql.Tx) error { + _, err := d.stm.validateSessionStmt.ExecContext(ctx, true, validatedAt, sid) + return err + } + return d.writeHandler(h) +} + +func newSQLiteDatabase(dbProperties *config.DatabaseOptions) (*Db, error) { + db, err := sqlutil.Open(dbProperties) + if err != nil { + return nil, err + } + writer := sqlutil.NewExclusiveWriter() + stmt := sessionStatements{ + db: db, + writer: writer, + } + + // Create tables before executing migrations so we don't fail if the table is missing, + // and THEN prepare statements so we don't fail due to referencing new columns + if err = stmt.execSchema(db); err != nil { + return nil, err + } + if err = stmt.prepare(); err != nil { + return nil, err + } + handler := func(f func(tx *sql.Tx) error) error { + return writer.Do(nil, nil, f) + } + return &Db{db, writer, &stmt, handler}, nil +} + +func newPostgresDatabase(dbProperties *config.DatabaseOptions) (*Db, error) { + db, err := sqlutil.Open(dbProperties) + if err != nil { + return nil, err + } + stmt := sessionStatements{ + db: db, + } + // Create tables before executing migrations so we don't fail if the table is missing, + // and THEN prepare statements so we don't fail due to referencing new columns + if err = stmt.execSchema(db); err != nil { + return nil, err + } + if err = stmt.prepare(); err != nil { + return nil, err + } + handler := func(f func(tx *sql.Tx) error) error { + return f(nil) + } + return &Db{db, nil, &stmt, handler}, nil +} + +func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) { + switch { + case dbProperties.ConnectionString.IsSQLite(): + return newSQLiteDatabase(dbProperties) + case dbProperties.ConnectionString.IsPostgres(): + return newPostgresDatabase(dbProperties) + default: + return nil, fmt.Errorf("unexpected database type") + } +} diff --git a/userapi/storage/threepid/storage_test.go b/userapi/storage/threepid/storage_test.go new file mode 100644 index 000000000..71036fd7c --- /dev/null +++ b/userapi/storage/threepid/storage_test.go @@ -0,0 +1,88 @@ +package threepid + +import ( + "context" + "database/sql" + "math/rand" + "os" + "strconv" + "testing" + + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matryer/is" +) + +var testSession = api.Session{ + // Sid: "123", + ClientSecret: "099azAZ.=_-", + ThreePid: "azAZ09!#$%&'*+-/=?^_`{|}~@bar09.com", + Token: "fooBAR123", + NextLink: "https://example.com?user=foo", + SendAttempt: 0, + Validated: true, + ValidatedAt: 0, +} + +var testCtx = context.Background() + +func mustNewDatabaseWithTestSession(is *is.I) *Db { + randPostfix := strconv.Itoa(rand.Int()) + dbPath := os.TempDir() + "/dendrite-" + randPostfix + println(dbPath) + dut, err := newSQLiteDatabase(&config.DatabaseOptions{ + ConnectionString: config.DataSource("file:" + dbPath), + }) + is.NoErr(err) + err = dut.InsertSession(testCtx, &testSession) + is.NoErr(err) + return dut +} +func TestGetSession(t *testing.T) { + is := is.New(t) + dut := mustNewDatabaseWithTestSession(is) + s, err := dut.GetSession(testCtx, testSession.Sid) + is.NoErr(err) + is.Equal(*s, testSession) +} + +func TestGetSessionByThreePidAndSecret(t *testing.T) { + is := is.New(t) + dut := mustNewDatabaseWithTestSession(is) + s, err := dut.GetSessionByThreePidAndSecret(testCtx, testSession.ThreePid, testSession.ClientSecret) + is.NoErr(err) + is.Equal(*s, testSession) +} + +func TestBumpSendAttempt(t *testing.T) { + is := is.New(t) + dut := mustNewDatabaseWithTestSession(is) + nextLink := "https://foo.bar" + err := dut.UpdateSendAttemptNextLink(testCtx, testSession.Sid, nextLink) + is.NoErr(err) + s, err := dut.GetSession(testCtx, testSession.Sid) + is.NoErr(err) + is.Equal(s.SendAttempt, 1) + is.Equal(s.NextLink, nextLink) +} + +func TestDeleteSession(t *testing.T) { + is := is.New(t) + dut := mustNewDatabaseWithTestSession(is) + err := dut.RemoveSession(testCtx, testSession.Sid) + is.NoErr(err) + _, err = dut.GetSession(testCtx, testSession.Sid) + is.Equal(err, sql.ErrNoRows) +} + +func TestValidateSession(t *testing.T) { + is := is.New(t) + dut := mustNewDatabaseWithTestSession(is) + validatedAt := 1_623_406_296 + err := dut.ValidateSession(testCtx, testSession.Sid, validatedAt) + is.NoErr(err) + session, err := dut.GetSession(testCtx, testSession.Sid) + is.NoErr(err) + is.Equal(session.Validated, true) + is.Equal(session.ValidatedAt, validatedAt) +}