Refactor appservice & client API to use userapi internal (#2290)

* Refactor user api internal

* Refactor clientapi to use internal userapi

* Use internal userapi instead of user DB directly

* Remove AccountDB dependency

* Fix linter issues

Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
This commit is contained in:
S7evinK 2022-03-24 22:45:44 +01:00 committed by GitHub
parent 8e76523b04
commit f2e550efd8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
30 changed files with 682 additions and 239 deletions

View file

@ -19,11 +19,10 @@ package api
import ( import (
"context" "context"
"database/sql"
"errors" "errors"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
userdb "github.com/matrix-org/dendrite/userapi/storage" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -85,7 +84,7 @@ func RetrieveUserProfile(
ctx context.Context, ctx context.Context,
userID string, userID string,
asAPI AppServiceQueryAPI, asAPI AppServiceQueryAPI,
accountDB userdb.Database, profileAPI userapi.UserProfileAPI,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
@ -93,10 +92,17 @@ func RetrieveUserProfile(
} }
// Try to query the user from the local database // Try to query the user from the local database
profile, err := accountDB.GetProfileByLocalpart(ctx, localpart) res := &userapi.QueryProfileResponse{}
if err != nil && err != sql.ErrNoRows { err = profileAPI.QueryProfile(ctx, &userapi.QueryProfileRequest{UserID: userID}, res)
if err != nil {
return nil, err return nil, err
} else if profile != nil { }
profile := &authtypes.Profile{
Localpart: localpart,
DisplayName: res.DisplayName,
AvatarURL: res.AvatarURL,
}
if res.UserExists {
return profile, nil return profile, nil
} }
@ -113,11 +119,15 @@ func RetrieveUserProfile(
} }
// Try to query the user from the local database again // Try to query the user from the local database again
profile, err = accountDB.GetProfileByLocalpart(ctx, localpart) err = profileAPI.QueryProfile(ctx, &userapi.QueryProfileRequest{UserID: userID}, res)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// profile should not be nil at this point // profile should not be nil at this point
return profile, nil return &authtypes.Profile{
Localpart: localpart,
DisplayName: res.DisplayName,
AvatarURL: res.AvatarURL,
}, nil
} }

View file

@ -33,7 +33,7 @@ import (
// called after authorization has completed, with the result of the authorization. // called after authorization has completed, with the result of the authorization.
// If the final return value is non-nil, an error occurred and the cleanup function // If the final return value is non-nil, an error occurred and the cleanup function
// is nil. // is nil.
func LoginFromJSONReader(ctx context.Context, r io.Reader, accountDB AccountDatabase, userAPI UserInternalAPIForLogin, cfg *config.ClientAPI) (*Login, LoginCleanupFunc, *util.JSONResponse) { func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.UserAccountAPI, userAPI UserInternalAPIForLogin, cfg *config.ClientAPI) (*Login, LoginCleanupFunc, *util.JSONResponse) {
reqBytes, err := ioutil.ReadAll(r) reqBytes, err := ioutil.ReadAll(r)
if err != nil { if err != nil {
err := &util.JSONResponse{ err := &util.JSONResponse{
@ -58,7 +58,7 @@ func LoginFromJSONReader(ctx context.Context, r io.Reader, accountDB AccountData
switch header.Type { switch header.Type {
case authtypes.LoginTypePassword: case authtypes.LoginTypePassword:
typ = &LoginTypePassword{ typ = &LoginTypePassword{
GetAccountByPassword: accountDB.GetAccountByPassword, GetAccountByPassword: useraccountAPI.QueryAccountByPassword,
Config: cfg, Config: cfg,
} }
case authtypes.LoginTypeToken: case authtypes.LoginTypeToken:

View file

@ -16,7 +16,6 @@ package auth
import ( import (
"context" "context"
"database/sql"
"net/http" "net/http"
"reflect" "reflect"
"strings" "strings"
@ -64,14 +63,13 @@ func TestLoginFromJSONReader(t *testing.T) {
} }
for _, tst := range tsts { for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) { t.Run(tst.Name, func(t *testing.T) {
var accountDB fakeAccountDB
var userAPI fakeUserInternalAPI var userAPI fakeUserInternalAPI
cfg := &config.ClientAPI{ cfg := &config.ClientAPI{
Matrix: &config.Global{ Matrix: &config.Global{
ServerName: serverName, ServerName: serverName,
}, },
} }
login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &accountDB, &userAPI, cfg) login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg)
if err != nil { if err != nil {
t.Fatalf("LoginFromJSONReader failed: %+v", err) t.Fatalf("LoginFromJSONReader failed: %+v", err)
} }
@ -143,14 +141,13 @@ func TestBadLoginFromJSONReader(t *testing.T) {
} }
for _, tst := range tsts { for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) { t.Run(tst.Name, func(t *testing.T) {
var accountDB fakeAccountDB
var userAPI fakeUserInternalAPI var userAPI fakeUserInternalAPI
cfg := &config.ClientAPI{ cfg := &config.ClientAPI{
Matrix: &config.Global{ Matrix: &config.Global{
ServerName: serverName, ServerName: serverName,
}, },
} }
_, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &accountDB, &userAPI, cfg) _, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg)
if errRes == nil { if errRes == nil {
cleanup(ctx, nil) cleanup(ctx, nil)
t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode) t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode)
@ -161,24 +158,22 @@ func TestBadLoginFromJSONReader(t *testing.T) {
} }
} }
type fakeAccountDB struct {
AccountDatabase
}
func (*fakeAccountDB) GetAccountByPassword(ctx context.Context, localpart, password string) (*uapi.Account, error) {
if password == "invalidpassword" {
return nil, sql.ErrNoRows
}
return &uapi.Account{}, nil
}
type fakeUserInternalAPI struct { type fakeUserInternalAPI struct {
UserInternalAPIForLogin UserInternalAPIForLogin
uapi.UserAccountAPI
DeletedTokens []string DeletedTokens []string
} }
func (ua *fakeUserInternalAPI) QueryAccountByPassword(ctx context.Context, req *uapi.QueryAccountByPasswordRequest, res *uapi.QueryAccountByPasswordResponse) error {
if req.PlaintextPassword == "invalidpassword" {
res.Account = nil
return nil
}
res.Exists = true
res.Account = &uapi.Account{}
return nil
}
func (ua *fakeUserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *uapi.PerformLoginTokenDeletionRequest, res *uapi.PerformLoginTokenDeletionResponse) error { func (ua *fakeUserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *uapi.PerformLoginTokenDeletionRequest, res *uapi.PerformLoginTokenDeletionResponse) error {
ua.DeletedTokens = append(ua.DeletedTokens, req.Token) ua.DeletedTokens = append(ua.DeletedTokens, req.Token)
return nil return nil

View file

@ -16,7 +16,6 @@ package auth
import ( import (
"context" "context"
"database/sql"
"net/http" "net/http"
"strings" "strings"
@ -29,7 +28,7 @@ import (
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
type GetAccountByPassword func(ctx context.Context, localpart, password string) (*api.Account, error) type GetAccountByPassword func(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error
type PasswordRequest struct { type PasswordRequest struct {
Login Login
@ -77,20 +76,34 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login,
} }
} }
// Squash username to all lowercase letters // Squash username to all lowercase letters
_, err = t.GetAccountByPassword(ctx, strings.ToLower(localpart), r.Password) res := &api.QueryAccountByPasswordResponse{}
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{Localpart: strings.ToLower(localpart), PlaintextPassword: r.Password}, res)
if err != nil { if err != nil {
if err == sql.ErrNoRows { return nil, &util.JSONResponse{
_, err = t.GetAccountByPassword(ctx, localpart, r.Password) Code: http.StatusInternalServerError,
if err == nil { JSON: jsonerror.Unknown("unable to fetch account by password"),
return &r.Login, nil }
}
if !res.Exists {
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{
Localpart: localpart,
PlaintextPassword: r.Password,
}, res)
if err != nil {
return nil, &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("unable to fetch account by password"),
} }
} }
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows
// but that would leak the existence of the user. // but that would leak the existence of the user.
if !res.Exists {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("The username or password was incorrect or the account does not exist."), JSON: jsonerror.Forbidden("The username or password was incorrect or the account does not exist."),
} }
} }
}
return &r.Login, nil return &r.Login, nil
} }

View file

@ -110,9 +110,9 @@ type UserInteractive struct {
Sessions map[string][]string Sessions map[string][]string
} }
func NewUserInteractive(accountDB AccountDatabase, cfg *config.ClientAPI) *UserInteractive { func NewUserInteractive(userAccountAPI api.UserAccountAPI, cfg *config.ClientAPI) *UserInteractive {
typePassword := &LoginTypePassword{ typePassword := &LoginTypePassword{
GetAccountByPassword: accountDB.GetAccountByPassword, GetAccountByPassword: userAccountAPI.QueryAccountByPassword,
Config: cfg, Config: cfg,
} }
return &UserInteractive{ return &UserInteractive{

View file

@ -25,15 +25,25 @@ var (
) )
type fakeAccountDatabase struct { type fakeAccountDatabase struct {
AccountDatabase api.UserAccountAPI
} }
func (*fakeAccountDatabase) GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) { func (d *fakeAccountDatabase) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
acc, ok := lookup[localpart+" "+plaintextPassword] return nil
}
func (d *fakeAccountDatabase) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error {
return nil
}
func (d *fakeAccountDatabase) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error {
acc, ok := lookup[req.Localpart+" "+req.PlaintextPassword]
if !ok { if !ok {
return nil, fmt.Errorf("unknown user/password") return fmt.Errorf("unknown user/password")
} }
return acc, nil res.Account = acc
res.Exists = true
return nil
} }
func setup() *UserInteractive { func setup() *UserInteractive {

View file

@ -29,7 +29,6 @@ import (
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -39,7 +38,6 @@ func AddPublicRoutes(
router *mux.Router, router *mux.Router,
synapseAdminRouter *mux.Router, synapseAdminRouter *mux.Router,
cfg *config.ClientAPI, cfg *config.ClientAPI,
accountsDB userdb.Database,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
eduInputAPI eduServerAPI.EDUServerInputAPI, eduInputAPI eduServerAPI.EDUServerInputAPI,
@ -60,7 +58,7 @@ func AddPublicRoutes(
routing.Setup( routing.Setup(
router, synapseAdminRouter, cfg, eduInputAPI, rsAPI, asAPI, router, synapseAdminRouter, cfg, eduInputAPI, rsAPI, asAPI,
accountsDB, userAPI, federation, userAPI, federation,
syncProducer, transactionsCache, fsAPI, keyAPI, syncProducer, transactionsCache, fsAPI, keyAPI,
extRoomsProvider, mscCfg, extRoomsProvider, mscCfg,
) )

View file

@ -31,7 +31,6 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -138,7 +137,7 @@ type fledglingEvent struct {
func CreateRoom( func CreateRoom(
req *http.Request, device *api.Device, req *http.Request, device *api.Device,
cfg *config.ClientAPI, cfg *config.ClientAPI,
accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI, profileAPI api.UserProfileAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
var r createRoomRequest var r createRoomRequest
@ -156,7 +155,7 @@ func CreateRoom(
JSON: jsonerror.InvalidArgumentValue(err.Error()), JSON: jsonerror.InvalidArgumentValue(err.Error()),
} }
} }
return createRoom(req.Context(), r, device, cfg, accountDB, rsAPI, asAPI, evTime) return createRoom(req.Context(), r, device, cfg, profileAPI, rsAPI, asAPI, evTime)
} }
// createRoom implements /createRoom // createRoom implements /createRoom
@ -165,7 +164,7 @@ func createRoom(
ctx context.Context, ctx context.Context,
r createRoomRequest, device *api.Device, r createRoomRequest, device *api.Device,
cfg *config.ClientAPI, cfg *config.ClientAPI,
accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI, profileAPI api.UserProfileAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
evTime time.Time, evTime time.Time,
) util.JSONResponse { ) util.JSONResponse {
@ -201,7 +200,7 @@ func createRoom(
"roomVersion": roomVersion, "roomVersion": roomVersion,
}).Info("Creating new room") }).Info("Creating new room")
profile, err := appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, accountDB) profile, err := appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, profileAPI)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("appserviceAPI.RetrieveUserProfile failed") util.GetLogger(ctx).WithError(err).Error("appserviceAPI.RetrieveUserProfile failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -520,7 +519,7 @@ func createRoom(
for _, invitee := range r.Invite { for _, invitee := range r.Invite {
// Build the invite event. // Build the invite event.
inviteEvent, err := buildMembershipEvent( inviteEvent, err := buildMembershipEvent(
ctx, invitee, "", accountDB, device, gomatrixserverlib.Invite, ctx, invitee, "", profileAPI, device, gomatrixserverlib.Invite,
roomID, true, cfg, evTime, rsAPI, asAPI, roomID, true, cfg, evTime, rsAPI, asAPI,
) )
if err != nil { if err != nil {

View file

@ -15,7 +15,7 @@ import (
func Deactivate( func Deactivate(
req *http.Request, req *http.Request,
userInteractiveAuth *auth.UserInteractive, userInteractiveAuth *auth.UserInteractive,
userAPI api.UserInternalAPI, accountAPI api.UserAccountAPI,
deviceAPI *api.Device, deviceAPI *api.Device,
) util.JSONResponse { ) util.JSONResponse {
ctx := req.Context() ctx := req.Context()
@ -40,7 +40,7 @@ func Deactivate(
} }
var res api.PerformAccountDeactivationResponse var res api.PerformAccountDeactivationResponse
err = userAPI.PerformAccountDeactivation(ctx, &api.PerformAccountDeactivationRequest{ err = accountAPI.PerformAccountDeactivation(ctx, &api.PerformAccountDeactivationRequest{
Localpart: localpart, Localpart: localpart,
}, &res) }, &res)
if err != nil { if err != nil {

View file

@ -18,12 +18,10 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -32,7 +30,7 @@ func JoinRoomByIDOrAlias(
req *http.Request, req *http.Request,
device *api.Device, device *api.Device,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB userdb.Database, profileAPI api.UserProfileAPI,
roomIDOrAlias string, roomIDOrAlias string,
) util.JSONResponse { ) util.JSONResponse {
// Prepare to ask the roomserver to perform the room join. // Prepare to ask the roomserver to perform the room join.
@ -60,21 +58,25 @@ func JoinRoomByIDOrAlias(
_ = httputil.UnmarshalJSONRequest(req, &joinReq.Content) _ = httputil.UnmarshalJSONRequest(req, &joinReq.Content)
// Work out our localpart for the client profile request. // Work out our localpart for the client profile request.
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
} else {
// Request our profile content to populate the request content with. // Request our profile content to populate the request content with.
var profile *authtypes.Profile res := &api.QueryProfileResponse{}
profile, err = accountDB.GetProfileByLocalpart(req.Context(), localpart) err := profileAPI.QueryProfile(req.Context(), &api.QueryProfileRequest{UserID: device.UserID}, res)
if err != nil { if err != nil || !res.UserExists {
util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetProfileByLocalpart failed") if !res.UserExists {
} else { util.GetLogger(req.Context()).Error("Unable to query user profile, no profile found.")
joinReq.Content["displayname"] = profile.DisplayName return util.JSONResponse{
joinReq.Content["avatar_url"] = profile.AvatarURL Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("Unable to query user profile, no profile found."),
} }
} }
util.GetLogger(req.Context()).WithError(err).Error("UserProfileAPI.QueryProfile failed")
} else {
joinReq.Content["displayname"] = res.DisplayName
joinReq.Content["avatar_url"] = res.AvatarURL
}
// Ask the roomserver to perform the join. // Ask the roomserver to perform the join.
done := make(chan util.JSONResponse, 1) done := make(chan util.JSONResponse, 1)
go func() { go func() {

View file

@ -24,7 +24,6 @@ import (
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -36,7 +35,7 @@ type crossSigningRequest struct {
func UploadCrossSigningDeviceKeys( func UploadCrossSigningDeviceKeys(
req *http.Request, userInteractiveAuth *auth.UserInteractive, req *http.Request, userInteractiveAuth *auth.UserInteractive,
keyserverAPI api.KeyInternalAPI, device *userapi.Device, keyserverAPI api.KeyInternalAPI, device *userapi.Device,
accountDB userdb.Database, cfg *config.ClientAPI, accountAPI userapi.UserAccountAPI, cfg *config.ClientAPI,
) util.JSONResponse { ) util.JSONResponse {
uploadReq := &crossSigningRequest{} uploadReq := &crossSigningRequest{}
uploadRes := &api.PerformUploadDeviceKeysResponse{} uploadRes := &api.PerformUploadDeviceKeysResponse{}
@ -64,7 +63,7 @@ func UploadCrossSigningDeviceKeys(
} }
} }
typePassword := auth.LoginTypePassword{ typePassword := auth.LoginTypePassword{
GetAccountByPassword: accountDB.GetAccountByPassword, GetAccountByPassword: accountAPI.QueryAccountByPassword,
Config: cfg, Config: cfg,
} }
if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil { if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil {

View file

@ -23,7 +23,6 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -54,7 +53,7 @@ func passwordLogin() flows {
// Login implements GET and POST /login // Login implements GET and POST /login
func Login( func Login(
req *http.Request, accountDB userdb.Database, userAPI userapi.UserInternalAPI, req *http.Request, userAPI userapi.UserInternalAPI,
cfg *config.ClientAPI, cfg *config.ClientAPI,
) util.JSONResponse { ) util.JSONResponse {
if req.Method == http.MethodGet { if req.Method == http.MethodGet {
@ -64,7 +63,7 @@ func Login(
JSON: passwordLogin(), JSON: passwordLogin(),
} }
} else if req.Method == http.MethodPost { } else if req.Method == http.MethodPost {
login, cleanup, authErr := auth.LoginFromJSONReader(req.Context(), req.Body, accountDB, userAPI, cfg) login, cleanup, authErr := auth.LoginFromJSONReader(req.Context(), req.Body, userAPI, userAPI, cfg)
if authErr != nil { if authErr != nil {
return *authErr return *authErr
} }

View file

@ -30,7 +30,6 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -39,7 +38,7 @@ import (
var errMissingUserID = errors.New("'user_id' must be supplied") var errMissingUserID = errors.New("'user_id' must be supplied")
func SendBan( func SendBan(
req *http.Request, accountDB userdb.Database, device *userapi.Device, req *http.Request, profileAPI userapi.UserProfileAPI, device *userapi.Device,
roomID string, cfg *config.ClientAPI, roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
@ -78,16 +77,16 @@ func SendBan(
} }
} }
return sendMembership(req.Context(), accountDB, device, roomID, "ban", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI) return sendMembership(req.Context(), profileAPI, device, roomID, "ban", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI)
} }
func sendMembership(ctx context.Context, accountDB userdb.Database, device *userapi.Device, func sendMembership(ctx context.Context, profileAPI userapi.UserProfileAPI, device *userapi.Device,
roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time, roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time,
roomVer gomatrixserverlib.RoomVersion, roomVer gomatrixserverlib.RoomVersion,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI) util.JSONResponse { rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI) util.JSONResponse {
event, err := buildMembershipEvent( event, err := buildMembershipEvent(
ctx, targetUserID, reason, accountDB, device, membership, ctx, targetUserID, reason, profileAPI, device, membership,
roomID, false, cfg, evTime, rsAPI, asAPI, roomID, false, cfg, evTime, rsAPI, asAPI,
) )
if err == errMissingUserID { if err == errMissingUserID {
@ -125,7 +124,7 @@ func sendMembership(ctx context.Context, accountDB userdb.Database, device *user
} }
func SendKick( func SendKick(
req *http.Request, accountDB userdb.Database, device *userapi.Device, req *http.Request, profileAPI userapi.UserProfileAPI, device *userapi.Device,
roomID string, cfg *config.ClientAPI, roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
@ -161,11 +160,11 @@ func SendKick(
} }
} }
// TODO: should we be using SendLeave instead? // TODO: should we be using SendLeave instead?
return sendMembership(req.Context(), accountDB, device, roomID, "leave", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI) return sendMembership(req.Context(), profileAPI, device, roomID, "leave", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI)
} }
func SendUnban( func SendUnban(
req *http.Request, accountDB userdb.Database, device *userapi.Device, req *http.Request, profileAPI userapi.UserProfileAPI, device *userapi.Device,
roomID string, cfg *config.ClientAPI, roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
@ -196,11 +195,11 @@ func SendUnban(
} }
} }
// TODO: should we be using SendLeave instead? // TODO: should we be using SendLeave instead?
return sendMembership(req.Context(), accountDB, device, roomID, "leave", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI) return sendMembership(req.Context(), profileAPI, device, roomID, "leave", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI)
} }
func SendInvite( func SendInvite(
req *http.Request, accountDB userdb.Database, device *userapi.Device, req *http.Request, profileAPI userapi.UserProfileAPI, device *userapi.Device,
roomID string, cfg *config.ClientAPI, roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
@ -210,7 +209,7 @@ func SendInvite(
} }
inviteStored, jsonErrResp := checkAndProcessThreepid( inviteStored, jsonErrResp := checkAndProcessThreepid(
req, device, body, cfg, rsAPI, accountDB, roomID, evTime, req, device, body, cfg, rsAPI, profileAPI, roomID, evTime,
) )
if jsonErrResp != nil { if jsonErrResp != nil {
return *jsonErrResp return *jsonErrResp
@ -227,14 +226,14 @@ func SendInvite(
} }
// We already received the return value, so no need to check for an error here. // We already received the return value, so no need to check for an error here.
response, _ := sendInvite(req.Context(), accountDB, device, roomID, body.UserID, body.Reason, cfg, rsAPI, asAPI, evTime) response, _ := sendInvite(req.Context(), profileAPI, device, roomID, body.UserID, body.Reason, cfg, rsAPI, asAPI, evTime)
return response return response
} }
// sendInvite sends an invitation to a user. Returns a JSONResponse and an error // sendInvite sends an invitation to a user. Returns a JSONResponse and an error
func sendInvite( func sendInvite(
ctx context.Context, ctx context.Context,
accountDB userdb.Database, profileAPI userapi.UserProfileAPI,
device *userapi.Device, device *userapi.Device,
roomID, userID, reason string, roomID, userID, reason string,
cfg *config.ClientAPI, cfg *config.ClientAPI,
@ -242,7 +241,7 @@ func sendInvite(
asAPI appserviceAPI.AppServiceQueryAPI, evTime time.Time, asAPI appserviceAPI.AppServiceQueryAPI, evTime time.Time,
) (util.JSONResponse, error) { ) (util.JSONResponse, error) {
event, err := buildMembershipEvent( event, err := buildMembershipEvent(
ctx, userID, reason, accountDB, device, "invite", ctx, userID, reason, profileAPI, device, "invite",
roomID, false, cfg, evTime, rsAPI, asAPI, roomID, false, cfg, evTime, rsAPI, asAPI,
) )
if err == errMissingUserID { if err == errMissingUserID {
@ -286,13 +285,13 @@ func sendInvite(
func buildMembershipEvent( func buildMembershipEvent(
ctx context.Context, ctx context.Context,
targetUserID, reason string, accountDB userdb.Database, targetUserID, reason string, profileAPI userapi.UserProfileAPI,
device *userapi.Device, device *userapi.Device,
membership, roomID string, isDirect bool, membership, roomID string, isDirect bool,
cfg *config.ClientAPI, evTime time.Time, cfg *config.ClientAPI, evTime time.Time,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) (*gomatrixserverlib.HeaderedEvent, error) { ) (*gomatrixserverlib.HeaderedEvent, error) {
profile, err := loadProfile(ctx, targetUserID, cfg, accountDB, asAPI) profile, err := loadProfile(ctx, targetUserID, cfg, profileAPI, asAPI)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -327,7 +326,7 @@ func loadProfile(
ctx context.Context, ctx context.Context,
userID string, userID string,
cfg *config.ClientAPI, cfg *config.ClientAPI,
accountDB userdb.Database, profileAPI userapi.UserProfileAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
_, serverName, err := gomatrixserverlib.SplitID('@', userID) _, serverName, err := gomatrixserverlib.SplitID('@', userID)
@ -337,7 +336,7 @@ func loadProfile(
var profile *authtypes.Profile var profile *authtypes.Profile
if serverName == cfg.Matrix.ServerName { if serverName == cfg.Matrix.ServerName {
profile, err = appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, accountDB) profile, err = appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, profileAPI)
} else { } else {
profile = &authtypes.Profile{} profile = &authtypes.Profile{}
} }
@ -381,13 +380,13 @@ func checkAndProcessThreepid(
body *threepid.MembershipRequest, body *threepid.MembershipRequest,
cfg *config.ClientAPI, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB userdb.Database, profileAPI userapi.UserProfileAPI,
roomID string, roomID string,
evTime time.Time, evTime time.Time,
) (inviteStored bool, errRes *util.JSONResponse) { ) (inviteStored bool, errRes *util.JSONResponse) {
inviteStored, err := threepid.CheckAndProcessInvite( inviteStored, err := threepid.CheckAndProcessInvite(
req.Context(), device, body, cfg, rsAPI, accountDB, req.Context(), device, body, cfg, rsAPI, profileAPI,
roomID, evTime, roomID, evTime,
) )
if err == threepid.ErrMissingParameter { if err == threepid.ErrMissingParameter {

View file

@ -9,7 +9,6 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -30,7 +29,6 @@ type newPasswordAuth struct {
func Password( func Password(
req *http.Request, req *http.Request,
userAPI api.UserInternalAPI, userAPI api.UserInternalAPI,
accountDB userdb.Database,
device *api.Device, device *api.Device,
cfg *config.ClientAPI, cfg *config.ClientAPI,
) util.JSONResponse { ) util.JSONResponse {
@ -74,7 +72,7 @@ func Password(
// Check if the existing password is correct. // Check if the existing password is correct.
typePassword := auth.LoginTypePassword{ typePassword := auth.LoginTypePassword{
GetAccountByPassword: accountDB.GetAccountByPassword, GetAccountByPassword: userAPI.QueryAccountByPassword,
Config: cfg, Config: cfg,
} }
if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil { if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil {

View file

@ -19,7 +19,6 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -28,7 +27,6 @@ func PeekRoomByIDOrAlias(
req *http.Request, req *http.Request,
device *api.Device, device *api.Device,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB userdb.Database,
roomIDOrAlias string, roomIDOrAlias string,
) util.JSONResponse { ) util.JSONResponse {
// if this is a remote roomIDOrAlias, we have to ask the roomserver (or federation sender?) to // if this is a remote roomIDOrAlias, we have to ask the roomserver (or federation sender?) to
@ -82,7 +80,6 @@ func UnpeekRoomByID(
req *http.Request, req *http.Request,
device *api.Device, device *api.Device,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB userdb.Database,
roomID string, roomID string,
) util.JSONResponse { ) util.JSONResponse {
unpeekReq := roomserverAPI.PerformUnpeekRequest{ unpeekReq := roomserverAPI.PerformUnpeekRequest{

View file

@ -27,7 +27,6 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
@ -36,12 +35,12 @@ import (
// GetProfile implements GET /profile/{userID} // GetProfile implements GET /profile/{userID}
func GetProfile( func GetProfile(
req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI, req *http.Request, profileAPI userapi.UserProfileAPI, cfg *config.ClientAPI,
userID string, userID string,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
) util.JSONResponse { ) util.JSONResponse {
profile, err := getProfile(req.Context(), accountDB, cfg, userID, asAPI, federation) profile, err := getProfile(req.Context(), profileAPI, cfg, userID, asAPI, federation)
if err != nil { if err != nil {
if err == eventutil.ErrProfileNoExists { if err == eventutil.ErrProfileNoExists {
return util.JSONResponse{ return util.JSONResponse{
@ -65,11 +64,11 @@ func GetProfile(
// GetAvatarURL implements GET /profile/{userID}/avatar_url // GetAvatarURL implements GET /profile/{userID}/avatar_url
func GetAvatarURL( func GetAvatarURL(
req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI, req *http.Request, profileAPI userapi.UserProfileAPI, cfg *config.ClientAPI,
userID string, asAPI appserviceAPI.AppServiceQueryAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
) util.JSONResponse { ) util.JSONResponse {
profile, err := getProfile(req.Context(), accountDB, cfg, userID, asAPI, federation) profile, err := getProfile(req.Context(), profileAPI, cfg, userID, asAPI, federation)
if err != nil { if err != nil {
if err == eventutil.ErrProfileNoExists { if err == eventutil.ErrProfileNoExists {
return util.JSONResponse{ return util.JSONResponse{
@ -92,7 +91,7 @@ func GetAvatarURL(
// SetAvatarURL implements PUT /profile/{userID}/avatar_url // SetAvatarURL implements PUT /profile/{userID}/avatar_url
func SetAvatarURL( func SetAvatarURL(
req *http.Request, accountDB userdb.Database, req *http.Request, profileAPI userapi.UserProfileAPI,
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { if userID != device.UserID {
@ -127,22 +126,34 @@ func SetAvatarURL(
} }
} }
oldProfile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart) res := &userapi.QueryProfileResponse{}
err = profileAPI.QueryProfile(req.Context(), &userapi.QueryProfileRequest{
UserID: userID,
}, res)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetProfileByLocalpart failed") util.GetLogger(req.Context()).WithError(err).Error("profileAPI.QueryProfile failed")
return jsonerror.InternalServerError()
}
oldProfile := &authtypes.Profile{
Localpart: localpart,
DisplayName: res.DisplayName,
AvatarURL: res.AvatarURL,
}
setRes := &userapi.PerformSetAvatarURLResponse{}
if err = profileAPI.SetAvatarURL(req.Context(), &userapi.PerformSetAvatarURLRequest{
Localpart: localpart,
AvatarURL: r.AvatarURL,
}, setRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetAvatarURL failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
if err = accountDB.SetAvatarURL(req.Context(), localpart, r.AvatarURL); err != nil { var roomsRes api.QueryRoomsForUserResponse
util.GetLogger(req.Context()).WithError(err).Error("accountDB.SetAvatarURL failed")
return jsonerror.InternalServerError()
}
var res api.QueryRoomsForUserResponse
err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{
UserID: device.UserID, UserID: device.UserID,
WantMembership: "join", WantMembership: "join",
}, &res) }, &roomsRes)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed") util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -155,7 +166,7 @@ func SetAvatarURL(
} }
events, err := buildMembershipEvents( events, err := buildMembershipEvents(
req.Context(), res.RoomIDs, newProfile, userID, cfg, evTime, rsAPI, req.Context(), roomsRes.RoomIDs, newProfile, userID, cfg, evTime, rsAPI,
) )
switch e := err.(type) { switch e := err.(type) {
case nil: case nil:
@ -182,11 +193,11 @@ func SetAvatarURL(
// GetDisplayName implements GET /profile/{userID}/displayname // GetDisplayName implements GET /profile/{userID}/displayname
func GetDisplayName( func GetDisplayName(
req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI, req *http.Request, profileAPI userapi.UserProfileAPI, cfg *config.ClientAPI,
userID string, asAPI appserviceAPI.AppServiceQueryAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
) util.JSONResponse { ) util.JSONResponse {
profile, err := getProfile(req.Context(), accountDB, cfg, userID, asAPI, federation) profile, err := getProfile(req.Context(), profileAPI, cfg, userID, asAPI, federation)
if err != nil { if err != nil {
if err == eventutil.ErrProfileNoExists { if err == eventutil.ErrProfileNoExists {
return util.JSONResponse{ return util.JSONResponse{
@ -209,7 +220,7 @@ func GetDisplayName(
// SetDisplayName implements PUT /profile/{userID}/displayname // SetDisplayName implements PUT /profile/{userID}/displayname
func SetDisplayName( func SetDisplayName(
req *http.Request, accountDB userdb.Database, req *http.Request, profileAPI userapi.UserProfileAPI,
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { if userID != device.UserID {
@ -244,14 +255,26 @@ func SetDisplayName(
} }
} }
oldProfile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart) pRes := &userapi.QueryProfileResponse{}
err = profileAPI.QueryProfile(req.Context(), &userapi.QueryProfileRequest{
UserID: userID,
}, pRes)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetProfileByLocalpart failed") util.GetLogger(req.Context()).WithError(err).Error("profileAPI.QueryProfile failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
oldProfile := &authtypes.Profile{
Localpart: localpart,
DisplayName: pRes.DisplayName,
AvatarURL: pRes.AvatarURL,
}
if err = accountDB.SetDisplayName(req.Context(), localpart, r.DisplayName); err != nil { err = profileAPI.SetDisplayName(req.Context(), &userapi.PerformUpdateDisplayNameRequest{
util.GetLogger(req.Context()).WithError(err).Error("accountDB.SetDisplayName failed") Localpart: localpart,
DisplayName: r.DisplayName,
}, &struct{}{})
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -302,7 +325,7 @@ func SetDisplayName(
// Returns an error when something goes wrong or specifically // Returns an error when something goes wrong or specifically
// eventutil.ErrProfileNoExists when the profile doesn't exist. // eventutil.ErrProfileNoExists when the profile doesn't exist.
func getProfile( func getProfile(
ctx context.Context, accountDB userdb.Database, cfg *config.ClientAPI, ctx context.Context, profileAPI userapi.UserProfileAPI, cfg *config.ClientAPI,
userID string, userID string,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
@ -331,7 +354,7 @@ func getProfile(
}, nil }, nil
} }
profile, err := appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, accountDB) profile, err := appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, profileAPI)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -44,7 +44,6 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
userdb "github.com/matrix-org/dendrite/userapi/storage"
) )
var ( var (
@ -523,8 +522,7 @@ func validateApplicationService(
// http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register // http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register
func Register( func Register(
req *http.Request, req *http.Request,
userAPI userapi.UserInternalAPI, userAPI userapi.UserRegisterAPI,
accountDB userdb.Database,
cfg *config.ClientAPI, cfg *config.ClientAPI,
) util.JSONResponse { ) util.JSONResponse {
var r registerRequest var r registerRequest
@ -552,13 +550,12 @@ func Register(
} }
// Auto generate a numeric username if r.Username is empty // Auto generate a numeric username if r.Username is empty
if r.Username == "" { if r.Username == "" {
id, err := accountDB.GetNewNumericLocalpart(req.Context()) res := &userapi.QueryNumericLocalpartResponse{}
if err != nil { if err := userAPI.QueryNumericLocalpart(req.Context(), res); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetNewNumericLocalpart failed") util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
r.Username = strconv.FormatInt(res.ID, 10)
r.Username = strconv.FormatInt(id, 10)
} }
// Is this an appservice registration? It will be if the access // Is this an appservice registration? It will be if the access
@ -606,7 +603,7 @@ func handleGuestRegistration(
req *http.Request, req *http.Request,
r registerRequest, r registerRequest,
cfg *config.ClientAPI, cfg *config.ClientAPI,
userAPI userapi.UserInternalAPI, userAPI userapi.UserRegisterAPI,
) util.JSONResponse { ) util.JSONResponse {
if cfg.RegistrationDisabled || cfg.GuestsDisabled { if cfg.RegistrationDisabled || cfg.GuestsDisabled {
return util.JSONResponse{ return util.JSONResponse{
@ -671,7 +668,7 @@ func handleRegistrationFlow(
r registerRequest, r registerRequest,
sessionID string, sessionID string,
cfg *config.ClientAPI, cfg *config.ClientAPI,
userAPI userapi.UserInternalAPI, userAPI userapi.UserRegisterAPI,
accessToken string, accessToken string,
accessTokenErr error, accessTokenErr error,
) util.JSONResponse { ) util.JSONResponse {
@ -760,7 +757,7 @@ func handleApplicationServiceRegistration(
req *http.Request, req *http.Request,
r registerRequest, r registerRequest,
cfg *config.ClientAPI, cfg *config.ClientAPI,
userAPI userapi.UserInternalAPI, userAPI userapi.UserRegisterAPI,
) util.JSONResponse { ) util.JSONResponse {
// Check if we previously had issues extracting the access token from the // Check if we previously had issues extracting the access token from the
// request. // request.
@ -798,7 +795,7 @@ func checkAndCompleteFlow(
r registerRequest, r registerRequest,
sessionID string, sessionID string,
cfg *config.ClientAPI, cfg *config.ClientAPI,
userAPI userapi.UserInternalAPI, userAPI userapi.UserRegisterAPI,
) util.JSONResponse { ) util.JSONResponse {
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
// This flow was completed, registration can continue // This flow was completed, registration can continue
@ -825,7 +822,7 @@ func checkAndCompleteFlow(
// not all // not all
func completeRegistration( func completeRegistration(
ctx context.Context, ctx context.Context,
userAPI userapi.UserInternalAPI, userAPI userapi.UserRegisterAPI,
username, password, appserviceID, ipAddr, userAgent, sessionID string, username, password, appserviceID, ipAddr, userAgent, sessionID string,
inhibitLogin eventutil.WeakBoolean, inhibitLogin eventutil.WeakBoolean,
displayName, deviceID *string, displayName, deviceID *string,
@ -991,7 +988,7 @@ type availableResponse struct {
func RegisterAvailable( func RegisterAvailable(
req *http.Request, req *http.Request,
cfg *config.ClientAPI, cfg *config.ClientAPI,
accountDB userdb.Database, registerAPI userapi.UserRegisterAPI,
) util.JSONResponse { ) util.JSONResponse {
username := req.URL.Query().Get("username") username := req.URL.Query().Get("username")
@ -1013,14 +1010,18 @@ func RegisterAvailable(
} }
} }
availability, availabilityErr := accountDB.CheckAccountAvailability(req.Context(), username) res := &userapi.QueryAccountAvailabilityResponse{}
if availabilityErr != nil { err := registerAPI.QueryAccountAvailability(req.Context(), &userapi.QueryAccountAvailabilityRequest{
Localpart: username,
}, res)
if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("failed to check availability: " + availabilityErr.Error()), JSON: jsonerror.Unknown("failed to check availability:" + err.Error()),
} }
} }
if !availability {
if !res.Available {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.UserInUse("Desired User ID is already taken."), JSON: jsonerror.UserInUse("Desired User ID is already taken."),

View file

@ -34,7 +34,6 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -51,7 +50,6 @@ func Setup(
eduAPI eduServerAPI.EDUServerInputAPI, eduAPI eduServerAPI.EDUServerInputAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
accountDB userdb.Database,
userAPI userapi.UserInternalAPI, userAPI userapi.UserInternalAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
syncProducer *producers.SyncAPIProducer, syncProducer *producers.SyncAPIProducer,
@ -62,7 +60,7 @@ func Setup(
mscCfg *config.MSCs, mscCfg *config.MSCs,
) { ) {
rateLimits := httputil.NewRateLimits(&cfg.RateLimiting) rateLimits := httputil.NewRateLimits(&cfg.RateLimiting)
userInteractiveAuth := auth.NewUserInteractive(accountDB, cfg) userInteractiveAuth := auth.NewUserInteractive(userAPI, cfg)
unstableFeatures := map[string]bool{ unstableFeatures := map[string]bool{
"org.matrix.e2e_cross_signing": true, "org.matrix.e2e_cross_signing": true,
@ -120,7 +118,7 @@ func Setup(
// server notifications // server notifications
if cfg.Matrix.ServerNotices.Enabled { if cfg.Matrix.ServerNotices.Enabled {
logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice") logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice")
serverNotificationSender, err := getSenderDevice(context.Background(), userAPI, accountDB, cfg) serverNotificationSender, err := getSenderDevice(context.Background(), userAPI, cfg)
if err != nil { if err != nil {
logrus.WithError(err).Fatal("unable to get account for sending sending server notices") logrus.WithError(err).Fatal("unable to get account for sending sending server notices")
} }
@ -138,7 +136,7 @@ func Setup(
txnID := vars["txnID"] txnID := vars["txnID"]
return SendServerNotice( return SendServerNotice(
req, &cfg.Matrix.ServerNotices, req, &cfg.Matrix.ServerNotices,
cfg, userAPI, rsAPI, accountDB, asAPI, cfg, userAPI, rsAPI, asAPI,
device, serverNotificationSender, device, serverNotificationSender,
&txnID, transactionsCache, &txnID, transactionsCache,
) )
@ -153,7 +151,7 @@ func Setup(
} }
return SendServerNotice( return SendServerNotice(
req, &cfg.Matrix.ServerNotices, req, &cfg.Matrix.ServerNotices,
cfg, userAPI, rsAPI, accountDB, asAPI, cfg, userAPI, rsAPI, asAPI,
device, serverNotificationSender, device, serverNotificationSender,
nil, transactionsCache, nil, transactionsCache,
) )
@ -173,7 +171,7 @@ func Setup(
v3mux.Handle("/createRoom", v3mux.Handle("/createRoom",
httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return CreateRoom(req, device, cfg, accountDB, rsAPI, asAPI) return CreateRoom(req, device, cfg, userAPI, rsAPI, asAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/join/{roomIDOrAlias}", v3mux.Handle("/join/{roomIDOrAlias}",
@ -186,7 +184,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return JoinRoomByIDOrAlias( return JoinRoomByIDOrAlias(
req, device, rsAPI, accountDB, vars["roomIDOrAlias"], req, device, rsAPI, userAPI, vars["roomIDOrAlias"],
) )
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
@ -202,7 +200,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return PeekRoomByIDOrAlias( return PeekRoomByIDOrAlias(
req, device, rsAPI, accountDB, vars["roomIDOrAlias"], req, device, rsAPI, vars["roomIDOrAlias"],
) )
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
@ -222,7 +220,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return JoinRoomByIDOrAlias( return JoinRoomByIDOrAlias(
req, device, rsAPI, accountDB, vars["roomID"], req, device, rsAPI, userAPI, vars["roomID"],
) )
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
@ -247,7 +245,7 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return UnpeekRoomByID( return UnpeekRoomByID(
req, device, rsAPI, accountDB, vars["roomID"], req, device, rsAPI, vars["roomID"],
) )
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
@ -257,7 +255,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SendBan(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) return SendBan(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/invite", v3mux.Handle("/rooms/{roomID}/invite",
@ -269,7 +267,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SendInvite(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/kick", v3mux.Handle("/rooms/{roomID}/kick",
@ -278,7 +276,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SendKick(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) return SendKick(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/unban", v3mux.Handle("/rooms/{roomID}/unban",
@ -287,7 +285,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SendUnban(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) return SendUnban(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/send/{eventType}", v3mux.Handle("/rooms/{roomID}/send/{eventType}",
@ -383,14 +381,14 @@ func Setup(
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
} }
return Register(req, userAPI, accountDB, cfg) return Register(req, userAPI, cfg)
})).Methods(http.MethodPost, http.MethodOptions) })).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { v3mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
} }
return RegisterAvailable(req, cfg, accountDB) return RegisterAvailable(req, cfg, userAPI)
})).Methods(http.MethodGet, http.MethodOptions) })).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/directory/room/{roomAlias}", v3mux.Handle("/directory/room/{roomAlias}",
@ -468,7 +466,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SendTyping(req, device, vars["roomID"], vars["userID"], accountDB, eduAPI, rsAPI) return SendTyping(req, device, vars["roomID"], vars["userID"], eduAPI, rsAPI)
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/redact/{eventID}", v3mux.Handle("/rooms/{roomID}/redact/{eventID}",
@ -529,7 +527,7 @@ func Setup(
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
} }
return Password(req, userAPI, accountDB, device, cfg) return Password(req, userAPI, device, cfg)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
@ -549,7 +547,7 @@ func Setup(
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
} }
return Login(req, accountDB, userAPI, cfg) return Login(req, userAPI, cfg)
}), }),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
@ -704,7 +702,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetProfile(req, accountDB, cfg, vars["userID"], asAPI, federation) return GetProfile(req, userAPI, cfg, vars["userID"], asAPI, federation)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
@ -714,7 +712,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetAvatarURL(req, accountDB, cfg, vars["userID"], asAPI, federation) return GetAvatarURL(req, userAPI, cfg, vars["userID"], asAPI, federation)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
@ -727,7 +725,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SetAvatarURL(req, accountDB, device, vars["userID"], cfg, rsAPI) return SetAvatarURL(req, userAPI, device, vars["userID"], cfg, rsAPI)
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows // Browsers use the OPTIONS HTTP method to check if the CORS policy allows
@ -739,7 +737,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetDisplayName(req, accountDB, cfg, vars["userID"], asAPI, federation) return GetDisplayName(req, userAPI, cfg, vars["userID"], asAPI, federation)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
@ -752,7 +750,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SetDisplayName(req, accountDB, device, vars["userID"], cfg, rsAPI) return SetDisplayName(req, userAPI, device, vars["userID"], cfg, rsAPI)
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows // Browsers use the OPTIONS HTTP method to check if the CORS policy allows
@ -760,25 +758,25 @@ func Setup(
v3mux.Handle("/account/3pid", v3mux.Handle("/account/3pid",
httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetAssociated3PIDs(req, accountDB, device) return GetAssociated3PIDs(req, userAPI, device)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/account/3pid", v3mux.Handle("/account/3pid",
httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return CheckAndSave3PIDAssociation(req, accountDB, device, cfg) return CheckAndSave3PIDAssociation(req, userAPI, device, cfg)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/account/3pid/delete", unstableMux.Handle("/account/3pid/delete",
httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return Forget3PID(req, accountDB) return Forget3PID(req, userAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken", v3mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken",
httputil.MakeExternalAPI("account_3pid_request_token", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("account_3pid_request_token", func(req *http.Request) util.JSONResponse {
return RequestEmailToken(req, accountDB, cfg) return RequestEmailToken(req, userAPI, cfg)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
@ -1253,7 +1251,7 @@ func Setup(
// Cross-signing device keys // Cross-signing device keys
postDeviceSigningKeys := httputil.MakeAuthAPI("post_device_signing_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { postDeviceSigningKeys := httputil.MakeAuthAPI("post_device_signing_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadCrossSigningDeviceKeys(req, userInteractiveAuth, keyAPI, device, accountDB, cfg) return UploadCrossSigningDeviceKeys(req, userInteractiveAuth, keyAPI, device, userAPI, cfg)
}) })
postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {

View file

@ -20,7 +20,6 @@ import (
"github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -33,7 +32,7 @@ type typingContentJSON struct {
// sends the typing events to client API typingProducer // sends the typing events to client API typingProducer
func SendTyping( func SendTyping(
req *http.Request, device *userapi.Device, roomID string, req *http.Request, device *userapi.Device, roomID string,
userID string, accountDB userdb.Database, userID string,
eduAPI api.EDUServerInputAPI, eduAPI api.EDUServerInputAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
) util.JSONResponse { ) util.JSONResponse {

View file

@ -21,7 +21,6 @@ import (
"net/http" "net/http"
"time" "time"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/tokens" "github.com/matrix-org/gomatrixserverlib/tokens"
@ -58,7 +57,6 @@ func SendServerNotice(
cfgClient *config.ClientAPI, cfgClient *config.ClientAPI,
userAPI userapi.UserInternalAPI, userAPI userapi.UserInternalAPI,
rsAPI api.RoomserverInternalAPI, rsAPI api.RoomserverInternalAPI,
accountsDB userdb.Database,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
device *userapi.Device, device *userapi.Device,
senderDevice *userapi.Device, senderDevice *userapi.Device,
@ -175,7 +173,7 @@ func SendServerNotice(
PowerLevelContentOverride: pl, PowerLevelContentOverride: pl,
} }
roomRes := createRoom(ctx, crReq, senderDevice, cfgClient, accountsDB, rsAPI, asAPI, time.Now()) roomRes := createRoom(ctx, crReq, senderDevice, cfgClient, userAPI, rsAPI, asAPI, time.Now())
switch data := roomRes.JSON.(type) { switch data := roomRes.JSON.(type) {
case createRoomResponse: case createRoomResponse:
@ -201,7 +199,7 @@ func SendServerNotice(
// we've found a room in common, check the membership // we've found a room in common, check the membership
roomID = commonRooms[0] roomID = commonRooms[0]
// re-invite the user // re-invite the user
res, err := sendInvite(ctx, accountsDB, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now()) res, err := sendInvite(ctx, userAPI, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now())
if err != nil { if err != nil {
return res return res
} }
@ -284,7 +282,6 @@ func (r sendServerNoticeRequest) valid() (ok bool) {
func getSenderDevice( func getSenderDevice(
ctx context.Context, ctx context.Context,
userAPI userapi.UserInternalAPI, userAPI userapi.UserInternalAPI,
accountDB userdb.Database,
cfg *config.ClientAPI, cfg *config.ClientAPI,
) (*userapi.Device, error) { ) (*userapi.Device, error) {
var accRes userapi.PerformAccountCreationResponse var accRes userapi.PerformAccountCreationResponse
@ -299,8 +296,12 @@ func getSenderDevice(
} }
// set the avatarurl for the user // set the avatarurl for the user
if err = accountDB.SetAvatarURL(ctx, cfg.Matrix.ServerNotices.LocalPart, cfg.Matrix.ServerNotices.AvatarURL); err != nil { res := &userapi.PerformSetAvatarURLResponse{}
util.GetLogger(ctx).WithError(err).Error("accountDB.SetAvatarURL failed") if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart,
AvatarURL: cfg.Matrix.ServerNotices.AvatarURL,
}, res); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed")
return nil, err return nil, err
} }

View file

@ -40,7 +40,7 @@ type threePIDsResponse struct {
// RequestEmailToken implements: // RequestEmailToken implements:
// POST /account/3pid/email/requestToken // POST /account/3pid/email/requestToken
// POST /register/email/requestToken // POST /register/email/requestToken
func RequestEmailToken(req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI) util.JSONResponse { func RequestEmailToken(req *http.Request, threePIDAPI api.UserThreePIDAPI, cfg *config.ClientAPI) util.JSONResponse {
var body threepid.EmailAssociationRequest var body threepid.EmailAssociationRequest
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr return *reqErr
@ -50,13 +50,18 @@ func RequestEmailToken(req *http.Request, accountDB userdb.Database, cfg *config
var err error var err error
// Check if the 3PID is already in use locally // Check if the 3PID is already in use locally
localpart, err := accountDB.GetLocalpartForThreePID(req.Context(), body.Email, "email") res := &api.QueryLocalpartForThreePIDResponse{}
err = threePIDAPI.QueryLocalpartForThreePID(req.Context(), &api.QueryLocalpartForThreePIDRequest{
ThreePID: body.Email,
Medium: "email",
}, res)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetLocalpartForThreePID failed") util.GetLogger(req.Context()).WithError(err).Error("threePIDAPI.QueryLocalpartForThreePID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
if len(localpart) > 0 { if len(res.Localpart) > 0 {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.MatrixError{ JSON: jsonerror.MatrixError{
@ -85,7 +90,7 @@ func RequestEmailToken(req *http.Request, accountDB userdb.Database, cfg *config
// CheckAndSave3PIDAssociation implements POST /account/3pid // CheckAndSave3PIDAssociation implements POST /account/3pid
func CheckAndSave3PIDAssociation( func CheckAndSave3PIDAssociation(
req *http.Request, accountDB userdb.Database, device *api.Device, req *http.Request, threePIDAPI api.UserThreePIDAPI, device *api.Device,
cfg *config.ClientAPI, cfg *config.ClientAPI,
) util.JSONResponse { ) util.JSONResponse {
var body threepid.EmailAssociationCheckRequest var body threepid.EmailAssociationCheckRequest
@ -136,8 +141,12 @@ func CheckAndSave3PIDAssociation(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
if err = accountDB.SaveThreePIDAssociation(req.Context(), address, localpart, medium); err != nil { if err = threePIDAPI.PerformSaveThreePIDAssociation(req.Context(), &api.PerformSaveThreePIDAssociationRequest{
util.GetLogger(req.Context()).WithError(err).Error("accountsDB.SaveThreePIDAssociation failed") ThreePID: address,
Localpart: localpart,
Medium: medium,
}, &struct{}{}); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("threePIDAPI.PerformSaveThreePIDAssociation failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -149,7 +158,7 @@ func CheckAndSave3PIDAssociation(
// GetAssociated3PIDs implements GET /account/3pid // GetAssociated3PIDs implements GET /account/3pid
func GetAssociated3PIDs( func GetAssociated3PIDs(
req *http.Request, accountDB userdb.Database, device *api.Device, req *http.Request, threepidAPI api.UserThreePIDAPI, device *api.Device,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
@ -157,27 +166,30 @@ func GetAssociated3PIDs(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
threepids, err := accountDB.GetThreePIDsForLocalpart(req.Context(), localpart) res := &api.QueryThreePIDsForLocalpartResponse{}
err = threepidAPI.QueryThreePIDsForLocalpart(req.Context(), &api.QueryThreePIDsForLocalpartRequest{
Localpart: localpart,
}, res)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetThreePIDsForLocalpart failed") util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.QueryThreePIDsForLocalpart failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: threePIDsResponse{threepids}, JSON: threePIDsResponse{res.ThreePIDs},
} }
} }
// Forget3PID implements POST /account/3pid/delete // Forget3PID implements POST /account/3pid/delete
func Forget3PID(req *http.Request, accountDB userdb.Database) util.JSONResponse { func Forget3PID(req *http.Request, threepidAPI api.UserThreePIDAPI) util.JSONResponse {
var body authtypes.ThreePID var body authtypes.ThreePID
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr return *reqErr
} }
if err := accountDB.RemoveThreePIDAssociation(req.Context(), body.Address, body.Medium); err != nil { if err := threepidAPI.PerformForgetThreePID(req.Context(), &api.PerformForgetThreePIDRequest{}, &struct{}{}); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("accountDB.RemoveThreePIDAssociation failed") util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.PerformForgetThreePID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }

View file

@ -29,7 +29,6 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -87,7 +86,7 @@ var (
func CheckAndProcessInvite( func CheckAndProcessInvite(
ctx context.Context, ctx context.Context,
device *userapi.Device, body *MembershipRequest, cfg *config.ClientAPI, device *userapi.Device, body *MembershipRequest, cfg *config.ClientAPI,
rsAPI api.RoomserverInternalAPI, db userdb.Database, rsAPI api.RoomserverInternalAPI, db userapi.UserProfileAPI,
roomID string, roomID string,
evTime time.Time, evTime time.Time,
) (inviteStoredOnIDServer bool, err error) { ) (inviteStoredOnIDServer bool, err error) {
@ -137,7 +136,7 @@ func CheckAndProcessInvite(
// Returns an error if a check or a request failed. // Returns an error if a check or a request failed.
func queryIDServer( func queryIDServer(
ctx context.Context, ctx context.Context,
db userdb.Database, cfg *config.ClientAPI, device *userapi.Device, db userapi.UserProfileAPI, cfg *config.ClientAPI, device *userapi.Device,
body *MembershipRequest, roomID string, body *MembershipRequest, roomID string,
) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) { ) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) {
if err = isTrusted(body.IDServer, cfg); err != nil { if err = isTrusted(body.IDServer, cfg); err != nil {
@ -206,7 +205,7 @@ func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServe
// Returns an error if the request failed to send or if the response couldn't be parsed. // Returns an error if the request failed to send or if the response couldn't be parsed.
func queryIDServerStoreInvite( func queryIDServerStoreInvite(
ctx context.Context, ctx context.Context,
db userdb.Database, cfg *config.ClientAPI, device *userapi.Device, db userapi.UserProfileAPI, cfg *config.ClientAPI, device *userapi.Device,
body *MembershipRequest, roomID string, body *MembershipRequest, roomID string,
) (*idServerStoreInviteResponse, error) { ) (*idServerStoreInviteResponse, error) {
// Retrieve the sender's profile to get their display name // Retrieve the sender's profile to get their display name
@ -217,10 +216,17 @@ func queryIDServerStoreInvite(
var profile *authtypes.Profile var profile *authtypes.Profile
if serverName == cfg.Matrix.ServerName { if serverName == cfg.Matrix.ServerName {
profile, err = db.GetProfileByLocalpart(ctx, localpart) res := &userapi.QueryProfileResponse{}
err = db.QueryProfile(ctx, &userapi.QueryProfileRequest{UserID: device.UserID}, res)
if err != nil { if err != nil {
return nil, err return nil, err
} }
profile = &authtypes.Profile{
Localpart: localpart,
DisplayName: res.DisplayName,
AvatarURL: res.AvatarURL,
}
} else { } else {
profile = &authtypes.Profile{} profile = &authtypes.Profile{}
} }

View file

@ -22,7 +22,6 @@ import (
) )
func ClientAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { func ClientAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) {
accountDB := base.CreateAccountsDB()
federation := base.CreateFederationClient() federation := base.CreateFederationClient()
asQuery := base.AppserviceHTTPClient() asQuery := base.AppserviceHTTPClient()
@ -34,7 +33,7 @@ func ClientAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) {
clientapi.AddPublicRoutes( clientapi.AddPublicRoutes(
base.ProcessContext, base.PublicClientAPIMux, base.SynapseAdminMux, &base.Cfg.ClientAPI, base.ProcessContext, base.PublicClientAPIMux, base.SynapseAdminMux, &base.Cfg.ClientAPI,
accountDB, federation, rsAPI, eduInputAPI, asQuery, transactions.New(), fsAPI, userAPI, federation, rsAPI, eduInputAPI, asQuery, transactions.New(), fsAPI, userAPI,
keyAPI, nil, &cfg.MSCs, keyAPI, nil, &cfg.MSCs,
) )

View file

@ -57,7 +57,7 @@ type Monolith struct {
// AddAllPublicRoutes attaches all public paths to the given router // AddAllPublicRoutes attaches all public paths to the given router
func (m *Monolith) AddAllPublicRoutes(process *process.ProcessContext, csMux, ssMux, keyMux, wkMux, mediaMux, synapseMux *mux.Router) { func (m *Monolith) AddAllPublicRoutes(process *process.ProcessContext, csMux, ssMux, keyMux, wkMux, mediaMux, synapseMux *mux.Router) {
clientapi.AddPublicRoutes( clientapi.AddPublicRoutes(
process, csMux, synapseMux, &m.Config.ClientAPI, m.AccountDB, process, csMux, synapseMux, &m.Config.ClientAPI,
m.FedClient, m.RoomserverAPI, m.FedClient, m.RoomserverAPI,
m.EDUInternalAPI, m.AppserviceAPI, transactions.New(), m.EDUInternalAPI, m.AppserviceAPI, transactions.New(),
m.FederationAPI, m.UserAPI, m.KeyAPI, m.FederationAPI, m.UserAPI, m.KeyAPI,

View file

@ -27,16 +27,16 @@ import (
// UserInternalAPI is the internal API for information about users and devices. // UserInternalAPI is the internal API for information about users and devices.
type UserInternalAPI interface { type UserInternalAPI interface {
LoginTokenInternalAPI LoginTokenInternalAPI
UserProfileAPI
UserRegisterAPI
UserAccountAPI
UserThreePIDAPI
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error
PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error
@ -44,18 +44,47 @@ type UserInternalAPI interface {
PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error
QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse)
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error
QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error
QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error
QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error
QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
} }
// UserProfileAPI provides functions for getting user profiles
type UserProfileAPI interface {
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error
SetAvatarURL(ctx context.Context, req *PerformSetAvatarURLRequest, res *PerformSetAvatarURLResponse) error
SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error
}
// UserRegisterAPI defines functions for registering accounts
type UserRegisterAPI interface {
QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error
QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
}
// UserAccountAPI defines functions for changing an account
type UserAccountAPI interface {
PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
QueryAccountByPassword(ctx context.Context, req *QueryAccountByPasswordRequest, res *QueryAccountByPasswordResponse) error
}
// UserThreePIDAPI defines functions for 3PID
type UserThreePIDAPI interface {
QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error
QueryThreePIDsForLocalpart(ctx context.Context, req *QueryThreePIDsForLocalpartRequest, res *QueryThreePIDsForLocalpartResponse) error
PerformForgetThreePID(ctx context.Context, req *PerformForgetThreePIDRequest, res *struct{}) error
PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error
}
type PerformKeyBackupRequest struct { type PerformKeyBackupRequest struct {
UserID string UserID string
Version string // optional if modifying a key backup Version string // optional if modifying a key backup
@ -507,3 +536,55 @@ type Notification struct {
RoomID string `json:"room_id"` // Required. RoomID string `json:"room_id"` // Required.
TS gomatrixserverlib.Timestamp `json:"ts"` // Required. TS gomatrixserverlib.Timestamp `json:"ts"` // Required.
} }
type PerformSetAvatarURLRequest struct {
Localpart, AvatarURL string
}
type PerformSetAvatarURLResponse struct{}
type QueryNumericLocalpartResponse struct {
ID int64
}
type QueryAccountAvailabilityRequest struct {
Localpart string
}
type QueryAccountAvailabilityResponse struct {
Available bool
}
type QueryAccountByPasswordRequest struct {
Localpart, PlaintextPassword string
}
type QueryAccountByPasswordResponse struct {
Account *Account
Exists bool
}
type PerformUpdateDisplayNameRequest struct {
Localpart, DisplayName string
}
type QueryLocalpartForThreePIDRequest struct {
ThreePID, Medium string
}
type QueryLocalpartForThreePIDResponse struct {
Localpart string
}
type QueryThreePIDsForLocalpartRequest struct {
Localpart string
}
type QueryThreePIDsForLocalpartResponse struct {
ThreePIDs []authtypes.ThreePID
}
type PerformForgetThreePIDRequest QueryLocalpartForThreePIDRequest
type PerformSaveThreePIDAssociationRequest struct {
ThreePID, Localpart, Medium string
}

View file

@ -149,6 +149,60 @@ func (t *UserInternalAPITrace) QueryNotifications(ctx context.Context, req *Quer
return err return err
} }
func (t *UserInternalAPITrace) SetAvatarURL(ctx context.Context, req *PerformSetAvatarURLRequest, res *PerformSetAvatarURLResponse) error {
err := t.Impl.SetAvatarURL(ctx, req, res)
util.GetLogger(ctx).Infof("SetAvatarURL req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error {
err := t.Impl.QueryNumericLocalpart(ctx, res)
util.GetLogger(ctx).Infof("QueryNumericLocalpart req= res=%+v", js(res))
return err
}
func (t *UserInternalAPITrace) QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error {
err := t.Impl.QueryAccountAvailability(ctx, req, res)
util.GetLogger(ctx).Infof("QueryAccountAvailability req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error {
err := t.Impl.SetDisplayName(ctx, req, res)
util.GetLogger(ctx).Infof("SetDisplayName req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) QueryAccountByPassword(ctx context.Context, req *QueryAccountByPasswordRequest, res *QueryAccountByPasswordResponse) error {
err := t.Impl.QueryAccountByPassword(ctx, req, res)
util.GetLogger(ctx).Infof("QueryAccountByPassword req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error {
err := t.Impl.QueryLocalpartForThreePID(ctx, req, res)
util.GetLogger(ctx).Infof("QueryLocalpartForThreePID req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) QueryThreePIDsForLocalpart(ctx context.Context, req *QueryThreePIDsForLocalpartRequest, res *QueryThreePIDsForLocalpartResponse) error {
err := t.Impl.QueryThreePIDsForLocalpart(ctx, req, res)
util.GetLogger(ctx).Infof("QueryThreePIDsForLocalpart req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) PerformForgetThreePID(ctx context.Context, req *PerformForgetThreePIDRequest, res *struct{}) error {
err := t.Impl.PerformForgetThreePID(ctx, req, res)
util.GetLogger(ctx).Infof("PerformForgetThreePID req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error {
err := t.Impl.PerformSaveThreePIDAssociation(ctx, req, res)
util.GetLogger(ctx).Infof("PerformSaveThreePIDAssociation req=%+v res=%+v", js(req), js(res))
return err
}
func js(thing interface{}) string { func js(thing interface{}) string {
b, err := json.Marshal(thing) b, err := json.Marshal(thing)
if err != nil { if err != nil {

View file

@ -26,6 +26,7 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/appservice/types" "github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
@ -761,4 +762,71 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush
return nil return nil
} }
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
return a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL)
}
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
id, err := a.DB.GetNewNumericLocalpart(ctx)
if err != nil {
return err
}
res.ID = id
return nil
}
func (a *UserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error {
_, err := a.DB.CheckAccountAvailability(ctx, req.Localpart)
if err == sql.ErrNoRows {
res.Available = true
return nil
}
res.Available = false
return err
}
func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error {
acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.PlaintextPassword)
switch err {
case sql.ErrNoRows: // user does not exist
return nil
case bcrypt.ErrMismatchedHashAndPassword: // user exists, but password doesn't match
return nil
default:
res.Exists = true
res.Account = acc
return nil
}
}
func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, _ *struct{}) error {
return a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName)
}
func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {
localpart, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium)
if err != nil {
return err
}
res.Localpart = localpart
return nil
}
func (a *UserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error {
r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart)
if err != nil {
return err
}
res.ThreePIDs = r
return nil
}
func (a *UserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.PerformForgetThreePIDRequest, res *struct{}) error {
return a.DB.RemoveThreePIDAssociation(ctx, req.ThreePID, req.Medium)
}
func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error {
return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.Medium)
}
const pushRulesAccountDataType = "m.push_rules" const pushRulesAccountDataType = "m.push_rules"

View file

@ -40,6 +40,10 @@ const (
PerformPusherSetPath = "/pushserver/performPusherSet" PerformPusherSetPath = "/pushserver/performPusherSet"
PerformPusherDeletionPath = "/pushserver/performPusherDeletion" PerformPusherDeletionPath = "/pushserver/performPusherDeletion"
PerformPushRulesPutPath = "/pushserver/performPushRulesPut" PerformPushRulesPutPath = "/pushserver/performPushRulesPut"
PerformSetAvatarURLPath = "/userapi/performSetAvatarURL"
PerformSetDisplayNamePath = "/userapi/performSetDisplayName"
PerformForgetThreePIDPath = "/userapi/performForgetThreePID"
PerformSaveThreePIDAssociationPath = "/userapi/performSaveThreePIDAssociation"
QueryKeyBackupPath = "/userapi/queryKeyBackup" QueryKeyBackupPath = "/userapi/queryKeyBackup"
QueryProfilePath = "/userapi/queryProfile" QueryProfilePath = "/userapi/queryProfile"
@ -52,6 +56,11 @@ const (
QueryPushersPath = "/pushserver/queryPushers" QueryPushersPath = "/pushserver/queryPushers"
QueryPushRulesPath = "/pushserver/queryPushRules" QueryPushRulesPath = "/pushserver/queryPushRules"
QueryNotificationsPath = "/pushserver/queryNotifications" QueryNotificationsPath = "/pushserver/queryNotifications"
QueryNumericLocalpartPath = "/userapi/queryNumericLocalpart"
QueryAccountAvailabilityPath = "/userapi/queryAccountAvailability"
QueryAccountByPasswordPath = "/userapi/queryAccountByPassword"
QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID"
QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart"
) )
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
@ -310,3 +319,75 @@ func (h *httpUserInternalAPI) QueryPushRules(ctx context.Context, req *api.Query
apiURL := h.apiURL + QueryPushRulesPath apiURL := h.apiURL + QueryPushRulesPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
} }
func (h *httpUserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, PerformSetAvatarURLPath)
defer span.Finish()
apiURL := h.apiURL + PerformSetAvatarURLPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, QueryNumericLocalpartPath)
defer span.Finish()
apiURL := h.apiURL + QueryNumericLocalpartPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, struct{}{}, res)
}
func (h *httpUserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, QueryAccountAvailabilityPath)
defer span.Finish()
apiURL := h.apiURL + QueryAccountAvailabilityPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, QueryAccountByPasswordPath)
defer span.Finish()
apiURL := h.apiURL + QueryAccountByPasswordPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *struct{}) error {
span, ctx := opentracing.StartSpanFromContext(ctx, PerformSetDisplayNamePath)
defer span.Finish()
apiURL := h.apiURL + PerformSetDisplayNamePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, QueryLocalpartForThreePIDPath)
defer span.Finish()
apiURL := h.apiURL + QueryLocalpartForThreePIDPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, QueryThreePIDsForLocalpartPath)
defer span.Finish()
apiURL := h.apiURL + QueryThreePIDsForLocalpartPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.PerformForgetThreePIDRequest, res *struct{}) error {
span, ctx := opentracing.StartSpanFromContext(ctx, PerformForgetThreePIDPath)
defer span.Finish()
apiURL := h.apiURL + PerformForgetThreePIDPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error {
span, ctx := opentracing.StartSpanFromContext(ctx, PerformSaveThreePIDAssociationPath)
defer span.Finish()
apiURL := h.apiURL + PerformSaveThreePIDAssociationPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}

View file

@ -347,4 +347,101 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(PerformSetAvatarURLPath,
httputil.MakeInternalAPI("performSetAvatarURL", func(req *http.Request) util.JSONResponse {
request := api.PerformSetAvatarURLRequest{}
response := api.PerformSetAvatarURLResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.SetAvatarURL(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryNumericLocalpartPath,
httputil.MakeInternalAPI("queryNumericLocalpart", func(req *http.Request) util.JSONResponse {
response := api.QueryNumericLocalpartResponse{}
if err := s.QueryNumericLocalpart(req.Context(), &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryAccountByPasswordPath,
httputil.MakeInternalAPI("queryAccountByPassword", func(req *http.Request) util.JSONResponse {
request := api.QueryAccountByPasswordRequest{}
response := api.QueryAccountByPasswordResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryAccountByPassword(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformSetDisplayNamePath,
httputil.MakeInternalAPI("performSetDisplayName", func(req *http.Request) util.JSONResponse {
request := api.PerformUpdateDisplayNameRequest{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.SetDisplayName(req.Context(), &request, &struct{}{}); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}}
}),
)
internalAPIMux.Handle(QueryLocalpartForThreePIDPath,
httputil.MakeInternalAPI("queryLocalpartForThreePID", func(req *http.Request) util.JSONResponse {
request := api.QueryLocalpartForThreePIDRequest{}
response := api.QueryLocalpartForThreePIDResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryLocalpartForThreePID(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryThreePIDsForLocalpartPath,
httputil.MakeInternalAPI("queryThreePIDsForLocalpart", func(req *http.Request) util.JSONResponse {
request := api.QueryThreePIDsForLocalpartRequest{}
response := api.QueryThreePIDsForLocalpartResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryThreePIDsForLocalpart(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformForgetThreePIDPath,
httputil.MakeInternalAPI("performForgetThreePID", func(req *http.Request) util.JSONResponse {
request := api.PerformForgetThreePIDRequest{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformForgetThreePID(req.Context(), &request, &struct{}{}); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}}
}),
)
internalAPIMux.Handle(PerformSaveThreePIDAssociationPath,
httputil.MakeInternalAPI("performSaveThreePIDAssociation", func(req *http.Request) util.JSONResponse {
request := api.PerformSaveThreePIDAssociationRequest{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformSaveThreePIDAssociation(req.Context(), &request, &struct{}{}); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}}
}),
)
} }

View file

@ -24,12 +24,17 @@ import (
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
) )
type Database interface { type Profile interface {
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
SetDisplayName(ctx context.Context, localpart string, displayName string) error SetDisplayName(ctx context.Context, localpart string, displayName string) error
}
type Database interface {
Profile
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists. // account already exists, it will return nil, ErrUserExists.
@ -48,7 +53,6 @@ type Database interface {
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
DeactivateAccount(ctx context.Context, localpart string) (err error) DeactivateAccount(ctx context.Context, localpart string) (err error)
CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error) CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error)
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)