Add context to the account database (#232)

This commit is contained in:
Mark Haines 2017-09-18 14:15:27 +01:00 committed by GitHub
parent 5ada8872bb
commit e28ee27605
21 changed files with 267 additions and 138 deletions

View file

@ -15,6 +15,7 @@
package accounts
import (
"context"
"database/sql"
"github.com/matrix-org/gomatrixserverlib"
@ -70,17 +71,22 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
return
}
func (s *accountDataStatements) insertAccountData(localpart string, roomID string, dataType string, content string) (err error) {
_, err = s.insertAccountDataStmt.Exec(localpart, roomID, dataType, content)
func (s *accountDataStatements) insertAccountData(
ctx context.Context, localpart, roomID, dataType, content string,
) (err error) {
stmt := s.insertAccountDataStmt
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
return
}
func (s *accountDataStatements) selectAccountData(localpart string) (
func (s *accountDataStatements) selectAccountData(
ctx context.Context, localpart string,
) (
global []gomatrixserverlib.ClientEvent,
rooms map[string][]gomatrixserverlib.ClientEvent,
err error,
) {
rows, err := s.selectAccountDataStmt.Query(localpart)
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
if err != nil {
return
}
@ -93,7 +99,7 @@ func (s *accountDataStatements) selectAccountData(localpart string) (
var dataType string
var content []byte
if err = rows.Scan(&roomID, &dataType, &content); err != nil && err != sql.ErrNoRows {
if err = rows.Scan(&roomID, &dataType, &content); err != nil {
return
}
@ -113,11 +119,12 @@ func (s *accountDataStatements) selectAccountData(localpart string) (
}
func (s *accountDataStatements) selectAccountDataByType(
localpart string, roomID string, dataType string,
ctx context.Context, localpart, roomID, dataType string,
) (data []gomatrixserverlib.ClientEvent, err error) {
data = []gomatrixserverlib.ClientEvent{}
rows, err := s.selectAccountDataByTypeStmt.Query(localpart, roomID, dataType)
stmt := s.selectAccountDataByTypeStmt
rows, err := stmt.QueryContext(ctx, localpart, roomID, dataType)
if err != nil {
return
}

View file

@ -15,6 +15,7 @@
package accounts
import (
"context"
"database/sql"
"fmt"
"time"
@ -76,26 +77,34 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
// this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success.
func (s *accountsStatements) insertAccount(localpart, hash string) (acc *authtypes.Account, err error) {
func (s *accountsStatements) insertAccount(
ctx context.Context, localpart, hash string,
) (*authtypes.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
if _, err = s.insertAccountStmt.Exec(localpart, createdTimeMS, hash); err == nil {
acc = &authtypes.Account{
stmt := s.insertAccountStmt
if _, err := stmt.ExecContext(ctx, localpart, createdTimeMS, hash); err != nil {
return nil, err
}
return &authtypes.Account{
Localpart: localpart,
UserID: makeUserID(localpart, s.serverName),
ServerName: s.serverName,
}
}
}, nil
}
func (s *accountsStatements) selectPasswordHash(
ctx context.Context, localpart string,
) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
return
}
func (s *accountsStatements) selectPasswordHash(localpart string) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRow(localpart).Scan(&hash)
return
}
func (s *accountsStatements) selectAccountByLocalpart(localpart string) (*authtypes.Account, error) {
func (s *accountsStatements) selectAccountByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Account, error) {
var acc authtypes.Account
err := s.selectAccountByLocalpartStmt.QueryRow(localpart).Scan(&acc.Localpart)
stmt := s.selectAccountByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart)
if err != nil {
acc.UserID = makeUserID(localpart, s.serverName)
acc.ServerName = s.serverName

View file

@ -15,6 +15,7 @@
package accounts
import (
"context"
"database/sql"
"github.com/lib/pq"
@ -80,18 +81,27 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
return
}
func (s *membershipStatements) insertMembership(localpart string, roomID string, eventID string, txn *sql.Tx) (err error) {
_, err = txn.Stmt(s.insertMembershipStmt).Exec(localpart, roomID, eventID)
func (s *membershipStatements) insertMembership(
ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
) (err error) {
stmt := txn.Stmt(s.insertMembershipStmt)
_, err = stmt.ExecContext(ctx, localpart, roomID, eventID)
return
}
func (s *membershipStatements) deleteMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) (err error) {
_, err = txn.Stmt(s.deleteMembershipsByEventIDsStmt).Exec(pq.StringArray(eventIDs))
func (s *membershipStatements) deleteMembershipsByEventIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) (err error) {
stmt := txn.Stmt(s.deleteMembershipsByEventIDsStmt)
_, err = stmt.ExecContext(ctx, pq.StringArray(eventIDs))
return
}
func (s *membershipStatements) selectMembershipsByLocalpart(localpart string) (memberships []authtypes.Membership, err error) {
rows, err := s.selectMembershipsByLocalpartStmt.Query(localpart)
func (s *membershipStatements) selectMembershipsByLocalpart(
ctx context.Context, localpart string,
) (memberships []authtypes.Membership, err error) {
stmt := s.selectMembershipsByLocalpartStmt
rows, err := stmt.QueryContext(ctx, localpart)
if err != nil {
return
}
@ -111,7 +121,11 @@ func (s *membershipStatements) selectMembershipsByLocalpart(localpart string) (m
return
}
func (s *membershipStatements) updateMembershipByEventID(oldEventID string, newEventID string) (err error) {
_, err = s.updateMembershipByEventIDStmt.Exec(oldEventID, newEventID)
func (s *membershipStatements) updateMembershipByEventID(
ctx context.Context, oldEventID string, newEventID string,
) (err error) {
_, err = s.updateMembershipByEventIDStmt.ExecContext(
ctx, oldEventID, newEventID,
)
return
}

View file

@ -15,6 +15,7 @@
package accounts
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -71,23 +72,36 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
return
}
func (s *profilesStatements) insertProfile(localpart string) (err error) {
_, err = s.insertProfileStmt.Exec(localpart, "", "")
func (s *profilesStatements) insertProfile(
ctx context.Context, localpart string,
) (err error) {
_, err = s.insertProfileStmt.ExecContext(ctx, localpart, "", "")
return
}
func (s *profilesStatements) selectProfileByLocalpart(localpart string) (*authtypes.Profile, error) {
func (s *profilesStatements) selectProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
var profile authtypes.Profile
err := s.selectProfileByLocalpartStmt.QueryRow(localpart).Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL)
return &profile, err
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
)
if err != nil {
return nil, err
}
return &profile, nil
}
func (s *profilesStatements) setAvatarURL(localpart string, avatarURL string) (err error) {
_, err = s.setAvatarURLStmt.Exec(avatarURL, localpart)
func (s *profilesStatements) setAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) (err error) {
_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart)
return
}
func (s *profilesStatements) setDisplayName(localpart string, displayName string) (err error) {
_, err = s.setDisplayNameStmt.Exec(displayName, localpart)
func (s *profilesStatements) setDisplayName(
ctx context.Context, localpart string, displayName string,
) (err error) {
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
return
}

View file

@ -15,6 +15,7 @@
package accounts
import (
"context"
"database/sql"
"errors"
@ -74,46 +75,56 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(localpart, plaintextPassword string) (*authtypes.Account, error) {
hash, err := d.accounts.selectPasswordHash(localpart)
func (d *Database) GetAccountByPassword(
ctx context.Context, localpart, plaintextPassword string,
) (*authtypes.Account, error) {
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
if err != nil {
return nil, err
}
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
return nil, err
}
return d.accounts.selectAccountByLocalpart(localpart)
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart(localpart string) (*authtypes.Profile, error) {
return d.profiles.selectProfileByLocalpart(localpart)
func (d *Database) GetProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
return d.profiles.selectProfileByLocalpart(ctx, localpart)
}
// SetAvatarURL updates the avatar URL of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(localpart string, avatarURL string) error {
return d.profiles.setAvatarURL(localpart, avatarURL)
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) error {
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
}
// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(localpart string, displayName string) error {
return d.profiles.setDisplayName(localpart, displayName)
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
) error {
return d.profiles.setDisplayName(ctx, localpart, displayName)
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account.
func (d *Database) CreateAccount(localpart, plaintextPassword string) (*authtypes.Account, error) {
func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword string,
) (*authtypes.Account, error) {
hash, err := hashPassword(plaintextPassword)
if err != nil {
return nil, err
}
if err := d.profiles.insertProfile(localpart); err != nil {
if err := d.profiles.insertProfile(ctx, localpart); err != nil {
return nil, err
}
return d.accounts.insertAccount(localpart, hash)
return d.accounts.insertAccount(ctx, localpart, hash)
}
// PartitionOffsets implements common.PartitionStorer
@ -131,15 +142,19 @@ func (d *Database) SetPartitionOffset(topic string, partition int32, offset int6
// is still in the room.
// If a membership already exists between the user and the room, or of the
// insert fails, returns the SQL error
func (d *Database) SaveMembership(localpart string, roomID string, eventID string, txn *sql.Tx) error {
return d.memberships.insertMembership(localpart, roomID, eventID, txn)
func (d *Database) saveMembership(
ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
) error {
return d.memberships.insertMembership(ctx, txn, localpart, roomID, eventID)
}
// removeMembershipsByEventIDs removes the memberships of which the `join` membership
// event ID is included in a given array of events IDs
// If the removal fails, or if there is no membership to remove, returns an error
func (d *Database) removeMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) error {
return d.memberships.deleteMembershipsByEventIDs(eventIDs, txn)
func (d *Database) removeMembershipsByEventIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) error {
return d.memberships.deleteMembershipsByEventIDs(ctx, txn, eventIDs)
}
// UpdateMemberships adds the "join" membership events included in a given state
@ -147,14 +162,16 @@ func (d *Database) removeMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) e
// IDs. All of the process is run in a transaction, which commits only once/if every
// insertion and deletion has been successfully processed.
// Returns a SQL error if there was an issue with any part of the process
func (d *Database) UpdateMemberships(eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error {
func (d *Database) UpdateMemberships(
ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.removeMembershipsByEventIDs(idsToRemove, txn); err != nil {
if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil {
return err
}
for _, event := range eventsToAdd {
if err := d.newMembership(event, txn); err != nil {
if err := d.newMembership(ctx, txn, event); err != nil {
return err
}
}
@ -167,8 +184,10 @@ func (d *Database) UpdateMemberships(eventsToAdd []gomatrixserverlib.Event, idsT
// the rooms a user matching a given localpart is a member of
// If no membership match the given localpart, returns an empty array
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetMembershipsByLocalpart(localpart string) (memberships []authtypes.Membership, err error) {
return d.memberships.selectMembershipsByLocalpart(localpart)
func (d *Database) GetMembershipsByLocalpart(
ctx context.Context, localpart string,
) (memberships []authtypes.Membership, err error) {
return d.memberships.selectMembershipsByLocalpart(ctx, localpart)
}
// newMembership will save a new membership in the database, with a flag on whether
@ -178,7 +197,9 @@ func (d *Database) GetMembershipsByLocalpart(localpart string) (memberships []au
// values, does nothing.
// If the event isn't a "join" membership event, does nothing
// If an error occurred, returns it
func (d *Database) newMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error {
func (d *Database) newMembership(
ctx context.Context, txn *sql.Tx, ev gomatrixserverlib.Event,
) error {
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
if err != nil {
@ -199,7 +220,7 @@ func (d *Database) newMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error
// Only "join" membership events can be considered as new memberships
if membership == "join" {
if err := d.SaveMembership(localpart, roomID, eventID, txn); err != nil {
if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil {
return err
}
}
@ -212,27 +233,33 @@ func (d *Database) newMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error
// If an account data already exists for a given set (user, room, data type), it will
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(localpart string, roomID string, dataType string, content string) error {
return d.accountDatas.insertAccountData(localpart, roomID, dataType, content)
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType, content string,
) error {
return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content)
}
// GetAccountData returns account data related to a given localpart
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(localpart string) (
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global []gomatrixserverlib.ClientEvent,
rooms map[string][]gomatrixserverlib.ClientEvent,
err error,
) {
return d.accountDatas.selectAccountData(localpart)
return d.accountDatas.selectAccountData(ctx, localpart)
}
// GetAccountDataByType returns account data matching a given
// localpart, room ID and type.
// If no account data could be found, returns an empty array
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(localpart string, roomID string, dataType string) (data []gomatrixserverlib.ClientEvent, err error) {
return d.accountDatas.selectAccountDataByType(localpart, roomID, dataType)
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data []gomatrixserverlib.ClientEvent, err error) {
return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType,
)
}
func hashPassword(plaintext string) (hash string, err error) {
@ -248,9 +275,13 @@ var Err3PIDInUse = errors.New("This third-party identifier is already in use")
// and a local Matrix user (identified by the user's ID's local part).
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation(threepid string, localpart string, medium string) (err error) {
func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string,
) (err error) {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
user, err := d.threepids.selectLocalpartForThreePID(txn, threepid, medium)
user, err := d.threepids.selectLocalpartForThreePID(
ctx, txn, threepid, medium,
)
if err != nil {
return err
}
@ -259,7 +290,7 @@ func (d *Database) SaveThreePIDAssociation(threepid string, localpart string, me
return Err3PIDInUse
}
return d.threepids.insertThreePID(txn, threepid, medium, localpart)
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
})
}
@ -267,8 +298,10 @@ func (d *Database) SaveThreePIDAssociation(threepid string, localpart string, me
// identifier.
// If no association exists involving this third-party identifier, returns nothing.
// If there was a problem talking to the database, returns an error.
func (d *Database) RemoveThreePIDAssociation(threepid string, medium string) (err error) {
return d.threepids.deleteThreePID(threepid, medium)
func (d *Database) RemoveThreePIDAssociation(
ctx context.Context, threepid string, medium string,
) (err error) {
return d.threepids.deleteThreePID(ctx, threepid, medium)
}
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
@ -276,14 +309,18 @@ func (d *Database) RemoveThreePIDAssociation(threepid string, medium string) (er
// If no association involves the given third-party idenfitier, returns an empty
// string.
// Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID(threepid string, medium string) (localpart string, err error) {
return d.threepids.selectLocalpartForThreePID(nil, threepid, medium)
func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
}
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
// a given local user.
// If no association is known for this user, returns an empty slice.
// Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart(localpart string) (threepids []authtypes.ThreePID, err error) {
return d.threepids.selectThreePIDsForLocalpart(localpart)
func (d *Database) GetThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
}

View file

@ -15,8 +15,11 @@
package accounts
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
@ -76,22 +79,21 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) {
return
}
func (s *threepidStatements) selectLocalpartForThreePID(txn *sql.Tx, threepid string, medium string) (localpart string, err error) {
var stmt *sql.Stmt
if txn != nil {
stmt = txn.Stmt(s.selectLocalpartForThreePIDStmt)
} else {
stmt = s.selectLocalpartForThreePIDStmt
}
err = stmt.QueryRow(threepid, medium).Scan(&localpart)
func (s *threepidStatements) selectLocalpartForThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string,
) (localpart string, err error) {
stmt := common.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
if err == sql.ErrNoRows {
return "", nil
}
return
}
func (s *threepidStatements) selectThreePIDsForLocalpart(localpart string) (threepids []authtypes.ThreePID, err error) {
rows, err := s.selectThreePIDsForLocalpartStmt.Query(localpart)
func (s *threepidStatements) selectThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
if err != nil {
return
}
@ -103,18 +105,25 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(localpart string) (thre
if err = rows.Scan(&threepid, &medium); err != nil {
return
}
threepids = append(threepids, authtypes.ThreePID{threepid, medium})
threepids = append(threepids, authtypes.ThreePID{
Address: threepid,
Medium: medium,
})
}
return
}
func (s *threepidStatements) insertThreePID(txn *sql.Tx, threepid string, medium string, localpart string) (err error) {
_, err = txn.Stmt(s.insertThreePIDStmt).Exec(threepid, medium, localpart)
func (s *threepidStatements) insertThreePID(
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
) (err error) {
stmt := common.TxStmt(txn, s.insertThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
return
}
func (s *threepidStatements) deleteThreePID(threepid string, medium string) (err error) {
_, err = s.deleteThreePIDStmt.Exec(threepid, medium)
func (s *threepidStatements) deleteThreePID(
ctx context.Context, threepid string, medium string) (err error) {
_, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium)
return
}

View file

@ -96,7 +96,7 @@ func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error {
return err
}
if err := s.db.UpdateMemberships(events, output.NewRoomEvent.RemovesStateEventIDs); err != nil {
if err := s.db.UpdateMemberships(context.TODO(), events, output.NewRoomEvent.RemovesStateEventIDs); err != nil {
return err
}

View file

@ -59,7 +59,9 @@ func SaveAccountData(
return httputil.LogThenError(req, err)
}
if err := accountDB.SaveAccountData(localpart, roomID, dataType, string(body)); err != nil {
if err := accountDB.SaveAccountData(
req.Context(), localpart, roomID, dataType, string(body),
); err != nil {
return httputil.LogThenError(req, err)
}

View file

@ -79,7 +79,7 @@ func Login(
util.GetLogger(req.Context()).WithField("user", r.User).Info("Processing login request")
acc, err := accountDB.GetAccountByPassword(r.User, r.Password)
acc, err := accountDB.GetAccountByPassword(req.Context(), r.User, r.Password)
if err != nil {
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows
// but that would leak the existence of the user.

View file

@ -60,7 +60,7 @@ func GetProfile(
return httputil.LogThenError(req, err)
}
profile, err := accountDB.GetProfileByLocalpart(localpart)
profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
if err != nil {
return httputil.LogThenError(req, err)
}
@ -83,7 +83,7 @@ func GetAvatarURL(
return httputil.LogThenError(req, err)
}
profile, err := accountDB.GetProfileByLocalpart(localpart)
profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
if err != nil {
return httputil.LogThenError(req, err)
}
@ -127,16 +127,16 @@ func SetAvatarURL(
return httputil.LogThenError(req, err)
}
oldProfile, err := accountDB.GetProfileByLocalpart(localpart)
oldProfile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
if err != nil {
return httputil.LogThenError(req, err)
}
if err = accountDB.SetAvatarURL(localpart, r.AvatarURL); err != nil {
if err = accountDB.SetAvatarURL(req.Context(), localpart, r.AvatarURL); err != nil {
return httputil.LogThenError(req, err)
}
memberships, err := accountDB.GetMembershipsByLocalpart(localpart)
memberships, err := accountDB.GetMembershipsByLocalpart(req.Context(), localpart)
if err != nil {
return httputil.LogThenError(req, err)
}
@ -175,7 +175,7 @@ func GetDisplayName(
return httputil.LogThenError(req, err)
}
profile, err := accountDB.GetProfileByLocalpart(localpart)
profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
if err != nil {
return httputil.LogThenError(req, err)
}
@ -219,16 +219,16 @@ func SetDisplayName(
return httputil.LogThenError(req, err)
}
oldProfile, err := accountDB.GetProfileByLocalpart(localpart)
oldProfile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
if err != nil {
return httputil.LogThenError(req, err)
}
if err = accountDB.SetDisplayName(localpart, r.DisplayName); err != nil {
if err = accountDB.SetDisplayName(req.Context(), localpart, r.DisplayName); err != nil {
return httputil.LogThenError(req, err)
}
memberships, err := accountDB.GetMembershipsByLocalpart(localpart)
memberships, err := accountDB.GetMembershipsByLocalpart(req.Context(), localpart)
if err != nil {
return httputil.LogThenError(req, err)
}

View file

@ -49,7 +49,7 @@ func RequestEmailToken(req *http.Request, accountDB *accounts.Database, cfg conf
var err error
// Check if the 3PID is already in use locally
localpart, err := accountDB.GetLocalpartForThreePID(body.Email, "email")
localpart, err := accountDB.GetLocalpartForThreePID(req.Context(), body.Email, "email")
if err != nil {
return httputil.LogThenError(req, err)
}
@ -64,7 +64,7 @@ func RequestEmailToken(req *http.Request, accountDB *accounts.Database, cfg conf
}
}
resp.SID, err = threepid.CreateSession(body, cfg)
resp.SID, err = threepid.CreateSession(req.Context(), body, cfg)
if err == threepid.ErrNotTrusted {
return util.JSONResponse{
Code: 400,
@ -91,7 +91,7 @@ func CheckAndSave3PIDAssociation(
}
// Check if the association has been validated
verified, address, medium, err := threepid.CheckAssociation(body.Creds, cfg)
verified, address, medium, err := threepid.CheckAssociation(req.Context(), body.Creds, cfg)
if err == threepid.ErrNotTrusted {
return util.JSONResponse{
Code: 400,
@ -130,7 +130,7 @@ func CheckAndSave3PIDAssociation(
return httputil.LogThenError(req, err)
}
if err = accountDB.SaveThreePIDAssociation(address, localpart, medium); err != nil {
if err = accountDB.SaveThreePIDAssociation(req.Context(), address, localpart, medium); err != nil {
return httputil.LogThenError(req, err)
}
@ -149,7 +149,7 @@ func GetAssociated3PIDs(
return httputil.LogThenError(req, err)
}
threepids, err := accountDB.GetThreePIDsForLocalpart(localpart)
threepids, err := accountDB.GetThreePIDsForLocalpart(req.Context(), localpart)
if err != nil {
return httputil.LogThenError(req, err)
}
@ -167,7 +167,7 @@ func Forget3PID(req *http.Request, accountDB *accounts.Database) util.JSONRespon
return *reqErr
}
if err := accountDB.RemoveThreePIDAssociation(body.Address, body.Medium); err != nil {
if err := accountDB.RemoveThreePIDAssociation(req.Context(), body.Address, body.Medium); err != nil {
return httputil.LogThenError(req, err)
}

View file

@ -102,7 +102,7 @@ func CheckAndProcessInvite(
return
}
lookupRes, storeInviteRes, err := queryIDServer(db, cfg, device, body, roomID)
lookupRes, storeInviteRes, err := queryIDServer(ctx, db, cfg, device, body, roomID)
if err != nil {
return
}
@ -134,6 +134,7 @@ func CheckAndProcessInvite(
// Returns a representation of the response for both cases.
// Returns an error if a check or a request failed.
func queryIDServer(
ctx context.Context,
db *accounts.Database, cfg config.Dendrite, device *authtypes.Device,
body *MembershipRequest, roomID string,
) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) {
@ -142,7 +143,7 @@ func queryIDServer(
}
// Lookup the 3PID
lookupRes, err = queryIDServerLookup(body)
lookupRes, err = queryIDServerLookup(ctx, body)
if err != nil {
return
}
@ -150,7 +151,7 @@ func queryIDServer(
if lookupRes.MXID == "" {
// No Matrix ID matches with the given 3PID, ask the server to store the
// invite and return a token
storeInviteRes, err = queryIDServerStoreInvite(db, cfg, device, body, roomID)
storeInviteRes, err = queryIDServerStoreInvite(ctx, db, cfg, device, body, roomID)
return
}
@ -161,11 +162,11 @@ func queryIDServer(
if lookupRes.NotBefore > now || now > lookupRes.NotAfter {
// If the current timestamp isn't in the time frame in which the association
// is known to be valid, re-run the query
return queryIDServer(db, cfg, device, body, roomID)
return queryIDServer(ctx, db, cfg, device, body, roomID)
}
// Check the request signatures and send an error if one isn't valid
if err = checkIDServerSignatures(body, lookupRes); err != nil {
if err = checkIDServerSignatures(ctx, body, lookupRes); err != nil {
return
}
@ -175,10 +176,14 @@ func queryIDServer(
// queryIDServerLookup sends a response to the identity server on /_matrix/identity/api/v1/lookup
// and returns the response as a structure.
// Returns an error if the request failed to send or if the response couldn't be parsed.
func queryIDServerLookup(body *MembershipRequest) (*idServerLookupResponse, error) {
func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServerLookupResponse, error) {
address := url.QueryEscape(body.Address)
url := fmt.Sprintf("https://%s/_matrix/identity/api/v1/lookup?medium=%s&address=%s", body.IDServer, body.Medium, address)
resp, err := http.Get(url)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
resp, err := http.DefaultClient.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
@ -198,6 +203,7 @@ func queryIDServerLookup(body *MembershipRequest) (*idServerLookupResponse, erro
// and returns the response as a structure.
// Returns an error if the request failed to send or if the response couldn't be parsed.
func queryIDServerStoreInvite(
ctx context.Context,
db *accounts.Database, cfg config.Dendrite, device *authtypes.Device,
body *MembershipRequest, roomID string,
) (*idServerStoreInviteResponse, error) {
@ -209,7 +215,7 @@ func queryIDServerStoreInvite(
var profile *authtypes.Profile
if serverName == cfg.Matrix.ServerName {
profile, err = db.GetProfileByLocalpart(localpart)
profile, err = db.GetProfileByLocalpart(ctx, localpart)
if err != nil {
return nil, err
}
@ -239,7 +245,7 @@ func queryIDServerStoreInvite(
}
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
resp, err := client.Do(req)
resp, err := client.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
@ -259,9 +265,13 @@ func queryIDServerStoreInvite(
// We assume that the ID server is trusted at this point.
// Returns an error if the request couldn't be sent, if its body couldn't be parsed
// or if the key couldn't be decoded from base64.
func queryIDServerPubKey(idServerName string, keyID string) ([]byte, error) {
func queryIDServerPubKey(ctx context.Context, idServerName string, keyID string) ([]byte, error) {
url := fmt.Sprintf("https://%s/_matrix/identity/api/v1/pubkey/%s", idServerName, keyID)
resp, err := http.Get(url)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
resp, err := http.DefaultClient.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
@ -286,7 +296,9 @@ func queryIDServerPubKey(idServerName string, keyID string) ([]byte, error) {
// We assume that the ID server is trusted at this point.
// Returns nil if all the verifications succeeded.
// Returns an error if something failed in the process.
func checkIDServerSignatures(body *MembershipRequest, res *idServerLookupResponse) error {
func checkIDServerSignatures(
ctx context.Context, body *MembershipRequest, res *idServerLookupResponse,
) error {
// Mashall the body so we can give it to VerifyJSON
marshalledBody, err := json.Marshal(*res)
if err != nil {
@ -299,7 +311,7 @@ func checkIDServerSignatures(body *MembershipRequest, res *idServerLookupRespons
}
for keyID := range signatures {
pubKey, err := queryIDServerPubKey(body.IDServer, keyID)
pubKey, err := queryIDServerPubKey(ctx, body.IDServer, keyID)
if err != nil {
return err
}

View file

@ -15,6 +15,7 @@
package threepid
import (
"context"
"encoding/json"
"errors"
"fmt"
@ -51,7 +52,9 @@ type Credentials struct {
// Returns the session's ID.
// Returns an error if there was a problem sending the request or decoding the
// response, or if the identity server responded with a non-OK status.
func CreateSession(req EmailAssociationRequest, cfg config.Dendrite) (string, error) {
func CreateSession(
ctx context.Context, req EmailAssociationRequest, cfg config.Dendrite,
) (string, error) {
if err := isTrusted(req.IDServer, cfg); err != nil {
return "", err
}
@ -71,7 +74,7 @@ func CreateSession(req EmailAssociationRequest, cfg config.Dendrite) (string, er
request.Header.Add("Content-Type", "application/x-www-form-urlencoded")
client := http.Client{}
resp, err := client.Do(request)
resp, err := client.Do(request.WithContext(ctx))
if err != nil {
return "", err
}
@ -97,13 +100,19 @@ func CreateSession(req EmailAssociationRequest, cfg config.Dendrite) (string, er
// identifier and its medium.
// Returns an error if there was a problem sending the request or decoding the
// response, or if the identity server responded with a non-OK status.
func CheckAssociation(creds Credentials, cfg config.Dendrite) (bool, string, string, error) {
func CheckAssociation(
ctx context.Context, creds Credentials, cfg config.Dendrite,
) (bool, string, string, error) {
if err := isTrusted(creds.IDServer, cfg); err != nil {
return false, "", "", err
}
url := fmt.Sprintf("https://%s/_matrix/identity/api/v1/3pid/getValidated3pid?sid=%s&client_secret=%s", creds.IDServer, creds.SID, creds.Secret)
resp, err := http.Get(url)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return false, "", "", err
}
resp, err := http.DefaultClient.Do(req.WithContext(ctx))
if err != nil {
return false, "", "", err
}

View file

@ -127,7 +127,7 @@ func createRoom(req *http.Request, device *authtypes.Device,
return httputil.LogThenError(req, err)
}
profile, err := accountDB.GetProfileByLocalpart(localpart)
profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
if err != nil {
return httputil.LogThenError(req, err)
}

View file

@ -57,7 +57,7 @@ func JoinRoomByIDOrAlias(
return httputil.LogThenError(req, err)
}
profile, err := accountDB.GetProfileByLocalpart(localpart)
profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
if err != nil {
return httputil.LogThenError(req, err)
}

View file

@ -121,7 +121,7 @@ func buildMembershipEvent(
return nil, err
}
profile, err := loadProfile(stateKey, cfg, accountDB)
profile, err := loadProfile(ctx, stateKey, cfg, accountDB)
if err != nil {
return nil, err
}
@ -156,7 +156,9 @@ func buildMembershipEvent(
// it if the user is local to this server, or returns an empty profile if not.
// Returns an error if the retrieval failed or if the first parameter isn't a
// valid Matrix ID.
func loadProfile(userID string, cfg config.Dendrite, accountDB *accounts.Database) (*authtypes.Profile, error) {
func loadProfile(
ctx context.Context, userID string, cfg config.Dendrite, accountDB *accounts.Database,
) (*authtypes.Profile, error) {
localpart, serverName, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return nil, err
@ -164,7 +166,7 @@ func loadProfile(userID string, cfg config.Dendrite, accountDB *accounts.Databas
var profile *authtypes.Profile
if serverName == cfg.Matrix.ServerName {
profile, err = accountDB.GetProfileByLocalpart(localpart)
profile, err = accountDB.GetProfileByLocalpart(ctx, localpart)
} else {
profile = &authtypes.Profile{}
}

View file

@ -1,6 +1,7 @@
package writers
import (
"context"
"fmt"
"net/http"
"time"
@ -134,7 +135,9 @@ func Register(req *http.Request, accountDB *accounts.Database, deviceDB *devices
switch r.Auth.Type {
case authtypes.LoginTypeDummy:
// there is nothing to do
return completeRegistration(accountDB, deviceDB, r.Username, r.Password)
return completeRegistration(
req.Context(), accountDB, deviceDB, r.Username, r.Password,
)
default:
return util.JSONResponse{
Code: 501,
@ -143,7 +146,12 @@ func Register(req *http.Request, accountDB *accounts.Database, deviceDB *devices
}
}
func completeRegistration(accountDB *accounts.Database, deviceDB *devices.Database, username, password string) util.JSONResponse {
func completeRegistration(
ctx context.Context,
accountDB *accounts.Database,
deviceDB *devices.Database,
username, password string,
) util.JSONResponse {
if username == "" {
return util.JSONResponse{
Code: 400,
@ -157,7 +165,7 @@ func completeRegistration(accountDB *accounts.Database, deviceDB *devices.Databa
}
}
acc, err := accountDB.CreateAccount(username, password)
acc, err := accountDB.CreateAccount(ctx, username, password)
if err != nil {
return util.JSONResponse{
Code: 500,

View file

@ -15,6 +15,7 @@
package main
import (
"context"
"flag"
"fmt"
"os"
@ -68,7 +69,7 @@ func main() {
os.Exit(1)
}
_, err = accountDB.CreateAccount(*username, *password)
_, err = accountDB.CreateAccount(context.Background(), *username, *password)
if err != nil {
fmt.Println(err.Error())
os.Exit(1)

View file

@ -191,7 +191,7 @@ func createInviteFrom3PIDInvite(
StateKey: &inv.MXID,
}
profile, err := accountDB.GetProfileByLocalpart(localpart)
profile, err := accountDB.GetProfileByLocalpart(ctx, localpart)
if err != nil {
return nil, err
}

View file

@ -15,6 +15,7 @@
package sync
import (
"context"
"net/http"
"strconv"
"time"
@ -29,6 +30,7 @@ const defaultTimelineLimit = 20
// syncRequest represents a /sync request, with sensible defaults/sanity checks applied.
type syncRequest struct {
ctx context.Context
userID string
limit int
timeout time.Duration
@ -47,6 +49,7 @@ func newSyncRequest(req *http.Request, userID string) (*syncRequest, error) {
}
// TODO: Additional query params: set_presence, filter
return &syncRequest{
ctx: req.Context(),
userID: userID,
timeout: timeout,
since: since,

View file

@ -128,7 +128,7 @@ func (rp *RequestPool) appendAccountData(
// already been sent. Instead, we send the whole batch.
var global []gomatrixserverlib.ClientEvent
var rooms map[string][]gomatrixserverlib.ClientEvent
global, rooms, err = rp.accountDB.GetAccountData(localpart)
global, rooms, err = rp.accountDB.GetAccountData(req.ctx, localpart)
if err != nil {
return nil, err
}
@ -159,7 +159,9 @@ func (rp *RequestPool) appendAccountData(
events := []gomatrixserverlib.ClientEvent{}
// Request the missing data from the database
for _, dataType := range dataTypes {
evs, err := rp.accountDB.GetAccountDataByType(localpart, roomID, dataType)
evs, err := rp.accountDB.GetAccountDataByType(
req.ctx, localpart, roomID, dataType,
)
if err != nil {
return nil, err
}