Implement threepid sessions storage

This commit is contained in:
Piotr Kozimor 2021-06-11 12:40:58 +02:00
parent d36116c089
commit 3082a5dee9
6 changed files with 306 additions and 8 deletions

1
go.mod
View file

@ -27,6 +27,7 @@ require (
github.com/matrix-org/naffka v0.0.0-20201009174903-d26a3b9cb161
github.com/matrix-org/pinecone v0.0.0-20210602111459-5cb0e6aa1a6a
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/matryer/is v1.4.0
github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6

2
go.sum
View file

@ -711,6 +711,8 @@ github.com/matrix-org/pinecone v0.0.0-20210602111459-5cb0e6aa1a6a/go.mod h1:UQzJ
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
github.com/matryer/is v1.4.0 h1:sosSmIWwkYITGrxZ25ULNDeKiMNzFSr4V/eqBQP0PeE=
github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU=
github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ=
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=

View file

@ -1,8 +0,0 @@
package threepid
type storage interface {
InsertSession(*Session) error
GetSession(sid string) (*Session, error)
GetSessionByThreePidAndSecret(threePid, ClientSecret string) (*Session, error)
RemoveSession(sid string) error
}

View file

@ -0,0 +1,81 @@
package threepid
import (
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const sessionsSchema = `
-- This sequence is used for automatic allocation of session_id.
-- CREATE SEQUENCE IF NOT EXISTS threepid_session_id_seq START 1;
-- Stores data about devices.
CREATE TABLE IF NOT EXISTS threepid_sessions (
sid VARCHAR(255) PRIMARY KEY,
client_secret VARCHAR(255),
threepid TEXT ,
token VARCHAR(255) ,
next_link TEXT,
validated_at_ts BIGINT,
validated BOOLEAN,
send_attempt INT
);
CREATE UNIQUE INDEX IF NOT EXISTS threepid_sessions_threepids
ON threepid_sessions (threepid, client_secret)
`
const (
insertSessionSQL = "" +
"INSERT INTO threepid_sessions (sid, client_secret, threepid, token, next_link, send_attempt, validated_at_ts, validated)" +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
selectSessionSQL = "" +
"SELECT client_secret, threepid, token, next_link, validated, validated_at_ts, send_attempt FROM threepid_sessions WHERE sid == $1"
selectSessionByThreePidAndCLientSecretSQL = "" +
"SELECT sid, token, next_link, validated, validated_at_ts, send_attempt FROM threepid_sessions WHERE threepid == $1 AND client_secret == $2"
deleteSessionSQL = "" +
"DELETE FROM threepid_sessions WHERE sid = $1"
validateSessionSQL = "" +
"UPDATE threepid_sessions SET validated = $1, validated_at_ts = $2 WHERE sid = $3"
updateSendAttemptNextLinkSQL = "" +
"UPDATE threepid_sessions SET send_attempt = send_attempt + 1, next_link = $1 WHERE sid = $2"
)
type sessionStatements struct {
db *sql.DB
writer sqlutil.Writer
insertSessionStmt *sql.Stmt
selectSessionStmt *sql.Stmt
selectSessionByThreePidAndCLientSecretStmt *sql.Stmt
deleteSessionStmt *sql.Stmt
validateSessionStmt *sql.Stmt
updateSendAttemptNextLinkStmt *sql.Stmt
}
func (s *sessionStatements) prepare() (err error) {
if s.insertSessionStmt, err = s.db.Prepare(insertSessionSQL); err != nil {
return
}
if s.selectSessionStmt, err = s.db.Prepare(selectSessionSQL); err != nil {
return
}
if s.selectSessionByThreePidAndCLientSecretStmt, err = s.db.Prepare(selectSessionByThreePidAndCLientSecretSQL); err != nil {
return
}
if s.deleteSessionStmt, err = s.db.Prepare(deleteSessionSQL); err != nil {
return
}
if s.validateSessionStmt, err = s.db.Prepare(validateSessionSQL); err != nil {
return
}
if s.updateSendAttemptNextLinkStmt, err = s.db.Prepare(updateSendAttemptNextLinkSQL); err != nil {
return
}
return
}
func (s *sessionStatements) execSchema(db *sql.DB) error {
_, err := db.Exec(sessionsSchema)
return err
}

View file

@ -0,0 +1,134 @@
package threepid
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
)
type Database interface {
InsertSession(context.Context, *api.Session) error
GetSession(ctx context.Context, sid string) (*api.Session, error)
GetSessionByThreePidAndSecret(ctx context.Context, threePid, ClientSecret string) (*api.Session, error)
UpdateSendAttemptNextLink(ctx context.Context, sid, nextLink string) error
RemoveSession(ctx context.Context, sid string) error
ValidateSession(ctx context.Context, sid string, validatedAt int) error
}
type Db struct {
db *sql.DB
writer sqlutil.Writer
stm *sessionStatements
writeHandler func(func(*sql.Tx) error) error
}
func (d *Db) InsertSession(ctx context.Context, s *api.Session) error {
h := func(_ *sql.Tx) error {
_, err := d.stm.insertSessionStmt.ExecContext(ctx, s.Sid, s.ClientSecret, s.ThreePid, s.Token, s.NextLink, s.SendAttempt, s.ValidatedAt, s.Validated)
return err
}
return d.writeHandler(h)
}
func (d *Db) GetSession(ctx context.Context, sid string) (*api.Session, error) {
s := api.Session{}
err := d.stm.selectSessionStmt.QueryRowContext(ctx, sid).Scan(&s.ClientSecret, &s.ThreePid, &s.Token, &s.NextLink, &s.Validated, &s.ValidatedAt, &s.SendAttempt)
s.Sid = sid
return &s, err
}
func (d *Db) GetSessionByThreePidAndSecret(ctx context.Context, threePid, ClientSecret string) (*api.Session, error) {
s := api.Session{}
err := d.stm.selectSessionByThreePidAndCLientSecretStmt.
QueryRowContext(ctx, threePid, ClientSecret).Scan(
&s.Sid, &s.Token, &s.NextLink, &s.Validated, &s.ValidatedAt, &s.SendAttempt)
s.ThreePid = threePid
s.ClientSecret = ClientSecret
return &s, err
}
func (d *Db) UpdateSendAttemptNextLink(ctx context.Context, sid, nextLink string) error {
h := func(_ *sql.Tx) error {
_, err := d.stm.updateSendAttemptNextLinkStmt.ExecContext(ctx, nextLink, sid)
return err
}
return d.writeHandler(h)
}
func (d *Db) RemoveSession(ctx context.Context, sid string) error {
h := func(_ *sql.Tx) error {
_, err := d.stm.deleteSessionStmt.ExecContext(ctx, sid)
return err
}
return d.writeHandler(h)
}
func (d *Db) ValidateSession(ctx context.Context, sid string, validatedAt int) error {
h := func(_ *sql.Tx) error {
_, err := d.stm.validateSessionStmt.ExecContext(ctx, true, validatedAt, sid)
return err
}
return d.writeHandler(h)
}
func newSQLiteDatabase(dbProperties *config.DatabaseOptions) (*Db, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
writer := sqlutil.NewExclusiveWriter()
stmt := sessionStatements{
db: db,
writer: writer,
}
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
if err = stmt.execSchema(db); err != nil {
return nil, err
}
if err = stmt.prepare(); err != nil {
return nil, err
}
handler := func(f func(tx *sql.Tx) error) error {
return writer.Do(nil, nil, f)
}
return &Db{db, writer, &stmt, handler}, nil
}
func newPostgresDatabase(dbProperties *config.DatabaseOptions) (*Db, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
stmt := sessionStatements{
db: db,
}
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
if err = stmt.execSchema(db); err != nil {
return nil, err
}
if err = stmt.prepare(); err != nil {
return nil, err
}
handler := func(f func(tx *sql.Tx) error) error {
return f(nil)
}
return &Db{db, nil, &stmt, handler}, nil
}
func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return newSQLiteDatabase(dbProperties)
case dbProperties.ConnectionString.IsPostgres():
return newPostgresDatabase(dbProperties)
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View file

@ -0,0 +1,88 @@
package threepid
import (
"context"
"database/sql"
"math/rand"
"os"
"strconv"
"testing"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matryer/is"
)
var testSession = api.Session{
// Sid: "123",
ClientSecret: "099azAZ.=_-",
ThreePid: "azAZ09!#$%&'*+-/=?^_`{|}~@bar09.com",
Token: "fooBAR123",
NextLink: "https://example.com?user=foo",
SendAttempt: 0,
Validated: true,
ValidatedAt: 0,
}
var testCtx = context.Background()
func mustNewDatabaseWithTestSession(is *is.I) *Db {
randPostfix := strconv.Itoa(rand.Int())
dbPath := os.TempDir() + "/dendrite-" + randPostfix
println(dbPath)
dut, err := newSQLiteDatabase(&config.DatabaseOptions{
ConnectionString: config.DataSource("file:" + dbPath),
})
is.NoErr(err)
err = dut.InsertSession(testCtx, &testSession)
is.NoErr(err)
return dut
}
func TestGetSession(t *testing.T) {
is := is.New(t)
dut := mustNewDatabaseWithTestSession(is)
s, err := dut.GetSession(testCtx, testSession.Sid)
is.NoErr(err)
is.Equal(*s, testSession)
}
func TestGetSessionByThreePidAndSecret(t *testing.T) {
is := is.New(t)
dut := mustNewDatabaseWithTestSession(is)
s, err := dut.GetSessionByThreePidAndSecret(testCtx, testSession.ThreePid, testSession.ClientSecret)
is.NoErr(err)
is.Equal(*s, testSession)
}
func TestBumpSendAttempt(t *testing.T) {
is := is.New(t)
dut := mustNewDatabaseWithTestSession(is)
nextLink := "https://foo.bar"
err := dut.UpdateSendAttemptNextLink(testCtx, testSession.Sid, nextLink)
is.NoErr(err)
s, err := dut.GetSession(testCtx, testSession.Sid)
is.NoErr(err)
is.Equal(s.SendAttempt, 1)
is.Equal(s.NextLink, nextLink)
}
func TestDeleteSession(t *testing.T) {
is := is.New(t)
dut := mustNewDatabaseWithTestSession(is)
err := dut.RemoveSession(testCtx, testSession.Sid)
is.NoErr(err)
_, err = dut.GetSession(testCtx, testSession.Sid)
is.Equal(err, sql.ErrNoRows)
}
func TestValidateSession(t *testing.T) {
is := is.New(t)
dut := mustNewDatabaseWithTestSession(is)
validatedAt := 1_623_406_296
err := dut.ValidateSession(testCtx, testSession.Sid, validatedAt)
is.NoErr(err)
session, err := dut.GetSession(testCtx, testSession.Sid)
is.NoErr(err)
is.Equal(session.Validated, true)
is.Equal(session.ValidatedAt, validatedAt)
}