Add device ID and revoke old tokens

This commit is contained in:
Kegan Dougal 2017-05-24 16:30:06 +01:00
parent 1e8f174837
commit dcb0d995ce
3 changed files with 83 additions and 13 deletions

View file

@ -16,6 +16,7 @@
package auth package auth
import ( import (
"database/sql"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
@ -40,9 +41,16 @@ func VerifyAccessToken(req *http.Request, deviceDB *devices.Database) (device *a
} }
device, err = deviceDB.GetDeviceByAccessToken(token) device, err = deviceDB.GetDeviceByAccessToken(token)
if err != nil { if err != nil {
resErr = &util.JSONResponse{ if err == sql.ErrNoRows {
Code: 500, resErr = &util.JSONResponse{
JSON: jsonerror.Unknown("Failed to check access token"), Code: 403,
JSON: jsonerror.Unknown("Invalid access token"),
}
} else {
resErr = &util.JSONResponse{
Code: 500,
JSON: jsonerror.Unknown("Failed to check access token"),
}
} }
} }
return return

View file

@ -29,25 +29,37 @@ CREATE TABLE IF NOT EXISTS devices (
-- The access token granted to this device. This has to be the primary key -- The access token granted to this device. This has to be the primary key
-- so we can distinguish which device is making a given request. -- so we can distinguish which device is making a given request.
access_token TEXT NOT NULL PRIMARY KEY, access_token TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID localpart for this device -- The device identifier. This only needs to uniquely identify a device for a given user, not globally.
-- access_tokens will be clobbered based on the device ID for a user.
id TEXT NOT NULL,
-- The Matrix user ID localpart for this device. This is preferable to storing the full user_id
-- as it is smaller, makes it clearer that we only manage devices for our own users, and may make
-- migration to different domain names easier.
localpart TEXT NOT NULL, localpart TEXT NOT NULL,
-- When this devices was first recognised on the network, as a unix timestamp (ms resolution). -- When this devices was first recognised on the network, as a unix timestamp (ms resolution).
created_ts BIGINT NOT NULL created_ts BIGINT NOT NULL
-- TODO: device keys, device display names, last used ts and IP address?, token restrictions (if 3rd-party OAuth app) -- TODO: device keys, device display names, last used ts and IP address?, token restrictions (if 3rd-party OAuth app)
); );
-- Device IDs must be unique for a given user.
CREATE UNIQUE INDEX IF NOT EXISTS localpart_id_idx ON devices(localpart, id);
` `
const insertDeviceSQL = "" + const insertDeviceSQL = "" +
"INSERT INTO devices(access_token, localpart, created_ts) VALUES ($1, $2, $3)" "INSERT INTO devices(id, localpart, access_token, created_ts) VALUES ($1, $2, $3, $4)"
const selectDeviceByTokenSQL = "" + const selectDeviceByTokenSQL = "" +
"SELECT localpart FROM devices WHERE access_token = $1" "SELECT id, localpart FROM devices WHERE access_token = $1"
const deleteDeviceSQL = "" +
"DELETE FROM devices WHERE id = $1 AND localpart = $2"
// TODO: List devices, delete device API // TODO: List devices, delete device API
type devicesStatements struct { type devicesStatements struct {
insertDeviceStmt *sql.Stmt insertDeviceStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
@ -62,15 +74,19 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil { if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
return return
} }
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
return
}
s.serverName = server s.serverName = server
return return
} }
// insertDevice creates a new device. Returns an error if a device with the same access token already exists. // insertDevice creates a new device. Returns an error if any device with the same access token already exists.
// Returns an error if the user already has a device with the given device ID.
// Returns the device on success. // Returns the device on success.
func (s *devicesStatements) insertDevice(localpart, accessToken string) (dev *authtypes.Device, err error) { func (s *devicesStatements) insertDevice(txn *sql.Tx, id, localpart, accessToken string) (dev *authtypes.Device, err error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
if _, err = s.insertDeviceStmt.Exec(accessToken, localpart, createdTimeMS); err == nil { if _, err = s.insertDeviceStmt.Exec(id, localpart, accessToken, createdTimeMS); err == nil {
dev = &authtypes.Device{ dev = &authtypes.Device{
UserID: makeUserID(localpart, s.serverName), UserID: makeUserID(localpart, s.serverName),
AccessToken: accessToken, AccessToken: accessToken,
@ -79,6 +95,11 @@ func (s *devicesStatements) insertDevice(localpart, accessToken string) (dev *au
return return
} }
func (s *devicesStatements) deleteDevice(txn *sql.Tx, id, localpart string) error {
_, err := txn.Stmt(s.deleteDeviceStmt).Exec(id, localpart)
return err
}
func (s *devicesStatements) selectDeviceByToken(accessToken string) (*authtypes.Device, error) { func (s *devicesStatements) selectDeviceByToken(accessToken string) (*authtypes.Device, error) {
var dev authtypes.Device var dev authtypes.Device
var localpart string var localpart string

View file

@ -42,9 +42,50 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
} }
// GetDeviceByAccessToken returns the device matching the given access token. // GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(token string) (*authtypes.Device, error) { func (d *Database) GetDeviceByAccessToken(token string) (*authtypes.Device, error) {
// TODO: Actual implementation return d.devices.selectDeviceByToken(token)
return &authtypes.Device{ }
UserID: token,
}, nil // CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with a newly generated token.
// Returns the device on success.
func (d *Database) CreateDevice(localpart, deviceID string) (dev *authtypes.Device, returnErr error) {
returnErr = runTransaction(d.db, func(txn *sql.Tx) error {
var err error
// Revoke existing token for this device
if err = d.devices.deleteDevice(txn, deviceID, localpart); err != nil {
return err
}
// TODO: generate an access token. We should probably make sure that it's not possible for this
// token to be the same as the one we just revoked...
accessToken := makeUserID(localpart, d.devices.serverName)
dev, err = d.devices.insertDevice(txn, deviceID, localpart, accessToken)
if err != nil {
return err
}
return nil
})
return
}
func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
txn, err := db.Begin()
if err != nil {
return
}
defer func() {
if r := recover(); r != nil {
txn.Rollback()
panic(r)
} else if err != nil {
txn.Rollback()
} else {
err = txn.Commit()
}
}()
err = fn(txn)
return
} }