Add context to the account database (#232)

This commit is contained in:
Mark Haines 2017-09-18 14:15:27 +01:00 committed by GitHub
parent 5ada8872bb
commit e28ee27605
21 changed files with 267 additions and 138 deletions

View file

@ -15,6 +15,7 @@
package accounts package accounts
import ( import (
"context"
"database/sql" "database/sql"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -70,17 +71,22 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
return return
} }
func (s *accountDataStatements) insertAccountData(localpart string, roomID string, dataType string, content string) (err error) { func (s *accountDataStatements) insertAccountData(
_, err = s.insertAccountDataStmt.Exec(localpart, roomID, dataType, content) ctx context.Context, localpart, roomID, dataType, content string,
) (err error) {
stmt := s.insertAccountDataStmt
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
return return
} }
func (s *accountDataStatements) selectAccountData(localpart string) ( func (s *accountDataStatements) selectAccountData(
ctx context.Context, localpart string,
) (
global []gomatrixserverlib.ClientEvent, global []gomatrixserverlib.ClientEvent,
rooms map[string][]gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent,
err error, err error,
) { ) {
rows, err := s.selectAccountDataStmt.Query(localpart) rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
if err != nil { if err != nil {
return return
} }
@ -93,7 +99,7 @@ func (s *accountDataStatements) selectAccountData(localpart string) (
var dataType string var dataType string
var content []byte var content []byte
if err = rows.Scan(&roomID, &dataType, &content); err != nil && err != sql.ErrNoRows { if err = rows.Scan(&roomID, &dataType, &content); err != nil {
return return
} }
@ -113,11 +119,12 @@ func (s *accountDataStatements) selectAccountData(localpart string) (
} }
func (s *accountDataStatements) selectAccountDataByType( func (s *accountDataStatements) selectAccountDataByType(
localpart string, roomID string, dataType string, ctx context.Context, localpart, roomID, dataType string,
) (data []gomatrixserverlib.ClientEvent, err error) { ) (data []gomatrixserverlib.ClientEvent, err error) {
data = []gomatrixserverlib.ClientEvent{} data = []gomatrixserverlib.ClientEvent{}
rows, err := s.selectAccountDataByTypeStmt.Query(localpart, roomID, dataType) stmt := s.selectAccountDataByTypeStmt
rows, err := stmt.QueryContext(ctx, localpart, roomID, dataType)
if err != nil { if err != nil {
return return
} }

View file

@ -15,6 +15,7 @@
package accounts package accounts
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"time" "time"
@ -76,26 +77,34 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) insertAccount(localpart, hash string) (acc *authtypes.Account, err error) { func (s *accountsStatements) insertAccount(
ctx context.Context, localpart, hash string,
) (*authtypes.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
if _, err = s.insertAccountStmt.Exec(localpart, createdTimeMS, hash); err == nil { stmt := s.insertAccountStmt
acc = &authtypes.Account{ if _, err := stmt.ExecContext(ctx, localpart, createdTimeMS, hash); err != nil {
return nil, err
}
return &authtypes.Account{
Localpart: localpart, Localpart: localpart,
UserID: makeUserID(localpart, s.serverName), UserID: makeUserID(localpart, s.serverName),
ServerName: s.serverName, ServerName: s.serverName,
} }, nil
} }
func (s *accountsStatements) selectPasswordHash(
ctx context.Context, localpart string,
) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
return return
} }
func (s *accountsStatements) selectPasswordHash(localpart string) (hash string, err error) { func (s *accountsStatements) selectAccountByLocalpart(
err = s.selectPasswordHashStmt.QueryRow(localpart).Scan(&hash) ctx context.Context, localpart string,
return ) (*authtypes.Account, error) {
}
func (s *accountsStatements) selectAccountByLocalpart(localpart string) (*authtypes.Account, error) {
var acc authtypes.Account var acc authtypes.Account
err := s.selectAccountByLocalpartStmt.QueryRow(localpart).Scan(&acc.Localpart) stmt := s.selectAccountByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart)
if err != nil { if err != nil {
acc.UserID = makeUserID(localpart, s.serverName) acc.UserID = makeUserID(localpart, s.serverName)
acc.ServerName = s.serverName acc.ServerName = s.serverName

View file

@ -15,6 +15,7 @@
package accounts package accounts
import ( import (
"context"
"database/sql" "database/sql"
"github.com/lib/pq" "github.com/lib/pq"
@ -80,18 +81,27 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
return return
} }
func (s *membershipStatements) insertMembership(localpart string, roomID string, eventID string, txn *sql.Tx) (err error) { func (s *membershipStatements) insertMembership(
_, err = txn.Stmt(s.insertMembershipStmt).Exec(localpart, roomID, eventID) ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
) (err error) {
stmt := txn.Stmt(s.insertMembershipStmt)
_, err = stmt.ExecContext(ctx, localpart, roomID, eventID)
return return
} }
func (s *membershipStatements) deleteMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) (err error) { func (s *membershipStatements) deleteMembershipsByEventIDs(
_, err = txn.Stmt(s.deleteMembershipsByEventIDsStmt).Exec(pq.StringArray(eventIDs)) ctx context.Context, txn *sql.Tx, eventIDs []string,
) (err error) {
stmt := txn.Stmt(s.deleteMembershipsByEventIDsStmt)
_, err = stmt.ExecContext(ctx, pq.StringArray(eventIDs))
return return
} }
func (s *membershipStatements) selectMembershipsByLocalpart(localpart string) (memberships []authtypes.Membership, err error) { func (s *membershipStatements) selectMembershipsByLocalpart(
rows, err := s.selectMembershipsByLocalpartStmt.Query(localpart) ctx context.Context, localpart string,
) (memberships []authtypes.Membership, err error) {
stmt := s.selectMembershipsByLocalpartStmt
rows, err := stmt.QueryContext(ctx, localpart)
if err != nil { if err != nil {
return return
} }
@ -111,7 +121,11 @@ func (s *membershipStatements) selectMembershipsByLocalpart(localpart string) (m
return return
} }
func (s *membershipStatements) updateMembershipByEventID(oldEventID string, newEventID string) (err error) { func (s *membershipStatements) updateMembershipByEventID(
_, err = s.updateMembershipByEventIDStmt.Exec(oldEventID, newEventID) ctx context.Context, oldEventID string, newEventID string,
) (err error) {
_, err = s.updateMembershipByEventIDStmt.ExecContext(
ctx, oldEventID, newEventID,
)
return return
} }

View file

@ -15,6 +15,7 @@
package accounts package accounts
import ( import (
"context"
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -71,23 +72,36 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
return return
} }
func (s *profilesStatements) insertProfile(localpart string) (err error) { func (s *profilesStatements) insertProfile(
_, err = s.insertProfileStmt.Exec(localpart, "", "") ctx context.Context, localpart string,
) (err error) {
_, err = s.insertProfileStmt.ExecContext(ctx, localpart, "", "")
return return
} }
func (s *profilesStatements) selectProfileByLocalpart(localpart string) (*authtypes.Profile, error) { func (s *profilesStatements) selectProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
var profile authtypes.Profile var profile authtypes.Profile
err := s.selectProfileByLocalpartStmt.QueryRow(localpart).Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL) err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
return &profile, err &profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
)
if err != nil {
return nil, err
}
return &profile, nil
} }
func (s *profilesStatements) setAvatarURL(localpart string, avatarURL string) (err error) { func (s *profilesStatements) setAvatarURL(
_, err = s.setAvatarURLStmt.Exec(avatarURL, localpart) ctx context.Context, localpart string, avatarURL string,
) (err error) {
_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart)
return return
} }
func (s *profilesStatements) setDisplayName(localpart string, displayName string) (err error) { func (s *profilesStatements) setDisplayName(
_, err = s.setDisplayNameStmt.Exec(displayName, localpart) ctx context.Context, localpart string, displayName string,
) (err error) {
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
return return
} }

View file

@ -15,6 +15,7 @@
package accounts package accounts
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
@ -74,46 +75,56 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
// GetAccountByPassword returns the account associated with the given localpart and password. // GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart. // Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(localpart, plaintextPassword string) (*authtypes.Account, error) { func (d *Database) GetAccountByPassword(
hash, err := d.accounts.selectPasswordHash(localpart) ctx context.Context, localpart, plaintextPassword string,
) (*authtypes.Account, error) {
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
return nil, err return nil, err
} }
return d.accounts.selectAccountByLocalpart(localpart) return d.accounts.selectAccountByLocalpart(ctx, localpart)
} }
// GetProfileByLocalpart returns the profile associated with the given localpart. // GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart. // Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart(localpart string) (*authtypes.Profile, error) { func (d *Database) GetProfileByLocalpart(
return d.profiles.selectProfileByLocalpart(localpart) ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
return d.profiles.selectProfileByLocalpart(ctx, localpart)
} }
// SetAvatarURL updates the avatar URL of the profile associated with the given // SetAvatarURL updates the avatar URL of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query // localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(localpart string, avatarURL string) error { func (d *Database) SetAvatarURL(
return d.profiles.setAvatarURL(localpart, avatarURL) ctx context.Context, localpart string, avatarURL string,
) error {
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
} }
// SetDisplayName updates the display name of the profile associated with the given // SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query // localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(localpart string, displayName string) error { func (d *Database) SetDisplayName(
return d.profiles.setDisplayName(localpart, displayName) ctx context.Context, localpart string, displayName string,
) error {
return d.profiles.setDisplayName(ctx, localpart, displayName)
} }
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. // for this account. If no password is supplied, the account will be a passwordless account.
func (d *Database) CreateAccount(localpart, plaintextPassword string) (*authtypes.Account, error) { func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword string,
) (*authtypes.Account, error) {
hash, err := hashPassword(plaintextPassword) hash, err := hashPassword(plaintextPassword)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := d.profiles.insertProfile(localpart); err != nil { if err := d.profiles.insertProfile(ctx, localpart); err != nil {
return nil, err return nil, err
} }
return d.accounts.insertAccount(localpart, hash) return d.accounts.insertAccount(ctx, localpart, hash)
} }
// PartitionOffsets implements common.PartitionStorer // PartitionOffsets implements common.PartitionStorer
@ -131,15 +142,19 @@ func (d *Database) SetPartitionOffset(topic string, partition int32, offset int6
// is still in the room. // is still in the room.
// If a membership already exists between the user and the room, or of the // If a membership already exists between the user and the room, or of the
// insert fails, returns the SQL error // insert fails, returns the SQL error
func (d *Database) SaveMembership(localpart string, roomID string, eventID string, txn *sql.Tx) error { func (d *Database) saveMembership(
return d.memberships.insertMembership(localpart, roomID, eventID, txn) ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
) error {
return d.memberships.insertMembership(ctx, txn, localpart, roomID, eventID)
} }
// removeMembershipsByEventIDs removes the memberships of which the `join` membership // removeMembershipsByEventIDs removes the memberships of which the `join` membership
// event ID is included in a given array of events IDs // event ID is included in a given array of events IDs
// If the removal fails, or if there is no membership to remove, returns an error // If the removal fails, or if there is no membership to remove, returns an error
func (d *Database) removeMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) error { func (d *Database) removeMembershipsByEventIDs(
return d.memberships.deleteMembershipsByEventIDs(eventIDs, txn) ctx context.Context, txn *sql.Tx, eventIDs []string,
) error {
return d.memberships.deleteMembershipsByEventIDs(ctx, txn, eventIDs)
} }
// UpdateMemberships adds the "join" membership events included in a given state // UpdateMemberships adds the "join" membership events included in a given state
@ -147,14 +162,16 @@ func (d *Database) removeMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) e
// IDs. All of the process is run in a transaction, which commits only once/if every // IDs. All of the process is run in a transaction, which commits only once/if every
// insertion and deletion has been successfully processed. // insertion and deletion has been successfully processed.
// Returns a SQL error if there was an issue with any part of the process // Returns a SQL error if there was an issue with any part of the process
func (d *Database) UpdateMemberships(eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error { func (d *Database) UpdateMemberships(
ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error { return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.removeMembershipsByEventIDs(idsToRemove, txn); err != nil { if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil {
return err return err
} }
for _, event := range eventsToAdd { for _, event := range eventsToAdd {
if err := d.newMembership(event, txn); err != nil { if err := d.newMembership(ctx, txn, event); err != nil {
return err return err
} }
} }
@ -167,8 +184,10 @@ func (d *Database) UpdateMemberships(eventsToAdd []gomatrixserverlib.Event, idsT
// the rooms a user matching a given localpart is a member of // the rooms a user matching a given localpart is a member of
// If no membership match the given localpart, returns an empty array // If no membership match the given localpart, returns an empty array
// If there was an issue during the retrieval, returns the SQL error // If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetMembershipsByLocalpart(localpart string) (memberships []authtypes.Membership, err error) { func (d *Database) GetMembershipsByLocalpart(
return d.memberships.selectMembershipsByLocalpart(localpart) ctx context.Context, localpart string,
) (memberships []authtypes.Membership, err error) {
return d.memberships.selectMembershipsByLocalpart(ctx, localpart)
} }
// newMembership will save a new membership in the database, with a flag on whether // newMembership will save a new membership in the database, with a flag on whether
@ -178,7 +197,9 @@ func (d *Database) GetMembershipsByLocalpart(localpart string) (memberships []au
// values, does nothing. // values, does nothing.
// If the event isn't a "join" membership event, does nothing // If the event isn't a "join" membership event, does nothing
// If an error occurred, returns it // If an error occurred, returns it
func (d *Database) newMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error { func (d *Database) newMembership(
ctx context.Context, txn *sql.Tx, ev gomatrixserverlib.Event,
) error {
if ev.Type() == "m.room.member" && ev.StateKey() != nil { if ev.Type() == "m.room.member" && ev.StateKey() != nil {
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey()) localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
if err != nil { if err != nil {
@ -199,7 +220,7 @@ func (d *Database) newMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error
// Only "join" membership events can be considered as new memberships // Only "join" membership events can be considered as new memberships
if membership == "join" { if membership == "join" {
if err := d.SaveMembership(localpart, roomID, eventID, txn); err != nil { if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil {
return err return err
} }
} }
@ -212,27 +233,33 @@ func (d *Database) newMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error
// If an account data already exists for a given set (user, room, data type), it will // If an account data already exists for a given set (user, room, data type), it will
// update the corresponding row with the new content // update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update // Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(localpart string, roomID string, dataType string, content string) error { func (d *Database) SaveAccountData(
return d.accountDatas.insertAccountData(localpart, roomID, dataType, content) ctx context.Context, localpart, roomID, dataType, content string,
) error {
return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content)
} }
// GetAccountData returns account data related to a given localpart // GetAccountData returns account data related to a given localpart
// If no account data could be found, returns an empty arrays // If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval // Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(localpart string) ( func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global []gomatrixserverlib.ClientEvent, global []gomatrixserverlib.ClientEvent,
rooms map[string][]gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent,
err error, err error,
) { ) {
return d.accountDatas.selectAccountData(localpart) return d.accountDatas.selectAccountData(ctx, localpart)
} }
// GetAccountDataByType returns account data matching a given // GetAccountDataByType returns account data matching a given
// localpart, room ID and type. // localpart, room ID and type.
// If no account data could be found, returns an empty array // If no account data could be found, returns an empty array
// Returns an error if there was an issue with the retrieval // Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(localpart string, roomID string, dataType string) (data []gomatrixserverlib.ClientEvent, err error) { func (d *Database) GetAccountDataByType(
return d.accountDatas.selectAccountDataByType(localpart, roomID, dataType) ctx context.Context, localpart, roomID, dataType string,
) (data []gomatrixserverlib.ClientEvent, err error) {
return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType,
)
} }
func hashPassword(plaintext string) (hash string, err error) { func hashPassword(plaintext string) (hash string, err error) {
@ -248,9 +275,13 @@ var Err3PIDInUse = errors.New("This third-party identifier is already in use")
// and a local Matrix user (identified by the user's ID's local part). // and a local Matrix user (identified by the user's ID's local part).
// If the third-party identifier is already part of an association, returns Err3PIDInUse. // If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation(threepid string, localpart string, medium string) (err error) { func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string,
) (err error) {
return common.WithTransaction(d.db, func(txn *sql.Tx) error { return common.WithTransaction(d.db, func(txn *sql.Tx) error {
user, err := d.threepids.selectLocalpartForThreePID(txn, threepid, medium) user, err := d.threepids.selectLocalpartForThreePID(
ctx, txn, threepid, medium,
)
if err != nil { if err != nil {
return err return err
} }
@ -259,7 +290,7 @@ func (d *Database) SaveThreePIDAssociation(threepid string, localpart string, me
return Err3PIDInUse return Err3PIDInUse
} }
return d.threepids.insertThreePID(txn, threepid, medium, localpart) return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
}) })
} }
@ -267,8 +298,10 @@ func (d *Database) SaveThreePIDAssociation(threepid string, localpart string, me
// identifier. // identifier.
// If no association exists involving this third-party identifier, returns nothing. // If no association exists involving this third-party identifier, returns nothing.
// If there was a problem talking to the database, returns an error. // If there was a problem talking to the database, returns an error.
func (d *Database) RemoveThreePIDAssociation(threepid string, medium string) (err error) { func (d *Database) RemoveThreePIDAssociation(
return d.threepids.deleteThreePID(threepid, medium) ctx context.Context, threepid string, medium string,
) (err error) {
return d.threepids.deleteThreePID(ctx, threepid, medium)
} }
// GetLocalpartForThreePID looks up the localpart associated with a given third-party // GetLocalpartForThreePID looks up the localpart associated with a given third-party
@ -276,14 +309,18 @@ func (d *Database) RemoveThreePIDAssociation(threepid string, medium string) (er
// If no association involves the given third-party idenfitier, returns an empty // If no association involves the given third-party idenfitier, returns an empty
// string. // string.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID(threepid string, medium string) (localpart string, err error) { func (d *Database) GetLocalpartForThreePID(
return d.threepids.selectLocalpartForThreePID(nil, threepid, medium) ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
} }
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with // GetThreePIDsForLocalpart looks up the third-party identifiers associated with
// a given local user. // a given local user.
// If no association is known for this user, returns an empty slice. // If no association is known for this user, returns an empty slice.
// Returns an error if there was an issue talking to the database. // Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart(localpart string) (threepids []authtypes.ThreePID, err error) { func (d *Database) GetThreePIDsForLocalpart(
return d.threepids.selectThreePIDsForLocalpart(localpart) ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
} }

View file

@ -15,8 +15,11 @@
package accounts package accounts
import ( import (
"context"
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
) )
@ -76,22 +79,21 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) {
return return
} }
func (s *threepidStatements) selectLocalpartForThreePID(txn *sql.Tx, threepid string, medium string) (localpart string, err error) { func (s *threepidStatements) selectLocalpartForThreePID(
var stmt *sql.Stmt ctx context.Context, txn *sql.Tx, threepid string, medium string,
if txn != nil { ) (localpart string, err error) {
stmt = txn.Stmt(s.selectLocalpartForThreePIDStmt) stmt := common.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
} else { err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
stmt = s.selectLocalpartForThreePIDStmt
}
err = stmt.QueryRow(threepid, medium).Scan(&localpart)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", nil return "", nil
} }
return return
} }
func (s *threepidStatements) selectThreePIDsForLocalpart(localpart string) (threepids []authtypes.ThreePID, err error) { func (s *threepidStatements) selectThreePIDsForLocalpart(
rows, err := s.selectThreePIDsForLocalpartStmt.Query(localpart) ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
if err != nil { if err != nil {
return return
} }
@ -103,18 +105,25 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(localpart string) (thre
if err = rows.Scan(&threepid, &medium); err != nil { if err = rows.Scan(&threepid, &medium); err != nil {
return return
} }
threepids = append(threepids, authtypes.ThreePID{threepid, medium}) threepids = append(threepids, authtypes.ThreePID{
Address: threepid,
Medium: medium,
})
} }
return return
} }
func (s *threepidStatements) insertThreePID(txn *sql.Tx, threepid string, medium string, localpart string) (err error) { func (s *threepidStatements) insertThreePID(
_, err = txn.Stmt(s.insertThreePIDStmt).Exec(threepid, medium, localpart) ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
) (err error) {
stmt := common.TxStmt(txn, s.insertThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
return return
} }
func (s *threepidStatements) deleteThreePID(threepid string, medium string) (err error) { func (s *threepidStatements) deleteThreePID(
_, err = s.deleteThreePIDStmt.Exec(threepid, medium) ctx context.Context, threepid string, medium string) (err error) {
_, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium)
return return
} }

View file

@ -96,7 +96,7 @@ func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error {
return err return err
} }
if err := s.db.UpdateMemberships(events, output.NewRoomEvent.RemovesStateEventIDs); err != nil { if err := s.db.UpdateMemberships(context.TODO(), events, output.NewRoomEvent.RemovesStateEventIDs); err != nil {
return err return err
} }

View file

@ -59,7 +59,9 @@ func SaveAccountData(
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
if err := accountDB.SaveAccountData(localpart, roomID, dataType, string(body)); err != nil { if err := accountDB.SaveAccountData(
req.Context(), localpart, roomID, dataType, string(body),
); err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }

View file

@ -79,7 +79,7 @@ func Login(
util.GetLogger(req.Context()).WithField("user", r.User).Info("Processing login request") util.GetLogger(req.Context()).WithField("user", r.User).Info("Processing login request")
acc, err := accountDB.GetAccountByPassword(r.User, r.Password) acc, err := accountDB.GetAccountByPassword(req.Context(), r.User, r.Password)
if err != nil { if err != nil {
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows
// but that would leak the existence of the user. // but that would leak the existence of the user.

View file

@ -60,7 +60,7 @@ func GetProfile(
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
profile, err := accountDB.GetProfileByLocalpart(localpart) profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
@ -83,7 +83,7 @@ func GetAvatarURL(
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
profile, err := accountDB.GetProfileByLocalpart(localpart) profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
@ -127,16 +127,16 @@ func SetAvatarURL(
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
oldProfile, err := accountDB.GetProfileByLocalpart(localpart) oldProfile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
if err = accountDB.SetAvatarURL(localpart, r.AvatarURL); err != nil { if err = accountDB.SetAvatarURL(req.Context(), localpart, r.AvatarURL); err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
memberships, err := accountDB.GetMembershipsByLocalpart(localpart) memberships, err := accountDB.GetMembershipsByLocalpart(req.Context(), localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
@ -175,7 +175,7 @@ func GetDisplayName(
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
profile, err := accountDB.GetProfileByLocalpart(localpart) profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
@ -219,16 +219,16 @@ func SetDisplayName(
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
oldProfile, err := accountDB.GetProfileByLocalpart(localpart) oldProfile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
if err = accountDB.SetDisplayName(localpart, r.DisplayName); err != nil { if err = accountDB.SetDisplayName(req.Context(), localpart, r.DisplayName); err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
memberships, err := accountDB.GetMembershipsByLocalpart(localpart) memberships, err := accountDB.GetMembershipsByLocalpart(req.Context(), localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }

View file

@ -49,7 +49,7 @@ func RequestEmailToken(req *http.Request, accountDB *accounts.Database, cfg conf
var err error var err error
// Check if the 3PID is already in use locally // Check if the 3PID is already in use locally
localpart, err := accountDB.GetLocalpartForThreePID(body.Email, "email") localpart, err := accountDB.GetLocalpartForThreePID(req.Context(), body.Email, "email")
if err != nil { if err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
@ -64,7 +64,7 @@ func RequestEmailToken(req *http.Request, accountDB *accounts.Database, cfg conf
} }
} }
resp.SID, err = threepid.CreateSession(body, cfg) resp.SID, err = threepid.CreateSession(req.Context(), body, cfg)
if err == threepid.ErrNotTrusted { if err == threepid.ErrNotTrusted {
return util.JSONResponse{ return util.JSONResponse{
Code: 400, Code: 400,
@ -91,7 +91,7 @@ func CheckAndSave3PIDAssociation(
} }
// Check if the association has been validated // Check if the association has been validated
verified, address, medium, err := threepid.CheckAssociation(body.Creds, cfg) verified, address, medium, err := threepid.CheckAssociation(req.Context(), body.Creds, cfg)
if err == threepid.ErrNotTrusted { if err == threepid.ErrNotTrusted {
return util.JSONResponse{ return util.JSONResponse{
Code: 400, Code: 400,
@ -130,7 +130,7 @@ func CheckAndSave3PIDAssociation(
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
if err = accountDB.SaveThreePIDAssociation(address, localpart, medium); err != nil { if err = accountDB.SaveThreePIDAssociation(req.Context(), address, localpart, medium); err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
@ -149,7 +149,7 @@ func GetAssociated3PIDs(
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
threepids, err := accountDB.GetThreePIDsForLocalpart(localpart) threepids, err := accountDB.GetThreePIDsForLocalpart(req.Context(), localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
@ -167,7 +167,7 @@ func Forget3PID(req *http.Request, accountDB *accounts.Database) util.JSONRespon
return *reqErr return *reqErr
} }
if err := accountDB.RemoveThreePIDAssociation(body.Address, body.Medium); err != nil { if err := accountDB.RemoveThreePIDAssociation(req.Context(), body.Address, body.Medium); err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }

View file

@ -102,7 +102,7 @@ func CheckAndProcessInvite(
return return
} }
lookupRes, storeInviteRes, err := queryIDServer(db, cfg, device, body, roomID) lookupRes, storeInviteRes, err := queryIDServer(ctx, db, cfg, device, body, roomID)
if err != nil { if err != nil {
return return
} }
@ -134,6 +134,7 @@ func CheckAndProcessInvite(
// Returns a representation of the response for both cases. // Returns a representation of the response for both cases.
// Returns an error if a check or a request failed. // Returns an error if a check or a request failed.
func queryIDServer( func queryIDServer(
ctx context.Context,
db *accounts.Database, cfg config.Dendrite, device *authtypes.Device, db *accounts.Database, cfg config.Dendrite, device *authtypes.Device,
body *MembershipRequest, roomID string, body *MembershipRequest, roomID string,
) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) { ) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) {
@ -142,7 +143,7 @@ func queryIDServer(
} }
// Lookup the 3PID // Lookup the 3PID
lookupRes, err = queryIDServerLookup(body) lookupRes, err = queryIDServerLookup(ctx, body)
if err != nil { if err != nil {
return return
} }
@ -150,7 +151,7 @@ func queryIDServer(
if lookupRes.MXID == "" { if lookupRes.MXID == "" {
// No Matrix ID matches with the given 3PID, ask the server to store the // No Matrix ID matches with the given 3PID, ask the server to store the
// invite and return a token // invite and return a token
storeInviteRes, err = queryIDServerStoreInvite(db, cfg, device, body, roomID) storeInviteRes, err = queryIDServerStoreInvite(ctx, db, cfg, device, body, roomID)
return return
} }
@ -161,11 +162,11 @@ func queryIDServer(
if lookupRes.NotBefore > now || now > lookupRes.NotAfter { if lookupRes.NotBefore > now || now > lookupRes.NotAfter {
// If the current timestamp isn't in the time frame in which the association // If the current timestamp isn't in the time frame in which the association
// is known to be valid, re-run the query // is known to be valid, re-run the query
return queryIDServer(db, cfg, device, body, roomID) return queryIDServer(ctx, db, cfg, device, body, roomID)
} }
// Check the request signatures and send an error if one isn't valid // Check the request signatures and send an error if one isn't valid
if err = checkIDServerSignatures(body, lookupRes); err != nil { if err = checkIDServerSignatures(ctx, body, lookupRes); err != nil {
return return
} }
@ -175,10 +176,14 @@ func queryIDServer(
// queryIDServerLookup sends a response to the identity server on /_matrix/identity/api/v1/lookup // queryIDServerLookup sends a response to the identity server on /_matrix/identity/api/v1/lookup
// and returns the response as a structure. // and returns the response as a structure.
// Returns an error if the request failed to send or if the response couldn't be parsed. // Returns an error if the request failed to send or if the response couldn't be parsed.
func queryIDServerLookup(body *MembershipRequest) (*idServerLookupResponse, error) { func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServerLookupResponse, error) {
address := url.QueryEscape(body.Address) address := url.QueryEscape(body.Address)
url := fmt.Sprintf("https://%s/_matrix/identity/api/v1/lookup?medium=%s&address=%s", body.IDServer, body.Medium, address) url := fmt.Sprintf("https://%s/_matrix/identity/api/v1/lookup?medium=%s&address=%s", body.IDServer, body.Medium, address)
resp, err := http.Get(url) req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
resp, err := http.DefaultClient.Do(req.WithContext(ctx))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -198,6 +203,7 @@ func queryIDServerLookup(body *MembershipRequest) (*idServerLookupResponse, erro
// and returns the response as a structure. // and returns the response as a structure.
// Returns an error if the request failed to send or if the response couldn't be parsed. // Returns an error if the request failed to send or if the response couldn't be parsed.
func queryIDServerStoreInvite( func queryIDServerStoreInvite(
ctx context.Context,
db *accounts.Database, cfg config.Dendrite, device *authtypes.Device, db *accounts.Database, cfg config.Dendrite, device *authtypes.Device,
body *MembershipRequest, roomID string, body *MembershipRequest, roomID string,
) (*idServerStoreInviteResponse, error) { ) (*idServerStoreInviteResponse, error) {
@ -209,7 +215,7 @@ func queryIDServerStoreInvite(
var profile *authtypes.Profile var profile *authtypes.Profile
if serverName == cfg.Matrix.ServerName { if serverName == cfg.Matrix.ServerName {
profile, err = db.GetProfileByLocalpart(localpart) profile, err = db.GetProfileByLocalpart(ctx, localpart)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -239,7 +245,7 @@ func queryIDServerStoreInvite(
} }
req.Header.Add("Content-Type", "application/x-www-form-urlencoded") req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
resp, err := client.Do(req) resp, err := client.Do(req.WithContext(ctx))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -259,9 +265,13 @@ func queryIDServerStoreInvite(
// We assume that the ID server is trusted at this point. // We assume that the ID server is trusted at this point.
// Returns an error if the request couldn't be sent, if its body couldn't be parsed // Returns an error if the request couldn't be sent, if its body couldn't be parsed
// or if the key couldn't be decoded from base64. // or if the key couldn't be decoded from base64.
func queryIDServerPubKey(idServerName string, keyID string) ([]byte, error) { func queryIDServerPubKey(ctx context.Context, idServerName string, keyID string) ([]byte, error) {
url := fmt.Sprintf("https://%s/_matrix/identity/api/v1/pubkey/%s", idServerName, keyID) url := fmt.Sprintf("https://%s/_matrix/identity/api/v1/pubkey/%s", idServerName, keyID)
resp, err := http.Get(url) req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
resp, err := http.DefaultClient.Do(req.WithContext(ctx))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -286,7 +296,9 @@ func queryIDServerPubKey(idServerName string, keyID string) ([]byte, error) {
// We assume that the ID server is trusted at this point. // We assume that the ID server is trusted at this point.
// Returns nil if all the verifications succeeded. // Returns nil if all the verifications succeeded.
// Returns an error if something failed in the process. // Returns an error if something failed in the process.
func checkIDServerSignatures(body *MembershipRequest, res *idServerLookupResponse) error { func checkIDServerSignatures(
ctx context.Context, body *MembershipRequest, res *idServerLookupResponse,
) error {
// Mashall the body so we can give it to VerifyJSON // Mashall the body so we can give it to VerifyJSON
marshalledBody, err := json.Marshal(*res) marshalledBody, err := json.Marshal(*res)
if err != nil { if err != nil {
@ -299,7 +311,7 @@ func checkIDServerSignatures(body *MembershipRequest, res *idServerLookupRespons
} }
for keyID := range signatures { for keyID := range signatures {
pubKey, err := queryIDServerPubKey(body.IDServer, keyID) pubKey, err := queryIDServerPubKey(ctx, body.IDServer, keyID)
if err != nil { if err != nil {
return err return err
} }

View file

@ -15,6 +15,7 @@
package threepid package threepid
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -51,7 +52,9 @@ type Credentials struct {
// Returns the session's ID. // Returns the session's ID.
// Returns an error if there was a problem sending the request or decoding the // Returns an error if there was a problem sending the request or decoding the
// response, or if the identity server responded with a non-OK status. // response, or if the identity server responded with a non-OK status.
func CreateSession(req EmailAssociationRequest, cfg config.Dendrite) (string, error) { func CreateSession(
ctx context.Context, req EmailAssociationRequest, cfg config.Dendrite,
) (string, error) {
if err := isTrusted(req.IDServer, cfg); err != nil { if err := isTrusted(req.IDServer, cfg); err != nil {
return "", err return "", err
} }
@ -71,7 +74,7 @@ func CreateSession(req EmailAssociationRequest, cfg config.Dendrite) (string, er
request.Header.Add("Content-Type", "application/x-www-form-urlencoded") request.Header.Add("Content-Type", "application/x-www-form-urlencoded")
client := http.Client{} client := http.Client{}
resp, err := client.Do(request) resp, err := client.Do(request.WithContext(ctx))
if err != nil { if err != nil {
return "", err return "", err
} }
@ -97,13 +100,19 @@ func CreateSession(req EmailAssociationRequest, cfg config.Dendrite) (string, er
// identifier and its medium. // identifier and its medium.
// Returns an error if there was a problem sending the request or decoding the // Returns an error if there was a problem sending the request or decoding the
// response, or if the identity server responded with a non-OK status. // response, or if the identity server responded with a non-OK status.
func CheckAssociation(creds Credentials, cfg config.Dendrite) (bool, string, string, error) { func CheckAssociation(
ctx context.Context, creds Credentials, cfg config.Dendrite,
) (bool, string, string, error) {
if err := isTrusted(creds.IDServer, cfg); err != nil { if err := isTrusted(creds.IDServer, cfg); err != nil {
return false, "", "", err return false, "", "", err
} }
url := fmt.Sprintf("https://%s/_matrix/identity/api/v1/3pid/getValidated3pid?sid=%s&client_secret=%s", creds.IDServer, creds.SID, creds.Secret) url := fmt.Sprintf("https://%s/_matrix/identity/api/v1/3pid/getValidated3pid?sid=%s&client_secret=%s", creds.IDServer, creds.SID, creds.Secret)
resp, err := http.Get(url) req, err := http.NewRequest("GET", url, nil)
if err != nil {
return false, "", "", err
}
resp, err := http.DefaultClient.Do(req.WithContext(ctx))
if err != nil { if err != nil {
return false, "", "", err return false, "", "", err
} }

View file

@ -127,7 +127,7 @@ func createRoom(req *http.Request, device *authtypes.Device,
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
profile, err := accountDB.GetProfileByLocalpart(localpart) profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }

View file

@ -57,7 +57,7 @@ func JoinRoomByIDOrAlias(
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
profile, err := accountDB.GetProfileByLocalpart(localpart) profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }

View file

@ -121,7 +121,7 @@ func buildMembershipEvent(
return nil, err return nil, err
} }
profile, err := loadProfile(stateKey, cfg, accountDB) profile, err := loadProfile(ctx, stateKey, cfg, accountDB)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -156,7 +156,9 @@ func buildMembershipEvent(
// it if the user is local to this server, or returns an empty profile if not. // it if the user is local to this server, or returns an empty profile if not.
// Returns an error if the retrieval failed or if the first parameter isn't a // Returns an error if the retrieval failed or if the first parameter isn't a
// valid Matrix ID. // valid Matrix ID.
func loadProfile(userID string, cfg config.Dendrite, accountDB *accounts.Database) (*authtypes.Profile, error) { func loadProfile(
ctx context.Context, userID string, cfg config.Dendrite, accountDB *accounts.Database,
) (*authtypes.Profile, error) {
localpart, serverName, err := gomatrixserverlib.SplitID('@', userID) localpart, serverName, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -164,7 +166,7 @@ func loadProfile(userID string, cfg config.Dendrite, accountDB *accounts.Databas
var profile *authtypes.Profile var profile *authtypes.Profile
if serverName == cfg.Matrix.ServerName { if serverName == cfg.Matrix.ServerName {
profile, err = accountDB.GetProfileByLocalpart(localpart) profile, err = accountDB.GetProfileByLocalpart(ctx, localpart)
} else { } else {
profile = &authtypes.Profile{} profile = &authtypes.Profile{}
} }

View file

@ -1,6 +1,7 @@
package writers package writers
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"time" "time"
@ -134,7 +135,9 @@ func Register(req *http.Request, accountDB *accounts.Database, deviceDB *devices
switch r.Auth.Type { switch r.Auth.Type {
case authtypes.LoginTypeDummy: case authtypes.LoginTypeDummy:
// there is nothing to do // there is nothing to do
return completeRegistration(accountDB, deviceDB, r.Username, r.Password) return completeRegistration(
req.Context(), accountDB, deviceDB, r.Username, r.Password,
)
default: default:
return util.JSONResponse{ return util.JSONResponse{
Code: 501, Code: 501,
@ -143,7 +146,12 @@ func Register(req *http.Request, accountDB *accounts.Database, deviceDB *devices
} }
} }
func completeRegistration(accountDB *accounts.Database, deviceDB *devices.Database, username, password string) util.JSONResponse { func completeRegistration(
ctx context.Context,
accountDB *accounts.Database,
deviceDB *devices.Database,
username, password string,
) util.JSONResponse {
if username == "" { if username == "" {
return util.JSONResponse{ return util.JSONResponse{
Code: 400, Code: 400,
@ -157,7 +165,7 @@ func completeRegistration(accountDB *accounts.Database, deviceDB *devices.Databa
} }
} }
acc, err := accountDB.CreateAccount(username, password) acc, err := accountDB.CreateAccount(ctx, username, password)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: 500, Code: 500,

View file

@ -15,6 +15,7 @@
package main package main
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"os" "os"
@ -68,7 +69,7 @@ func main() {
os.Exit(1) os.Exit(1)
} }
_, err = accountDB.CreateAccount(*username, *password) _, err = accountDB.CreateAccount(context.Background(), *username, *password)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
os.Exit(1) os.Exit(1)

View file

@ -191,7 +191,7 @@ func createInviteFrom3PIDInvite(
StateKey: &inv.MXID, StateKey: &inv.MXID,
} }
profile, err := accountDB.GetProfileByLocalpart(localpart) profile, err := accountDB.GetProfileByLocalpart(ctx, localpart)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -15,6 +15,7 @@
package sync package sync
import ( import (
"context"
"net/http" "net/http"
"strconv" "strconv"
"time" "time"
@ -29,6 +30,7 @@ const defaultTimelineLimit = 20
// syncRequest represents a /sync request, with sensible defaults/sanity checks applied. // syncRequest represents a /sync request, with sensible defaults/sanity checks applied.
type syncRequest struct { type syncRequest struct {
ctx context.Context
userID string userID string
limit int limit int
timeout time.Duration timeout time.Duration
@ -47,6 +49,7 @@ func newSyncRequest(req *http.Request, userID string) (*syncRequest, error) {
} }
// TODO: Additional query params: set_presence, filter // TODO: Additional query params: set_presence, filter
return &syncRequest{ return &syncRequest{
ctx: req.Context(),
userID: userID, userID: userID,
timeout: timeout, timeout: timeout,
since: since, since: since,

View file

@ -128,7 +128,7 @@ func (rp *RequestPool) appendAccountData(
// already been sent. Instead, we send the whole batch. // already been sent. Instead, we send the whole batch.
var global []gomatrixserverlib.ClientEvent var global []gomatrixserverlib.ClientEvent
var rooms map[string][]gomatrixserverlib.ClientEvent var rooms map[string][]gomatrixserverlib.ClientEvent
global, rooms, err = rp.accountDB.GetAccountData(localpart) global, rooms, err = rp.accountDB.GetAccountData(req.ctx, localpart)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -159,7 +159,9 @@ func (rp *RequestPool) appendAccountData(
events := []gomatrixserverlib.ClientEvent{} events := []gomatrixserverlib.ClientEvent{}
// Request the missing data from the database // Request the missing data from the database
for _, dataType := range dataTypes { for _, dataType := range dataTypes {
evs, err := rp.accountDB.GetAccountDataByType(localpart, roomID, dataType) evs, err := rp.accountDB.GetAccountDataByType(
req.ctx, localpart, roomID, dataType,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }