diff --git a/userapi/internal/threepid.go b/userapi/internal/threepid.go index a06816739..8511936b4 100644 --- a/userapi/internal/threepid.go +++ b/userapi/internal/threepid.go @@ -15,8 +15,7 @@ import ( ) const ( - sessionIdByteLength = 32 - tokenByteLength = 48 + tokenByteLength = 48 ) var ErrBadSession = errors.New("provided sid, client_secret and token does not point to valid session") @@ -25,9 +24,10 @@ func (a *UserInternalAPI) CreateSession(ctx context.Context, req *api.CreateSess s, err := a.ThreePidDB.GetSessionByThreePidAndSecret(ctx, req.ThreePid, req.ClientSecret) if err != nil { if err == sql.ErrNoRows { - token, errB := internal.GenerateBlob(tokenByteLength) - if errB != nil { - return errB + var token string + token, err = internal.GenerateBlob(tokenByteLength) + if err != nil { + return err } s = &api.Session{ ClientSecret: req.ClientSecret, diff --git a/userapi/storage/threepid/postgres/storage.go b/userapi/storage/threepid/postgres/storage.go index ab2c88c08..2c09afeb2 100644 --- a/userapi/storage/threepid/postgres/storage.go +++ b/userapi/storage/threepid/postgres/storage.go @@ -40,11 +40,11 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) { } // Create the tables. - if err := create(db); err != nil { + if err = create(db); err != nil { return nil, err } - - sharedDb, err := shared.NewDatabase(db, sqlutil.NewDummyWriter()) + var sharedDb *shared.Database + sharedDb, err = shared.NewDatabase(db, sqlutil.NewDummyWriter()) if err != nil { return nil, err } diff --git a/userapi/storage/threepid/shared/storage.go b/userapi/storage/threepid/shared/storage.go index 7e3749832..3530b8ecc 100644 --- a/userapi/storage/threepid/shared/storage.go +++ b/userapi/storage/threepid/shared/storage.go @@ -18,12 +18,7 @@ func (d *Database) InsertSession( var sid int64 return sid, d.writer.Do(nil, nil, func(_ *sql.Tx) error { err := d.stm.insertSessionStmt.QueryRowContext(ctx, clientSecret, threepid, token, nextlink, validatedAt, validated, sendAttempt).Scan(&sid) - // _, err := d.stm.insertSessionStmt.ExecContext(ctx, clientSecret, threepid, token, nextlink, sendAttempt, validatedAt, validated) return err - // if err != nil { - // return err - // } - // err = d.stm.selectSidStmt.QueryRowContext(ctx).Scan(&sid) }) } @@ -76,61 +71,3 @@ func NewDatabase(db *sql.DB, writer sqlutil.Writer) (*Database, error) { } return &d, nil } - -// func newSQLiteDatabase(dbProperties *config.DatabaseOptions) (*Db, error) { -// db, err := sqlutil.Open(dbProperties) -// if err != nil { -// return nil, err -// } -// writer := sqlutil.NewExclusiveWriter() -// stmt := sessionStatements{ -// db: db, -// writer: writer, -// } - -// // Create tables before executing migrations so we don't fail if the table is missing, -// // and THEN prepare statements so we don't fail due to referencing new columns -// if err = stmt.execSchema(db); err != nil { -// return nil, err -// } -// if err = stmt.prepare(); err != nil { -// return nil, err -// } -// handler := func(f func(tx *sql.Tx) error) error { -// return writer.Do(nil, nil, f) -// } -// return &Db{db, writer, &stmt, handler}, nil -// } - -// func newPostgresDatabase(dbProperties *config.DatabaseOptions) (*Db, error) { -// db, err := sqlutil.Open(dbProperties) -// if err != nil { -// return nil, err -// } -// stmt := sessionStatements{ -// db: db, -// } -// // Create tables before executing migrations so we don't fail if the table is missing, -// // and THEN prepare statements so we don't fail due to referencing new columns -// if err = stmt.execSchema(db); err != nil { -// return nil, err -// } -// if err = stmt.prepare(); err != nil { -// return nil, err -// } -// handler := func(f func(tx *sql.Tx) error) error { -// return f(nil) -// } -// return &Db{db, nil, &stmt, handler}, nil -// } - -// func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) { -// switch { -// case dbProperties.ConnectionString.IsSQLite(): -// return newSQLiteDatabase(dbProperties) -// case dbProperties.ConnectionString.IsPostgres(): -// return newPostgresDatabase(dbProperties) -// default: -// return nil, fmt.Errorf("unexpected database type") -// } -// } diff --git a/userapi/storage/threepid/shared/threepid_sessions_table.go b/userapi/storage/threepid/shared/threepid_sessions_table.go index 8a62982c0..50420935d 100644 --- a/userapi/storage/threepid/shared/threepid_sessions_table.go +++ b/userapi/storage/threepid/shared/threepid_sessions_table.go @@ -11,8 +11,6 @@ const ( "INSERT INTO threepid_sessions (client_secret, threepid, token, next_link, validated_at_ts, validated, send_attempt)" + "VALUES ($1, $2, $3, $4, $5, $6, $7)" + "RETURNING sid;" - // selectSidSQL = "" + - // "SELECT last_insert_rowid();" selectSessionSQL = "" + "SELECT client_secret, threepid, token, next_link, validated_at_ts, validated, send_attempt FROM threepid_sessions WHERE sid = $1" selectSessionByThreePidAndCLientSecretSQL = "" + @@ -26,8 +24,7 @@ const ( ) type ThreePidSessionStatements struct { - insertSessionStmt *sql.Stmt - // selectSidStmt *sql.Stmt + insertSessionStmt *sql.Stmt selectSessionStmt *sql.Stmt selectSessionByThreePidAndCLientSecretStmt *sql.Stmt deleteSessionStmt *sql.Stmt @@ -39,7 +36,6 @@ func PrepareThreePidSessionsTable(db *sql.DB) (*ThreePidSessionStatements, error s := ThreePidSessionStatements{} return &s, sqlutil.StatementList{ {&s.insertSessionStmt, insertSessionSQL}, - // {&s.selectSidStmt, selectSidSQL}, {&s.selectSessionStmt, selectSessionSQL}, {&s.selectSessionByThreePidAndCLientSecretStmt, selectSessionByThreePidAndCLientSecretSQL}, {&s.deleteSessionStmt, deleteSessionSQL}, diff --git a/userapi/storage/threepid/sqlite3/storage.go b/userapi/storage/threepid/sqlite3/storage.go index 22b3e36df..d59dcf1d2 100644 --- a/userapi/storage/threepid/sqlite3/storage.go +++ b/userapi/storage/threepid/sqlite3/storage.go @@ -40,11 +40,12 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) { } // Create the tables. - if err := create(db); err != nil { + if err = create(db); err != nil { return nil, err } - sharedDb, err := shared.NewDatabase(db, sqlutil.NewExclusiveWriter()) + var sharedDb *shared.Database + sharedDb, err = shared.NewDatabase(db, sqlutil.NewExclusiveWriter()) if err != nil { return nil, err }