mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-29 01:33:10 -06:00
Implement threepid sessions storage
This commit is contained in:
parent
d36116c089
commit
3082a5dee9
1
go.mod
1
go.mod
|
|
@ -27,6 +27,7 @@ require (
|
||||||
github.com/matrix-org/naffka v0.0.0-20201009174903-d26a3b9cb161
|
github.com/matrix-org/naffka v0.0.0-20201009174903-d26a3b9cb161
|
||||||
github.com/matrix-org/pinecone v0.0.0-20210602111459-5cb0e6aa1a6a
|
github.com/matrix-org/pinecone v0.0.0-20210602111459-5cb0e6aa1a6a
|
||||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
||||||
|
github.com/matryer/is v1.4.0
|
||||||
github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb
|
github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb
|
||||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
|
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
|
||||||
github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6
|
github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6
|
||||||
|
|
|
||||||
2
go.sum
2
go.sum
|
|
@ -711,6 +711,8 @@ github.com/matrix-org/pinecone v0.0.0-20210602111459-5cb0e6aa1a6a/go.mod h1:UQzJ
|
||||||
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
|
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
|
||||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
|
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
|
||||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
|
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
|
||||||
|
github.com/matryer/is v1.4.0 h1:sosSmIWwkYITGrxZ25ULNDeKiMNzFSr4V/eqBQP0PeE=
|
||||||
|
github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU=
|
||||||
github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
|
github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
|
||||||
github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ=
|
github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ=
|
||||||
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
|
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
|
||||||
|
|
|
||||||
|
|
@ -1,8 +0,0 @@
|
||||||
package threepid
|
|
||||||
|
|
||||||
type storage interface {
|
|
||||||
InsertSession(*Session) error
|
|
||||||
GetSession(sid string) (*Session, error)
|
|
||||||
GetSessionByThreePidAndSecret(threePid, ClientSecret string) (*Session, error)
|
|
||||||
RemoveSession(sid string) error
|
|
||||||
}
|
|
||||||
81
userapi/storage/threepid/stmt.go
Normal file
81
userapi/storage/threepid/stmt.go
Normal file
|
|
@ -0,0 +1,81 @@
|
||||||
|
package threepid
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
const sessionsSchema = `
|
||||||
|
-- This sequence is used for automatic allocation of session_id.
|
||||||
|
-- CREATE SEQUENCE IF NOT EXISTS threepid_session_id_seq START 1;
|
||||||
|
|
||||||
|
-- Stores data about devices.
|
||||||
|
CREATE TABLE IF NOT EXISTS threepid_sessions (
|
||||||
|
sid VARCHAR(255) PRIMARY KEY,
|
||||||
|
client_secret VARCHAR(255),
|
||||||
|
threepid TEXT ,
|
||||||
|
token VARCHAR(255) ,
|
||||||
|
next_link TEXT,
|
||||||
|
validated_at_ts BIGINT,
|
||||||
|
validated BOOLEAN,
|
||||||
|
send_attempt INT
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS threepid_sessions_threepids
|
||||||
|
ON threepid_sessions (threepid, client_secret)
|
||||||
|
`
|
||||||
|
|
||||||
|
const (
|
||||||
|
insertSessionSQL = "" +
|
||||||
|
"INSERT INTO threepid_sessions (sid, client_secret, threepid, token, next_link, send_attempt, validated_at_ts, validated)" +
|
||||||
|
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
|
||||||
|
selectSessionSQL = "" +
|
||||||
|
"SELECT client_secret, threepid, token, next_link, validated, validated_at_ts, send_attempt FROM threepid_sessions WHERE sid == $1"
|
||||||
|
selectSessionByThreePidAndCLientSecretSQL = "" +
|
||||||
|
"SELECT sid, token, next_link, validated, validated_at_ts, send_attempt FROM threepid_sessions WHERE threepid == $1 AND client_secret == $2"
|
||||||
|
deleteSessionSQL = "" +
|
||||||
|
"DELETE FROM threepid_sessions WHERE sid = $1"
|
||||||
|
validateSessionSQL = "" +
|
||||||
|
"UPDATE threepid_sessions SET validated = $1, validated_at_ts = $2 WHERE sid = $3"
|
||||||
|
updateSendAttemptNextLinkSQL = "" +
|
||||||
|
"UPDATE threepid_sessions SET send_attempt = send_attempt + 1, next_link = $1 WHERE sid = $2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type sessionStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
writer sqlutil.Writer
|
||||||
|
insertSessionStmt *sql.Stmt
|
||||||
|
selectSessionStmt *sql.Stmt
|
||||||
|
selectSessionByThreePidAndCLientSecretStmt *sql.Stmt
|
||||||
|
deleteSessionStmt *sql.Stmt
|
||||||
|
validateSessionStmt *sql.Stmt
|
||||||
|
updateSendAttemptNextLinkStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sessionStatements) prepare() (err error) {
|
||||||
|
if s.insertSessionStmt, err = s.db.Prepare(insertSessionSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectSessionStmt, err = s.db.Prepare(selectSessionSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectSessionByThreePidAndCLientSecretStmt, err = s.db.Prepare(selectSessionByThreePidAndCLientSecretSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.deleteSessionStmt, err = s.db.Prepare(deleteSessionSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.validateSessionStmt, err = s.db.Prepare(validateSessionSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.updateSendAttemptNextLinkStmt, err = s.db.Prepare(updateSendAttemptNextLinkSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sessionStatements) execSchema(db *sql.DB) error {
|
||||||
|
_, err := db.Exec(sessionsSchema)
|
||||||
|
return err
|
||||||
|
}
|
||||||
134
userapi/storage/threepid/storage.go
Normal file
134
userapi/storage/threepid/storage.go
Normal file
|
|
@ -0,0 +1,134 @@
|
||||||
|
package threepid
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Database interface {
|
||||||
|
InsertSession(context.Context, *api.Session) error
|
||||||
|
GetSession(ctx context.Context, sid string) (*api.Session, error)
|
||||||
|
GetSessionByThreePidAndSecret(ctx context.Context, threePid, ClientSecret string) (*api.Session, error)
|
||||||
|
UpdateSendAttemptNextLink(ctx context.Context, sid, nextLink string) error
|
||||||
|
RemoveSession(ctx context.Context, sid string) error
|
||||||
|
ValidateSession(ctx context.Context, sid string, validatedAt int) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Db struct {
|
||||||
|
db *sql.DB
|
||||||
|
writer sqlutil.Writer
|
||||||
|
stm *sessionStatements
|
||||||
|
writeHandler func(func(*sql.Tx) error) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Db) InsertSession(ctx context.Context, s *api.Session) error {
|
||||||
|
h := func(_ *sql.Tx) error {
|
||||||
|
_, err := d.stm.insertSessionStmt.ExecContext(ctx, s.Sid, s.ClientSecret, s.ThreePid, s.Token, s.NextLink, s.SendAttempt, s.ValidatedAt, s.Validated)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return d.writeHandler(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Db) GetSession(ctx context.Context, sid string) (*api.Session, error) {
|
||||||
|
s := api.Session{}
|
||||||
|
err := d.stm.selectSessionStmt.QueryRowContext(ctx, sid).Scan(&s.ClientSecret, &s.ThreePid, &s.Token, &s.NextLink, &s.Validated, &s.ValidatedAt, &s.SendAttempt)
|
||||||
|
s.Sid = sid
|
||||||
|
return &s, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Db) GetSessionByThreePidAndSecret(ctx context.Context, threePid, ClientSecret string) (*api.Session, error) {
|
||||||
|
s := api.Session{}
|
||||||
|
err := d.stm.selectSessionByThreePidAndCLientSecretStmt.
|
||||||
|
QueryRowContext(ctx, threePid, ClientSecret).Scan(
|
||||||
|
&s.Sid, &s.Token, &s.NextLink, &s.Validated, &s.ValidatedAt, &s.SendAttempt)
|
||||||
|
s.ThreePid = threePid
|
||||||
|
s.ClientSecret = ClientSecret
|
||||||
|
return &s, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Db) UpdateSendAttemptNextLink(ctx context.Context, sid, nextLink string) error {
|
||||||
|
h := func(_ *sql.Tx) error {
|
||||||
|
_, err := d.stm.updateSendAttemptNextLinkStmt.ExecContext(ctx, nextLink, sid)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return d.writeHandler(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Db) RemoveSession(ctx context.Context, sid string) error {
|
||||||
|
h := func(_ *sql.Tx) error {
|
||||||
|
_, err := d.stm.deleteSessionStmt.ExecContext(ctx, sid)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return d.writeHandler(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Db) ValidateSession(ctx context.Context, sid string, validatedAt int) error {
|
||||||
|
h := func(_ *sql.Tx) error {
|
||||||
|
_, err := d.stm.validateSessionStmt.ExecContext(ctx, true, validatedAt, sid)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return d.writeHandler(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSQLiteDatabase(dbProperties *config.DatabaseOptions) (*Db, error) {
|
||||||
|
db, err := sqlutil.Open(dbProperties)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
writer := sqlutil.NewExclusiveWriter()
|
||||||
|
stmt := sessionStatements{
|
||||||
|
db: db,
|
||||||
|
writer: writer,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create tables before executing migrations so we don't fail if the table is missing,
|
||||||
|
// and THEN prepare statements so we don't fail due to referencing new columns
|
||||||
|
if err = stmt.execSchema(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err = stmt.prepare(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
handler := func(f func(tx *sql.Tx) error) error {
|
||||||
|
return writer.Do(nil, nil, f)
|
||||||
|
}
|
||||||
|
return &Db{db, writer, &stmt, handler}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPostgresDatabase(dbProperties *config.DatabaseOptions) (*Db, error) {
|
||||||
|
db, err := sqlutil.Open(dbProperties)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
stmt := sessionStatements{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
// Create tables before executing migrations so we don't fail if the table is missing,
|
||||||
|
// and THEN prepare statements so we don't fail due to referencing new columns
|
||||||
|
if err = stmt.execSchema(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err = stmt.prepare(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
handler := func(f func(tx *sql.Tx) error) error {
|
||||||
|
return f(nil)
|
||||||
|
}
|
||||||
|
return &Db{db, nil, &stmt, handler}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) {
|
||||||
|
switch {
|
||||||
|
case dbProperties.ConnectionString.IsSQLite():
|
||||||
|
return newSQLiteDatabase(dbProperties)
|
||||||
|
case dbProperties.ConnectionString.IsPostgres():
|
||||||
|
return newPostgresDatabase(dbProperties)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unexpected database type")
|
||||||
|
}
|
||||||
|
}
|
||||||
88
userapi/storage/threepid/storage_test.go
Normal file
88
userapi/storage/threepid/storage_test.go
Normal file
|
|
@ -0,0 +1,88 @@
|
||||||
|
package threepid
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"math/rand"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matryer/is"
|
||||||
|
)
|
||||||
|
|
||||||
|
var testSession = api.Session{
|
||||||
|
// Sid: "123",
|
||||||
|
ClientSecret: "099azAZ.=_-",
|
||||||
|
ThreePid: "azAZ09!#$%&'*+-/=?^_`{|}~@bar09.com",
|
||||||
|
Token: "fooBAR123",
|
||||||
|
NextLink: "https://example.com?user=foo",
|
||||||
|
SendAttempt: 0,
|
||||||
|
Validated: true,
|
||||||
|
ValidatedAt: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
var testCtx = context.Background()
|
||||||
|
|
||||||
|
func mustNewDatabaseWithTestSession(is *is.I) *Db {
|
||||||
|
randPostfix := strconv.Itoa(rand.Int())
|
||||||
|
dbPath := os.TempDir() + "/dendrite-" + randPostfix
|
||||||
|
println(dbPath)
|
||||||
|
dut, err := newSQLiteDatabase(&config.DatabaseOptions{
|
||||||
|
ConnectionString: config.DataSource("file:" + dbPath),
|
||||||
|
})
|
||||||
|
is.NoErr(err)
|
||||||
|
err = dut.InsertSession(testCtx, &testSession)
|
||||||
|
is.NoErr(err)
|
||||||
|
return dut
|
||||||
|
}
|
||||||
|
func TestGetSession(t *testing.T) {
|
||||||
|
is := is.New(t)
|
||||||
|
dut := mustNewDatabaseWithTestSession(is)
|
||||||
|
s, err := dut.GetSession(testCtx, testSession.Sid)
|
||||||
|
is.NoErr(err)
|
||||||
|
is.Equal(*s, testSession)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSessionByThreePidAndSecret(t *testing.T) {
|
||||||
|
is := is.New(t)
|
||||||
|
dut := mustNewDatabaseWithTestSession(is)
|
||||||
|
s, err := dut.GetSessionByThreePidAndSecret(testCtx, testSession.ThreePid, testSession.ClientSecret)
|
||||||
|
is.NoErr(err)
|
||||||
|
is.Equal(*s, testSession)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBumpSendAttempt(t *testing.T) {
|
||||||
|
is := is.New(t)
|
||||||
|
dut := mustNewDatabaseWithTestSession(is)
|
||||||
|
nextLink := "https://foo.bar"
|
||||||
|
err := dut.UpdateSendAttemptNextLink(testCtx, testSession.Sid, nextLink)
|
||||||
|
is.NoErr(err)
|
||||||
|
s, err := dut.GetSession(testCtx, testSession.Sid)
|
||||||
|
is.NoErr(err)
|
||||||
|
is.Equal(s.SendAttempt, 1)
|
||||||
|
is.Equal(s.NextLink, nextLink)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteSession(t *testing.T) {
|
||||||
|
is := is.New(t)
|
||||||
|
dut := mustNewDatabaseWithTestSession(is)
|
||||||
|
err := dut.RemoveSession(testCtx, testSession.Sid)
|
||||||
|
is.NoErr(err)
|
||||||
|
_, err = dut.GetSession(testCtx, testSession.Sid)
|
||||||
|
is.Equal(err, sql.ErrNoRows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSession(t *testing.T) {
|
||||||
|
is := is.New(t)
|
||||||
|
dut := mustNewDatabaseWithTestSession(is)
|
||||||
|
validatedAt := 1_623_406_296
|
||||||
|
err := dut.ValidateSession(testCtx, testSession.Sid, validatedAt)
|
||||||
|
is.NoErr(err)
|
||||||
|
session, err := dut.GetSession(testCtx, testSession.Sid)
|
||||||
|
is.NoErr(err)
|
||||||
|
is.Equal(session.Validated, true)
|
||||||
|
is.Equal(session.ValidatedAt, validatedAt)
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue