use of interfaces for function calls in place of dynamic DB pointers

Signed-off-by: MohitKS5 <mohitkumarsingh907@gmail.com>
This commit is contained in:
MohitKS5 2018-03-11 10:12:27 +05:30
parent 6b55972183
commit 653f20995d
13 changed files with 127 additions and 61 deletions

View file

@ -15,22 +15,26 @@
package routing package routing
import ( import (
"context"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
// interface for database layer
type accountsData interface {
SaveAccountData(context.Context, string, string, string, string) error
}
// SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type} // SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type}
func SaveAccountData( func SaveAccountData(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device, req *http.Request, accountDB accountsData, device *authtypes.Device,
userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer, userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer,
) util.JSONResponse { ) util.JSONResponse {
if req.Method != "PUT" { if req.Method != "PUT" {

View file

@ -15,25 +15,29 @@
package routing package routing
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
"time" "time"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// interface to Database Layer
type createRoomData interface {
GetProfileByLocalpart(context.Context, string) (*authtypes.Profile, error)
}
// https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-client-r0-createroom // https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-client-r0-createroom
type createRoomRequest struct { type createRoomRequest struct {
Invite []string `json:"invite"` Invite []string `json:"invite"`
@ -116,7 +120,7 @@ type fledglingEvent struct {
// CreateRoom implements /createRoom // CreateRoom implements /createRoom
func CreateRoom(req *http.Request, device *authtypes.Device, func CreateRoom(req *http.Request, device *authtypes.Device,
cfg config.Dendrite, producer *producers.RoomserverProducer, cfg config.Dendrite, producer *producers.RoomserverProducer,
accountDB *accounts.Database, aliasAPI api.RoomserverAliasAPI, accountDB createRoomData, aliasAPI api.RoomserverAliasAPI,
) util.JSONResponse { ) util.JSONResponse {
// TODO (#267): Check room ID doesn't clash with an existing one, and we // TODO (#267): Check room ID doesn't clash with an existing one, and we
// probably shouldn't be using pseudo-random strings, maybe GUIDs? // probably shouldn't be using pseudo-random strings, maybe GUIDs?
@ -128,7 +132,7 @@ func CreateRoom(req *http.Request, device *authtypes.Device,
// nolint: gocyclo // nolint: gocyclo
func createRoom(req *http.Request, device *authtypes.Device, func createRoom(req *http.Request, device *authtypes.Device,
cfg config.Dendrite, roomID string, producer *producers.RoomserverProducer, cfg config.Dendrite, roomID string, producer *producers.RoomserverProducer,
accountDB *accounts.Database, aliasAPI api.RoomserverAliasAPI, accountDB createRoomData, aliasAPI api.RoomserverAliasAPI,
) util.JSONResponse { ) util.JSONResponse {
logger := util.GetLogger(req.Context()) logger := util.GetLogger(req.Context())
userID := device.UserID userID := device.UserID

View file

@ -15,18 +15,25 @@
package routing package routing
import ( import (
"context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
// interface to Database layer
type deviceData interface {
GetDevicesByLocalpart(context.Context, string) ([]authtypes.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*authtypes.Device, error)
UpdateDevice(context.Context, string, string, *string) error
}
type deviceJSON struct { type deviceJSON struct {
DeviceID string `json:"device_id"` DeviceID string `json:"device_id"`
UserID string `json:"user_id"` UserID string `json:"user_id"`
@ -42,7 +49,7 @@ type deviceUpdateJSON struct {
// GetDeviceByID handles /device/{deviceID} // GetDeviceByID handles /device/{deviceID}
func GetDeviceByID( func GetDeviceByID(
req *http.Request, deviceDB *devices.Database, device *authtypes.Device, req *http.Request, deviceDB deviceData, device *authtypes.Device,
deviceID string, deviceID string,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
@ -72,7 +79,7 @@ func GetDeviceByID(
// GetDevicesByLocalpart handles /devices // GetDevicesByLocalpart handles /devices
func GetDevicesByLocalpart( func GetDevicesByLocalpart(
req *http.Request, deviceDB *devices.Database, device *authtypes.Device, req *http.Request, deviceDB deviceData, device *authtypes.Device,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
@ -103,7 +110,7 @@ func GetDevicesByLocalpart(
// UpdateDeviceByID handles PUT on /devices/{deviceID} // UpdateDeviceByID handles PUT on /devices/{deviceID}
func UpdateDeviceByID( func UpdateDeviceByID(
req *http.Request, deviceDB *devices.Database, device *authtypes.Device, req *http.Request, deviceDB deviceData, device *authtypes.Device,
deviceID string, deviceID string,
) util.JSONResponse { ) util.JSONResponse {
if req.Method != "PUT" { if req.Method != "PUT" {

View file

@ -15,12 +15,11 @@
package routing package routing
import ( import (
"context"
"encoding/json"
"net/http" "net/http"
"encoding/json"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
@ -28,9 +27,15 @@ import (
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
// interface to database layer
type filterData interface {
GetFilter(context.Context, string, string) ([]byte, error)
PutFilter(context.Context, string, []byte) (string, error)
}
// GetFilter implements GET /_matrix/client/r0/user/{userId}/filter/{filterId} // GetFilter implements GET /_matrix/client/r0/user/{userId}/filter/{filterId}
func GetFilter( func GetFilter(
req *http.Request, device *authtypes.Device, accountDB *accounts.Database, userID string, filterID string, req *http.Request, device *authtypes.Device, accountDB filterData, userID string, filterID string,
) util.JSONResponse { ) util.JSONResponse {
if req.Method != http.MethodGet { if req.Method != http.MethodGet {
return util.JSONResponse{ return util.JSONResponse{
@ -77,7 +82,7 @@ type filterResponse struct {
//PutFilter implements POST /_matrix/client/r0/user/{userId}/filter //PutFilter implements POST /_matrix/client/r0/user/{userId}/filter
func PutFilter( func PutFilter(
req *http.Request, device *authtypes.Device, accountDB *accounts.Database, userID string, req *http.Request, device *authtypes.Device, accountDB filterData, userID string,
) util.JSONResponse { ) util.JSONResponse {
if req.Method != http.MethodPost { if req.Method != http.MethodPost {
return util.JSONResponse{ return util.JSONResponse{

View file

@ -15,13 +15,13 @@
package routing package routing
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
@ -33,6 +33,11 @@ import (
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
// interface to Database
type joinRoomData interface {
GetProfileByLocalpart(context.Context, string) (*authtypes.Profile, error)
}
// JoinRoomByIDOrAlias implements the "/join/{roomIDOrAlias}" API. // JoinRoomByIDOrAlias implements the "/join/{roomIDOrAlias}" API.
// https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-client-r0-join-roomidoralias // https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-client-r0-join-roomidoralias
func JoinRoomByIDOrAlias( func JoinRoomByIDOrAlias(
@ -45,7 +50,7 @@ func JoinRoomByIDOrAlias(
queryAPI api.RoomserverQueryAPI, queryAPI api.RoomserverQueryAPI,
aliasAPI api.RoomserverAliasAPI, aliasAPI api.RoomserverAliasAPI,
keyRing gomatrixserverlib.KeyRing, keyRing gomatrixserverlib.KeyRing,
accountDB *accounts.Database, accountDB joinRoomData,
) util.JSONResponse { ) util.JSONResponse {
var content map[string]interface{} // must be a JSON object var content map[string]interface{} // must be a JSON object
if resErr := httputil.UnmarshalJSONRequest(req, &content); resErr != nil { if resErr := httputil.UnmarshalJSONRequest(req, &content); resErr != nil {

View file

@ -15,12 +15,12 @@
package routing package routing
import ( import (
"context"
"net/http" "net/http"
"strings" "strings"
"github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
@ -28,6 +28,16 @@ import (
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
// interface to accounts database layer
type loginAccountsData interface {
GetAccountByPassword(context.Context, string, string) (*authtypes.Account, error)
}
// interface to devices database layer
type loginDevicesData interface {
CreateDevice(context.Context, string, *string, string, *string) (dev *authtypes.Device, returnErr error)
}
type loginFlows struct { type loginFlows struct {
Flows []flow `json:"flows"` Flows []flow `json:"flows"`
} }
@ -59,7 +69,7 @@ func passwordLogin() loginFlows {
// Login implements GET and POST /login // Login implements GET and POST /login
func Login( func Login(
req *http.Request, accountDB *accounts.Database, deviceDB *devices.Database, req *http.Request, accountDB loginAccountsData, deviceDB loginDevicesData,
cfg config.Dendrite, cfg config.Dendrite,
) util.JSONResponse { ) util.JSONResponse {
if req.Method == "GET" { // TODO: support other forms of login other than password, depending on config options if req.Method == "GET" { // TODO: support other forms of login other than password, depending on config options

View file

@ -15,19 +15,25 @@
package routing package routing
import ( import (
"context"
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
//interface to devices database layer
type logoutDevicesData interface {
RemoveDevice(context.Context, string, string) error
RemoveAllDevices(context.Context, string) error
}
// Logout handles POST /logout // Logout handles POST /logout
func Logout( func Logout(
req *http.Request, deviceDB *devices.Database, device *authtypes.Device, req *http.Request, deviceDB logoutDevicesData, device *authtypes.Device,
) util.JSONResponse { ) util.JSONResponse {
if req.Method != "POST" { if req.Method != "POST" {
return util.JSONResponse{ return util.JSONResponse{
@ -53,7 +59,7 @@ func Logout(
// LogoutAll handles POST /logout/all // LogoutAll handles POST /logout/all
func LogoutAll( func LogoutAll(
req *http.Request, deviceDB *devices.Database, device *authtypes.Device, req *http.Request, deviceDB logoutDevicesData, device *authtypes.Device,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {

View file

@ -20,7 +20,6 @@ import (
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
@ -29,16 +28,19 @@ import (
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
type membershipAccountsData interface {
GetProfileByLocalpart(context.Context, string) (*authtypes.Profile, error)
}
var errMissingUserID = errors.New("'user_id' must be supplied") var errMissingUserID = errors.New("'user_id' must be supplied")
// SendMembership implements PUT /rooms/{roomID}/(join|kick|ban|unban|leave|invite) // SendMembership implements PUT /rooms/{roomID}/(join|kick|ban|unban|leave|invite)
// by building a m.room.member event then sending it to the room server // by building a m.room.member event then sending it to the room server
func SendMembership( func SendMembership(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device, req *http.Request, accountDB membershipAccountsData, device *authtypes.Device,
roomID string, membership string, cfg config.Dendrite, roomID string, membership string, cfg config.Dendrite,
queryAPI api.RoomserverQueryAPI, producer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI, producer *producers.RoomserverProducer,
) util.JSONResponse { ) util.JSONResponse {
@ -111,7 +113,7 @@ func SendMembership(
func buildMembershipEvent( func buildMembershipEvent(
ctx context.Context, ctx context.Context,
body threepid.MembershipRequest, accountDB *accounts.Database, body threepid.MembershipRequest, accountDB membershipAccountsData,
device *authtypes.Device, membership string, roomID string, cfg config.Dendrite, device *authtypes.Device, membership string, roomID string, cfg config.Dendrite,
queryAPI api.RoomserverQueryAPI, queryAPI api.RoomserverQueryAPI,
) (*gomatrixserverlib.Event, error) { ) (*gomatrixserverlib.Event, error) {
@ -156,7 +158,7 @@ func buildMembershipEvent(
// Returns an error if the retrieval failed or if the first parameter isn't a // Returns an error if the retrieval failed or if the first parameter isn't a
// valid Matrix ID. // valid Matrix ID.
func loadProfile( func loadProfile(
ctx context.Context, userID string, cfg config.Dendrite, accountDB *accounts.Database, ctx context.Context, userID string, cfg config.Dendrite, accountDB membershipAccountsData,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
localpart, serverName, err := gomatrixserverlib.SplitID('@', userID) localpart, serverName, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {

View file

@ -19,7 +19,6 @@ import (
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
@ -27,13 +26,20 @@ import (
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
// interface to accounts database layer
type profileAccountsData interface {
GetProfileByLocalpart(context.Context, string) (*authtypes.Profile, error)
SetAvatarURL(context.Context, string, string) error
GetMembershipsByLocalpart(context.Context, string) (memberships []authtypes.Membership, err error)
SetDisplayName(context.Context, string, string) error
}
// GetProfile implements GET /profile/{userID} // GetProfile implements GET /profile/{userID}
func GetProfile( func GetProfile(
req *http.Request, accountDB *accounts.Database, userID string, req *http.Request, accountDB profileAccountsData, userID string,
) util.JSONResponse { ) util.JSONResponse {
if req.Method != "GET" { if req.Method != "GET" {
return util.JSONResponse{ return util.JSONResponse{
@ -62,7 +68,7 @@ func GetProfile(
// GetAvatarURL implements GET /profile/{userID}/avatar_url // GetAvatarURL implements GET /profile/{userID}/avatar_url
func GetAvatarURL( func GetAvatarURL(
req *http.Request, accountDB *accounts.Database, userID string, req *http.Request, accountDB profileAccountsData, userID string,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
@ -84,7 +90,7 @@ func GetAvatarURL(
// SetAvatarURL implements PUT /profile/{userID}/avatar_url // SetAvatarURL implements PUT /profile/{userID}/avatar_url
func SetAvatarURL( func SetAvatarURL(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device, req *http.Request, accountDB profileAccountsData, device *authtypes.Device,
userID string, producer *producers.UserUpdateProducer, cfg *config.Dendrite, userID string, producer *producers.UserUpdateProducer, cfg *config.Dendrite,
rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI, rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
@ -154,7 +160,7 @@ func SetAvatarURL(
// GetDisplayName implements GET /profile/{userID}/displayname // GetDisplayName implements GET /profile/{userID}/displayname
func GetDisplayName( func GetDisplayName(
req *http.Request, accountDB *accounts.Database, userID string, req *http.Request, accountDB profileAccountsData, userID string,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
@ -176,7 +182,7 @@ func GetDisplayName(
// SetDisplayName implements PUT /profile/{userID}/displayname // SetDisplayName implements PUT /profile/{userID}/displayname
func SetDisplayName( func SetDisplayName(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device, req *http.Request, accountDB profileAccountsData, device *authtypes.Device,
userID string, producer *producers.UserUpdateProducer, cfg *config.Dendrite, userID string, producer *producers.UserUpdateProducer, cfg *config.Dendrite,
rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI, rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI,
) util.JSONResponse { ) util.JSONResponse {

View file

@ -34,8 +34,6 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -56,6 +54,15 @@ var (
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-./]+$`) validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-./]+$`)
) )
type registerAccountsData interface {
CreateAccount(context.Context, string, string, string) (*authtypes.Account, error)
CheckAccountAvailability(context.Context, string) (bool, error)
}
type registerDevicesData interface {
CreateDevice(context.Context, string, *string, string, *string) (*authtypes.Device, error)
}
// registerRequest represents the submitted registration request. // registerRequest represents the submitted registration request.
// It can be broken down into 2 sections: the auth dictionary and registration parameters. // It can be broken down into 2 sections: the auth dictionary and registration parameters.
// Registration parameters vary depending on the request, and will need to remembered across // Registration parameters vary depending on the request, and will need to remembered across
@ -343,8 +350,8 @@ func validateApplicationService(
// http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register // http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register
func Register( func Register(
req *http.Request, req *http.Request,
accountDB *accounts.Database, accountDB registerAccountsData,
deviceDB *devices.Database, deviceDB registerDevicesData,
cfg *config.Dendrite, cfg *config.Dendrite,
) util.JSONResponse { ) util.JSONResponse {
@ -408,8 +415,8 @@ func handleRegistrationFlow(
r registerRequest, r registerRequest,
sessionID string, sessionID string,
cfg *config.Dendrite, cfg *config.Dendrite,
accountDB *accounts.Database, accountDB registerAccountsData,
deviceDB *devices.Database, deviceDB registerDevicesData,
) util.JSONResponse { ) util.JSONResponse {
// TODO: Shared secret registration (create new user scripts) // TODO: Shared secret registration (create new user scripts)
// TODO: Enable registration config flag // TODO: Enable registration config flag
@ -490,8 +497,8 @@ func checkAndCompleteFlow(
r registerRequest, r registerRequest,
sessionID string, sessionID string,
cfg *config.Dendrite, cfg *config.Dendrite,
accountDB *accounts.Database, accountDB registerAccountsData,
deviceDB *devices.Database, deviceDB registerDevicesData,
) util.JSONResponse { ) util.JSONResponse {
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
// This flow was completed, registration can continue // This flow was completed, registration can continue
@ -511,8 +518,8 @@ func checkAndCompleteFlow(
// LegacyRegister process register requests from the legacy v1 API // LegacyRegister process register requests from the legacy v1 API
func LegacyRegister( func LegacyRegister(
req *http.Request, req *http.Request,
accountDB *accounts.Database, accountDB registerAccountsData,
deviceDB *devices.Database, deviceDB registerDevicesData,
cfg *config.Dendrite, cfg *config.Dendrite,
) util.JSONResponse { ) util.JSONResponse {
var r legacyRegisterRequest var r legacyRegisterRequest
@ -589,8 +596,8 @@ func parseAndValidateLegacyLogin(req *http.Request, r *legacyRegisterRequest) *u
func completeRegistration( func completeRegistration(
ctx context.Context, ctx context.Context,
accountDB *accounts.Database, accountDB registerAccountsData,
deviceDB *devices.Database, deviceDB registerDevicesData,
username, password, appserviceID string, username, password, appserviceID string,
displayName *string, displayName *string,
) util.JSONResponse { ) util.JSONResponse {
@ -751,7 +758,7 @@ type availableResponse struct {
// RegisterAvailable checks if the username is already taken or invalid. // RegisterAvailable checks if the username is already taken or invalid.
func RegisterAvailable( func RegisterAvailable(
req *http.Request, req *http.Request,
accountDB *accounts.Database, accountDB registerAccountsData,
) util.JSONResponse { ) util.JSONResponse {
username := req.URL.Query().Get("username") username := req.URL.Query().Get("username")

View file

@ -15,6 +15,7 @@
package routing package routing
import ( import (
"context"
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -23,11 +24,18 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/threepid" "github.com/matrix-org/dendrite/clientapi/threepid"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
// interface to accounts database layer
type threepidAccountsData interface {
GetLocalpartForThreePID(context.Context, string, string) (string, error)
SaveThreePIDAssociation(context.Context, string, string, string) error
GetThreePIDsForLocalpart(context.Context, string) ([]authtypes.ThreePID, error)
RemoveThreePIDAssociation(context.Context, string, string) error
}
type reqTokenResponse struct { type reqTokenResponse struct {
SID string `json:"sid"` SID string `json:"sid"`
} }
@ -39,7 +47,7 @@ type threePIDsResponse struct {
// RequestEmailToken implements: // RequestEmailToken implements:
// POST /account/3pid/email/requestToken // POST /account/3pid/email/requestToken
// POST /register/email/requestToken // POST /register/email/requestToken
func RequestEmailToken(req *http.Request, accountDB *accounts.Database, cfg config.Dendrite) util.JSONResponse { func RequestEmailToken(req *http.Request, accountDB threepidAccountsData, cfg config.Dendrite) util.JSONResponse {
var body threepid.EmailAssociationRequest var body threepid.EmailAssociationRequest
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr return *reqErr
@ -82,7 +90,7 @@ func RequestEmailToken(req *http.Request, accountDB *accounts.Database, cfg conf
// CheckAndSave3PIDAssociation implements POST /account/3pid // CheckAndSave3PIDAssociation implements POST /account/3pid
func CheckAndSave3PIDAssociation( func CheckAndSave3PIDAssociation(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device, req *http.Request, accountDB threepidAccountsData, device *authtypes.Device,
cfg config.Dendrite, cfg config.Dendrite,
) util.JSONResponse { ) util.JSONResponse {
var body threepid.EmailAssociationCheckRequest var body threepid.EmailAssociationCheckRequest
@ -142,7 +150,7 @@ func CheckAndSave3PIDAssociation(
// GetAssociated3PIDs implements GET /account/3pid // GetAssociated3PIDs implements GET /account/3pid
func GetAssociated3PIDs( func GetAssociated3PIDs(
req *http.Request, accountDB *accounts.Database, device *authtypes.Device, req *http.Request, accountDB threepidAccountsData, device *authtypes.Device,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
@ -161,7 +169,7 @@ func GetAssociated3PIDs(
} }
// Forget3PID implements POST /account/3pid/delete // Forget3PID implements POST /account/3pid/delete
func Forget3PID(req *http.Request, accountDB *accounts.Database) util.JSONResponse { func Forget3PID(req *http.Request, accountDB threepidAccountsData) util.JSONResponse {
var body authtypes.ThreePID var body authtypes.ThreePID
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr return *reqErr

View file

@ -15,12 +15,11 @@
package routing package routing
import ( import (
"net/http"
"crypto/hmac" "crypto/hmac"
"crypto/sha1" "crypto/sha1"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net/http"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"

View file

@ -25,7 +25,6 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
@ -33,6 +32,10 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
type invitesAccountsData interface {
GetProfileByLocalpart(context.Context, string) (*authtypes.Profile, error)
}
// MembershipRequest represents the body of an incoming POST request // MembershipRequest represents the body of an incoming POST request
// on /rooms/{roomID}/(join|kick|ban|unban|leave|invite) // on /rooms/{roomID}/(join|kick|ban|unban|leave|invite)
type MembershipRequest struct { type MembershipRequest struct {
@ -87,7 +90,7 @@ var (
func CheckAndProcessInvite( func CheckAndProcessInvite(
ctx context.Context, ctx context.Context,
device *authtypes.Device, body *MembershipRequest, cfg config.Dendrite, device *authtypes.Device, body *MembershipRequest, cfg config.Dendrite,
queryAPI api.RoomserverQueryAPI, db *accounts.Database, queryAPI api.RoomserverQueryAPI, db invitesAccountsData,
producer *producers.RoomserverProducer, membership string, roomID string, producer *producers.RoomserverProducer, membership string, roomID string,
) (inviteStoredOnIDServer bool, err error) { ) (inviteStoredOnIDServer bool, err error) {
if membership != "invite" || (body.Address == "" && body.IDServer == "" && body.Medium == "") { if membership != "invite" || (body.Address == "" && body.IDServer == "" && body.Medium == "") {
@ -134,7 +137,7 @@ func CheckAndProcessInvite(
// Returns an error if a check or a request failed. // Returns an error if a check or a request failed.
func queryIDServer( func queryIDServer(
ctx context.Context, ctx context.Context,
db *accounts.Database, cfg config.Dendrite, device *authtypes.Device, db invitesAccountsData, cfg config.Dendrite, device *authtypes.Device,
body *MembershipRequest, roomID string, body *MembershipRequest, roomID string,
) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) { ) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) {
if err = isTrusted(body.IDServer, cfg); err != nil { if err = isTrusted(body.IDServer, cfg); err != nil {
@ -203,7 +206,7 @@ func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServe
// Returns an error if the request failed to send or if the response couldn't be parsed. // Returns an error if the request failed to send or if the response couldn't be parsed.
func queryIDServerStoreInvite( func queryIDServerStoreInvite(
ctx context.Context, ctx context.Context,
db *accounts.Database, cfg config.Dendrite, device *authtypes.Device, db invitesAccountsData, cfg config.Dendrite, device *authtypes.Device,
body *MembershipRequest, roomID string, body *MembershipRequest, roomID string,
) (*idServerStoreInviteResponse, error) { ) (*idServerStoreInviteResponse, error) {
// Retrieve the sender's profile to get their display name // Retrieve the sender's profile to get their display name