mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-28 17:23:09 -06:00
Implement threepid sessions storage
This commit is contained in:
parent
d36116c089
commit
3082a5dee9
1
go.mod
1
go.mod
|
|
@ -27,6 +27,7 @@ require (
|
|||
github.com/matrix-org/naffka v0.0.0-20201009174903-d26a3b9cb161
|
||||
github.com/matrix-org/pinecone v0.0.0-20210602111459-5cb0e6aa1a6a
|
||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
||||
github.com/matryer/is v1.4.0
|
||||
github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
|
||||
github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6
|
||||
|
|
|
|||
2
go.sum
2
go.sum
|
|
@ -711,6 +711,8 @@ github.com/matrix-org/pinecone v0.0.0-20210602111459-5cb0e6aa1a6a/go.mod h1:UQzJ
|
|||
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
|
||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
|
||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
|
||||
github.com/matryer/is v1.4.0 h1:sosSmIWwkYITGrxZ25ULNDeKiMNzFSr4V/eqBQP0PeE=
|
||||
github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU=
|
||||
github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
|
||||
github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ=
|
||||
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
|
||||
|
|
|
|||
|
|
@ -1,8 +0,0 @@
|
|||
package threepid
|
||||
|
||||
type storage interface {
|
||||
InsertSession(*Session) error
|
||||
GetSession(sid string) (*Session, error)
|
||||
GetSessionByThreePidAndSecret(threePid, ClientSecret string) (*Session, error)
|
||||
RemoveSession(sid string) error
|
||||
}
|
||||
81
userapi/storage/threepid/stmt.go
Normal file
81
userapi/storage/threepid/stmt.go
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
package threepid
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
)
|
||||
|
||||
const sessionsSchema = `
|
||||
-- This sequence is used for automatic allocation of session_id.
|
||||
-- CREATE SEQUENCE IF NOT EXISTS threepid_session_id_seq START 1;
|
||||
|
||||
-- Stores data about devices.
|
||||
CREATE TABLE IF NOT EXISTS threepid_sessions (
|
||||
sid VARCHAR(255) PRIMARY KEY,
|
||||
client_secret VARCHAR(255),
|
||||
threepid TEXT ,
|
||||
token VARCHAR(255) ,
|
||||
next_link TEXT,
|
||||
validated_at_ts BIGINT,
|
||||
validated BOOLEAN,
|
||||
send_attempt INT
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS threepid_sessions_threepids
|
||||
ON threepid_sessions (threepid, client_secret)
|
||||
`
|
||||
|
||||
const (
|
||||
insertSessionSQL = "" +
|
||||
"INSERT INTO threepid_sessions (sid, client_secret, threepid, token, next_link, send_attempt, validated_at_ts, validated)" +
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
|
||||
selectSessionSQL = "" +
|
||||
"SELECT client_secret, threepid, token, next_link, validated, validated_at_ts, send_attempt FROM threepid_sessions WHERE sid == $1"
|
||||
selectSessionByThreePidAndCLientSecretSQL = "" +
|
||||
"SELECT sid, token, next_link, validated, validated_at_ts, send_attempt FROM threepid_sessions WHERE threepid == $1 AND client_secret == $2"
|
||||
deleteSessionSQL = "" +
|
||||
"DELETE FROM threepid_sessions WHERE sid = $1"
|
||||
validateSessionSQL = "" +
|
||||
"UPDATE threepid_sessions SET validated = $1, validated_at_ts = $2 WHERE sid = $3"
|
||||
updateSendAttemptNextLinkSQL = "" +
|
||||
"UPDATE threepid_sessions SET send_attempt = send_attempt + 1, next_link = $1 WHERE sid = $2"
|
||||
)
|
||||
|
||||
type sessionStatements struct {
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
insertSessionStmt *sql.Stmt
|
||||
selectSessionStmt *sql.Stmt
|
||||
selectSessionByThreePidAndCLientSecretStmt *sql.Stmt
|
||||
deleteSessionStmt *sql.Stmt
|
||||
validateSessionStmt *sql.Stmt
|
||||
updateSendAttemptNextLinkStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *sessionStatements) prepare() (err error) {
|
||||
if s.insertSessionStmt, err = s.db.Prepare(insertSessionSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectSessionStmt, err = s.db.Prepare(selectSessionSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectSessionByThreePidAndCLientSecretStmt, err = s.db.Prepare(selectSessionByThreePidAndCLientSecretSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteSessionStmt, err = s.db.Prepare(deleteSessionSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.validateSessionStmt, err = s.db.Prepare(validateSessionSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.updateSendAttemptNextLinkStmt, err = s.db.Prepare(updateSendAttemptNextLinkSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *sessionStatements) execSchema(db *sql.DB) error {
|
||||
_, err := db.Exec(sessionsSchema)
|
||||
return err
|
||||
}
|
||||
134
userapi/storage/threepid/storage.go
Normal file
134
userapi/storage/threepid/storage.go
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
package threepid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
)
|
||||
|
||||
type Database interface {
|
||||
InsertSession(context.Context, *api.Session) error
|
||||
GetSession(ctx context.Context, sid string) (*api.Session, error)
|
||||
GetSessionByThreePidAndSecret(ctx context.Context, threePid, ClientSecret string) (*api.Session, error)
|
||||
UpdateSendAttemptNextLink(ctx context.Context, sid, nextLink string) error
|
||||
RemoveSession(ctx context.Context, sid string) error
|
||||
ValidateSession(ctx context.Context, sid string, validatedAt int) error
|
||||
}
|
||||
|
||||
type Db struct {
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
stm *sessionStatements
|
||||
writeHandler func(func(*sql.Tx) error) error
|
||||
}
|
||||
|
||||
func (d *Db) InsertSession(ctx context.Context, s *api.Session) error {
|
||||
h := func(_ *sql.Tx) error {
|
||||
_, err := d.stm.insertSessionStmt.ExecContext(ctx, s.Sid, s.ClientSecret, s.ThreePid, s.Token, s.NextLink, s.SendAttempt, s.ValidatedAt, s.Validated)
|
||||
return err
|
||||
}
|
||||
return d.writeHandler(h)
|
||||
}
|
||||
|
||||
func (d *Db) GetSession(ctx context.Context, sid string) (*api.Session, error) {
|
||||
s := api.Session{}
|
||||
err := d.stm.selectSessionStmt.QueryRowContext(ctx, sid).Scan(&s.ClientSecret, &s.ThreePid, &s.Token, &s.NextLink, &s.Validated, &s.ValidatedAt, &s.SendAttempt)
|
||||
s.Sid = sid
|
||||
return &s, err
|
||||
}
|
||||
|
||||
func (d *Db) GetSessionByThreePidAndSecret(ctx context.Context, threePid, ClientSecret string) (*api.Session, error) {
|
||||
s := api.Session{}
|
||||
err := d.stm.selectSessionByThreePidAndCLientSecretStmt.
|
||||
QueryRowContext(ctx, threePid, ClientSecret).Scan(
|
||||
&s.Sid, &s.Token, &s.NextLink, &s.Validated, &s.ValidatedAt, &s.SendAttempt)
|
||||
s.ThreePid = threePid
|
||||
s.ClientSecret = ClientSecret
|
||||
return &s, err
|
||||
}
|
||||
|
||||
func (d *Db) UpdateSendAttemptNextLink(ctx context.Context, sid, nextLink string) error {
|
||||
h := func(_ *sql.Tx) error {
|
||||
_, err := d.stm.updateSendAttemptNextLinkStmt.ExecContext(ctx, nextLink, sid)
|
||||
return err
|
||||
}
|
||||
return d.writeHandler(h)
|
||||
}
|
||||
|
||||
func (d *Db) RemoveSession(ctx context.Context, sid string) error {
|
||||
h := func(_ *sql.Tx) error {
|
||||
_, err := d.stm.deleteSessionStmt.ExecContext(ctx, sid)
|
||||
return err
|
||||
}
|
||||
return d.writeHandler(h)
|
||||
}
|
||||
|
||||
func (d *Db) ValidateSession(ctx context.Context, sid string, validatedAt int) error {
|
||||
h := func(_ *sql.Tx) error {
|
||||
_, err := d.stm.validateSessionStmt.ExecContext(ctx, true, validatedAt, sid)
|
||||
return err
|
||||
}
|
||||
return d.writeHandler(h)
|
||||
}
|
||||
|
||||
func newSQLiteDatabase(dbProperties *config.DatabaseOptions) (*Db, error) {
|
||||
db, err := sqlutil.Open(dbProperties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
writer := sqlutil.NewExclusiveWriter()
|
||||
stmt := sessionStatements{
|
||||
db: db,
|
||||
writer: writer,
|
||||
}
|
||||
|
||||
// Create tables before executing migrations so we don't fail if the table is missing,
|
||||
// and THEN prepare statements so we don't fail due to referencing new columns
|
||||
if err = stmt.execSchema(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = stmt.prepare(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
handler := func(f func(tx *sql.Tx) error) error {
|
||||
return writer.Do(nil, nil, f)
|
||||
}
|
||||
return &Db{db, writer, &stmt, handler}, nil
|
||||
}
|
||||
|
||||
func newPostgresDatabase(dbProperties *config.DatabaseOptions) (*Db, error) {
|
||||
db, err := sqlutil.Open(dbProperties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stmt := sessionStatements{
|
||||
db: db,
|
||||
}
|
||||
// Create tables before executing migrations so we don't fail if the table is missing,
|
||||
// and THEN prepare statements so we don't fail due to referencing new columns
|
||||
if err = stmt.execSchema(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = stmt.prepare(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
handler := func(f func(tx *sql.Tx) error) error {
|
||||
return f(nil)
|
||||
}
|
||||
return &Db{db, nil, &stmt, handler}, nil
|
||||
}
|
||||
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return newSQLiteDatabase(dbProperties)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return newPostgresDatabase(dbProperties)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected database type")
|
||||
}
|
||||
}
|
||||
88
userapi/storage/threepid/storage_test.go
Normal file
88
userapi/storage/threepid/storage_test.go
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
package threepid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matryer/is"
|
||||
)
|
||||
|
||||
var testSession = api.Session{
|
||||
// Sid: "123",
|
||||
ClientSecret: "099azAZ.=_-",
|
||||
ThreePid: "azAZ09!#$%&'*+-/=?^_`{|}~@bar09.com",
|
||||
Token: "fooBAR123",
|
||||
NextLink: "https://example.com?user=foo",
|
||||
SendAttempt: 0,
|
||||
Validated: true,
|
||||
ValidatedAt: 0,
|
||||
}
|
||||
|
||||
var testCtx = context.Background()
|
||||
|
||||
func mustNewDatabaseWithTestSession(is *is.I) *Db {
|
||||
randPostfix := strconv.Itoa(rand.Int())
|
||||
dbPath := os.TempDir() + "/dendrite-" + randPostfix
|
||||
println(dbPath)
|
||||
dut, err := newSQLiteDatabase(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource("file:" + dbPath),
|
||||
})
|
||||
is.NoErr(err)
|
||||
err = dut.InsertSession(testCtx, &testSession)
|
||||
is.NoErr(err)
|
||||
return dut
|
||||
}
|
||||
func TestGetSession(t *testing.T) {
|
||||
is := is.New(t)
|
||||
dut := mustNewDatabaseWithTestSession(is)
|
||||
s, err := dut.GetSession(testCtx, testSession.Sid)
|
||||
is.NoErr(err)
|
||||
is.Equal(*s, testSession)
|
||||
}
|
||||
|
||||
func TestGetSessionByThreePidAndSecret(t *testing.T) {
|
||||
is := is.New(t)
|
||||
dut := mustNewDatabaseWithTestSession(is)
|
||||
s, err := dut.GetSessionByThreePidAndSecret(testCtx, testSession.ThreePid, testSession.ClientSecret)
|
||||
is.NoErr(err)
|
||||
is.Equal(*s, testSession)
|
||||
}
|
||||
|
||||
func TestBumpSendAttempt(t *testing.T) {
|
||||
is := is.New(t)
|
||||
dut := mustNewDatabaseWithTestSession(is)
|
||||
nextLink := "https://foo.bar"
|
||||
err := dut.UpdateSendAttemptNextLink(testCtx, testSession.Sid, nextLink)
|
||||
is.NoErr(err)
|
||||
s, err := dut.GetSession(testCtx, testSession.Sid)
|
||||
is.NoErr(err)
|
||||
is.Equal(s.SendAttempt, 1)
|
||||
is.Equal(s.NextLink, nextLink)
|
||||
}
|
||||
|
||||
func TestDeleteSession(t *testing.T) {
|
||||
is := is.New(t)
|
||||
dut := mustNewDatabaseWithTestSession(is)
|
||||
err := dut.RemoveSession(testCtx, testSession.Sid)
|
||||
is.NoErr(err)
|
||||
_, err = dut.GetSession(testCtx, testSession.Sid)
|
||||
is.Equal(err, sql.ErrNoRows)
|
||||
}
|
||||
|
||||
func TestValidateSession(t *testing.T) {
|
||||
is := is.New(t)
|
||||
dut := mustNewDatabaseWithTestSession(is)
|
||||
validatedAt := 1_623_406_296
|
||||
err := dut.ValidateSession(testCtx, testSession.Sid, validatedAt)
|
||||
is.NoErr(err)
|
||||
session, err := dut.GetSession(testCtx, testSession.Sid)
|
||||
is.NoErr(err)
|
||||
is.Equal(session.Validated, true)
|
||||
is.Equal(session.ValidatedAt, validatedAt)
|
||||
}
|
||||
Loading…
Reference in a new issue