mirror of
https://github.com/matrix-org/dendrite.git
synced 2024-11-22 14:21:55 -06:00
Factor out runTransaction to common code (#162)
This commit is contained in:
parent
d3a29b7816
commit
b06d1124f7
|
@ -24,7 +24,6 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
@ -40,10 +39,16 @@ var UnknownDeviceID = "unknown-device"
|
|||
// 32 bytes => 256 bits
|
||||
var tokenByteLength = 32
|
||||
|
||||
// DeviceDatabase represents a device database.
|
||||
type DeviceDatabase interface {
|
||||
// Lookup the device matching the given access token.
|
||||
GetDeviceByAccessToken(token string) (*authtypes.Device, error)
|
||||
}
|
||||
|
||||
// VerifyAccessToken verifies that an access token was supplied in the given HTTP request
|
||||
// and returns the device it corresponds to. Returns resErr (an error response which can be
|
||||
// sent to the client) if the token is invalid or there was a problem querying the database.
|
||||
func VerifyAccessToken(req *http.Request, deviceDB *devices.Database) (device *authtypes.Device, resErr *util.JSONResponse) {
|
||||
func VerifyAccessToken(req *http.Request, deviceDB DeviceDatabase) (device *authtypes.Device, resErr *util.JSONResponse) {
|
||||
token, err := extractAccessToken(req)
|
||||
if err != nil {
|
||||
resErr = &util.JSONResponse{
|
||||
|
|
|
@ -18,6 +18,7 @@ import (
|
|||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/common"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
|
@ -53,7 +54,7 @@ func (d *Database) GetDeviceByAccessToken(token string) (*authtypes.Device, erro
|
|||
// an error will be returned.
|
||||
// Returns the device on success.
|
||||
func (d *Database) CreateDevice(localpart, deviceID, accessToken string) (dev *authtypes.Device, returnErr error) {
|
||||
returnErr = runTransaction(d.db, func(txn *sql.Tx) error {
|
||||
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
// Revoke existing token for this device
|
||||
if err = d.devices.deleteDevice(txn, deviceID, localpart); err != nil {
|
||||
|
@ -74,30 +75,10 @@ func (d *Database) CreateDevice(localpart, deviceID, accessToken string) (dev *a
|
|||
// If the device doesn't exist, it will not return an error
|
||||
// If something went wrong during the deletion, it will return the SQL error
|
||||
func (d *Database) RemoveDevice(deviceID string, localpart string) error {
|
||||
return runTransaction(d.db, func(txn *sql.Tx) error {
|
||||
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
if err := d.devices.deleteDevice(txn, deviceID, localpart); err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// TODO: factor out to common
|
||||
func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
txn.Rollback()
|
||||
panic(r)
|
||||
} else if err != nil {
|
||||
txn.Rollback()
|
||||
} else {
|
||||
err = txn.Commit()
|
||||
}
|
||||
}()
|
||||
err = fn(txn)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,16 +1,16 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth"
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which checks the access token in the request.
|
||||
func MakeAuthAPI(metricsName string, deviceDB *devices.Database, f func(*http.Request, *authtypes.Device) util.JSONResponse) http.Handler {
|
||||
func MakeAuthAPI(metricsName string, deviceDB auth.DeviceDatabase, f func(*http.Request, *authtypes.Device) util.JSONResponse) http.Handler {
|
||||
h := util.NewJSONRequestHandler(func(req *http.Request) util.JSONResponse {
|
||||
device, resErr := auth.VerifyAccessToken(req, deviceDB)
|
||||
if resErr != nil {
|
||||
|
|
41
src/github.com/matrix-org/dendrite/common/sql.go
Normal file
41
src/github.com/matrix-org/dendrite/common/sql.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package common
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// WithTransaction runs a block of code passing in an SQL transaction
|
||||
// If the code returns an error or panics then the transactions is rolledback
|
||||
// Otherwise the transaction is committed.
|
||||
func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
txn.Rollback()
|
||||
panic(r)
|
||||
} else if err != nil {
|
||||
txn.Rollback()
|
||||
} else {
|
||||
err = txn.Commit()
|
||||
}
|
||||
}()
|
||||
err = fn(txn)
|
||||
return
|
||||
}
|
|
@ -77,7 +77,7 @@ func (d *Database) UpdateRoom(
|
|||
addHosts []types.JoinedHost,
|
||||
removeHosts []string,
|
||||
) (joinedHosts []types.JoinedHost, err error) {
|
||||
err = runTransaction(d.db, func(txn *sql.Tx) error {
|
||||
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
if err = d.insertRoom(txn, roomID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -105,22 +105,3 @@ func (d *Database) UpdateRoom(
|
|||
})
|
||||
return
|
||||
}
|
||||
|
||||
func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
txn.Rollback()
|
||||
panic(r)
|
||||
} else if err != nil {
|
||||
txn.Rollback()
|
||||
} else {
|
||||
err = txn.Commit()
|
||||
}
|
||||
}()
|
||||
err = fn(txn)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -92,7 +92,7 @@ func (d *SyncServerDatabase) Events(eventIDs []string) ([]gomatrixserverlib.Even
|
|||
func (d *SyncServerDatabase) WriteEvent(
|
||||
ev *gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event, addStateEventIDs, removeStateEventIDs []string,
|
||||
) (streamPos types.StreamPosition, returnErr error) {
|
||||
returnErr = runTransaction(d.db, func(txn *sql.Tx) error {
|
||||
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
pos, err := d.events.insertEvent(txn, ev, addStateEventIDs, removeStateEventIDs)
|
||||
if err != nil {
|
||||
|
@ -162,7 +162,7 @@ func (d *SyncServerDatabase) SyncStreamPosition() (types.StreamPosition, error)
|
|||
|
||||
// IncrementalSync returns all the data needed in order to create an incremental sync response.
|
||||
func (d *SyncServerDatabase) IncrementalSync(userID string, fromPos, toPos types.StreamPosition, numRecentEventsPerRoom int) (res *types.Response, returnErr error) {
|
||||
returnErr = runTransaction(d.db, func(txn *sql.Tx) error {
|
||||
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
// Work out which rooms to return in the response. This is done by getting not only the currently
|
||||
// joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions.
|
||||
// This works out what the 'state' key should be for each room as well as which membership block
|
||||
|
@ -223,7 +223,7 @@ func (d *SyncServerDatabase) CompleteSync(userID string, numRecentEventsPerRoom
|
|||
// a consistent view of the database throughout. This includes extracting the sync stream position.
|
||||
// This does have the unfortunate side-effect that all the matrixy logic resides in this function,
|
||||
// but it's better to not hide the fact that this is being done in a transaction.
|
||||
returnErr = runTransaction(d.db, func(txn *sql.Tx) error {
|
||||
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
// Get the current stream position which we will base the sync response on.
|
||||
id, err := d.events.selectMaxID(txn)
|
||||
if err != nil {
|
||||
|
@ -479,22 +479,3 @@ func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) string {
|
|||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
txn.Rollback()
|
||||
panic(r)
|
||||
} else if err != nil {
|
||||
txn.Rollback()
|
||||
} else {
|
||||
err = txn.Commit()
|
||||
}
|
||||
}()
|
||||
err = fn(txn)
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue