mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-29 01:33:10 -06:00
137 lines
4.3 KiB
Go
137 lines
4.3 KiB
Go
package shared
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
|
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
"github.com/matrix-org/dendrite/userapi/api"
|
|
)
|
|
|
|
type Database struct {
|
|
writer sqlutil.Writer
|
|
stm *ThreePidSessionStatements
|
|
}
|
|
|
|
// InsertSession stores a new three-PID validation session and returns the
// session ID (sid) scanned back from the insert statement.
//
// The insert runs inside d.writer.Do with no db/txn supplied, so the
// statement executes outside of an explicit transaction.
//
// On error the returned sid is 0.
func (d *Database) InsertSession(
	ctx context.Context, clientSecret, threepid, token, nextlink string, validatedAt int64, validated bool, sendAttempt int,
) (int64, error) {
	var sid int64
	// Run Do and capture its error BEFORE reading sid. In the previous form
	// `return sid, d.writer.Do(...)`, the Go spec leaves unspecified whether
	// sid was read before or after the closure assigned it.
	err := d.writer.Do(nil, nil, func(_ *sql.Tx) error {
		return d.stm.insertSessionStmt.
			QueryRowContext(ctx, clientSecret, threepid, token, nextlink, validatedAt, validated, sendAttempt).
			Scan(&sid)
	})
	if err != nil {
		return 0, err
	}
	return sid, nil
}
|
|
|
|
func (d *Database) GetSession(ctx context.Context, sid int64) (*api.Session, error) {
|
|
s := api.Session{}
|
|
err := d.stm.selectSessionStmt.QueryRowContext(ctx, sid).Scan(&s.ClientSecret, &s.ThreePid, &s.Token, &s.NextLink, &s.ValidatedAt, &s.Validated, &s.SendAttempt)
|
|
s.Sid = sid
|
|
return &s, err
|
|
}
|
|
|
|
func (d *Database) GetSessionByThreePidAndSecret(ctx context.Context, threePid, ClientSecret string) (*api.Session, error) {
|
|
s := api.Session{}
|
|
err := d.stm.selectSessionByThreePidAndCLientSecretStmt.
|
|
QueryRowContext(ctx, threePid, ClientSecret).Scan(
|
|
&s.Sid, &s.Token, &s.NextLink, &s.ValidatedAt, &s.Validated, &s.SendAttempt)
|
|
s.ThreePid = threePid
|
|
s.ClientSecret = ClientSecret
|
|
return &s, err
|
|
}
|
|
|
|
// UpdateSendAttemptNextLink executes the send-attempt/next-link update
// statement for session sid, passing nextLink as its first argument.
//
// NOTE(review): the SQL lives with the statement definitions; judging by the
// name it bumps the send attempt counter and stores nextLink — confirm there.
func (d *Database) UpdateSendAttemptNextLink(ctx context.Context, sid int64, nextLink string) error {
	update := func(_ *sql.Tx) error {
		if _, execErr := d.stm.updateSendAttemptNextLinkStmt.ExecContext(ctx, nextLink, sid); execErr != nil {
			return execErr
		}
		return nil
	}
	return d.writer.Do(nil, nil, update)
}
|
|
|
|
// DeleteSession executes the delete statement for session sid through the
// database writer and returns any execution error.
func (d *Database) DeleteSession(ctx context.Context, sid int64) error {
	remove := func(_ *sql.Tx) error {
		if _, execErr := d.stm.deleteSessionStmt.ExecContext(ctx, sid); execErr != nil {
			return execErr
		}
		return nil
	}
	return d.writer.Do(nil, nil, remove)
}
|
|
|
|
func (d *Database) ValidateSession(ctx context.Context, sid int64, validatedAt int64) error {
|
|
return d.writer.Do(nil, nil, func(_ *sql.Tx) error {
|
|
_, err := d.stm.validateSessionStmt.ExecContext(ctx, true, validatedAt, sid)
|
|
return err
|
|
})
|
|
}
|
|
|
|
// NewDatabase prepares the three-PID session statements on db.
// It returns a Database that issues its writes through writer.
// Any error from preparing the statements is returned unchanged.
func NewDatabase(db *sql.DB, writer sqlutil.Writer) (*Database, error) {
	statements, err := PrepareThreePidSessionsTable(db)
	if err != nil {
		return nil, err
	}
	return &Database{stm: statements, writer: writer}, nil
}
|
|
|
|
// func newSQLiteDatabase(dbProperties *config.DatabaseOptions) (*Db, error) {
|
|
// db, err := sqlutil.Open(dbProperties)
|
|
// if err != nil {
|
|
// return nil, err
|
|
// }
|
|
// writer := sqlutil.NewExclusiveWriter()
|
|
// stmt := sessionStatements{
|
|
// db: db,
|
|
// writer: writer,
|
|
// }
|
|
|
|
// // Create tables before executing migrations so we don't fail if the table is missing,
|
|
// // and THEN prepare statements so we don't fail due to referencing new columns
|
|
// if err = stmt.execSchema(db); err != nil {
|
|
// return nil, err
|
|
// }
|
|
// if err = stmt.prepare(); err != nil {
|
|
// return nil, err
|
|
// }
|
|
// handler := func(f func(tx *sql.Tx) error) error {
|
|
// return writer.Do(nil, nil, f)
|
|
// }
|
|
// return &Db{db, writer, &stmt, handler}, nil
|
|
// }
|
|
|
|
// func newPostgresDatabase(dbProperties *config.DatabaseOptions) (*Db, error) {
|
|
// db, err := sqlutil.Open(dbProperties)
|
|
// if err != nil {
|
|
// return nil, err
|
|
// }
|
|
// stmt := sessionStatements{
|
|
// db: db,
|
|
// }
|
|
// // Create tables before executing migrations so we don't fail if the table is missing,
|
|
// // and THEN prepare statements so we don't fail due to referencing new columns
|
|
// if err = stmt.execSchema(db); err != nil {
|
|
// return nil, err
|
|
// }
|
|
// if err = stmt.prepare(); err != nil {
|
|
// return nil, err
|
|
// }
|
|
// handler := func(f func(tx *sql.Tx) error) error {
|
|
// return f(nil)
|
|
// }
|
|
// return &Db{db, nil, &stmt, handler}, nil
|
|
// }
|
|
|
|
// func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) {
|
|
// switch {
|
|
// case dbProperties.ConnectionString.IsSQLite():
|
|
// return newSQLiteDatabase(dbProperties)
|
|
// case dbProperties.ConnectionString.IsPostgres():
|
|
// return newPostgresDatabase(dbProperties)
|
|
// default:
|
|
// return nil, fmt.Errorf("unexpected database type")
|
|
// }
|
|
// }
|