Factor out runTransaction to common code (#162)

This commit is contained in:
Mark Haines 2017-07-17 17:20:57 +01:00 committed by GitHub
parent d3a29b7816
commit b06d1124f7
6 changed files with 58 additions and 69 deletions

View file

@ -24,7 +24,6 @@ import (
"strings" "strings"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -40,10 +39,16 @@ var UnknownDeviceID = "unknown-device"
// 32 bytes => 256 bits // 32 bytes => 256 bits
var tokenByteLength = 32 var tokenByteLength = 32
// DeviceDatabase represents a device database.
type DeviceDatabase interface {
// Lookup the device matching the given access token.
GetDeviceByAccessToken(token string) (*authtypes.Device, error)
}
// VerifyAccessToken verifies that an access token was supplied in the given HTTP request // VerifyAccessToken verifies that an access token was supplied in the given HTTP request
// and returns the device it corresponds to. Returns resErr (an error response which can be // and returns the device it corresponds to. Returns resErr (an error response which can be
// sent to the client) if the token is invalid or there was a problem querying the database. // sent to the client) if the token is invalid or there was a problem querying the database.
func VerifyAccessToken(req *http.Request, deviceDB *devices.Database) (device *authtypes.Device, resErr *util.JSONResponse) { func VerifyAccessToken(req *http.Request, deviceDB DeviceDatabase) (device *authtypes.Device, resErr *util.JSONResponse) {
token, err := extractAccessToken(req) token, err := extractAccessToken(req)
if err != nil { if err != nil {
resErr = &util.JSONResponse{ resErr = &util.JSONResponse{

View file

@ -18,6 +18,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -53,7 +54,7 @@ func (d *Database) GetDeviceByAccessToken(token string) (*authtypes.Device, erro
// an error will be returned. // an error will be returned.
// Returns the device on success. // Returns the device on success.
func (d *Database) CreateDevice(localpart, deviceID, accessToken string) (dev *authtypes.Device, returnErr error) { func (d *Database) CreateDevice(localpart, deviceID, accessToken string) (dev *authtypes.Device, returnErr error) {
returnErr = runTransaction(d.db, func(txn *sql.Tx) error { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error var err error
// Revoke existing token for this device // Revoke existing token for this device
if err = d.devices.deleteDevice(txn, deviceID, localpart); err != nil { if err = d.devices.deleteDevice(txn, deviceID, localpart); err != nil {
@ -74,30 +75,10 @@ func (d *Database) CreateDevice(localpart, deviceID, accessToken string) (dev *a
// If the device doesn't exist, it will not return an error // If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error // If something went wrong during the deletion, it will return the SQL error
func (d *Database) RemoveDevice(deviceID string, localpart string) error { func (d *Database) RemoveDevice(deviceID string, localpart string) error {
return runTransaction(d.db, func(txn *sql.Tx) error { return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(txn, deviceID, localpart); err != sql.ErrNoRows { if err := d.devices.deleteDevice(txn, deviceID, localpart); err != sql.ErrNoRows {
return err return err
} }
return nil return nil
}) })
} }
// TODO: factor out to common
func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
txn, err := db.Begin()
if err != nil {
return
}
defer func() {
if r := recover(); r != nil {
txn.Rollback()
panic(r)
} else if err != nil {
txn.Rollback()
} else {
err = txn.Commit()
}
}()
err = fn(txn)
return
}

View file

@ -1,16 +1,16 @@
package common package common
import ( import (
"net/http"
"github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"net/http"
) )
// MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which checks the access token in the request. // MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which checks the access token in the request.
func MakeAuthAPI(metricsName string, deviceDB *devices.Database, f func(*http.Request, *authtypes.Device) util.JSONResponse) http.Handler { func MakeAuthAPI(metricsName string, deviceDB auth.DeviceDatabase, f func(*http.Request, *authtypes.Device) util.JSONResponse) http.Handler {
h := util.NewJSONRequestHandler(func(req *http.Request) util.JSONResponse { h := util.NewJSONRequestHandler(func(req *http.Request) util.JSONResponse {
device, resErr := auth.VerifyAccessToken(req, deviceDB) device, resErr := auth.VerifyAccessToken(req, deviceDB)
if resErr != nil { if resErr != nil {

View file

@ -0,0 +1,41 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package common
import (
"database/sql"
)
// WithTransaction runs a block of code passing in an SQL transaction
// If the code returns an error or panics then the transactions is rolledback
// Otherwise the transaction is committed.
func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
txn, err := db.Begin()
if err != nil {
return
}
defer func() {
if r := recover(); r != nil {
txn.Rollback()
panic(r)
} else if err != nil {
txn.Rollback()
} else {
err = txn.Commit()
}
}()
err = fn(txn)
return
}

View file

@ -77,7 +77,7 @@ func (d *Database) UpdateRoom(
addHosts []types.JoinedHost, addHosts []types.JoinedHost,
removeHosts []string, removeHosts []string,
) (joinedHosts []types.JoinedHost, err error) { ) (joinedHosts []types.JoinedHost, err error) {
err = runTransaction(d.db, func(txn *sql.Tx) error { err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err = d.insertRoom(txn, roomID); err != nil { if err = d.insertRoom(txn, roomID); err != nil {
return err return err
} }
@ -105,22 +105,3 @@ func (d *Database) UpdateRoom(
}) })
return return
} }
func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
txn, err := db.Begin()
if err != nil {
return
}
defer func() {
if r := recover(); r != nil {
txn.Rollback()
panic(r)
} else if err != nil {
txn.Rollback()
} else {
err = txn.Commit()
}
}()
err = fn(txn)
return
}

View file

@ -92,7 +92,7 @@ func (d *SyncServerDatabase) Events(eventIDs []string) ([]gomatrixserverlib.Even
func (d *SyncServerDatabase) WriteEvent( func (d *SyncServerDatabase) WriteEvent(
ev *gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event, addStateEventIDs, removeStateEventIDs []string, ev *gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event, addStateEventIDs, removeStateEventIDs []string,
) (streamPos types.StreamPosition, returnErr error) { ) (streamPos types.StreamPosition, returnErr error) {
returnErr = runTransaction(d.db, func(txn *sql.Tx) error { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error var err error
pos, err := d.events.insertEvent(txn, ev, addStateEventIDs, removeStateEventIDs) pos, err := d.events.insertEvent(txn, ev, addStateEventIDs, removeStateEventIDs)
if err != nil { if err != nil {
@ -162,7 +162,7 @@ func (d *SyncServerDatabase) SyncStreamPosition() (types.StreamPosition, error)
// IncrementalSync returns all the data needed in order to create an incremental sync response. // IncrementalSync returns all the data needed in order to create an incremental sync response.
func (d *SyncServerDatabase) IncrementalSync(userID string, fromPos, toPos types.StreamPosition, numRecentEventsPerRoom int) (res *types.Response, returnErr error) { func (d *SyncServerDatabase) IncrementalSync(userID string, fromPos, toPos types.StreamPosition, numRecentEventsPerRoom int) (res *types.Response, returnErr error) {
returnErr = runTransaction(d.db, func(txn *sql.Tx) error { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
// Work out which rooms to return in the response. This is done by getting not only the currently // Work out which rooms to return in the response. This is done by getting not only the currently
// joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions. // joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions.
// This works out what the 'state' key should be for each room as well as which membership block // This works out what the 'state' key should be for each room as well as which membership block
@ -223,7 +223,7 @@ func (d *SyncServerDatabase) CompleteSync(userID string, numRecentEventsPerRoom
// a consistent view of the database throughout. This includes extracting the sync stream position. // a consistent view of the database throughout. This includes extracting the sync stream position.
// This does have the unfortunate side-effect that all the matrixy logic resides in this function, // This does have the unfortunate side-effect that all the matrixy logic resides in this function,
// but it's better to not hide the fact that this is being done in a transaction. // but it's better to not hide the fact that this is being done in a transaction.
returnErr = runTransaction(d.db, func(txn *sql.Tx) error { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
// Get the current stream position which we will base the sync response on. // Get the current stream position which we will base the sync response on.
id, err := d.events.selectMaxID(txn) id, err := d.events.selectMaxID(txn)
if err != nil { if err != nil {
@ -479,22 +479,3 @@ func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) string {
} }
return "" return ""
} }
func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
txn, err := db.Begin()
if err != nil {
return
}
defer func() {
if r := recover(); r != nil {
txn.Rollback()
panic(r)
} else if err != nil {
txn.Rollback()
} else {
err = txn.Commit()
}
}()
err = fn(txn)
return
}