Merge user API databases into one

This commit is contained in:
Neil Alexander 2022-02-14 12:00:27 +00:00
parent 5106cc807c
commit 7878012e7a
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
55 changed files with 590 additions and 822 deletions

View file

@ -23,7 +23,7 @@ import (
"errors" "errors"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -85,7 +85,7 @@ func RetrieveUserProfile(
ctx context.Context, ctx context.Context,
userID string, userID string,
asAPI AppServiceQueryAPI, asAPI AppServiceQueryAPI,
accountDB accounts.Database, accountDB userdb.Database,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {

View file

@ -28,7 +28,7 @@ import (
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -37,7 +37,7 @@ func AddPublicRoutes(
router *mux.Router, router *mux.Router,
synapseAdminRouter *mux.Router, synapseAdminRouter *mux.Router,
cfg *config.ClientAPI, cfg *config.ClientAPI,
accountsDB accounts.Database, accountsDB userdb.Database,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
eduInputAPI eduServerAPI.EDUServerInputAPI, eduInputAPI eduServerAPI.EDUServerInputAPI,

View file

@ -30,7 +30,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -137,7 +137,7 @@ type fledglingEvent struct {
func CreateRoom( func CreateRoom(
req *http.Request, device *api.Device, req *http.Request, device *api.Device,
cfg *config.ClientAPI, cfg *config.ClientAPI,
accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI, accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
// TODO (#267): Check room ID doesn't clash with an existing one, and we // TODO (#267): Check room ID doesn't clash with an existing one, and we
@ -151,7 +151,7 @@ func CreateRoom(
func createRoom( func createRoom(
req *http.Request, device *api.Device, req *http.Request, device *api.Device,
cfg *config.ClientAPI, roomID string, cfg *config.ClientAPI, roomID string,
accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI, accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
logger := util.GetLogger(req.Context()) logger := util.GetLogger(req.Context())

View file

@ -23,7 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -32,7 +32,7 @@ func JoinRoomByIDOrAlias(
req *http.Request, req *http.Request,
device *api.Device, device *api.Device,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB accounts.Database, accountDB userdb.Database,
roomIDOrAlias string, roomIDOrAlias string,
) util.JSONResponse { ) util.JSONResponse {
// Prepare to ask the roomserver to perform the room join. // Prepare to ask the roomserver to perform the room join.

View file

@ -24,7 +24,7 @@ import (
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -36,7 +36,7 @@ type crossSigningRequest struct {
func UploadCrossSigningDeviceKeys( func UploadCrossSigningDeviceKeys(
req *http.Request, userInteractiveAuth *auth.UserInteractive, req *http.Request, userInteractiveAuth *auth.UserInteractive,
keyserverAPI api.KeyInternalAPI, device *userapi.Device, keyserverAPI api.KeyInternalAPI, device *userapi.Device,
accountDB accounts.Database, cfg *config.ClientAPI, accountDB userdb.Database, cfg *config.ClientAPI,
) util.JSONResponse { ) util.JSONResponse {
uploadReq := &crossSigningRequest{} uploadReq := &crossSigningRequest{}
uploadRes := &api.PerformUploadDeviceKeysResponse{} uploadRes := &api.PerformUploadDeviceKeysResponse{}

View file

@ -23,7 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -54,7 +54,7 @@ func passwordLogin() flows {
// Login implements GET and POST /login // Login implements GET and POST /login
func Login( func Login(
req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI, req *http.Request, accountDB userdb.Database, userAPI userapi.UserInternalAPI,
cfg *config.ClientAPI, cfg *config.ClientAPI,
) util.JSONResponse { ) util.JSONResponse {
if req.Method == http.MethodGet { if req.Method == http.MethodGet {

View file

@ -30,7 +30,7 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -39,7 +39,7 @@ import (
var errMissingUserID = errors.New("'user_id' must be supplied") var errMissingUserID = errors.New("'user_id' must be supplied")
func SendBan( func SendBan(
req *http.Request, accountDB accounts.Database, device *userapi.Device, req *http.Request, accountDB userdb.Database, device *userapi.Device,
roomID string, cfg *config.ClientAPI, roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
@ -81,7 +81,7 @@ func SendBan(
return sendMembership(req.Context(), accountDB, device, roomID, "ban", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI) return sendMembership(req.Context(), accountDB, device, roomID, "ban", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI)
} }
func sendMembership(ctx context.Context, accountDB accounts.Database, device *userapi.Device, func sendMembership(ctx context.Context, accountDB userdb.Database, device *userapi.Device,
roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time, roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time,
roomVer gomatrixserverlib.RoomVersion, roomVer gomatrixserverlib.RoomVersion,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI) util.JSONResponse { rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI) util.JSONResponse {
@ -125,7 +125,7 @@ func sendMembership(ctx context.Context, accountDB accounts.Database, device *us
} }
func SendKick( func SendKick(
req *http.Request, accountDB accounts.Database, device *userapi.Device, req *http.Request, accountDB userdb.Database, device *userapi.Device,
roomID string, cfg *config.ClientAPI, roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
@ -165,7 +165,7 @@ func SendKick(
} }
func SendUnban( func SendUnban(
req *http.Request, accountDB accounts.Database, device *userapi.Device, req *http.Request, accountDB userdb.Database, device *userapi.Device,
roomID string, cfg *config.ClientAPI, roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
@ -200,7 +200,7 @@ func SendUnban(
} }
func SendInvite( func SendInvite(
req *http.Request, accountDB accounts.Database, device *userapi.Device, req *http.Request, accountDB userdb.Database, device *userapi.Device,
roomID string, cfg *config.ClientAPI, roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
@ -271,7 +271,7 @@ func SendInvite(
func buildMembershipEvent( func buildMembershipEvent(
ctx context.Context, ctx context.Context,
targetUserID, reason string, accountDB accounts.Database, targetUserID, reason string, accountDB userdb.Database,
device *userapi.Device, device *userapi.Device,
membership, roomID string, isDirect bool, membership, roomID string, isDirect bool,
cfg *config.ClientAPI, evTime time.Time, cfg *config.ClientAPI, evTime time.Time,
@ -312,7 +312,7 @@ func loadProfile(
ctx context.Context, ctx context.Context,
userID string, userID string,
cfg *config.ClientAPI, cfg *config.ClientAPI,
accountDB accounts.Database, accountDB userdb.Database,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
_, serverName, err := gomatrixserverlib.SplitID('@', userID) _, serverName, err := gomatrixserverlib.SplitID('@', userID)
@ -366,7 +366,7 @@ func checkAndProcessThreepid(
body *threepid.MembershipRequest, body *threepid.MembershipRequest,
cfg *config.ClientAPI, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB accounts.Database, accountDB userdb.Database,
roomID string, roomID string,
evTime time.Time, evTime time.Time,
) (inviteStored bool, errRes *util.JSONResponse) { ) (inviteStored bool, errRes *util.JSONResponse) {

View file

@ -9,7 +9,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -29,7 +29,7 @@ type newPasswordAuth struct {
func Password( func Password(
req *http.Request, req *http.Request,
userAPI api.UserInternalAPI, userAPI api.UserInternalAPI,
accountDB accounts.Database, accountDB userdb.Database,
device *api.Device, device *api.Device,
cfg *config.ClientAPI, cfg *config.ClientAPI,
) util.JSONResponse { ) util.JSONResponse {

View file

@ -19,7 +19,7 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -28,7 +28,7 @@ func PeekRoomByIDOrAlias(
req *http.Request, req *http.Request,
device *api.Device, device *api.Device,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB accounts.Database, accountDB userdb.Database,
roomIDOrAlias string, roomIDOrAlias string,
) util.JSONResponse { ) util.JSONResponse {
// if this is a remote roomIDOrAlias, we have to ask the roomserver (or federation sender?) to // if this is a remote roomIDOrAlias, we have to ask the roomserver (or federation sender?) to
@ -82,7 +82,7 @@ func UnpeekRoomByID(
req *http.Request, req *http.Request,
device *api.Device, device *api.Device,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB accounts.Database, accountDB userdb.Database,
roomID string, roomID string,
) util.JSONResponse { ) util.JSONResponse {
unpeekReq := roomserverAPI.PerformUnpeekRequest{ unpeekReq := roomserverAPI.PerformUnpeekRequest{

View file

@ -27,7 +27,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
@ -36,7 +36,7 @@ import (
// GetProfile implements GET /profile/{userID} // GetProfile implements GET /profile/{userID}
func GetProfile( func GetProfile(
req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI, req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI,
userID string, userID string,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
@ -65,7 +65,7 @@ func GetProfile(
// GetAvatarURL implements GET /profile/{userID}/avatar_url // GetAvatarURL implements GET /profile/{userID}/avatar_url
func GetAvatarURL( func GetAvatarURL(
req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI, req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI,
userID string, asAPI appserviceAPI.AppServiceQueryAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
) util.JSONResponse { ) util.JSONResponse {
@ -92,7 +92,7 @@ func GetAvatarURL(
// SetAvatarURL implements PUT /profile/{userID}/avatar_url // SetAvatarURL implements PUT /profile/{userID}/avatar_url
func SetAvatarURL( func SetAvatarURL(
req *http.Request, accountDB accounts.Database, req *http.Request, accountDB userdb.Database,
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { if userID != device.UserID {
@ -182,7 +182,7 @@ func SetAvatarURL(
// GetDisplayName implements GET /profile/{userID}/displayname // GetDisplayName implements GET /profile/{userID}/displayname
func GetDisplayName( func GetDisplayName(
req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI, req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI,
userID string, asAPI appserviceAPI.AppServiceQueryAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
) util.JSONResponse { ) util.JSONResponse {
@ -209,7 +209,7 @@ func GetDisplayName(
// SetDisplayName implements PUT /profile/{userID}/displayname // SetDisplayName implements PUT /profile/{userID}/displayname
func SetDisplayName( func SetDisplayName(
req *http.Request, accountDB accounts.Database, req *http.Request, accountDB userdb.Database,
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { if userID != device.UserID {
@ -302,7 +302,7 @@ func SetDisplayName(
// Returns an error when something goes wrong or specifically // Returns an error when something goes wrong or specifically
// eventutil.ErrProfileNoExists when the profile doesn't exist. // eventutil.ErrProfileNoExists when the profile doesn't exist.
func getProfile( func getProfile(
ctx context.Context, accountDB accounts.Database, cfg *config.ClientAPI, ctx context.Context, accountDB userdb.Database, cfg *config.ClientAPI,
userID string, userID string,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,

View file

@ -38,7 +38,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/tokens" "github.com/matrix-org/gomatrixserverlib/tokens"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -447,7 +447,7 @@ func validateApplicationService(
func Register( func Register(
req *http.Request, req *http.Request,
userAPI userapi.UserInternalAPI, userAPI userapi.UserInternalAPI,
accountDB accounts.Database, accountDB userdb.Database,
cfg *config.ClientAPI, cfg *config.ClientAPI,
) util.JSONResponse { ) util.JSONResponse {
var r registerRequest var r registerRequest
@ -891,7 +891,7 @@ type availableResponse struct {
func RegisterAvailable( func RegisterAvailable(
req *http.Request, req *http.Request,
cfg *config.ClientAPI, cfg *config.ClientAPI,
accountDB accounts.Database, accountDB userdb.Database,
) util.JSONResponse { ) util.JSONResponse {
username := req.URL.Query().Get("username") username := req.URL.Query().Get("username")

View file

@ -34,7 +34,7 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -51,7 +51,7 @@ func Setup(
eduAPI eduServerAPI.EDUServerInputAPI, eduAPI eduServerAPI.EDUServerInputAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
accountDB accounts.Database, accountDB userdb.Database,
userAPI userapi.UserInternalAPI, userAPI userapi.UserInternalAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
syncProducer *producers.SyncAPIProducer, syncProducer *producers.SyncAPIProducer,

View file

@ -20,7 +20,7 @@ import (
"github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -33,7 +33,7 @@ type typingContentJSON struct {
// sends the typing events to client API typingProducer // sends the typing events to client API typingProducer
func SendTyping( func SendTyping(
req *http.Request, device *userapi.Device, roomID string, req *http.Request, device *userapi.Device, roomID string,
userID string, accountDB accounts.Database, userID string, accountDB userdb.Database,
eduAPI api.EDUServerInputAPI, eduAPI api.EDUServerInputAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
) util.JSONResponse { ) util.JSONResponse {

View file

@ -23,7 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/threepid" "github.com/matrix-org/dendrite/clientapi/threepid"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -40,7 +40,7 @@ type threePIDsResponse struct {
// RequestEmailToken implements: // RequestEmailToken implements:
// POST /account/3pid/email/requestToken // POST /account/3pid/email/requestToken
// POST /register/email/requestToken // POST /register/email/requestToken
func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI) util.JSONResponse { func RequestEmailToken(req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI) util.JSONResponse {
var body threepid.EmailAssociationRequest var body threepid.EmailAssociationRequest
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr return *reqErr
@ -61,7 +61,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.MatrixError{ JSON: jsonerror.MatrixError{
ErrCode: "M_THREEPID_IN_USE", ErrCode: "M_THREEPID_IN_USE",
Err: accounts.Err3PIDInUse.Error(), Err: userdb.Err3PIDInUse.Error(),
}, },
} }
} }
@ -85,7 +85,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf
// CheckAndSave3PIDAssociation implements POST /account/3pid // CheckAndSave3PIDAssociation implements POST /account/3pid
func CheckAndSave3PIDAssociation( func CheckAndSave3PIDAssociation(
req *http.Request, accountDB accounts.Database, device *api.Device, req *http.Request, accountDB userdb.Database, device *api.Device,
cfg *config.ClientAPI, cfg *config.ClientAPI,
) util.JSONResponse { ) util.JSONResponse {
var body threepid.EmailAssociationCheckRequest var body threepid.EmailAssociationCheckRequest
@ -149,7 +149,7 @@ func CheckAndSave3PIDAssociation(
// GetAssociated3PIDs implements GET /account/3pid // GetAssociated3PIDs implements GET /account/3pid
func GetAssociated3PIDs( func GetAssociated3PIDs(
req *http.Request, accountDB accounts.Database, device *api.Device, req *http.Request, accountDB userdb.Database, device *api.Device,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
@ -170,7 +170,7 @@ func GetAssociated3PIDs(
} }
// Forget3PID implements POST /account/3pid/delete // Forget3PID implements POST /account/3pid/delete
func Forget3PID(req *http.Request, accountDB accounts.Database) util.JSONResponse { func Forget3PID(req *http.Request, accountDB userdb.Database) util.JSONResponse {
var body authtypes.ThreePID var body authtypes.ThreePID
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr return *reqErr

View file

@ -29,7 +29,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -87,7 +87,7 @@ var (
func CheckAndProcessInvite( func CheckAndProcessInvite(
ctx context.Context, ctx context.Context,
device *userapi.Device, body *MembershipRequest, cfg *config.ClientAPI, device *userapi.Device, body *MembershipRequest, cfg *config.ClientAPI,
rsAPI api.RoomserverInternalAPI, db accounts.Database, rsAPI api.RoomserverInternalAPI, db userdb.Database,
roomID string, roomID string,
evTime time.Time, evTime time.Time,
) (inviteStoredOnIDServer bool, err error) { ) (inviteStoredOnIDServer bool, err error) {
@ -137,7 +137,7 @@ func CheckAndProcessInvite(
// Returns an error if a check or a request failed. // Returns an error if a check or a request failed.
func queryIDServer( func queryIDServer(
ctx context.Context, ctx context.Context,
db accounts.Database, cfg *config.ClientAPI, device *userapi.Device, db userdb.Database, cfg *config.ClientAPI, device *userapi.Device,
body *MembershipRequest, roomID string, body *MembershipRequest, roomID string,
) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) { ) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) {
if err = isTrusted(body.IDServer, cfg); err != nil { if err = isTrusted(body.IDServer, cfg); err != nil {
@ -206,7 +206,7 @@ func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServe
// Returns an error if the request failed to send or if the response couldn't be parsed. // Returns an error if the request failed to send or if the response couldn't be parsed.
func queryIDServerStoreInvite( func queryIDServerStoreInvite(
ctx context.Context, ctx context.Context,
db accounts.Database, cfg *config.ClientAPI, device *userapi.Device, db userdb.Database, cfg *config.ClientAPI, device *userapi.Device,
body *MembershipRequest, roomID string, body *MembershipRequest, roomID string,
) (*idServerStoreInviteResponse, error) { ) (*idServerStoreInviteResponse, error) {
// Retrieve the sender's profile to get their display name // Retrieve the sender's profile to get their display name

View file

@ -22,10 +22,11 @@ import (
"io/ioutil" "io/ioutil"
"os" "os"
"strings" "strings"
"time"
"github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"golang.org/x/term" "golang.org/x/term"
@ -74,9 +75,13 @@ func main() {
pass := getPassword(password, pwdFile, pwdStdin, askPass, os.Stdin) pass := getPassword(password, pwdFile, pwdStdin, askPass, os.Stdin)
accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{ accountDB, err := userdb.NewDatabase(
ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString, &config.DatabaseOptions{
}, cfg.Global.ServerName, bcrypt.DefaultCost, cfg.UserAPI.OpenIDTokenLifetimeMS) ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString,
},
cfg.Global.ServerName, bcrypt.DefaultCost,
time.Duration(cfg.UserAPI.OpenIDTokenLifetimeMS)*time.Millisecond,
)
if err != nil { if err != nil {
logrus.Fatalln("Failed to connect to the database:", err.Error()) logrus.Fatalln("Failed to connect to the database:", err.Error())
} }

View file

@ -8,12 +8,11 @@ import (
"log" "log"
"os" "os"
pgaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas"
slaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas"
pgdevices "github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas"
sldevices "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas"
"github.com/pressly/goose" "github.com/pressly/goose"
pgusers "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
slusers "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
_ "github.com/lib/pq" _ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
@ -26,8 +25,7 @@ const (
RoomServer = "roomserver" RoomServer = "roomserver"
SigningKeyServer = "signingkeyserver" SigningKeyServer = "signingkeyserver"
SyncAPI = "syncapi" SyncAPI = "syncapi"
UserAPIAccounts = "userapi_accounts" UserAPI = "userapi"
UserAPIDevices = "userapi_devices"
) )
var ( var (
@ -35,7 +33,7 @@ var (
flags = flag.NewFlagSet("goose", flag.ExitOnError) flags = flag.NewFlagSet("goose", flag.ExitOnError)
component = flags.String("component", "", "dendrite component name") component = flags.String("component", "", "dendrite component name")
knownDBs = []string{ knownDBs = []string{
AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPIAccounts, UserAPIDevices, AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPI,
} }
) )
@ -143,18 +141,14 @@ Commands:
func loadSQLiteDeltas(component string) { func loadSQLiteDeltas(component string) {
switch component { switch component {
case UserAPIAccounts: case UserAPI:
slaccounts.LoadFromGoose() slusers.LoadFromGoose()
case UserAPIDevices:
sldevices.LoadFromGoose()
} }
} }
func loadPostgresDeltas(component string) { func loadPostgresDeltas(component string) {
switch component { switch component {
case UserAPIAccounts: case UserAPI:
pgaccounts.LoadFromGoose() pgusers.LoadFromGoose()
case UserAPIDevices:
pgdevices.LoadFromGoose()
} }
} }

View file

@ -38,7 +38,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -273,8 +273,13 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI {
// CreateAccountsDB creates a new instance of the accounts database. Should only // CreateAccountsDB creates a new instance of the accounts database. Should only
// be called once per component. // be called once per component.
func (b *BaseDendrite) CreateAccountsDB() accounts.Database { func (b *BaseDendrite) CreateAccountsDB() userdb.Database {
db, err := accounts.NewDatabase(&b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost, b.Cfg.UserAPI.OpenIDTokenLifetimeMS) db, err := userdb.NewDatabase(
&b.Cfg.UserAPI.AccountDatabase,
b.Cfg.Global.ServerName,
b.Cfg.UserAPI.BCryptCost,
time.Duration(b.Cfg.UserAPI.OpenIDTokenLifetimeMS)*time.Millisecond,
)
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to connect to accounts db") logrus.WithError(err).Panicf("failed to connect to accounts db")
} }

View file

@ -30,7 +30,7 @@ import (
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi" "github.com/matrix-org/dendrite/syncapi"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -38,7 +38,7 @@ import (
// all components of Dendrite, for use in monolith mode. // all components of Dendrite, for use in monolith mode.
type Monolith struct { type Monolith struct {
Config *config.Dendrite Config *config.Dendrite
AccountDB accounts.Database AccountDB userdb.Database
KeyRing *gomatrixserverlib.KeyRing KeyRing *gomatrixserverlib.KeyRing
Client *gomatrixserverlib.Client Client *gomatrixserverlib.Client
FedClient *gomatrixserverlib.FederationClient FedClient *gomatrixserverlib.FederationClient

View file

@ -27,16 +27,14 @@ import (
keyapi "github.com/matrix-org/dendrite/keyserver/api" keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/devices"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
type UserInternalAPI struct { type UserInternalAPI struct {
AccountDB accounts.Database DB storage.Database
DeviceDB devices.Database
ServerName gomatrixserverlib.ServerName ServerName gomatrixserverlib.ServerName
// AppServices is the list of all registered AS // AppServices is the list of all registered AS
AppServices []config.ApplicationService AppServices []config.ApplicationService
@ -54,12 +52,12 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
if req.DataType == "" { if req.DataType == "" {
return fmt.Errorf("data type must not be empty") return fmt.Errorf("data type must not be empty")
} }
return a.AccountDB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData) return a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
} }
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
if req.AccountType == api.AccountTypeGuest { if req.AccountType == api.AccountTypeGuest {
acc, err := a.AccountDB.CreateGuestAccount(ctx) acc, err := a.DB.CreateGuestAccount(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -67,7 +65,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
res.Account = acc res.Account = acc
return nil return nil
} }
acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID) acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID)
if err != nil { if err != nil {
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
switch req.OnConflict { switch req.OnConflict {
@ -90,7 +88,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
return nil return nil
} }
if err = a.AccountDB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
return err return err
} }
@ -100,7 +98,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
} }
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil { if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
return err return err
} }
res.PasswordUpdated = true res.PasswordUpdated = true
@ -113,7 +111,7 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
"device_id": req.DeviceID, "device_id": req.DeviceID,
"display_name": req.DeviceDisplayName, "display_name": req.DeviceDisplayName,
}).Info("PerformDeviceCreation") }).Info("PerformDeviceCreation")
dev, err := a.DeviceDB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
if err != nil { if err != nil {
return err return err
} }
@ -138,12 +136,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deletedDeviceIDs := req.DeviceIDs deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 { if len(req.DeviceIDs) == 0 {
var devices []api.Device var devices []api.Device
devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID) devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
for _, d := range devices { for _, d := range devices {
deletedDeviceIDs = append(deletedDeviceIDs, d.ID) deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
} }
} else { } else {
err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs) err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs)
} }
if err != nil { if err != nil {
return err return err
@ -197,7 +195,7 @@ func (a *UserInternalAPI) PerformLastSeenUpdate(
if err != nil { if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
} }
if err := a.DeviceDB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil { if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil {
return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err) return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err)
} }
return nil return nil
@ -209,7 +207,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
return err return err
} }
dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID) dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
res.DeviceExists = false res.DeviceExists = false
return nil return nil
@ -224,7 +222,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
return nil return nil
} }
err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName) err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed") util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
return err return err
@ -262,7 +260,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
if domain != a.ServerName { if domain != a.ServerName {
return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName)
} }
prof, err := a.AccountDB.GetProfileByLocalpart(ctx, local) prof, err := a.DB.GetProfileByLocalpart(ctx, local)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil return nil
@ -276,7 +274,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
} }
func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error { func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error {
profiles, err := a.AccountDB.SearchProfiles(ctx, req.SearchString, req.Limit) profiles, err := a.DB.SearchProfiles(ctx, req.SearchString, req.Limit)
if err != nil { if err != nil {
return err return err
} }
@ -285,7 +283,7 @@ func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.Quer
} }
func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error { func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error {
devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs) devices, err := a.DB.GetDevicesByID(ctx, req.DeviceIDs)
if err != nil { if err != nil {
return err return err
} }
@ -313,7 +311,7 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
if domain != a.ServerName { if domain != a.ServerName {
return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName) return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName)
} }
devs, err := a.DeviceDB.GetDevicesByLocalpart(ctx, local) devs, err := a.DB.GetDevicesByLocalpart(ctx, local)
if err != nil { if err != nil {
return err return err
} }
@ -331,7 +329,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
} }
if req.DataType != "" { if req.DataType != "" {
var data json.RawMessage var data json.RawMessage
data, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
if err != nil { if err != nil {
return err return err
} }
@ -349,7 +347,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
} }
return nil return nil
} }
global, rooms, err := a.AccountDB.GetAccountData(ctx, local) global, rooms, err := a.DB.GetAccountData(ctx, local)
if err != nil { if err != nil {
return err return err
} }
@ -368,7 +366,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
return nil return nil
} }
device, err := a.DeviceDB.GetDeviceByAccessToken(ctx, req.AccessToken) device, err := a.DB.GetDeviceByAccessToken(ctx, req.AccessToken)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil return nil
@ -410,7 +408,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
if localpart != "" { // AS is masquerading as another user if localpart != "" { // AS is masquerading as another user
// Verify that the user is registered // Verify that the user is registered
account, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart) account, err := a.DB.GetAccountByLocalpart(ctx, localpart)
// Verify that the account exists and either appServiceID matches or // Verify that the account exists and either appServiceID matches or
// it belongs to the appservice user namespaces // it belongs to the appservice user namespaces
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) { if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
@ -428,7 +426,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
// PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again. // PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again.
func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error { func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error {
err := a.AccountDB.DeactivateAccount(ctx, req.Localpart) err := a.DB.DeactivateAccount(ctx, req.Localpart)
res.AccountDeactivated = err == nil res.AccountDeactivated = err == nil
return err return err
} }
@ -437,7 +435,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error { func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error {
token := util.RandomString(24) token := util.RandomString(24)
exp, err := a.AccountDB.CreateOpenIDToken(ctx, token, req.UserID) exp, err := a.DB.CreateOpenIDToken(ctx, token, req.UserID)
res.Token = api.OpenIDToken{ res.Token = api.OpenIDToken{
Token: token, Token: token,
@ -450,7 +448,7 @@ func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *a
// QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation // QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation
func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error { func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, req.Token) openIDTokenAttrs, err := a.DB.GetOpenIDTokenAttributes(ctx, req.Token)
if err != nil { if err != nil {
return err return err
} }
@ -472,7 +470,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
} }
return nil return nil
} }
exists, err := a.AccountDB.DeleteKeyBackup(ctx, req.UserID, req.Version) exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version)
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to delete backup: %s", err) res.Error = fmt.Sprintf("failed to delete backup: %s", err)
} }
@ -485,7 +483,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
} }
// Create metadata // Create metadata
if req.Version == "" { if req.Version == "" {
version, err := a.AccountDB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData) version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to create backup: %s", err) res.Error = fmt.Sprintf("failed to create backup: %s", err)
} }
@ -498,7 +496,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
} }
// Update metadata // Update metadata
if len(req.Keys.Rooms) == 0 { if len(req.Keys.Rooms) == 0 {
err := a.AccountDB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData) err := a.DB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to update backup: %s", err) res.Error = fmt.Sprintf("failed to update backup: %s", err)
} }
@ -519,7 +517,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) {
// you can only upload keys for the CURRENT version // you can only upload keys for the CURRENT version
version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, "") version, _, _, _, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, "")
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to query version: %s", err) res.Error = fmt.Sprintf("failed to query version: %s", err)
return return
@ -547,7 +545,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
}) })
} }
} }
count, etag, err := a.AccountDB.UpsertBackupKeys(ctx, version, req.UserID, uploads) count, etag, err := a.DB.UpsertBackupKeys(ctx, version, req.UserID, uploads)
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to upsert keys: %s", err) res.Error = fmt.Sprintf("failed to upsert keys: %s", err)
return return
@ -557,7 +555,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
} }
func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) { func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) {
version, algorithm, authData, etag, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version) version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version)
res.Version = version res.Version = version
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -573,14 +571,14 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
res.Exists = !deleted res.Exists = !deleted
if !req.ReturnKeys { if !req.ReturnKeys {
res.Count, err = a.AccountDB.CountBackupKeys(ctx, version, req.UserID) res.Count, err = a.DB.CountBackupKeys(ctx, version, req.UserID)
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to count keys: %s", err) res.Error = fmt.Sprintf("failed to count keys: %s", err)
} }
return return
} }
result, err := a.AccountDB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID) result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to query keys: %s", err) res.Error = fmt.Sprintf("failed to query keys: %s", err)
return return

View file

@ -34,7 +34,7 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap
if domain != a.ServerName { if domain != a.ServerName {
return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName) return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName)
} }
tokenMeta, err := a.DeviceDB.CreateLoginToken(ctx, &req.Data) tokenMeta, err := a.DB.CreateLoginToken(ctx, &req.Data)
if err != nil { if err != nil {
return err return err
} }
@ -45,13 +45,13 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap
// PerformLoginTokenDeletion ensures the token doesn't exist. // PerformLoginTokenDeletion ensures the token doesn't exist.
func (a *UserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *api.PerformLoginTokenDeletionRequest, res *api.PerformLoginTokenDeletionResponse) error { func (a *UserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *api.PerformLoginTokenDeletionRequest, res *api.PerformLoginTokenDeletionResponse) error {
util.GetLogger(ctx).Info("PerformLoginTokenDeletion") util.GetLogger(ctx).Info("PerformLoginTokenDeletion")
return a.DeviceDB.RemoveLoginToken(ctx, req.Token) return a.DB.RemoveLoginToken(ctx, req.Token)
} }
// QueryLoginToken returns the data associated with a login token. If // QueryLoginToken returns the data associated with a login token. If
// the token is not valid, success is returned, but res.Data == nil. // the token is not valid, success is returned, but res.Data == nil.
func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLoginTokenRequest, res *api.QueryLoginTokenResponse) error { func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLoginTokenRequest, res *api.QueryLoginTokenResponse) error {
tokenData, err := a.DeviceDB.GetLoginTokenDataByToken(ctx, req.Token) tokenData, err := a.DB.GetLoginTokenDataByToken(ctx, req.Token)
if err != nil { if err != nil {
res.Data = nil res.Data = nil
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog
if domain != a.ServerName { if domain != a.ServerName {
return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName) return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName)
} }
if _, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart); err != nil { if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil {
res.Data = nil res.Data = nil
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil return nil

View file

@ -1,52 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package devices
import (
"context"
"github.com/matrix-org/dendrite/userapi/api"
)
type Database interface {
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
RemoveLoginToken(ctx context.Context, token string) error
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
}

View file

@ -1,270 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas"
"github.com/matrix-org/gomatrixserverlib"
)
const (
// The length of generated device IDs
deviceIDByteLength = 6
loginTokenByteLength = 32
)
// Database represents a device database.
type Database struct {
db *sql.DB
devices devicesStatements
loginTokens loginTokenStatements
loginTokenLifetime time.Duration
}
// NewDatabase creates a new device database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
var d devicesStatements
var lt loginTokenStatements
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
if err = d.execSchema(db); err != nil {
return nil, err
}
if err = lt.execSchema(db); err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
deltas.LoadLastSeenTSIP(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
if err = d.prepare(db, serverName); err != nil {
return nil, err
}
if err = lt.prepare(db); err != nil {
return nil, err
}
return &Database{db, d, lt, loginTokenLifetime}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*api.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.devices.selectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil
})
return
}
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
})
}
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
tok, err := generateLoginToken()
if err != nil {
return nil, err
}
meta := &api.LoginTokenMetadata{
Token: tok,
Expiration: time.Now().Add(d.loginTokenLifetime),
}
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.loginTokens.insert(ctx, txn, meta, data)
})
if err != nil {
return nil, err
}
return meta, nil
}
func generateLoginToken() (string, error) {
b := make([]byte, loginTokenByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.loginTokens.deleteByToken(ctx, txn, token)
})
}
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
return d.loginTokens.selectByToken(ctx, token)
}

View file

@ -1,271 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas"
"github.com/matrix-org/gomatrixserverlib"
)
const (
// The length of generated device IDs
deviceIDByteLength = 6
loginTokenByteLength = 32
)
// Database represents a device database.
type Database struct {
db *sql.DB
writer sqlutil.Writer
devices devicesStatements
loginTokens loginTokenStatements
loginTokenLifetime time.Duration
}
// NewDatabase creates a new device database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
writer := sqlutil.NewExclusiveWriter()
var d devicesStatements
var lt loginTokenStatements
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
if err = d.execSchema(db); err != nil {
return nil, err
}
if err = lt.execSchema(db); err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
deltas.LoadLastSeenTSIP(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
if err = d.prepare(db, writer, serverName); err != nil {
return nil, err
}
if err = lt.prepare(db); err != nil {
return nil, err
}
return &Database{db, writer, d, lt, loginTokenLifetime}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*api.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.devices.selectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil
})
return
}
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
})
}
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
tok, err := generateLoginToken()
if err != nil {
return nil, err
}
meta := &api.LoginTokenMetadata{
Token: tok,
Expiration: time.Now().Add(d.loginTokenLifetime),
}
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.loginTokens.insert(ctx, txn, meta, data)
})
if err != nil {
return nil, err
}
return meta, nil
}
func generateLoginToken() (string, error) {
b := make([]byte, loginTokenByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.loginTokens.deleteByToken(ctx, txn, token)
})
}
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
return d.loginTokens.selectByToken(ctx, token)
}

View file

@ -1,42 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !wasm
// +build !wasm
package devices
import (
"fmt"
"time"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/devices/postgres"
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
)
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets postgres connection parameters. loginTokenLifetime determines how long a
// login token from CreateLoginToken is valid.
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime)
case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(dbProperties, serverName, loginTokenLifetime)
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View file

@ -1,39 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package devices
import (
"fmt"
"time"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
)
func NewDatabase(
dbProperties *config.DatabaseOptions,
serverName gomatrixserverlib.ServerName,
loginTokenLifetime time.Duration,
) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime)
case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation")
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package accounts package storage
import ( import (
"context" "context"
@ -61,6 +61,35 @@ type Database interface {
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error) UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error) GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error) CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
RemoveLoginToken(ctx context.Context, token string) error
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
} }
// Err3PIDInUse is the error returned when trying to save an association involving // Err3PIDInUse is the error returned when trying to save an association involving

View file

@ -5,13 +5,8 @@ import (
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
) )
func LoadFromGoose() {
goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
}
func LoadLastSeenTSIP(m *sqlutil.Migrations) { func LoadLastSeenTSIP(m *sqlutil.Migrations) {
m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP) m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
} }

View file

@ -16,7 +16,9 @@ package postgres
import ( import (
"context" "context"
"crypto/rand"
"database/sql" "database/sql"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -27,7 +29,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas" "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
@ -46,14 +48,23 @@ type Database struct {
threepids threepidStatements threepids threepidStatements
openIDTokens tokenStatements openIDTokens tokenStatements
keyBackupVersions keyBackupVersionStatements keyBackupVersions keyBackupVersionStatements
devices devicesStatements
loginTokens loginTokenStatements
loginTokenLifetime time.Duration
keyBackups keyBackupStatements keyBackups keyBackupStatements
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
bcryptCost int bcryptCost int
openIDTokenLifetimeMS int64 openIDTokenLifetimeMS time.Duration
} }
const (
// The length of generated device IDs
deviceIDByteLength = 6
loginTokenByteLength = 32
)
// NewDatabase creates a new accounts and profiles database // NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS time.Duration) (*Database, error) {
db, err := sqlutil.Open(dbProperties) db, err := sqlutil.Open(dbProperties)
if err != nil { if err != nil {
return nil, err return nil, err
@ -73,6 +84,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
} }
m := sqlutil.NewMigrations() m := sqlutil.NewMigrations()
deltas.LoadIsActive(m) deltas.LoadIsActive(m)
deltas.LoadLastSeenTSIP(m)
if err = m.RunDeltas(db, dbProperties); err != nil { if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err return nil, err
} }
@ -101,6 +113,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.keyBackups.prepare(db); err != nil { if err = d.keyBackups.prepare(db); err != nil {
return nil, err return nil, err
} }
if err = d.devices.prepare(db, serverName); err != nil {
return nil, err
}
if err = d.loginTokens.prepare(db); err != nil {
return nil, err
}
return d, nil return d, nil
} }
@ -362,7 +380,7 @@ func (d *Database) CreateOpenIDToken(
ctx context.Context, ctx context.Context,
token, localpart string, token, localpart string,
) (int64, error) { ) (int64, error) {
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS expiresAtMS := time.Now().Add(d.openIDTokenLifetimeMS).UnixNano() / int64(time.Millisecond)
err := sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { err := sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS) return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
}) })
@ -518,3 +536,196 @@ func (d *Database) UpsertBackupKeys(
}) })
return return
} }
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*api.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.devices.selectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil
})
return
}
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
})
}
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
tok, err := generateLoginToken()
if err != nil {
return nil, err
}
meta := &api.LoginTokenMetadata{
Token: tok,
Expiration: time.Now().Add(d.loginTokenLifetime),
}
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.loginTokens.insert(ctx, txn, meta, data)
})
if err != nil {
return nil, err
}
return meta, nil
}
func generateLoginToken() (string, error) {
b := make([]byte, loginTokenByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.loginTokens.deleteByToken(ctx, txn, token)
})
}
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
return d.loginTokens.selectByToken(ctx, token)
}

View file

@ -5,13 +5,8 @@ import (
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
) )
func LoadFromGoose() {
goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
}
func LoadLastSeenTSIP(m *sqlutil.Migrations) { func LoadLastSeenTSIP(m *sqlutil.Migrations) {
m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP) m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
} }

View file

@ -16,7 +16,9 @@ package sqlite3
import ( import (
"context" "context"
"crypto/rand"
"database/sql" "database/sql"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -28,7 +30,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas" "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@ -46,9 +48,12 @@ type Database struct {
openIDTokens tokenStatements openIDTokens tokenStatements
keyBackupVersions keyBackupVersionStatements keyBackupVersions keyBackupVersionStatements
keyBackups keyBackupStatements keyBackups keyBackupStatements
devices devicesStatements
loginTokens loginTokenStatements
loginTokenLifetime time.Duration
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
bcryptCost int bcryptCost int
openIDTokenLifetimeMS int64 openIDTokenLifetimeMS time.Duration
accountsMu sync.Mutex accountsMu sync.Mutex
profilesMu sync.Mutex profilesMu sync.Mutex
@ -56,8 +61,14 @@ type Database struct {
threepidsMu sync.Mutex threepidsMu sync.Mutex
} }
const (
// The length of generated device IDs
deviceIDByteLength = 6
loginTokenByteLength = 32
)
// NewDatabase creates a new accounts and profiles database // NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS time.Duration) (*Database, error) {
db, err := sqlutil.Open(dbProperties) db, err := sqlutil.Open(dbProperties)
if err != nil { if err != nil {
return nil, err return nil, err
@ -77,6 +88,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
} }
m := sqlutil.NewMigrations() m := sqlutil.NewMigrations()
deltas.LoadIsActive(m) deltas.LoadIsActive(m)
deltas.LoadLastSeenTSIP(m)
if err = m.RunDeltas(db, dbProperties); err != nil { if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err return nil, err
} }
@ -106,6 +118,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.keyBackups.prepare(db); err != nil { if err = d.keyBackups.prepare(db); err != nil {
return nil, err return nil, err
} }
if err = d.devices.prepare(db, d.writer, serverName); err != nil {
return nil, err
}
if err = d.loginTokens.prepare(db); err != nil {
return nil, err
}
return d, nil return d, nil
} }
@ -404,7 +422,7 @@ func (d *Database) CreateOpenIDToken(
ctx context.Context, ctx context.Context,
token, localpart string, token, localpart string,
) (int64, error) { ) (int64, error) {
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS expiresAtMS := time.Now().Add(d.openIDTokenLifetimeMS).UnixNano() / int64(time.Millisecond)
err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error { err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS) return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
}) })
@ -561,3 +579,196 @@ func (d *Database) UpsertBackupKeys(
}) })
return return
} }
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*api.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.devices.selectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil
})
return
}
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
})
}
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
tok, err := generateLoginToken()
if err != nil {
return nil, err
}
meta := &api.LoginTokenMetadata{
Token: tok,
Expiration: time.Now().Add(d.loginTokenLifetime),
}
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.loginTokens.insert(ctx, txn, meta, data)
})
if err != nil {
return nil, err
}
return meta, nil
}
func generateLoginToken() (string, error) {
b := make([]byte, loginTokenByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.loginTokens.deleteByToken(ctx, txn, token)
})
}
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
return d.loginTokens.selectByToken(ctx, token)
}

View file

@ -15,20 +15,21 @@
//go:build !wasm //go:build !wasm
// +build !wasm // +build !wasm
package accounts package storage
import ( import (
"fmt" "fmt"
"time"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/accounts/postgres" "github.com/matrix-org/dendrite/userapi/storage/postgres"
"github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3" "github.com/matrix-org/dendrite/userapi/storage/sqlite3"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets postgres connection parameters // and sets postgres connection parameters
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS time.Duration) (Database, error) {
switch { switch {
case dbProperties.ConnectionString.IsSQLite(): case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS) return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)

View file

@ -12,13 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package accounts package storage
import ( import (
"fmt" "fmt"
"time"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3" "github.com/matrix-org/dendrite/userapi/storage/sqlite3"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -26,7 +27,7 @@ func NewDatabase(
dbProperties *config.DatabaseOptions, dbProperties *config.DatabaseOptions,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
bcryptCost int, bcryptCost int,
openIDTokenLifetimeMS int64, openIDTokenLifetimeMS time.Duration,
) (Database, error) { ) (Database, error) {
switch { switch {
case dbProperties.ConnectionString.IsSQLite(): case dbProperties.ConnectionString.IsSQLite():

View file

@ -23,8 +23,7 @@ import (
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/internal"
"github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/inthttp"
"github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/devices"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -44,26 +43,24 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) {
// NewInternalAPI returns a concerete implementation of the internal API. Callers // NewInternalAPI returns a concerete implementation of the internal API. Callers
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI( func NewInternalAPI(
accountDB accounts.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, accountDB storage.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI,
) api.UserInternalAPI { ) api.UserInternalAPI {
deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName, defaultLoginTokenLifetime) db, err := storage.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName, cfg.BCryptCost, defaultLoginTokenLifetime)
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to connect to device db") logrus.WithError(err).Panicf("failed to connect to device db")
} }
return newInternalAPI(accountDB, deviceDB, cfg, appServices, keyAPI) return newInternalAPI(db, cfg, appServices, keyAPI)
} }
func newInternalAPI( func newInternalAPI(
accountDB accounts.Database, db storage.Database,
deviceDB devices.Database,
cfg *config.UserAPI, cfg *config.UserAPI,
appServices []config.ApplicationService, appServices []config.ApplicationService,
keyAPI keyapi.KeyInternalAPI, keyAPI keyapi.KeyInternalAPI,
) api.UserInternalAPI { ) api.UserInternalAPI {
return &internal.UserInternalAPI{ return &internal.UserInternalAPI{
AccountDB: accountDB, DB: db,
DeviceDB: deviceDB,
ServerName: cfg.Matrix.ServerName, ServerName: cfg.Matrix.ServerName,
AppServices: appServices, AppServices: appServices,
KeyAPI: keyAPI, KeyAPI: keyAPI,