mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-01 03:03:10 -06:00
Merge branch 'main' into implement-push-notifications
This commit is contained in:
commit
857b75d66e
|
|
@ -23,7 +23,7 @@ import (
|
|||
"errors"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
|
|
@ -85,7 +85,7 @@ func RetrieveUserProfile(
|
|||
ctx context.Context,
|
||||
userID string,
|
||||
asAPI AppServiceQueryAPI,
|
||||
accountDB accounts.Database,
|
||||
accountDB userdb.Database,
|
||||
) (*authtypes.Profile, error) {
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -283,8 +283,7 @@ func (m *DendriteMonolith) Start() {
|
|||
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
|
||||
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/%s", m.StorageDirectory, prefix))
|
||||
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-account.db", m.StorageDirectory, prefix))
|
||||
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-device.db", m.StorageDirectory, prefix))
|
||||
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-mediaapi.db", m.CacheDirectory, prefix))
|
||||
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory))
|
||||
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-syncapi.db", m.StorageDirectory, prefix))
|
||||
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-roomserver.db", m.StorageDirectory, prefix))
|
||||
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-keyserver.db", m.StorageDirectory, prefix))
|
||||
|
|
|
|||
|
|
@ -88,7 +88,6 @@ func (m *DendriteMonolith) Start() {
|
|||
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
|
||||
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", m.StorageDirectory))
|
||||
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-account.db", m.StorageDirectory))
|
||||
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-device.db", m.StorageDirectory))
|
||||
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory))
|
||||
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-syncapi.db", m.StorageDirectory))
|
||||
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-roomserver.db", m.StorageDirectory))
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
|
|
@ -38,7 +38,7 @@ func AddPublicRoutes(
|
|||
router *mux.Router,
|
||||
synapseAdminRouter *mux.Router,
|
||||
cfg *config.ClientAPI,
|
||||
accountsDB accounts.Database,
|
||||
accountsDB userdb.Database,
|
||||
federation *gomatrixserverlib.FederationClient,
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||
eduInputAPI eduServerAPI.EDUServerInputAPI,
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
|
@ -137,7 +137,7 @@ type fledglingEvent struct {
|
|||
func CreateRoom(
|
||||
req *http.Request, device *api.Device,
|
||||
cfg *config.ClientAPI,
|
||||
accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||
accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||
asAPI appserviceAPI.AppServiceQueryAPI,
|
||||
) util.JSONResponse {
|
||||
// TODO (#267): Check room ID doesn't clash with an existing one, and we
|
||||
|
|
@ -151,7 +151,7 @@ func CreateRoom(
|
|||
func createRoom(
|
||||
req *http.Request, device *api.Device,
|
||||
cfg *config.ClientAPI, roomID string,
|
||||
accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||
accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||
asAPI appserviceAPI.AppServiceQueryAPI,
|
||||
) util.JSONResponse {
|
||||
logger := util.GetLogger(req.Context())
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
|
@ -32,7 +32,7 @@ func JoinRoomByIDOrAlias(
|
|||
req *http.Request,
|
||||
device *api.Device,
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||
accountDB accounts.Database,
|
||||
accountDB userdb.Database,
|
||||
roomIDOrAlias string,
|
||||
) util.JSONResponse {
|
||||
// Prepare to ask the roomserver to perform the room join.
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
|
|
@ -36,7 +36,7 @@ type crossSigningRequest struct {
|
|||
func UploadCrossSigningDeviceKeys(
|
||||
req *http.Request, userInteractiveAuth *auth.UserInteractive,
|
||||
keyserverAPI api.KeyInternalAPI, device *userapi.Device,
|
||||
accountDB accounts.Database, cfg *config.ClientAPI,
|
||||
accountDB userdb.Database, cfg *config.ClientAPI,
|
||||
) util.JSONResponse {
|
||||
uploadReq := &crossSigningRequest{}
|
||||
uploadRes := &api.PerformUploadDeviceKeysResponse{}
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
|
@ -54,7 +54,7 @@ func passwordLogin() flows {
|
|||
|
||||
// Login implements GET and POST /login
|
||||
func Login(
|
||||
req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI,
|
||||
req *http.Request, accountDB userdb.Database, userAPI userapi.UserInternalAPI,
|
||||
cfg *config.ClientAPI,
|
||||
) util.JSONResponse {
|
||||
if req.Method == http.MethodGet {
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ import (
|
|||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/util"
|
||||
|
|
@ -39,7 +39,7 @@ import (
|
|||
var errMissingUserID = errors.New("'user_id' must be supplied")
|
||||
|
||||
func SendBan(
|
||||
req *http.Request, accountDB accounts.Database, device *userapi.Device,
|
||||
req *http.Request, accountDB userdb.Database, device *userapi.Device,
|
||||
roomID string, cfg *config.ClientAPI,
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
|
||||
) util.JSONResponse {
|
||||
|
|
@ -81,7 +81,7 @@ func SendBan(
|
|||
return sendMembership(req.Context(), accountDB, device, roomID, "ban", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI)
|
||||
}
|
||||
|
||||
func sendMembership(ctx context.Context, accountDB accounts.Database, device *userapi.Device,
|
||||
func sendMembership(ctx context.Context, accountDB userdb.Database, device *userapi.Device,
|
||||
roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time,
|
||||
roomVer gomatrixserverlib.RoomVersion,
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI) util.JSONResponse {
|
||||
|
|
@ -125,7 +125,7 @@ func sendMembership(ctx context.Context, accountDB accounts.Database, device *us
|
|||
}
|
||||
|
||||
func SendKick(
|
||||
req *http.Request, accountDB accounts.Database, device *userapi.Device,
|
||||
req *http.Request, accountDB userdb.Database, device *userapi.Device,
|
||||
roomID string, cfg *config.ClientAPI,
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
|
||||
) util.JSONResponse {
|
||||
|
|
@ -165,7 +165,7 @@ func SendKick(
|
|||
}
|
||||
|
||||
func SendUnban(
|
||||
req *http.Request, accountDB accounts.Database, device *userapi.Device,
|
||||
req *http.Request, accountDB userdb.Database, device *userapi.Device,
|
||||
roomID string, cfg *config.ClientAPI,
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
|
||||
) util.JSONResponse {
|
||||
|
|
@ -200,7 +200,7 @@ func SendUnban(
|
|||
}
|
||||
|
||||
func SendInvite(
|
||||
req *http.Request, accountDB accounts.Database, device *userapi.Device,
|
||||
req *http.Request, accountDB userdb.Database, device *userapi.Device,
|
||||
roomID string, cfg *config.ClientAPI,
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
|
||||
) util.JSONResponse {
|
||||
|
|
@ -271,7 +271,7 @@ func SendInvite(
|
|||
|
||||
func buildMembershipEvent(
|
||||
ctx context.Context,
|
||||
targetUserID, reason string, accountDB accounts.Database,
|
||||
targetUserID, reason string, accountDB userdb.Database,
|
||||
device *userapi.Device,
|
||||
membership, roomID string, isDirect bool,
|
||||
cfg *config.ClientAPI, evTime time.Time,
|
||||
|
|
@ -312,7 +312,7 @@ func loadProfile(
|
|||
ctx context.Context,
|
||||
userID string,
|
||||
cfg *config.ClientAPI,
|
||||
accountDB accounts.Database,
|
||||
accountDB userdb.Database,
|
||||
asAPI appserviceAPI.AppServiceQueryAPI,
|
||||
) (*authtypes.Profile, error) {
|
||||
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
|
||||
|
|
@ -366,7 +366,7 @@ func checkAndProcessThreepid(
|
|||
body *threepid.MembershipRequest,
|
||||
cfg *config.ClientAPI,
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||
accountDB accounts.Database,
|
||||
accountDB userdb.Database,
|
||||
roomID string,
|
||||
evTime time.Time,
|
||||
) (inviteStored bool, errRes *util.JSONResponse) {
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import (
|
|||
pushserverapi "github.com/matrix-org/dendrite/pushserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
|
@ -32,7 +32,7 @@ func Password(
|
|||
req *http.Request,
|
||||
psAPI pushserverapi.PushserverInternalAPI,
|
||||
userAPI api.UserInternalAPI,
|
||||
accountDB accounts.Database,
|
||||
accountDB userdb.Database,
|
||||
device *api.Device,
|
||||
cfg *config.ClientAPI,
|
||||
) util.JSONResponse {
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ import (
|
|||
|
||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
|
@ -28,7 +28,7 @@ func PeekRoomByIDOrAlias(
|
|||
req *http.Request,
|
||||
device *api.Device,
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||
accountDB accounts.Database,
|
||||
accountDB userdb.Database,
|
||||
roomIDOrAlias string,
|
||||
) util.JSONResponse {
|
||||
// if this is a remote roomIDOrAlias, we have to ask the roomserver (or federation sender?) to
|
||||
|
|
@ -82,7 +82,7 @@ func UnpeekRoomByID(
|
|||
req *http.Request,
|
||||
device *api.Device,
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||
accountDB accounts.Database,
|
||||
accountDB userdb.Database,
|
||||
roomID string,
|
||||
) util.JSONResponse {
|
||||
unpeekReq := roomserverAPI.PerformUnpeekRequest{
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/gomatrix"
|
||||
|
|
@ -36,7 +36,7 @@ import (
|
|||
|
||||
// GetProfile implements GET /profile/{userID}
|
||||
func GetProfile(
|
||||
req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI,
|
||||
req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI,
|
||||
userID string,
|
||||
asAPI appserviceAPI.AppServiceQueryAPI,
|
||||
federation *gomatrixserverlib.FederationClient,
|
||||
|
|
@ -65,7 +65,7 @@ func GetProfile(
|
|||
|
||||
// GetAvatarURL implements GET /profile/{userID}/avatar_url
|
||||
func GetAvatarURL(
|
||||
req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI,
|
||||
req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI,
|
||||
userID string, asAPI appserviceAPI.AppServiceQueryAPI,
|
||||
federation *gomatrixserverlib.FederationClient,
|
||||
) util.JSONResponse {
|
||||
|
|
@ -92,7 +92,7 @@ func GetAvatarURL(
|
|||
|
||||
// SetAvatarURL implements PUT /profile/{userID}/avatar_url
|
||||
func SetAvatarURL(
|
||||
req *http.Request, accountDB accounts.Database,
|
||||
req *http.Request, accountDB userdb.Database,
|
||||
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI,
|
||||
) util.JSONResponse {
|
||||
if userID != device.UserID {
|
||||
|
|
@ -182,7 +182,7 @@ func SetAvatarURL(
|
|||
|
||||
// GetDisplayName implements GET /profile/{userID}/displayname
|
||||
func GetDisplayName(
|
||||
req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI,
|
||||
req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI,
|
||||
userID string, asAPI appserviceAPI.AppServiceQueryAPI,
|
||||
federation *gomatrixserverlib.FederationClient,
|
||||
) util.JSONResponse {
|
||||
|
|
@ -209,7 +209,7 @@ func GetDisplayName(
|
|||
|
||||
// SetDisplayName implements PUT /profile/{userID}/displayname
|
||||
func SetDisplayName(
|
||||
req *http.Request, accountDB accounts.Database,
|
||||
req *http.Request, accountDB userdb.Database,
|
||||
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI,
|
||||
) util.JSONResponse {
|
||||
if userID != device.UserID {
|
||||
|
|
@ -302,7 +302,7 @@ func SetDisplayName(
|
|||
// Returns an error when something goes wrong or specifically
|
||||
// eventutil.ErrProfileNoExists when the profile doesn't exist.
|
||||
func getProfile(
|
||||
ctx context.Context, accountDB accounts.Database, cfg *config.ClientAPI,
|
||||
ctx context.Context, accountDB userdb.Database, cfg *config.ClientAPI,
|
||||
userID string,
|
||||
asAPI appserviceAPI.AppServiceQueryAPI,
|
||||
federation *gomatrixserverlib.FederationClient,
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -448,7 +448,7 @@ func validateApplicationService(
|
|||
func Register(
|
||||
req *http.Request,
|
||||
userAPI userapi.UserInternalAPI,
|
||||
accountDB accounts.Database,
|
||||
accountDB userdb.Database,
|
||||
cfg *config.ClientAPI,
|
||||
) util.JSONResponse {
|
||||
var r registerRequest
|
||||
|
|
@ -532,6 +532,13 @@ func handleGuestRegistration(
|
|||
cfg *config.ClientAPI,
|
||||
userAPI userapi.UserInternalAPI,
|
||||
) util.JSONResponse {
|
||||
if cfg.RegistrationDisabled || cfg.GuestsDisabled {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: jsonerror.Forbidden("Guest registration is disabled"),
|
||||
}
|
||||
}
|
||||
|
||||
var res userapi.PerformAccountCreationResponse
|
||||
err := userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{
|
||||
AccountType: userapi.AccountTypeGuest,
|
||||
|
|
@ -892,7 +899,7 @@ type availableResponse struct {
|
|||
func RegisterAvailable(
|
||||
req *http.Request,
|
||||
cfg *config.ClientAPI,
|
||||
accountDB accounts.Database,
|
||||
accountDB userdb.Database,
|
||||
) util.JSONResponse {
|
||||
username := req.URL.Query().Get("username")
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ import (
|
|||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
|
@ -51,7 +51,7 @@ func Setup(
|
|||
eduAPI eduServerAPI.EDUServerInputAPI,
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||
asAPI appserviceAPI.AppServiceQueryAPI,
|
||||
accountDB accounts.Database,
|
||||
accountDB userdb.Database,
|
||||
userAPI userapi.UserInternalAPI,
|
||||
federation *gomatrixserverlib.FederationClient,
|
||||
syncProducer *producers.SyncAPIProducer,
|
||||
|
|
@ -118,15 +118,22 @@ func Setup(
|
|||
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||
}
|
||||
|
||||
r0mux := publicAPIMux.PathPrefix("/r0").Subrouter()
|
||||
// You can't just do PathPrefix("/(r0|v3)") because regexps only apply when inside named path variables.
|
||||
// So make a named path variable called 'apiversion' (which we will never read in handlers) and then do
|
||||
// (r0|v3) - BUT this is a captured group, which makes no sense because you cannot extract this group
|
||||
// from a match (gorilla/mux exposes no way to do this) so it demands you make it a non-capturing group
|
||||
// using ?: so the final regexp becomes what is below. We also need a trailing slash to stop 'v33333' matching.
|
||||
// Note that 'apiversion' is chosen because it must not collide with a variable used in any of the routing!
|
||||
v3mux := publicAPIMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter()
|
||||
|
||||
unstableMux := publicAPIMux.PathPrefix("/unstable").Subrouter()
|
||||
|
||||
r0mux.Handle("/createRoom",
|
||||
v3mux.Handle("/createRoom",
|
||||
httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return CreateRoom(req, device, cfg, accountDB, rsAPI, asAPI)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
r0mux.Handle("/join/{roomIDOrAlias}",
|
||||
v3mux.Handle("/join/{roomIDOrAlias}",
|
||||
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
|
|
@ -142,7 +149,7 @@ func Setup(
|
|||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
if mscCfg.Enabled("msc2753") {
|
||||
r0mux.Handle("/peek/{roomIDOrAlias}",
|
||||
v3mux.Handle("/peek/{roomIDOrAlias}",
|
||||
httputil.MakeAuthAPI(gomatrixserverlib.Peek, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
|
|
@ -157,12 +164,12 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
}
|
||||
r0mux.Handle("/joined_rooms",
|
||||
v3mux.Handle("/joined_rooms",
|
||||
httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return GetJoinedRooms(req, device, rsAPI)
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
r0mux.Handle("/rooms/{roomID}/join",
|
||||
v3mux.Handle("/rooms/{roomID}/join",
|
||||
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
|
|
@ -176,7 +183,7 @@ func Setup(
|
|||
)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
r0mux.Handle("/rooms/{roomID}/leave",
|
||||
v3mux.Handle("/rooms/{roomID}/leave",
|
||||
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
|
|
@ -190,7 +197,7 @@ func Setup(
|
|||
)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
r0mux.Handle("/rooms/{roomID}/unpeek",
|
||||
v3mux.Handle("/rooms/{roomID}/unpeek",
|
||||
httputil.MakeAuthAPI("unpeek", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -201,7 +208,7 @@ func Setup(
|
|||
)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
r0mux.Handle("/rooms/{roomID}/ban",
|
||||
v3mux.Handle("/rooms/{roomID}/ban",
|
||||
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -210,7 +217,7 @@ func Setup(
|
|||
return SendBan(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
r0mux.Handle("/rooms/{roomID}/invite",
|
||||
v3mux.Handle("/rooms/{roomID}/invite",
|
||||
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
|
|
@ -222,7 +229,7 @@ func Setup(
|
|||
return SendInvite(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
r0mux.Handle("/rooms/{roomID}/kick",
|
||||
v3mux.Handle("/rooms/{roomID}/kick",
|
||||
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -231,7 +238,7 @@ func Setup(
|
|||
return SendKick(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
r0mux.Handle("/rooms/{roomID}/unban",
|
||||
v3mux.Handle("/rooms/{roomID}/unban",
|
||||
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -240,7 +247,7 @@ func Setup(
|
|||
return SendUnban(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
r0mux.Handle("/rooms/{roomID}/send/{eventType}",
|
||||
v3mux.Handle("/rooms/{roomID}/send/{eventType}",
|
||||
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -249,7 +256,7 @@ func Setup(
|
|||
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
r0mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}",
|
||||
v3mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}",
|
||||
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -260,7 +267,7 @@ func Setup(
|
|||
nil, cfg, rsAPI, transactionsCache)
|
||||
}),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
r0mux.Handle("/rooms/{roomID}/event/{eventID}",
|
||||
v3mux.Handle("/rooms/{roomID}/event/{eventID}",
|
||||
httputil.MakeAuthAPI("rooms_get_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -270,7 +277,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
|
|
@ -278,7 +285,7 @@ func Setup(
|
|||
return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"])
|
||||
})).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
|
|
@ -286,7 +293,7 @@ func Setup(
|
|||
return GetAliases(req, rsAPI, device, vars["roomID"])
|
||||
})).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
v3mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
|
|
@ -297,7 +304,7 @@ func Setup(
|
|||
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat)
|
||||
})).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
|
|
@ -306,7 +313,7 @@ func Setup(
|
|||
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat)
|
||||
})).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}",
|
||||
v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}",
|
||||
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -318,7 +325,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}",
|
||||
v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}",
|
||||
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -329,21 +336,21 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
|
||||
v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
}
|
||||
return Register(req, userAPI, accountDB, cfg)
|
||||
})).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse {
|
||||
v3mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
}
|
||||
return RegisterAvailable(req, cfg, accountDB)
|
||||
})).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/directory/room/{roomAlias}",
|
||||
v3mux.Handle("/directory/room/{roomAlias}",
|
||||
httputil.MakeExternalAPI("directory_room", func(req *http.Request) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -353,7 +360,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/directory/room/{roomAlias}",
|
||||
v3mux.Handle("/directory/room/{roomAlias}",
|
||||
httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -363,7 +370,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/directory/room/{roomAlias}",
|
||||
v3mux.Handle("/directory/room/{roomAlias}",
|
||||
httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -372,7 +379,7 @@ func Setup(
|
|||
return RemoveLocalAlias(req, device, vars["roomAlias"], rsAPI)
|
||||
}),
|
||||
).Methods(http.MethodDelete, http.MethodOptions)
|
||||
r0mux.Handle("/directory/list/room/{roomID}",
|
||||
v3mux.Handle("/directory/list/room/{roomID}",
|
||||
httputil.MakeExternalAPI("directory_list", func(req *http.Request) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -382,7 +389,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
// TODO: Add AS support
|
||||
r0mux.Handle("/directory/list/room/{roomID}",
|
||||
v3mux.Handle("/directory/list/room/{roomID}",
|
||||
httputil.MakeAuthAPI("directory_list", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -391,25 +398,25 @@ func Setup(
|
|||
return SetVisibility(req, rsAPI, device, vars["roomID"])
|
||||
}),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
r0mux.Handle("/publicRooms",
|
||||
v3mux.Handle("/publicRooms",
|
||||
httputil.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse {
|
||||
return GetPostPublicRooms(req, rsAPI, extRoomsProvider, federation, cfg)
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/logout",
|
||||
v3mux.Handle("/logout",
|
||||
httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return Logout(req, userAPI, device)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/logout/all",
|
||||
v3mux.Handle("/logout/all",
|
||||
httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return LogoutAll(req, userAPI, device)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/rooms/{roomID}/typing/{userID}",
|
||||
v3mux.Handle("/rooms/{roomID}/typing/{userID}",
|
||||
httputil.MakeAuthAPI("rooms_typing", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
|
|
@ -421,7 +428,7 @@ func Setup(
|
|||
return SendTyping(req, device, vars["roomID"], vars["userID"], accountDB, eduAPI, rsAPI)
|
||||
}),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
r0mux.Handle("/rooms/{roomID}/redact/{eventID}",
|
||||
v3mux.Handle("/rooms/{roomID}/redact/{eventID}",
|
||||
httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -430,7 +437,7 @@ func Setup(
|
|||
return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
r0mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}",
|
||||
v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}",
|
||||
httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -440,7 +447,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/sendToDevice/{eventType}/{txnID}",
|
||||
v3mux.Handle("/sendToDevice/{eventType}/{txnID}",
|
||||
httputil.MakeAuthAPI("send_to_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -465,7 +472,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/account/whoami",
|
||||
v3mux.Handle("/account/whoami",
|
||||
httputil.MakeAuthAPI("whoami", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
|
|
@ -474,7 +481,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/account/password",
|
||||
v3mux.Handle("/account/password",
|
||||
httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
|
|
@ -483,7 +490,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/account/deactivate",
|
||||
v3mux.Handle("/account/deactivate",
|
||||
httputil.MakeAuthAPI("deactivate", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
|
|
@ -494,7 +501,7 @@ func Setup(
|
|||
|
||||
// Stub endpoints required by Element
|
||||
|
||||
r0mux.Handle("/login",
|
||||
v3mux.Handle("/login",
|
||||
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
|
|
@ -503,7 +510,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/auth/{authType}/fallback/web",
|
||||
v3mux.Handle("/auth/{authType}/fallback/web",
|
||||
httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse {
|
||||
vars := mux.Vars(req)
|
||||
return AuthFallback(w, req, vars["authType"], cfg)
|
||||
|
|
@ -512,7 +519,7 @@ func Setup(
|
|||
|
||||
// Push rules
|
||||
|
||||
r0mux.Handle("/pushrules",
|
||||
v3mux.Handle("/pushrules",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
|
|
@ -521,13 +528,13 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/pushrules/",
|
||||
v3mux.Handle("/pushrules/",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return GetAllPushRules(req.Context(), device, psAPI)
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/pushrules/",
|
||||
v3mux.Handle("/pushrules/",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
|
|
@ -536,7 +543,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPut)
|
||||
|
||||
r0mux.Handle("/pushrules/{scope}/",
|
||||
v3mux.Handle("/pushrules/{scope}/",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -546,7 +553,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/pushrules/{scope}",
|
||||
v3mux.Handle("/pushrules/{scope}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
|
|
@ -555,7 +562,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/pushrules/{scope:[^/]+/?}",
|
||||
v3mux.Handle("/pushrules/{scope:[^/]+/?}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
|
|
@ -564,7 +571,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPut)
|
||||
|
||||
r0mux.Handle("/pushrules/{scope}/{kind}/",
|
||||
v3mux.Handle("/pushrules/{scope}/{kind}/",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -574,7 +581,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/pushrules/{scope}/{kind}",
|
||||
v3mux.Handle("/pushrules/{scope}/{kind}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
|
|
@ -583,7 +590,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/pushrules/{scope}/{kind:[^/]+/?}",
|
||||
v3mux.Handle("/pushrules/{scope}/{kind:[^/]+/?}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
|
|
@ -592,7 +599,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPut)
|
||||
|
||||
r0mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
|
||||
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -602,7 +609,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
|
||||
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
|
|
@ -616,7 +623,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPut)
|
||||
|
||||
r0mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
|
||||
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -626,7 +633,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodDelete)
|
||||
|
||||
r0mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}",
|
||||
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -636,7 +643,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}",
|
||||
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}",
|
||||
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -648,7 +655,7 @@ func Setup(
|
|||
|
||||
// Element user settings
|
||||
|
||||
r0mux.Handle("/profile/{userID}",
|
||||
v3mux.Handle("/profile/{userID}",
|
||||
httputil.MakeExternalAPI("profile", func(req *http.Request) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -658,7 +665,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/profile/{userID}/avatar_url",
|
||||
v3mux.Handle("/profile/{userID}/avatar_url",
|
||||
httputil.MakeExternalAPI("profile_avatar_url", func(req *http.Request) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -668,7 +675,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/profile/{userID}/avatar_url",
|
||||
v3mux.Handle("/profile/{userID}/avatar_url",
|
||||
httputil.MakeAuthAPI("profile_avatar_url", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
|
|
@ -683,7 +690,7 @@ func Setup(
|
|||
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows
|
||||
// PUT requests, so we need to allow this method
|
||||
|
||||
r0mux.Handle("/profile/{userID}/displayname",
|
||||
v3mux.Handle("/profile/{userID}/displayname",
|
||||
httputil.MakeExternalAPI("profile_displayname", func(req *http.Request) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -693,7 +700,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/profile/{userID}/displayname",
|
||||
v3mux.Handle("/profile/{userID}/displayname",
|
||||
httputil.MakeAuthAPI("profile_displayname", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
|
|
@ -708,13 +715,13 @@ func Setup(
|
|||
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows
|
||||
// PUT requests, so we need to allow this method
|
||||
|
||||
r0mux.Handle("/account/3pid",
|
||||
v3mux.Handle("/account/3pid",
|
||||
httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return GetAssociated3PIDs(req, accountDB, device)
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/account/3pid",
|
||||
v3mux.Handle("/account/3pid",
|
||||
httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return CheckAndSave3PIDAssociation(req, accountDB, device, cfg)
|
||||
}),
|
||||
|
|
@ -726,14 +733,14 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken",
|
||||
v3mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken",
|
||||
httputil.MakeExternalAPI("account_3pid_request_token", func(req *http.Request) util.JSONResponse {
|
||||
return RequestEmailToken(req, accountDB, cfg)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
// Element logs get flooded unless this is handled
|
||||
r0mux.Handle("/presence/{userID}/status",
|
||||
v3mux.Handle("/presence/{userID}/status",
|
||||
httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
|
|
@ -746,7 +753,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/voip/turnServer",
|
||||
v3mux.Handle("/voip/turnServer",
|
||||
httputil.MakeAuthAPI("turn_server", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
|
|
@ -755,7 +762,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/thirdparty/protocols",
|
||||
v3mux.Handle("/thirdparty/protocols",
|
||||
httputil.MakeExternalAPI("thirdparty_protocols", func(req *http.Request) util.JSONResponse {
|
||||
// TODO: Return the third party protcols
|
||||
return util.JSONResponse{
|
||||
|
|
@ -765,7 +772,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/rooms/{roomID}/initialSync",
|
||||
v3mux.Handle("/rooms/{roomID}/initialSync",
|
||||
httputil.MakeExternalAPI("rooms_initial_sync", func(req *http.Request) util.JSONResponse {
|
||||
// TODO: Allow people to peek into rooms.
|
||||
return util.JSONResponse{
|
||||
|
|
@ -775,7 +782,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/user/{userID}/account_data/{type}",
|
||||
v3mux.Handle("/user/{userID}/account_data/{type}",
|
||||
httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -785,7 +792,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}",
|
||||
v3mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}",
|
||||
httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -795,7 +802,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/user/{userID}/account_data/{type}",
|
||||
v3mux.Handle("/user/{userID}/account_data/{type}",
|
||||
httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -805,7 +812,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet)
|
||||
|
||||
r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}",
|
||||
v3mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}",
|
||||
httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -815,7 +822,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet)
|
||||
|
||||
r0mux.Handle("/admin/whois/{userID}",
|
||||
v3mux.Handle("/admin/whois/{userID}",
|
||||
httputil.MakeAuthAPI("admin_whois", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -825,7 +832,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet)
|
||||
|
||||
r0mux.Handle("/user/{userID}/openid/request_token",
|
||||
v3mux.Handle("/user/{userID}/openid/request_token",
|
||||
httputil.MakeAuthAPI("openid_request_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
|
|
@ -838,7 +845,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/user_directory/search",
|
||||
v3mux.Handle("/user_directory/search",
|
||||
httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
|
|
@ -863,7 +870,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/rooms/{roomID}/members",
|
||||
v3mux.Handle("/rooms/{roomID}/members",
|
||||
httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -873,7 +880,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/rooms/{roomID}/joined_members",
|
||||
v3mux.Handle("/rooms/{roomID}/joined_members",
|
||||
httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -883,7 +890,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/rooms/{roomID}/read_markers",
|
||||
v3mux.Handle("/rooms/{roomID}/read_markers",
|
||||
httputil.MakeAuthAPI("rooms_read_markers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
|
|
@ -896,7 +903,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/rooms/{roomID}/forget",
|
||||
v3mux.Handle("/rooms/{roomID}/forget",
|
||||
httputil.MakeAuthAPI("rooms_forget", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
|
|
@ -909,13 +916,13 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/devices",
|
||||
v3mux.Handle("/devices",
|
||||
httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return GetDevicesByLocalpart(req, userAPI, device)
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/devices/{deviceID}",
|
||||
v3mux.Handle("/devices/{deviceID}",
|
||||
httputil.MakeAuthAPI("get_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -925,7 +932,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/devices/{deviceID}",
|
||||
v3mux.Handle("/devices/{deviceID}",
|
||||
httputil.MakeAuthAPI("device_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -935,7 +942,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/devices/{deviceID}",
|
||||
v3mux.Handle("/devices/{deviceID}",
|
||||
httputil.MakeAuthAPI("delete_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -945,7 +952,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodDelete, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/delete_devices",
|
||||
v3mux.Handle("/delete_devices",
|
||||
httputil.MakeAuthAPI("delete_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return DeleteDevices(req, userAPI, device)
|
||||
}),
|
||||
|
|
@ -957,13 +964,13 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/pushers",
|
||||
v3mux.Handle("/pushers",
|
||||
httputil.MakeAuthAPI("get_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return GetPushers(req, device, psAPI)
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/pushers/set",
|
||||
v3mux.Handle("/pushers/set",
|
||||
httputil.MakeAuthAPI("set_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
|
|
@ -973,7 +980,7 @@ func Setup(
|
|||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
// Stub implementations for sytest
|
||||
r0mux.Handle("/events",
|
||||
v3mux.Handle("/events",
|
||||
httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{
|
||||
"chunk": []interface{}{},
|
||||
|
|
@ -983,7 +990,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/initialSync",
|
||||
v3mux.Handle("/initialSync",
|
||||
httputil.MakeExternalAPI("initial_sync", func(req *http.Request) util.JSONResponse {
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{
|
||||
"end": "",
|
||||
|
|
@ -991,7 +998,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/user/{userId}/rooms/{roomId}/tags",
|
||||
v3mux.Handle("/user/{userId}/rooms/{roomId}/tags",
|
||||
httputil.MakeAuthAPI("get_tags", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -1001,7 +1008,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
|
||||
v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
|
||||
httputil.MakeAuthAPI("put_tag", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -1011,7 +1018,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
|
||||
v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
|
||||
httputil.MakeAuthAPI("delete_tag", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -1021,7 +1028,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodDelete, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/capabilities",
|
||||
v3mux.Handle("/capabilities",
|
||||
httputil.MakeAuthAPI("capabilities", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
|
|
@ -1064,11 +1071,11 @@ func Setup(
|
|||
return CreateKeyBackupVersion(req, userAPI, device)
|
||||
})
|
||||
|
||||
r0mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
|
||||
r0mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
|
||||
r0mux.Handle("/room_keys/version/{version}", putBackupKeysVersion).Methods(http.MethodPut)
|
||||
r0mux.Handle("/room_keys/version/{version}", deleteBackupKeysVersion).Methods(http.MethodDelete)
|
||||
r0mux.Handle("/room_keys/version", postNewBackupKeysVersion).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
|
||||
v3mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
|
||||
v3mux.Handle("/room_keys/version/{version}", putBackupKeysVersion).Methods(http.MethodPut)
|
||||
v3mux.Handle("/room_keys/version/{version}", deleteBackupKeysVersion).Methods(http.MethodDelete)
|
||||
v3mux.Handle("/room_keys/version", postNewBackupKeysVersion).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
unstableMux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
|
||||
unstableMux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
|
@ -1160,9 +1167,9 @@ func Setup(
|
|||
return UploadBackupKeys(req, userAPI, device, version, &keyReq)
|
||||
})
|
||||
|
||||
r0mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut)
|
||||
r0mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut)
|
||||
r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", putBackupKeysRoomSession).Methods(http.MethodPut)
|
||||
v3mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut)
|
||||
v3mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut)
|
||||
v3mux.Handle("/room_keys/keys/{roomID}/{sessionID}", putBackupKeysRoomSession).Methods(http.MethodPut)
|
||||
|
||||
unstableMux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut)
|
||||
unstableMux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut)
|
||||
|
|
@ -1190,9 +1197,9 @@ func Setup(
|
|||
return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"])
|
||||
})
|
||||
|
||||
r0mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions)
|
||||
r0mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions)
|
||||
r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", getBackupKeysRoomSession).Methods(http.MethodGet, http.MethodOptions)
|
||||
v3mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions)
|
||||
v3mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions)
|
||||
v3mux.Handle("/room_keys/keys/{roomID}/{sessionID}", getBackupKeysRoomSession).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
unstableMux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions)
|
||||
unstableMux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
|
@ -1210,34 +1217,34 @@ func Setup(
|
|||
return UploadCrossSigningDeviceSignatures(req, keyAPI, device)
|
||||
})
|
||||
|
||||
r0mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions)
|
||||
r0mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
unstableMux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions)
|
||||
unstableMux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
// Supplying a device ID is deprecated.
|
||||
r0mux.Handle("/keys/upload/{deviceID}",
|
||||
v3mux.Handle("/keys/upload/{deviceID}",
|
||||
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return UploadKeys(req, keyAPI, device)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
r0mux.Handle("/keys/upload",
|
||||
v3mux.Handle("/keys/upload",
|
||||
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return UploadKeys(req, keyAPI, device)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
r0mux.Handle("/keys/query",
|
||||
v3mux.Handle("/keys/query",
|
||||
httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return QueryKeys(req, keyAPI, device)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
r0mux.Handle("/keys/claim",
|
||||
v3mux.Handle("/keys/claim",
|
||||
httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return ClaimKeys(req, keyAPI)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
r0mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}",
|
||||
v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}",
|
||||
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req); r != nil {
|
||||
return *r
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/eduserver/api"
|
||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
|
|
@ -33,7 +33,7 @@ type typingContentJSON struct {
|
|||
// sends the typing events to client API typingProducer
|
||||
func SendTyping(
|
||||
req *http.Request, device *userapi.Device, roomID string,
|
||||
userID string, accountDB accounts.Database,
|
||||
userID string, accountDB userdb.Database,
|
||||
eduAPI api.EDUServerInputAPI,
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||
) util.JSONResponse {
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/clientapi/threepid"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
|
|
@ -40,7 +40,7 @@ type threePIDsResponse struct {
|
|||
// RequestEmailToken implements:
|
||||
// POST /account/3pid/email/requestToken
|
||||
// POST /register/email/requestToken
|
||||
func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI) util.JSONResponse {
|
||||
func RequestEmailToken(req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI) util.JSONResponse {
|
||||
var body threepid.EmailAssociationRequest
|
||||
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
|
||||
return *reqErr
|
||||
|
|
@ -61,7 +61,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf
|
|||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.MatrixError{
|
||||
ErrCode: "M_THREEPID_IN_USE",
|
||||
Err: accounts.Err3PIDInUse.Error(),
|
||||
Err: userdb.Err3PIDInUse.Error(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
@ -85,7 +85,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf
|
|||
|
||||
// CheckAndSave3PIDAssociation implements POST /account/3pid
|
||||
func CheckAndSave3PIDAssociation(
|
||||
req *http.Request, accountDB accounts.Database, device *api.Device,
|
||||
req *http.Request, accountDB userdb.Database, device *api.Device,
|
||||
cfg *config.ClientAPI,
|
||||
) util.JSONResponse {
|
||||
var body threepid.EmailAssociationCheckRequest
|
||||
|
|
@ -149,7 +149,7 @@ func CheckAndSave3PIDAssociation(
|
|||
|
||||
// GetAssociated3PIDs implements GET /account/3pid
|
||||
func GetAssociated3PIDs(
|
||||
req *http.Request, accountDB accounts.Database, device *api.Device,
|
||||
req *http.Request, accountDB userdb.Database, device *api.Device,
|
||||
) util.JSONResponse {
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
|
|
@ -170,7 +170,7 @@ func GetAssociated3PIDs(
|
|||
}
|
||||
|
||||
// Forget3PID implements POST /account/3pid/delete
|
||||
func Forget3PID(req *http.Request, accountDB accounts.Database) util.JSONResponse {
|
||||
func Forget3PID(req *http.Request, accountDB userdb.Database) util.JSONResponse {
|
||||
var body authtypes.ThreePID
|
||||
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
|
||||
return *reqErr
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ import (
|
|||
// whoamiResponse represents an response for a `whoami` request
|
||||
type whoamiResponse struct {
|
||||
UserID string `json:"user_id"`
|
||||
DeviceID string `json:"device_id"`
|
||||
IsGuest bool `json:"is_guest"`
|
||||
}
|
||||
|
||||
// Whoami implements `/account/whoami` which enables client to query their account user id.
|
||||
|
|
@ -29,6 +31,10 @@ type whoamiResponse struct {
|
|||
func Whoami(req *http.Request, device *api.Device) util.JSONResponse {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: whoamiResponse{UserID: device.UserID},
|
||||
JSON: whoamiResponse{
|
||||
UserID: device.UserID,
|
||||
DeviceID: device.ID,
|
||||
IsGuest: device.AccountType == api.AccountTypeGuest,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
|
|
@ -87,7 +87,7 @@ var (
|
|||
func CheckAndProcessInvite(
|
||||
ctx context.Context,
|
||||
device *userapi.Device, body *MembershipRequest, cfg *config.ClientAPI,
|
||||
rsAPI api.RoomserverInternalAPI, db accounts.Database,
|
||||
rsAPI api.RoomserverInternalAPI, db userdb.Database,
|
||||
roomID string,
|
||||
evTime time.Time,
|
||||
) (inviteStoredOnIDServer bool, err error) {
|
||||
|
|
@ -137,7 +137,7 @@ func CheckAndProcessInvite(
|
|||
// Returns an error if a check or a request failed.
|
||||
func queryIDServer(
|
||||
ctx context.Context,
|
||||
db accounts.Database, cfg *config.ClientAPI, device *userapi.Device,
|
||||
db userdb.Database, cfg *config.ClientAPI, device *userapi.Device,
|
||||
body *MembershipRequest, roomID string,
|
||||
) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) {
|
||||
if err = isTrusted(body.IDServer, cfg); err != nil {
|
||||
|
|
@ -206,7 +206,7 @@ func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServe
|
|||
// Returns an error if the request failed to send or if the response couldn't be parsed.
|
||||
func queryIDServerStoreInvite(
|
||||
ctx context.Context,
|
||||
db accounts.Database, cfg *config.ClientAPI, device *userapi.Device,
|
||||
db userdb.Database, cfg *config.ClientAPI, device *userapi.Device,
|
||||
body *MembershipRequest, roomID string,
|
||||
) (*idServerStoreInviteResponse, error) {
|
||||
// Retrieve the sender's profile to get their display name
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/setup"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||
)
|
||||
|
||||
const usage = `Usage: %s
|
||||
|
|
@ -77,9 +77,14 @@ func main() {
|
|||
|
||||
pass := getPassword(password, pwdFile, pwdStdin, askPass, os.Stdin)
|
||||
|
||||
accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{
|
||||
accountDB, err := userdb.NewDatabase(
|
||||
&config.DatabaseOptions{
|
||||
ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString,
|
||||
}, cfg.Global.ServerName, bcrypt.DefaultCost, cfg.UserAPI.OpenIDTokenLifetimeMS)
|
||||
},
|
||||
cfg.Global.ServerName, bcrypt.DefaultCost,
|
||||
cfg.UserAPI.OpenIDTokenLifetimeMS,
|
||||
api.DefaultLoginTokenLifetime,
|
||||
)
|
||||
if err != nil {
|
||||
logrus.Fatalln("Failed to connect to the database:", err.Error())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -127,7 +127,6 @@ func main() {
|
|||
cfg.FederationAPI.FederationMaxRetries = 6
|
||||
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
|
||||
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
|
||||
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName))
|
||||
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
|
||||
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
|
||||
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
|
||||
|
|
|
|||
|
|
@ -160,7 +160,6 @@ func main() {
|
|||
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
|
||||
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
|
||||
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
|
||||
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName))
|
||||
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
|
||||
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
|
||||
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
|
||||
|
|
|
|||
|
|
@ -80,7 +80,6 @@ func main() {
|
|||
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
|
||||
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
|
||||
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
|
||||
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName))
|
||||
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
|
||||
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
|
||||
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
|
||||
|
|
|
|||
|
|
@ -164,7 +164,6 @@ func startup() {
|
|||
cfg.Defaults(true)
|
||||
cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db"
|
||||
cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db"
|
||||
cfg.UserAPI.DeviceDatabase.ConnectionString = "file:/idb/dendritejs_device.db"
|
||||
cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db"
|
||||
cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db"
|
||||
cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db"
|
||||
|
|
|
|||
|
|
@ -168,7 +168,6 @@ func main() {
|
|||
cfg.Defaults(true)
|
||||
cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db"
|
||||
cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db"
|
||||
cfg.UserAPI.DeviceDatabase.ConnectionString = "file:/idb/dendritejs_device.db"
|
||||
cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db"
|
||||
cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db"
|
||||
cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db"
|
||||
|
|
|
|||
|
|
@ -32,7 +32,6 @@ func main() {
|
|||
cfg.RoomServer.Database.ConnectionString = config.DataSource(*dbURI)
|
||||
cfg.SyncAPI.Database.ConnectionString = config.DataSource(*dbURI)
|
||||
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(*dbURI)
|
||||
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(*dbURI)
|
||||
}
|
||||
cfg.Global.TrustedIDServers = []string{
|
||||
"matrix.org",
|
||||
|
|
|
|||
|
|
@ -8,12 +8,11 @@ import (
|
|||
"log"
|
||||
"os"
|
||||
|
||||
pgaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas"
|
||||
slaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas"
|
||||
pgdevices "github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas"
|
||||
sldevices "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas"
|
||||
"github.com/pressly/goose"
|
||||
|
||||
pgusers "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
|
||||
slusers "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
|
@ -26,8 +25,7 @@ const (
|
|||
RoomServer = "roomserver"
|
||||
SigningKeyServer = "signingkeyserver"
|
||||
SyncAPI = "syncapi"
|
||||
UserAPIAccounts = "userapi_accounts"
|
||||
UserAPIDevices = "userapi_devices"
|
||||
UserAPI = "userapi"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -35,7 +33,7 @@ var (
|
|||
flags = flag.NewFlagSet("goose", flag.ExitOnError)
|
||||
component = flags.String("component", "", "dendrite component name")
|
||||
knownDBs = []string{
|
||||
AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPIAccounts, UserAPIDevices,
|
||||
AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPI,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -143,18 +141,14 @@ Commands:
|
|||
|
||||
func loadSQLiteDeltas(component string) {
|
||||
switch component {
|
||||
case UserAPIAccounts:
|
||||
slaccounts.LoadFromGoose()
|
||||
case UserAPIDevices:
|
||||
sldevices.LoadFromGoose()
|
||||
case UserAPI:
|
||||
slusers.LoadFromGoose()
|
||||
}
|
||||
}
|
||||
|
||||
func loadPostgresDeltas(component string) {
|
||||
switch component {
|
||||
case UserAPIAccounts:
|
||||
pgaccounts.LoadFromGoose()
|
||||
case UserAPIDevices:
|
||||
pgdevices.LoadFromGoose()
|
||||
case UserAPI:
|
||||
pgusers.LoadFromGoose()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -142,6 +142,10 @@ client_api:
|
|||
# using the registration shared secret below.
|
||||
registration_disabled: false
|
||||
|
||||
# Prevents new guest accounts from being created. Guest registration is also
|
||||
# disabled implicitly by setting 'registration_disabled' above.
|
||||
guests_disabled: true
|
||||
|
||||
# If set, allows registration by anyone who knows the shared secret, regardless of
|
||||
# whether registration is otherwise disabled.
|
||||
registration_shared_secret: ""
|
||||
|
|
|
|||
|
|
@ -7,14 +7,6 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
RoomServerStateKeyNIDsCacheName = "roomserver_statekey_nids"
|
||||
RoomServerStateKeyNIDsCacheMaxEntries = 1024
|
||||
RoomServerStateKeyNIDsCacheMutable = false
|
||||
|
||||
RoomServerEventTypeNIDsCacheName = "roomserver_eventtype_nids"
|
||||
RoomServerEventTypeNIDsCacheMaxEntries = 64
|
||||
RoomServerEventTypeNIDsCacheMutable = false
|
||||
|
||||
RoomServerRoomIDsCacheName = "roomserver_room_ids"
|
||||
RoomServerRoomIDsCacheMaxEntries = 1024
|
||||
RoomServerRoomIDsCacheMutable = false
|
||||
|
|
@ -29,44 +21,10 @@ type RoomServerCaches interface {
|
|||
// RoomServerNIDsCache contains the subset of functions needed for
|
||||
// a roomserver NID cache.
|
||||
type RoomServerNIDsCache interface {
|
||||
GetRoomServerStateKeyNID(stateKey string) (types.EventStateKeyNID, bool)
|
||||
StoreRoomServerStateKeyNID(stateKey string, nid types.EventStateKeyNID)
|
||||
|
||||
GetRoomServerEventTypeNID(eventType string) (types.EventTypeNID, bool)
|
||||
StoreRoomServerEventTypeNID(eventType string, nid types.EventTypeNID)
|
||||
|
||||
GetRoomServerRoomID(roomNID types.RoomNID) (string, bool)
|
||||
StoreRoomServerRoomID(roomNID types.RoomNID, roomID string)
|
||||
}
|
||||
|
||||
func (c Caches) GetRoomServerStateKeyNID(stateKey string) (types.EventStateKeyNID, bool) {
|
||||
val, found := c.RoomServerStateKeyNIDs.Get(stateKey)
|
||||
if found && val != nil {
|
||||
if stateKeyNID, ok := val.(types.EventStateKeyNID); ok {
|
||||
return stateKeyNID, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (c Caches) StoreRoomServerStateKeyNID(stateKey string, nid types.EventStateKeyNID) {
|
||||
c.RoomServerStateKeyNIDs.Set(stateKey, nid)
|
||||
}
|
||||
|
||||
func (c Caches) GetRoomServerEventTypeNID(eventType string) (types.EventTypeNID, bool) {
|
||||
val, found := c.RoomServerEventTypeNIDs.Get(eventType)
|
||||
if found && val != nil {
|
||||
if eventTypeNID, ok := val.(types.EventTypeNID); ok {
|
||||
return eventTypeNID, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (c Caches) StoreRoomServerEventTypeNID(eventType string, nid types.EventTypeNID) {
|
||||
c.RoomServerEventTypeNIDs.Set(eventType, nid)
|
||||
}
|
||||
|
||||
func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) {
|
||||
val, found := c.RoomServerRoomIDs.Get(strconv.Itoa(int(roomNID)))
|
||||
if found && val != nil {
|
||||
|
|
|
|||
|
|
@ -6,8 +6,6 @@ package caching
|
|||
type Caches struct {
|
||||
RoomVersions Cache // RoomVersionCache
|
||||
ServerKeys Cache // ServerKeyCache
|
||||
RoomServerStateKeyNIDs Cache // RoomServerNIDsCache
|
||||
RoomServerEventTypeNIDs Cache // RoomServerNIDsCache
|
||||
RoomServerRoomNIDs Cache // RoomServerNIDsCache
|
||||
RoomServerRoomIDs Cache // RoomServerNIDsCache
|
||||
RoomInfos Cache // RoomInfoCache
|
||||
|
|
|
|||
|
|
@ -28,24 +28,6 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roomServerStateKeyNIDs, err := NewInMemoryLRUCachePartition(
|
||||
RoomServerStateKeyNIDsCacheName,
|
||||
RoomServerStateKeyNIDsCacheMutable,
|
||||
RoomServerStateKeyNIDsCacheMaxEntries,
|
||||
enablePrometheus,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roomServerEventTypeNIDs, err := NewInMemoryLRUCachePartition(
|
||||
RoomServerEventTypeNIDsCacheName,
|
||||
RoomServerEventTypeNIDsCacheMutable,
|
||||
RoomServerEventTypeNIDsCacheMaxEntries,
|
||||
enablePrometheus,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roomServerRoomIDs, err := NewInMemoryLRUCachePartition(
|
||||
RoomServerRoomIDsCacheName,
|
||||
RoomServerRoomIDsCacheMutable,
|
||||
|
|
@ -74,15 +56,12 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
|
|||
return nil, err
|
||||
}
|
||||
go cacheCleaner(
|
||||
roomVersions, serverKeys, roomServerStateKeyNIDs,
|
||||
roomServerEventTypeNIDs, roomServerRoomIDs,
|
||||
roomVersions, serverKeys, roomServerRoomIDs,
|
||||
roomInfos, federationEvents,
|
||||
)
|
||||
return &Caches{
|
||||
RoomVersions: roomVersions,
|
||||
ServerKeys: serverKeys,
|
||||
RoomServerStateKeyNIDs: roomServerStateKeyNIDs,
|
||||
RoomServerEventTypeNIDs: roomServerEventTypeNIDs,
|
||||
RoomServerRoomIDs: roomServerRoomIDs,
|
||||
RoomInfos: roomInfos,
|
||||
FederationEvents: federationEvents,
|
||||
|
|
|
|||
|
|
@ -95,7 +95,6 @@ func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*con
|
|||
cfg.RoomServer.Database.ConnectionString = config.DataSource(database)
|
||||
cfg.SyncAPI.Database.ConnectionString = config.DataSource(database)
|
||||
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(database)
|
||||
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(database)
|
||||
|
||||
cfg.AppServiceAPI.InternalAPI.Listen = assignAddress()
|
||||
cfg.EDUServer.InternalAPI.Listen = assignAddress()
|
||||
|
|
|
|||
|
|
@ -198,7 +198,7 @@ func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOne
|
|||
}
|
||||
|
||||
func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) {
|
||||
msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil)
|
||||
msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
|
||||
|
|
@ -244,7 +244,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
|||
domain := string(serverName)
|
||||
// query local devices
|
||||
if serverName == a.ThisServer {
|
||||
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
|
||||
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to query local device keys: %s", err),
|
||||
|
|
@ -525,7 +525,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
|
|||
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
||||
ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string,
|
||||
) error {
|
||||
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
|
||||
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
|
||||
// if we can't query the db or there are fewer keys than requested, fetch from remote.
|
||||
if err != nil {
|
||||
return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
|
||||
|
|
@ -554,10 +554,58 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
|||
}
|
||||
|
||||
func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||
// get a list of devices from the user API that actually exist, as
|
||||
// we won't store keys for devices that don't exist
|
||||
uapidevices := &userapi.QueryDevicesResponse{}
|
||||
if err := a.UserAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: err.Error(),
|
||||
}
|
||||
return
|
||||
}
|
||||
if !uapidevices.UserExists {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("user %q does not exist", req.UserID),
|
||||
}
|
||||
return
|
||||
}
|
||||
existingDeviceMap := make(map[string]struct{}, len(uapidevices.Devices))
|
||||
for _, key := range uapidevices.Devices {
|
||||
existingDeviceMap[key.ID] = struct{}{}
|
||||
}
|
||||
|
||||
// Get all of the user existing device keys so we can check for changes.
|
||||
existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, true)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Work out whether we have device keys in the keyserver for devices that
|
||||
// no longer exist in the user API. This is mostly an exercise to ensure
|
||||
// that we keep some integrity between the two.
|
||||
var toClean []gomatrixserverlib.KeyID
|
||||
for _, k := range existingKeys {
|
||||
if _, ok := existingDeviceMap[k.DeviceID]; !ok {
|
||||
toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID))
|
||||
}
|
||||
}
|
||||
|
||||
if len(toClean) > 0 {
|
||||
if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil {
|
||||
logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean))
|
||||
} else {
|
||||
logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean))
|
||||
}
|
||||
}
|
||||
|
||||
var keysToStore []api.DeviceMessage
|
||||
// assert that the user ID / device ID are not lying for each key
|
||||
for _, key := range req.DeviceKeys {
|
||||
_, serverName, err := gomatrixserverlib.SplitID('@', key.UserID)
|
||||
var serverName gomatrixserverlib.ServerName
|
||||
_, serverName, err = gomatrixserverlib.SplitID('@', key.UserID)
|
||||
if err != nil {
|
||||
continue // ignore invalid users
|
||||
}
|
||||
|
|
@ -568,6 +616,11 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
|
|||
keysToStore = append(keysToStore, key.WithStreamID(0))
|
||||
continue // deleted keys don't need sanity checking
|
||||
}
|
||||
// check that the device in question actually exists in the user
|
||||
// API before we try and store a key for it
|
||||
if _, ok := existingDeviceMap[key.DeviceID]; !ok {
|
||||
continue
|
||||
}
|
||||
gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
|
||||
gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
|
||||
if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
|
||||
|
|
@ -583,29 +636,12 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
|
|||
})
|
||||
}
|
||||
|
||||
// get existing device keys so we can check for changes
|
||||
existingKeys := make([]api.DeviceMessage, len(keysToStore))
|
||||
for i := range keysToStore {
|
||||
existingKeys[i] = api.DeviceMessage{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
UserID: keysToStore[i].UserID,
|
||||
DeviceID: keysToStore[i].DeviceID,
|
||||
},
|
||||
}
|
||||
}
|
||||
if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
|
||||
}
|
||||
return
|
||||
}
|
||||
if req.OnlyDisplayNameUpdates {
|
||||
// add the display name field from keysToStore into existingKeys
|
||||
keysToStore = appendDisplayNames(existingKeys, keysToStore)
|
||||
}
|
||||
// store the device keys and emit changes
|
||||
err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
|
||||
err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ type Database interface {
|
|||
|
||||
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
|
||||
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
|
||||
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
|
||||
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
|
||||
|
||||
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
|
||||
// cross-signing signatures relating to that device.
|
||||
|
|
|
|||
|
|
@ -56,6 +56,9 @@ const selectDeviceKeysSQL = "" +
|
|||
const selectBatchDeviceKeysSQL = "" +
|
||||
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
|
||||
|
||||
const selectBatchDeviceKeysWithEmptiesSQL = "" +
|
||||
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
const selectMaxStreamForUserSQL = "" +
|
||||
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
|
|
@ -73,6 +76,7 @@ type deviceKeysStatements struct {
|
|||
upsertDeviceKeysStmt *sql.Stmt
|
||||
selectDeviceKeysStmt *sql.Stmt
|
||||
selectBatchDeviceKeysStmt *sql.Stmt
|
||||
selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
|
||||
selectMaxStreamForUserStmt *sql.Stmt
|
||||
countStreamIDsForUserStmt *sql.Stmt
|
||||
deleteDeviceKeysStmt *sql.Stmt
|
||||
|
|
@ -96,6 +100,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
|||
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -180,8 +187,14 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
||||
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
|
||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
|
||||
var stmt *sql.Stmt
|
||||
if includeEmpty {
|
||||
stmt = s.selectBatchDeviceKeysWithEmptiesStmt
|
||||
} else {
|
||||
stmt = s.selectBatchDeviceKeysStmt
|
||||
}
|
||||
rows, err := stmt.QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -108,8 +108,8 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe
|
|||
})
|
||||
}
|
||||
|
||||
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
||||
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs)
|
||||
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
|
||||
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty)
|
||||
}
|
||||
|
||||
func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
|
||||
|
|
|
|||
|
|
@ -52,6 +52,9 @@ const selectDeviceKeysSQL = "" +
|
|||
const selectBatchDeviceKeysSQL = "" +
|
||||
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
|
||||
|
||||
const selectBatchDeviceKeysWithEmptiesSQL = "" +
|
||||
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
const selectMaxStreamForUserSQL = "" +
|
||||
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
|
|
@ -69,6 +72,7 @@ type deviceKeysStatements struct {
|
|||
upsertDeviceKeysStmt *sql.Stmt
|
||||
selectDeviceKeysStmt *sql.Stmt
|
||||
selectBatchDeviceKeysStmt *sql.Stmt
|
||||
selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
|
||||
selectMaxStreamForUserStmt *sql.Stmt
|
||||
deleteDeviceKeysStmt *sql.Stmt
|
||||
deleteAllDeviceKeysStmt *sql.Stmt
|
||||
|
|
@ -91,6 +95,9 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
|||
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -113,12 +120,18 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
|
||||
deviceIDMap := make(map[string]bool)
|
||||
for _, d := range deviceIDs {
|
||||
deviceIDMap[d] = true
|
||||
}
|
||||
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
|
||||
var stmt *sql.Stmt
|
||||
if includeEmpty {
|
||||
stmt = s.selectBatchDeviceKeysWithEmptiesStmt
|
||||
} else {
|
||||
stmt = s.selectBatchDeviceKeysStmt
|
||||
}
|
||||
rows, err := stmt.QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -173,7 +173,7 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
|||
}
|
||||
|
||||
// Querying for device keys returns the latest stream IDs
|
||||
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"})
|
||||
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("DeviceKeysForUser returned error: %s", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ type DeviceKeys interface {
|
|||
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
|
||||
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
|
||||
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
|
||||
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
|
||||
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
|
||||
DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
|
||||
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
|
||||
}
|
||||
|
|
|
|||
|
|
@ -59,23 +59,12 @@ func (d *Database) eventTypeNIDs(
|
|||
ctx context.Context, txn *sql.Tx, eventTypes []string,
|
||||
) (map[string]types.EventTypeNID, error) {
|
||||
result := make(map[string]types.EventTypeNID)
|
||||
remaining := []string{}
|
||||
for _, eventType := range eventTypes {
|
||||
if nid, ok := d.Cache.GetRoomServerEventTypeNID(eventType); ok {
|
||||
result[eventType] = nid
|
||||
} else {
|
||||
remaining = append(remaining, eventType)
|
||||
}
|
||||
}
|
||||
if len(remaining) > 0 {
|
||||
nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, remaining)
|
||||
nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, eventTypes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for eventType, nid := range nids {
|
||||
result[eventType] = nid
|
||||
d.Cache.StoreRoomServerEventTypeNID(eventType, nid)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
|
@ -96,23 +85,12 @@ func (d *Database) eventStateKeyNIDs(
|
|||
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
|
||||
) (map[string]types.EventStateKeyNID, error) {
|
||||
result := make(map[string]types.EventStateKeyNID)
|
||||
remaining := []string{}
|
||||
for _, eventStateKey := range eventStateKeys {
|
||||
if nid, ok := d.Cache.GetRoomServerStateKeyNID(eventStateKey); ok {
|
||||
result[eventStateKey] = nid
|
||||
} else {
|
||||
remaining = append(remaining, eventStateKey)
|
||||
}
|
||||
}
|
||||
if len(remaining) > 0 {
|
||||
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, remaining)
|
||||
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, eventStateKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for eventStateKey, nid := range nids {
|
||||
result[eventStateKey] = nid
|
||||
d.Cache.StoreRoomServerStateKeyNID(eventStateKey, nid)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
|
@ -718,9 +696,6 @@ func (d *Database) assignRoomNID(
|
|||
func (d *Database) assignEventTypeNID(
|
||||
ctx context.Context, txn *sql.Tx, eventType string,
|
||||
) (types.EventTypeNID, error) {
|
||||
if eventTypeNID, ok := d.Cache.GetRoomServerEventTypeNID(eventType); ok {
|
||||
return eventTypeNID, nil
|
||||
}
|
||||
// Check if we already have a numeric ID in the database.
|
||||
eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType)
|
||||
if err == sql.ErrNoRows {
|
||||
|
|
@ -731,18 +706,12 @@ func (d *Database) assignEventTypeNID(
|
|||
eventTypeNID, err = d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType)
|
||||
}
|
||||
}
|
||||
if err == nil {
|
||||
d.Cache.StoreRoomServerEventTypeNID(eventType, eventTypeNID)
|
||||
}
|
||||
return eventTypeNID, err
|
||||
}
|
||||
|
||||
func (d *Database) assignStateKeyNID(
|
||||
ctx context.Context, txn *sql.Tx, eventStateKey string,
|
||||
) (types.EventStateKeyNID, error) {
|
||||
if eventStateKeyNID, ok := d.Cache.GetRoomServerStateKeyNID(eventStateKey); ok {
|
||||
return eventStateKeyNID, nil
|
||||
}
|
||||
// Check if we already have a numeric ID in the database.
|
||||
eventStateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey)
|
||||
if err == sql.ErrNoRows {
|
||||
|
|
@ -753,9 +722,6 @@ func (d *Database) assignStateKeyNID(
|
|||
eventStateKeyNID, err = d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey)
|
||||
}
|
||||
}
|
||||
if err == nil {
|
||||
d.Cache.StoreRoomServerStateKeyNID(eventStateKey, eventStateKeyNID)
|
||||
}
|
||||
return eventStateKeyNID, err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
|
|
@ -290,8 +290,14 @@ func (b *BaseDendrite) PushGatewayHTTPClient() pushgateway.Client {
|
|||
|
||||
// CreateAccountsDB creates a new instance of the accounts database. Should only
|
||||
// be called once per component.
|
||||
func (b *BaseDendrite) CreateAccountsDB() accounts.Database {
|
||||
db, err := accounts.NewDatabase(&b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost, b.Cfg.UserAPI.OpenIDTokenLifetimeMS)
|
||||
func (b *BaseDendrite) CreateAccountsDB() userdb.Database {
|
||||
db, err := userdb.NewDatabase(
|
||||
&b.Cfg.UserAPI.AccountDatabase,
|
||||
b.Cfg.Global.ServerName,
|
||||
b.Cfg.UserAPI.BCryptCost,
|
||||
b.Cfg.UserAPI.OpenIDTokenLifetimeMS,
|
||||
userapi.DefaultLoginTokenLifetime,
|
||||
)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panicf("failed to connect to accounts db")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,10 @@ type ClientAPI struct {
|
|||
// If set, allows registration by anyone who also has the shared
|
||||
// secret, even if registration is otherwise disabled.
|
||||
RegistrationSharedSecret string `yaml:"registration_shared_secret"`
|
||||
// If set, prevents guest accounts from being created. Only takes
|
||||
// effect if registration is enabled, otherwise guests registration
|
||||
// is forbidden either way.
|
||||
GuestsDisabled bool `yaml:"guests_disabled"`
|
||||
|
||||
// Boolean stating whether catpcha registration is enabled
|
||||
// and required
|
||||
|
|
|
|||
|
|
@ -16,9 +16,6 @@ type UserAPI struct {
|
|||
// The Account database stores the login details and account information
|
||||
// for local users. It is accessed by the UserAPI.
|
||||
AccountDatabase DatabaseOptions `yaml:"account_database"`
|
||||
// The Device database stores session information for the devices of logged
|
||||
// in local users. It is accessed by the UserAPI.
|
||||
DeviceDatabase DatabaseOptions `yaml:"device_database"`
|
||||
}
|
||||
|
||||
const DefaultOpenIDTokenLifetimeMS = 3600000 // 60 minutes
|
||||
|
|
@ -27,10 +24,8 @@ func (c *UserAPI) Defaults(generate bool) {
|
|||
c.InternalAPI.Listen = "http://localhost:7781"
|
||||
c.InternalAPI.Connect = "http://localhost:7781"
|
||||
c.AccountDatabase.Defaults(10)
|
||||
c.DeviceDatabase.Defaults(10)
|
||||
if generate {
|
||||
c.AccountDatabase.ConnectionString = "file:userapi_accounts.db"
|
||||
c.DeviceDatabase.ConnectionString = "file:userapi_devices.db"
|
||||
}
|
||||
c.BCryptCost = bcrypt.DefaultCost
|
||||
c.OpenIDTokenLifetimeMS = DefaultOpenIDTokenLifetimeMS
|
||||
|
|
@ -40,6 +35,5 @@ func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
|||
checkURL(configErrs, "user_api.internal_api.listen", string(c.InternalAPI.Listen))
|
||||
checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect))
|
||||
checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString))
|
||||
checkNotEmpty(configErrs, "user_api.device_database.connection_string", string(c.DeviceDatabase.ConnectionString))
|
||||
checkPositive(configErrs, "user_api.openid_token_lifetime_ms", c.OpenIDTokenLifetimeMS)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/matrix-org/dendrite/syncapi"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
|
|
@ -39,7 +39,7 @@ import (
|
|||
// all components of Dendrite, for use in monolith mode.
|
||||
type Monolith struct {
|
||||
Config *config.Dendrite
|
||||
AccountDB accounts.Database
|
||||
AccountDB userdb.Database
|
||||
KeyRing *gomatrixserverlib.KeyRing
|
||||
Client *gomatrixserverlib.Client
|
||||
FedClient *gomatrixserverlib.FederationClient
|
||||
|
|
|
|||
|
|
@ -39,14 +39,14 @@ func Setup(
|
|||
rsAPI api.RoomserverInternalAPI,
|
||||
cfg *config.SyncAPI,
|
||||
) {
|
||||
r0mux := csMux.PathPrefix("/r0").Subrouter()
|
||||
v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter()
|
||||
|
||||
// TODO: Add AS support for all handlers below.
|
||||
r0mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
v3mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return srp.OnIncomingSyncRequest(req, device)
|
||||
})).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
v3mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
|
|
@ -54,7 +54,7 @@ func Setup(
|
|||
return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], device, federation, rsAPI, cfg, srp)
|
||||
})).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/user/{userId}/filter",
|
||||
v3mux.Handle("/user/{userId}/filter",
|
||||
httputil.MakeAuthAPI("put_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -64,7 +64,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/user/{userId}/filter/{filterId}",
|
||||
v3mux.Handle("/user/{userId}/filter/{filterId}",
|
||||
httputil.MakeAuthAPI("get_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
|
|
@ -74,7 +74,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
v3mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return srp.OnIncomingKeyChangeRequest(req, device)
|
||||
})).Methods(http.MethodGet, http.MethodOptions)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,6 +19,13 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
// DefaultLoginTokenLifetime determines how old a valid token may be.
|
||||
//
|
||||
// NOTSPEC: The current spec says "SHOULD be limited to around five
|
||||
// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low.
|
||||
// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325).
|
||||
const DefaultLoginTokenLifetime = 2 * time.Minute
|
||||
|
||||
type LoginTokenInternalAPI interface {
|
||||
// PerformLoginTokenCreation creates a new login token and associates it with the provided data.
|
||||
PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error
|
||||
|
|
|
|||
|
|
@ -31,13 +31,11 @@ import (
|
|||
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/devices"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
)
|
||||
|
||||
type UserInternalAPI struct {
|
||||
AccountDB accounts.Database
|
||||
DeviceDB devices.Database
|
||||
DB storage.Database
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
// AppServices is the list of all registered AS
|
||||
AppServices []config.ApplicationService
|
||||
|
|
@ -55,11 +53,11 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
|
|||
if req.DataType == "" {
|
||||
return fmt.Errorf("data type must not be empty")
|
||||
}
|
||||
return a.AccountDB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
|
||||
return a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
|
||||
acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
|
||||
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
|
||||
if err != nil {
|
||||
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
|
||||
switch req.OnConflict {
|
||||
|
|
@ -89,7 +87,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
|||
return nil
|
||||
}
|
||||
|
||||
if err = a.AccountDB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
|
||||
if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -99,7 +97,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
|
||||
if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
|
||||
if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
|
||||
return err
|
||||
}
|
||||
res.PasswordUpdated = true
|
||||
|
|
@ -112,7 +110,7 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
|
|||
"device_id": req.DeviceID,
|
||||
"display_name": req.DeviceDisplayName,
|
||||
}).Info("PerformDeviceCreation")
|
||||
dev, err := a.DeviceDB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
|
||||
dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -137,12 +135,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
|
|||
deletedDeviceIDs := req.DeviceIDs
|
||||
if len(req.DeviceIDs) == 0 {
|
||||
var devices []api.Device
|
||||
devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
|
||||
devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
|
||||
for _, d := range devices {
|
||||
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
|
||||
}
|
||||
} else {
|
||||
err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs)
|
||||
err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -196,7 +194,7 @@ func (a *UserInternalAPI) PerformLastSeenUpdate(
|
|||
if err != nil {
|
||||
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
|
||||
}
|
||||
if err := a.DeviceDB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil {
|
||||
if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil {
|
||||
return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
|
@ -208,7 +206,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
|
|||
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||
return err
|
||||
}
|
||||
dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID)
|
||||
dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID)
|
||||
if err == sql.ErrNoRows {
|
||||
res.DeviceExists = false
|
||||
return nil
|
||||
|
|
@ -223,7 +221,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
|
|||
return nil
|
||||
}
|
||||
|
||||
err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
|
||||
err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
|
||||
return err
|
||||
|
|
@ -261,7 +259,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
|
|||
if domain != a.ServerName {
|
||||
return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName)
|
||||
}
|
||||
prof, err := a.AccountDB.GetProfileByLocalpart(ctx, local)
|
||||
prof, err := a.DB.GetProfileByLocalpart(ctx, local)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil
|
||||
|
|
@ -275,7 +273,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error {
|
||||
profiles, err := a.AccountDB.SearchProfiles(ctx, req.SearchString, req.Limit)
|
||||
profiles, err := a.DB.SearchProfiles(ctx, req.SearchString, req.Limit)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -284,7 +282,7 @@ func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.Quer
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error {
|
||||
devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs)
|
||||
devices, err := a.DB.GetDevicesByID(ctx, req.DeviceIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -312,10 +310,11 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
|
|||
if domain != a.ServerName {
|
||||
return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName)
|
||||
}
|
||||
devs, err := a.DeviceDB.GetDevicesByLocalpart(ctx, local)
|
||||
devs, err := a.DB.GetDevicesByLocalpart(ctx, local)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.UserExists = true
|
||||
res.Devices = devs
|
||||
return nil
|
||||
}
|
||||
|
|
@ -330,7 +329,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
|
|||
}
|
||||
if req.DataType != "" {
|
||||
var data json.RawMessage
|
||||
data, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
|
||||
data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -348,7 +347,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
|
|||
}
|
||||
return nil
|
||||
}
|
||||
global, rooms, err := a.AccountDB.GetAccountData(ctx, local)
|
||||
global, rooms, err := a.DB.GetAccountData(ctx, local)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -367,7 +366,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
|
|||
|
||||
return nil
|
||||
}
|
||||
device, err := a.DeviceDB.GetDeviceByAccessToken(ctx, req.AccessToken)
|
||||
device, err := a.DB.GetDeviceByAccessToken(ctx, req.AccessToken)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil
|
||||
|
|
@ -378,7 +377,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
acc, err := a.AccountDB.GetAccountByLocalpart(ctx, localPart)
|
||||
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -419,7 +418,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
|
|||
|
||||
if localpart != "" { // AS is masquerading as another user
|
||||
// Verify that the user is registered
|
||||
account, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart)
|
||||
account, err := a.DB.GetAccountByLocalpart(ctx, localpart)
|
||||
// Verify that the account exists and either appServiceID matches or
|
||||
// it belongs to the appservice user namespaces
|
||||
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
|
||||
|
|
@ -437,7 +436,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
|
|||
|
||||
// PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again.
|
||||
func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error {
|
||||
err := a.AccountDB.DeactivateAccount(ctx, req.Localpart)
|
||||
err := a.DB.DeactivateAccount(ctx, req.Localpart)
|
||||
res.AccountDeactivated = err == nil
|
||||
return err
|
||||
}
|
||||
|
|
@ -446,7 +445,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
|
|||
func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error {
|
||||
token := util.RandomString(24)
|
||||
|
||||
exp, err := a.AccountDB.CreateOpenIDToken(ctx, token, req.UserID)
|
||||
exp, err := a.DB.CreateOpenIDToken(ctx, token, req.UserID)
|
||||
|
||||
res.Token = api.OpenIDToken{
|
||||
Token: token,
|
||||
|
|
@ -459,7 +458,7 @@ func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *a
|
|||
|
||||
// QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation
|
||||
func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
|
||||
openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, req.Token)
|
||||
openIDTokenAttrs, err := a.DB.GetOpenIDTokenAttributes(ctx, req.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -481,7 +480,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
|
|||
}
|
||||
return nil
|
||||
}
|
||||
exists, err := a.AccountDB.DeleteKeyBackup(ctx, req.UserID, req.Version)
|
||||
exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version)
|
||||
if err != nil {
|
||||
res.Error = fmt.Sprintf("failed to delete backup: %s", err)
|
||||
}
|
||||
|
|
@ -494,7 +493,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
|
|||
}
|
||||
// Create metadata
|
||||
if req.Version == "" {
|
||||
version, err := a.AccountDB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
|
||||
version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
|
||||
if err != nil {
|
||||
res.Error = fmt.Sprintf("failed to create backup: %s", err)
|
||||
}
|
||||
|
|
@ -507,7 +506,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
|
|||
}
|
||||
// Update metadata
|
||||
if len(req.Keys.Rooms) == 0 {
|
||||
err := a.AccountDB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
|
||||
err := a.DB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
|
||||
if err != nil {
|
||||
res.Error = fmt.Sprintf("failed to update backup: %s", err)
|
||||
}
|
||||
|
|
@ -528,7 +527,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
|
|||
|
||||
func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) {
|
||||
// you can only upload keys for the CURRENT version
|
||||
version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, "")
|
||||
version, _, _, _, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, "")
|
||||
if err != nil {
|
||||
res.Error = fmt.Sprintf("failed to query version: %s", err)
|
||||
return
|
||||
|
|
@ -556,7 +555,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
|
|||
})
|
||||
}
|
||||
}
|
||||
count, etag, err := a.AccountDB.UpsertBackupKeys(ctx, version, req.UserID, uploads)
|
||||
count, etag, err := a.DB.UpsertBackupKeys(ctx, version, req.UserID, uploads)
|
||||
if err != nil {
|
||||
res.Error = fmt.Sprintf("failed to upsert keys: %s", err)
|
||||
return
|
||||
|
|
@ -566,7 +565,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) {
|
||||
version, algorithm, authData, etag, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version)
|
||||
version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version)
|
||||
res.Version = version
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
|
|
@ -582,14 +581,14 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
|
|||
res.Exists = !deleted
|
||||
|
||||
if !req.ReturnKeys {
|
||||
res.Count, err = a.AccountDB.CountBackupKeys(ctx, version, req.UserID)
|
||||
res.Count, err = a.DB.CountBackupKeys(ctx, version, req.UserID)
|
||||
if err != nil {
|
||||
res.Error = fmt.Sprintf("failed to count keys: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
result, err := a.AccountDB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
|
||||
result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
|
||||
if err != nil {
|
||||
res.Error = fmt.Sprintf("failed to query keys: %s", err)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap
|
|||
if domain != a.ServerName {
|
||||
return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName)
|
||||
}
|
||||
tokenMeta, err := a.DeviceDB.CreateLoginToken(ctx, &req.Data)
|
||||
tokenMeta, err := a.DB.CreateLoginToken(ctx, &req.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -45,13 +45,13 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap
|
|||
// PerformLoginTokenDeletion ensures the token doesn't exist.
|
||||
func (a *UserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *api.PerformLoginTokenDeletionRequest, res *api.PerformLoginTokenDeletionResponse) error {
|
||||
util.GetLogger(ctx).Info("PerformLoginTokenDeletion")
|
||||
return a.DeviceDB.RemoveLoginToken(ctx, req.Token)
|
||||
return a.DB.RemoveLoginToken(ctx, req.Token)
|
||||
}
|
||||
|
||||
// QueryLoginToken returns the data associated with a login token. If
|
||||
// the token is not valid, success is returned, but res.Data == nil.
|
||||
func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLoginTokenRequest, res *api.QueryLoginTokenResponse) error {
|
||||
tokenData, err := a.DeviceDB.GetLoginTokenDataByToken(ctx, req.Token)
|
||||
tokenData, err := a.DB.GetLoginTokenDataByToken(ctx, req.Token)
|
||||
if err != nil {
|
||||
res.Data = nil
|
||||
if err == sql.ErrNoRows {
|
||||
|
|
@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog
|
|||
if domain != a.ServerName {
|
||||
return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName)
|
||||
}
|
||||
if _, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart); err != nil {
|
||||
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil {
|
||||
res.Data = nil
|
||||
if err == sql.ErrNoRows {
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -1,515 +0,0 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas"
|
||||
|
||||
// Import the postgres database driver.
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// Database represents an account database
|
||||
type Database struct {
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
sqlutil.PartitionOffsetStatements
|
||||
accounts accountsStatements
|
||||
profiles profilesStatements
|
||||
accountDatas accountDataStatements
|
||||
threepids threepidStatements
|
||||
openIDTokens tokenStatements
|
||||
keyBackupVersions keyBackupVersionStatements
|
||||
keyBackups keyBackupStatements
|
||||
serverName gomatrixserverlib.ServerName
|
||||
bcryptCost int
|
||||
openIDTokenLifetimeMS int64
|
||||
}
|
||||
|
||||
// NewDatabase creates a new accounts and profiles database
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
|
||||
db, err := sqlutil.Open(dbProperties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d := &Database{
|
||||
serverName: serverName,
|
||||
db: db,
|
||||
writer: sqlutil.NewDummyWriter(),
|
||||
bcryptCost: bcryptCost,
|
||||
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||
}
|
||||
|
||||
// Create tables before executing migrations so we don't fail if the table is missing,
|
||||
// and THEN prepare statements so we don't fail due to referencing new columns
|
||||
if err = d.accounts.execSchema(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := sqlutil.NewMigrations()
|
||||
deltas.LoadIsActive(m)
|
||||
deltas.LoadAddAccountType(m)
|
||||
if err = m.RunDeltas(db, dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = d.PartitionOffsetStatements.Prepare(db, d.writer, "account"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.accounts.prepare(db, serverName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.profiles.prepare(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.accountDatas.prepare(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.threepids.prepare(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.openIDTokens.prepare(db, serverName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.keyBackupVersions.prepare(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.keyBackups.prepare(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// GetAccountByPassword returns the account associated with the given localpart and password.
|
||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||
func (d *Database) GetAccountByPassword(
|
||||
ctx context.Context, localpart, plaintextPassword string,
|
||||
) (*api.Account, error) {
|
||||
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
||||
}
|
||||
|
||||
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
||||
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
|
||||
func (d *Database) GetProfileByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) (*authtypes.Profile, error) {
|
||||
return d.profiles.selectProfileByLocalpart(ctx, localpart)
|
||||
}
|
||||
|
||||
// SetAvatarURL updates the avatar URL of the profile associated with the given
|
||||
// localpart. Returns an error if something went wrong with the SQL query
|
||||
func (d *Database) SetAvatarURL(
|
||||
ctx context.Context, localpart string, avatarURL string,
|
||||
) error {
|
||||
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
|
||||
}
|
||||
|
||||
// SetDisplayName updates the display name of the profile associated with the given
|
||||
// localpart. Returns an error if something went wrong with the SQL query
|
||||
func (d *Database) SetDisplayName(
|
||||
ctx context.Context, localpart string, displayName string,
|
||||
) error {
|
||||
return d.profiles.setDisplayName(ctx, localpart, displayName)
|
||||
}
|
||||
|
||||
// SetPassword sets the account password to the given hash.
|
||||
func (d *Database) SetPassword(
|
||||
ctx context.Context, localpart, plaintextPassword string,
|
||||
) error {
|
||||
hash, err := d.hashPassword(plaintextPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return d.accounts.updatePassword(ctx, localpart, hash)
|
||||
}
|
||||
|
||||
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||
// account already exists, it will return nil, sqlutil.ErrUserExists.
|
||||
func (d *Database) CreateAccount(
|
||||
ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||
) (acc *api.Account, err error) {
|
||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
// For guest accounts, we create a new numeric local part
|
||||
if accountType == api.AccountTypeGuest {
|
||||
var numLocalpart int64
|
||||
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
localpart = strconv.FormatInt(numLocalpart, 10)
|
||||
plaintextPassword = ""
|
||||
appserviceID = ""
|
||||
}
|
||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) createAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||
) (*api.Account, error) {
|
||||
var account *api.Account
|
||||
var err error
|
||||
// Generate a password hash if this is not a password-less user
|
||||
hash := ""
|
||||
if plaintextPassword != "" {
|
||||
hash, err = d.hashPassword(plaintextPassword)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
|
||||
if sqlutil.IsUniqueConstraintViolationErr(err) {
|
||||
return nil, sqlutil.ErrUserExists
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.serverName)
|
||||
prbs, err := json.Marshal(pushRuleSets)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// SaveAccountData saves new account data for a given user and a given room.
|
||||
// If the account data is not specific to a room, the room ID should be an empty string
|
||||
// If an account data already exists for a given set (user, room, data type), it will
|
||||
// update the corresponding row with the new content
|
||||
// Returns a SQL error if there was an issue with the insertion/update
|
||||
func (d *Database) SaveAccountData(
|
||||
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
|
||||
) error {
|
||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
|
||||
})
|
||||
}
|
||||
|
||||
// GetAccountData returns account data related to a given localpart
|
||||
// If no account data could be found, returns an empty arrays
|
||||
// Returns an error if there was an issue with the retrieval
|
||||
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
||||
global map[string]json.RawMessage,
|
||||
rooms map[string]map[string]json.RawMessage,
|
||||
err error,
|
||||
) {
|
||||
return d.accountDatas.selectAccountData(ctx, localpart)
|
||||
}
|
||||
|
||||
// GetAccountDataByType returns account data matching a given
|
||||
// localpart, room ID and type.
|
||||
// If no account data could be found, returns nil
|
||||
// Returns an error if there was an issue with the retrieval
|
||||
func (d *Database) GetAccountDataByType(
|
||||
ctx context.Context, localpart, roomID, dataType string,
|
||||
) (data json.RawMessage, err error) {
|
||||
return d.accountDatas.selectAccountDataByType(
|
||||
ctx, localpart, roomID, dataType,
|
||||
)
|
||||
}
|
||||
|
||||
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
|
||||
func (d *Database) GetNewNumericLocalpart(
|
||||
ctx context.Context,
|
||||
) (int64, error) {
|
||||
return d.accounts.selectNewNumericLocalpart(ctx, nil)
|
||||
}
|
||||
|
||||
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
|
||||
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.bcryptCost)
|
||||
return string(hashBytes), err
|
||||
}
|
||||
|
||||
// Err3PIDInUse is the error returned when trying to save an association involving
|
||||
// a third-party identifier which is already associated to a local user.
|
||||
var Err3PIDInUse = errors.New("this third-party identifier is already in use")
|
||||
|
||||
// SaveThreePIDAssociation saves the association between a third party identifier
|
||||
// and a local Matrix user (identified by the user's ID's local part).
|
||||
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
func (d *Database) SaveThreePIDAssociation(
|
||||
ctx context.Context, threepid, localpart, medium string,
|
||||
) (err error) {
|
||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
user, err := d.threepids.selectLocalpartForThreePID(
|
||||
ctx, txn, threepid, medium,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(user) > 0 {
|
||||
return Err3PIDInUse
|
||||
}
|
||||
|
||||
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveThreePIDAssociation removes the association involving a given third-party
|
||||
// identifier.
|
||||
// If no association exists involving this third-party identifier, returns nothing.
|
||||
// If there was a problem talking to the database, returns an error.
|
||||
func (d *Database) RemoveThreePIDAssociation(
|
||||
ctx context.Context, threepid string, medium string,
|
||||
) (err error) {
|
||||
return d.threepids.deleteThreePID(ctx, threepid, medium)
|
||||
}
|
||||
|
||||
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
|
||||
// identifier.
|
||||
// If no association involves the given third-party idenfitier, returns an empty
|
||||
// string.
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
func (d *Database) GetLocalpartForThreePID(
|
||||
ctx context.Context, threepid string, medium string,
|
||||
) (localpart string, err error) {
|
||||
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
|
||||
}
|
||||
|
||||
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
|
||||
// a given local user.
|
||||
// If no association is known for this user, returns an empty slice.
|
||||
// Returns an error if there was an issue talking to the database.
|
||||
func (d *Database) GetThreePIDsForLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) (threepids []authtypes.ThreePID, err error) {
|
||||
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
|
||||
}
|
||||
|
||||
// CheckAccountAvailability checks if the username/localpart is already present
|
||||
// in the database.
|
||||
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
|
||||
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
|
||||
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
|
||||
if err == sql.ErrNoRows {
|
||||
return true, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
// GetAccountByLocalpart returns the account associated with the given localpart.
|
||||
// This function assumes the request is authenticated or the account data is used only internally.
|
||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
|
||||
) (*api.Account, error) {
|
||||
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
||||
}
|
||||
|
||||
// SearchProfiles returns all profiles where the provided localpart or display name
|
||||
// match any part of the profiles in the database.
|
||||
func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int,
|
||||
) ([]authtypes.Profile, error) {
|
||||
return d.profiles.selectProfilesBySearch(ctx, searchString, limit)
|
||||
}
|
||||
|
||||
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
|
||||
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
|
||||
return d.accounts.deactivateAccount(ctx, localpart)
|
||||
}
|
||||
|
||||
// CreateOpenIDToken persists a new token that was issued through OpenID Connect
|
||||
func (d *Database) CreateOpenIDToken(
|
||||
ctx context.Context,
|
||||
token, localpart string,
|
||||
) (int64, error) {
|
||||
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
|
||||
err := sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
|
||||
})
|
||||
return expiresAtMS, err
|
||||
}
|
||||
|
||||
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
|
||||
func (d *Database) GetOpenIDTokenAttributes(
|
||||
ctx context.Context,
|
||||
token string,
|
||||
) (*api.OpenIDTokenAttributes, error) {
|
||||
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
|
||||
}
|
||||
|
||||
func (d *Database) CreateKeyBackup(
|
||||
ctx context.Context, userID, algorithm string, authData json.RawMessage,
|
||||
) (version string, err error) {
|
||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "")
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) UpdateKeyBackupAuthData(
|
||||
ctx context.Context, userID, version string, authData json.RawMessage,
|
||||
) (err error) {
|
||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) DeleteKeyBackup(
|
||||
ctx context.Context, userID, version string,
|
||||
) (exists bool, err error) {
|
||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) GetKeyBackup(
|
||||
ctx context.Context, userID, version string,
|
||||
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
|
||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) GetBackupKeys(
|
||||
ctx context.Context, version, userID, filterRoomID, filterSessionID string,
|
||||
) (result map[string]map[string]api.KeyBackupSession, err error) {
|
||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
if filterSessionID != "" {
|
||||
result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
|
||||
return err
|
||||
}
|
||||
if filterRoomID != "" {
|
||||
result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
|
||||
return err
|
||||
}
|
||||
result, err = d.keyBackups.selectKeys(ctx, txn, userID, version)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) CountBackupKeys(
|
||||
ctx context.Context, version, userID string,
|
||||
) (count int64, err error) {
|
||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// nolint:nakedret
|
||||
func (d *Database) UpsertBackupKeys(
|
||||
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
|
||||
) (count int64, etag string, err error) {
|
||||
// wrap the following logic in a txn to ensure we atomically upload keys
|
||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
_, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if deleted {
|
||||
return fmt.Errorf("backup was deleted")
|
||||
}
|
||||
// pull out all keys for this (user_id, version)
|
||||
existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
changed := false
|
||||
// loop over all the new keys (which should be smaller than the set of backed up keys)
|
||||
for _, newKey := range uploads {
|
||||
// if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them.
|
||||
existingRoom := existingKeys[newKey.RoomID]
|
||||
if existingRoom != nil {
|
||||
existingSession, ok := existingRoom[newKey.SessionID]
|
||||
if ok {
|
||||
if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
|
||||
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
|
||||
changed = true
|
||||
if err != nil {
|
||||
return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err)
|
||||
}
|
||||
}
|
||||
// if we shouldn't replace the key we do nothing with it
|
||||
continue
|
||||
}
|
||||
}
|
||||
// if we're here, either the room or session are new, either way, we insert
|
||||
err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey)
|
||||
changed = true
|
||||
if err != nil {
|
||||
return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if changed {
|
||||
// update the etag
|
||||
var newETag string
|
||||
if oldETag == "" {
|
||||
newETag = "1"
|
||||
} else {
|
||||
oldETagInt, err := strconv.ParseInt(oldETag, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse old etag: %s", err)
|
||||
}
|
||||
newETag = strconv.FormatInt(oldETagInt+1, 10)
|
||||
}
|
||||
etag = newETag
|
||||
return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag)
|
||||
} else {
|
||||
etag = oldETag
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
|
@ -1,52 +0,0 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package devices
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
)
|
||||
|
||||
type Database interface {
|
||||
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
|
||||
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
|
||||
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
|
||||
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
|
||||
// CreateDevice makes a new device associated with the given user ID localpart.
|
||||
// If there is already a device with the same device ID for this user, that access token will be revoked
|
||||
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
||||
// an error will be returned.
|
||||
// If no device ID is given one is generated.
|
||||
// Returns the device on success.
|
||||
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
|
||||
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
|
||||
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
|
||||
RemoveDevice(ctx context.Context, deviceID, localpart string) error
|
||||
RemoveDevices(ctx context.Context, localpart string, devices []string) error
|
||||
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
|
||||
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
|
||||
|
||||
// CreateLoginToken generates a token, stores and returns it. The lifetime is
|
||||
// determined by the loginTokenLifetime given to the Database constructor.
|
||||
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
|
||||
|
||||
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
|
||||
RemoveLoginToken(ctx context.Context, token string) error
|
||||
|
||||
// GetLoginTokenDataByToken returns the data associated with the given token.
|
||||
// May return sql.ErrNoRows.
|
||||
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
|
||||
}
|
||||
|
|
@ -1,270 +0,0 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const (
|
||||
// The length of generated device IDs
|
||||
deviceIDByteLength = 6
|
||||
loginTokenByteLength = 32
|
||||
)
|
||||
|
||||
// Database represents a device database.
|
||||
type Database struct {
|
||||
db *sql.DB
|
||||
devices devicesStatements
|
||||
loginTokens loginTokenStatements
|
||||
loginTokenLifetime time.Duration
|
||||
}
|
||||
|
||||
// NewDatabase creates a new device database
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) {
|
||||
db, err := sqlutil.Open(dbProperties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var d devicesStatements
|
||||
var lt loginTokenStatements
|
||||
|
||||
// Create tables before executing migrations so we don't fail if the table is missing,
|
||||
// and THEN prepare statements so we don't fail due to referencing new columns
|
||||
if err = d.execSchema(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = lt.execSchema(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := sqlutil.NewMigrations()
|
||||
deltas.LoadLastSeenTSIP(m)
|
||||
if err = m.RunDeltas(db, dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = d.prepare(db, serverName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = lt.prepare(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Database{db, d, lt, loginTokenLifetime}, nil
|
||||
}
|
||||
|
||||
// GetDeviceByAccessToken returns the device matching the given access token.
|
||||
// Returns sql.ErrNoRows if no matching device was found.
|
||||
func (d *Database) GetDeviceByAccessToken(
|
||||
ctx context.Context, token string,
|
||||
) (*api.Device, error) {
|
||||
return d.devices.selectDeviceByToken(ctx, token)
|
||||
}
|
||||
|
||||
// GetDeviceByID returns the device matching the given ID.
|
||||
// Returns sql.ErrNoRows if no matching device was found.
|
||||
func (d *Database) GetDeviceByID(
|
||||
ctx context.Context, localpart, deviceID string,
|
||||
) (*api.Device, error) {
|
||||
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
|
||||
}
|
||||
|
||||
// GetDevicesByLocalpart returns the devices matching the given localpart.
|
||||
func (d *Database) GetDevicesByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) ([]api.Device, error) {
|
||||
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
|
||||
}
|
||||
|
||||
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||
return d.devices.selectDevicesByID(ctx, deviceIDs)
|
||||
}
|
||||
|
||||
// CreateDevice makes a new device associated with the given user ID localpart.
|
||||
// If there is already a device with the same device ID for this user, that access token will be revoked
|
||||
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
||||
// an error will be returned.
|
||||
// If no device ID is given one is generated.
|
||||
// Returns the device on success.
|
||||
func (d *Database) CreateDevice(
|
||||
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
||||
displayName *string, ipAddr, userAgent string,
|
||||
) (dev *api.Device, returnErr error) {
|
||||
if deviceID != nil {
|
||||
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
// Revoke existing tokens for this device
|
||||
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||
return err
|
||||
})
|
||||
} else {
|
||||
// We generate device IDs in a loop in case its already taken.
|
||||
// We cap this at going round 5 times to ensure we don't spin forever
|
||||
var newDeviceID string
|
||||
for i := 1; i <= 5; i++ {
|
||||
newDeviceID, returnErr = generateDeviceID()
|
||||
if returnErr != nil {
|
||||
return
|
||||
}
|
||||
|
||||
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||
return err
|
||||
})
|
||||
if returnErr == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// generateDeviceID creates a new device id. Returns an error if failed to generate
|
||||
// random bytes.
|
||||
func generateDeviceID() (string, error) {
|
||||
b := make([]byte, deviceIDByteLength)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// url-safe no padding
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// UpdateDevice updates the given device with the display name.
|
||||
// Returns SQL error if there are problems and nil on success.
|
||||
func (d *Database) UpdateDevice(
|
||||
ctx context.Context, localpart, deviceID string, displayName *string,
|
||||
) error {
|
||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveDevice revokes a device by deleting the entry in the database
|
||||
// matching with the given device ID and user ID localpart.
|
||||
// If the device doesn't exist, it will not return an error
|
||||
// If something went wrong during the deletion, it will return the SQL error.
|
||||
func (d *Database) RemoveDevice(
|
||||
ctx context.Context, deviceID, localpart string,
|
||||
) error {
|
||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveDevices revokes one or more devices by deleting the entry in the database
|
||||
// matching with the given device IDs and user ID localpart.
|
||||
// If the devices don't exist, it will not return an error
|
||||
// If something went wrong during the deletion, it will return the SQL error.
|
||||
func (d *Database) RemoveDevices(
|
||||
ctx context.Context, localpart string, devices []string,
|
||||
) error {
|
||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveAllDevices revokes devices by deleting the entry in the
|
||||
// database matching the given user ID localpart.
|
||||
// If something went wrong during the deletion, it will return the SQL error.
|
||||
func (d *Database) RemoveAllDevices(
|
||||
ctx context.Context, localpart, exceptDeviceID string,
|
||||
) (devices []api.Device, err error) {
|
||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
|
||||
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
|
||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
|
||||
})
|
||||
}
|
||||
|
||||
// CreateLoginToken generates a token, stores and returns it. The lifetime is
|
||||
// determined by the loginTokenLifetime given to the Database constructor.
|
||||
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
|
||||
tok, err := generateLoginToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
meta := &api.LoginTokenMetadata{
|
||||
Token: tok,
|
||||
Expiration: time.Now().Add(d.loginTokenLifetime),
|
||||
}
|
||||
|
||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
return d.loginTokens.insert(ctx, txn, meta, data)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
func generateLoginToken() (string, error) {
|
||||
b := make([]byte, loginTokenByteLength)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
|
||||
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
|
||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
return d.loginTokens.deleteByToken(ctx, txn, token)
|
||||
})
|
||||
}
|
||||
|
||||
// GetLoginTokenDataByToken returns the data associated with the given token.
|
||||
// May return sql.ErrNoRows.
|
||||
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
|
||||
return d.loginTokens.selectByToken(ctx, token)
|
||||
}
|
||||
|
|
@ -1,271 +0,0 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const (
|
||||
// The length of generated device IDs
|
||||
deviceIDByteLength = 6
|
||||
|
||||
loginTokenByteLength = 32
|
||||
)
|
||||
|
||||
// Database represents a device database.
|
||||
type Database struct {
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
devices devicesStatements
|
||||
loginTokens loginTokenStatements
|
||||
loginTokenLifetime time.Duration
|
||||
}
|
||||
|
||||
// NewDatabase creates a new device database
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) {
|
||||
db, err := sqlutil.Open(dbProperties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
writer := sqlutil.NewExclusiveWriter()
|
||||
var d devicesStatements
|
||||
var lt loginTokenStatements
|
||||
|
||||
// Create tables before executing migrations so we don't fail if the table is missing,
|
||||
// and THEN prepare statements so we don't fail due to referencing new columns
|
||||
if err = d.execSchema(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = lt.execSchema(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := sqlutil.NewMigrations()
|
||||
deltas.LoadLastSeenTSIP(m)
|
||||
if err = m.RunDeltas(db, dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.prepare(db, writer, serverName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = lt.prepare(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Database{db, writer, d, lt, loginTokenLifetime}, nil
|
||||
}
|
||||
|
||||
// GetDeviceByAccessToken returns the device matching the given access token.
|
||||
// Returns sql.ErrNoRows if no matching device was found.
|
||||
func (d *Database) GetDeviceByAccessToken(
|
||||
ctx context.Context, token string,
|
||||
) (*api.Device, error) {
|
||||
return d.devices.selectDeviceByToken(ctx, token)
|
||||
}
|
||||
|
||||
// GetDeviceByID returns the device matching the given ID.
|
||||
// Returns sql.ErrNoRows if no matching device was found.
|
||||
func (d *Database) GetDeviceByID(
|
||||
ctx context.Context, localpart, deviceID string,
|
||||
) (*api.Device, error) {
|
||||
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
|
||||
}
|
||||
|
||||
// GetDevicesByLocalpart returns the devices matching the given localpart.
|
||||
func (d *Database) GetDevicesByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) ([]api.Device, error) {
|
||||
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
|
||||
}
|
||||
|
||||
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||
return d.devices.selectDevicesByID(ctx, deviceIDs)
|
||||
}
|
||||
|
||||
// CreateDevice makes a new device associated with the given user ID localpart.
|
||||
// If there is already a device with the same device ID for this user, that access token will be revoked
|
||||
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
||||
// an error will be returned.
|
||||
// If no device ID is given one is generated.
|
||||
// Returns the device on success.
|
||||
func (d *Database) CreateDevice(
|
||||
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
||||
displayName *string, ipAddr, userAgent string,
|
||||
) (dev *api.Device, returnErr error) {
|
||||
if deviceID != nil {
|
||||
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
// Revoke existing tokens for this device
|
||||
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||
return err
|
||||
})
|
||||
} else {
|
||||
// We generate device IDs in a loop in case its already taken.
|
||||
// We cap this at going round 5 times to ensure we don't spin forever
|
||||
var newDeviceID string
|
||||
for i := 1; i <= 5; i++ {
|
||||
newDeviceID, returnErr = generateDeviceID()
|
||||
if returnErr != nil {
|
||||
return
|
||||
}
|
||||
|
||||
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||
return err
|
||||
})
|
||||
if returnErr == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// generateDeviceID creates a new device id. Returns an error if failed to generate
|
||||
// random bytes.
|
||||
func generateDeviceID() (string, error) {
|
||||
b := make([]byte, deviceIDByteLength)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// url-safe no padding
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// UpdateDevice updates the given device with the display name.
|
||||
// Returns SQL error if there are problems and nil on success.
|
||||
func (d *Database) UpdateDevice(
|
||||
ctx context.Context, localpart, deviceID string, displayName *string,
|
||||
) error {
|
||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveDevice revokes a device by deleting the entry in the database
|
||||
// matching with the given device ID and user ID localpart.
|
||||
// If the device doesn't exist, it will not return an error
|
||||
// If something went wrong during the deletion, it will return the SQL error.
|
||||
func (d *Database) RemoveDevice(
|
||||
ctx context.Context, deviceID, localpart string,
|
||||
) error {
|
||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveDevices revokes one or more devices by deleting the entry in the database
|
||||
// matching with the given device IDs and user ID localpart.
|
||||
// If the devices don't exist, it will not return an error
|
||||
// If something went wrong during the deletion, it will return the SQL error.
|
||||
func (d *Database) RemoveDevices(
|
||||
ctx context.Context, localpart string, devices []string,
|
||||
) error {
|
||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveAllDevices revokes devices by deleting the entry in the
|
||||
// database matching the given user ID localpart.
|
||||
// If something went wrong during the deletion, it will return the SQL error.
|
||||
func (d *Database) RemoveAllDevices(
|
||||
ctx context.Context, localpart, exceptDeviceID string,
|
||||
) (devices []api.Device, err error) {
|
||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
|
||||
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
|
||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
|
||||
})
|
||||
}
|
||||
|
||||
// CreateLoginToken generates a token, stores and returns it. The lifetime is
|
||||
// determined by the loginTokenLifetime given to the Database constructor.
|
||||
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
|
||||
tok, err := generateLoginToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
meta := &api.LoginTokenMetadata{
|
||||
Token: tok,
|
||||
Expiration: time.Now().Add(d.loginTokenLifetime),
|
||||
}
|
||||
|
||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
return d.loginTokens.insert(ctx, txn, meta, data)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
func generateLoginToken() (string, error) {
|
||||
b := make([]byte, loginTokenByteLength)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
|
||||
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
|
||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
return d.loginTokens.deleteByToken(ctx, txn, token)
|
||||
})
|
||||
}
|
||||
|
||||
// GetLoginTokenDataByToken returns the data associated with the given token.
|
||||
// May return sql.ErrNoRows.
|
||||
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
|
||||
return d.loginTokens.selectByToken(ctx, token)
|
||||
}
|
||||
|
|
@ -1,42 +0,0 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build !wasm
|
||||
// +build !wasm
|
||||
|
||||
package devices
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/devices/postgres"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
|
||||
// and sets postgres connection parameters. loginTokenLifetime determines how long a
|
||||
// login token from CreateLoginToken is valid.
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (Database, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return postgres.NewDatabase(dbProperties, serverName, loginTokenLifetime)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected database type")
|
||||
}
|
||||
}
|
||||
|
|
@ -1,39 +0,0 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package devices
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
func NewDatabase(
|
||||
dbProperties *config.DatabaseOptions,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
loginTokenLifetime time.Duration,
|
||||
) (Database, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return nil, fmt.Errorf("can't use Postgres implementation")
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected database type")
|
||||
}
|
||||
}
|
||||
|
|
@ -12,7 +12,7 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package accounts
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
|
@ -60,6 +60,35 @@ type Database interface {
|
|||
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
|
||||
GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
|
||||
CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
|
||||
|
||||
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
|
||||
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
|
||||
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
|
||||
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
|
||||
// CreateDevice makes a new device associated with the given user ID localpart.
|
||||
// If there is already a device with the same device ID for this user, that access token will be revoked
|
||||
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
||||
// an error will be returned.
|
||||
// If no device ID is given one is generated.
|
||||
// Returns the device on success.
|
||||
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
|
||||
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
|
||||
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
|
||||
RemoveDevice(ctx context.Context, deviceID, localpart string) error
|
||||
RemoveDevices(ctx context.Context, localpart string, devices []string) error
|
||||
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
|
||||
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
|
||||
|
||||
// CreateLoginToken generates a token, stores and returns it. The lifetime is
|
||||
// determined by the loginTokenLifetime given to the Database constructor.
|
||||
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
|
||||
|
||||
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
|
||||
RemoveLoginToken(ctx context.Context, token string) error
|
||||
|
||||
// GetLoginTokenDataByToken returns the data associated with the given token.
|
||||
// May return sql.ErrNoRows.
|
||||
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
|
||||
}
|
||||
|
||||
// Err3PIDInUse is the error returned when trying to save an association involving
|
||||
|
|
@ -21,6 +21,7 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
const accountDataSchema = `
|
||||
|
|
@ -56,19 +57,20 @@ type accountDataStatements struct {
|
|||
selectAccountDataByTypeStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
|
||||
_, err = db.Exec(accountDataSchema)
|
||||
func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
|
||||
s := &accountDataStatements{}
|
||||
_, err := db.Exec(accountDataSchema)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
return sqlutil.StatementList{
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertAccountDataStmt, insertAccountDataSQL},
|
||||
{&s.selectAccountDataStmt, selectAccountDataSQL},
|
||||
{&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) insertAccountData(
|
||||
func (s *accountDataStatements) InsertAccountData(
|
||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
|
||||
|
|
@ -76,7 +78,7 @@ func (s *accountDataStatements) insertAccountData(
|
|||
return
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) selectAccountData(
|
||||
func (s *accountDataStatements) SelectAccountData(
|
||||
ctx context.Context, localpart string,
|
||||
) (
|
||||
/* global */ map[string]json.RawMessage,
|
||||
|
|
@ -114,7 +116,7 @@ func (s *accountDataStatements) selectAccountData(
|
|||
return global, rooms, rows.Err()
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) selectAccountDataByType(
|
||||
func (s *accountDataStatements) SelectAccountDataByType(
|
||||
ctx context.Context, localpart, roomID, dataType string,
|
||||
) (data json.RawMessage, err error) {
|
||||
var bytes []byte
|
||||
|
|
@ -24,6 +24,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
|
@ -78,14 +79,15 @@ type accountsStatements struct {
|
|||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
func (s *accountsStatements) execSchema(db *sql.DB) error {
|
||||
_, err := db.Exec(accountsSchema)
|
||||
return err
|
||||
func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) {
|
||||
s := &accountsStatements{
|
||||
serverName: serverName,
|
||||
}
|
||||
|
||||
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
||||
s.serverName = server
|
||||
return sqlutil.StatementList{
|
||||
_, err := db.Exec(accountsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertAccountStmt, insertAccountSQL},
|
||||
{&s.updatePasswordStmt, updatePasswordSQL},
|
||||
{&s.deactivateAccountStmt, deactivateAccountSQL},
|
||||
|
|
@ -98,7 +100,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
|||
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
|
||||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||
// on success.
|
||||
func (s *accountsStatements) insertAccount(
|
||||
func (s *accountsStatements) InsertAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
|
||||
) (*api.Account, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
|
|
@ -123,28 +125,28 @@ func (s *accountsStatements) insertAccount(
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (s *accountsStatements) updatePassword(
|
||||
func (s *accountsStatements) UpdatePassword(
|
||||
ctx context.Context, localpart, passwordHash string,
|
||||
) (err error) {
|
||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) deactivateAccount(
|
||||
func (s *accountsStatements) DeactivateAccount(
|
||||
ctx context.Context, localpart string,
|
||||
) (err error) {
|
||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) selectPasswordHash(
|
||||
func (s *accountsStatements) SelectPasswordHash(
|
||||
ctx context.Context, localpart string,
|
||||
) (hash string, err error) {
|
||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) selectAccountByLocalpart(
|
||||
func (s *accountsStatements) SelectAccountByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) (*api.Account, error) {
|
||||
var appserviceIDPtr sql.NullString
|
||||
|
|
@ -168,7 +170,7 @@ func (s *accountsStatements) selectAccountByLocalpart(
|
|||
return &acc, nil
|
||||
}
|
||||
|
||||
func (s *accountsStatements) selectNewNumericLocalpart(
|
||||
func (s *accountsStatements) SelectNewNumericLocalpart(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
) (id int64, err error) {
|
||||
stmt := s.selectNewNumericLocalpartStmt
|
||||
|
|
@ -5,13 +5,8 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/pressly/goose"
|
||||
)
|
||||
|
||||
func LoadFromGoose() {
|
||||
goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
|
||||
}
|
||||
|
||||
func LoadLastSeenTSIP(m *sqlutil.Migrations) {
|
||||
m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
|
||||
}
|
||||
|
|
@ -24,6 +24,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
|
|
@ -111,50 +112,32 @@ type devicesStatements struct {
|
|||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
func (s *devicesStatements) execSchema(db *sql.DB) error {
|
||||
func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) {
|
||||
s := &devicesStatements{
|
||||
serverName: serverName,
|
||||
}
|
||||
_, err := db.Exec(devicesSchema)
|
||||
return err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
||||
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil {
|
||||
return
|
||||
}
|
||||
s.serverName = server
|
||||
return
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertDeviceStmt, insertDeviceSQL},
|
||||
{&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL},
|
||||
{&s.selectDeviceByIDStmt, selectDeviceByIDSQL},
|
||||
{&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL},
|
||||
{&s.updateDeviceNameStmt, updateDeviceNameSQL},
|
||||
{&s.deleteDeviceStmt, deleteDeviceSQL},
|
||||
{&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL},
|
||||
{&s.deleteDevicesStmt, deleteDevicesSQL},
|
||||
{&s.selectDevicesByIDStmt, selectDevicesByIDSQL},
|
||||
{&s.updateDeviceLastSeenStmt, updateDeviceLastSeen},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
|
||||
// Returns an error if the user already has a device with the given device ID.
|
||||
// Returns the device on success.
|
||||
func (s *devicesStatements) insertDevice(
|
||||
func (s *devicesStatements) InsertDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
||||
displayName *string, ipAddr, userAgent string,
|
||||
) (*api.Device, error) {
|
||||
|
|
@ -176,7 +159,7 @@ func (s *devicesStatements) insertDevice(
|
|||
}
|
||||
|
||||
// deleteDevice removes a single device by id and user localpart.
|
||||
func (s *devicesStatements) deleteDevice(
|
||||
func (s *devicesStatements) DeleteDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
||||
|
|
@ -186,7 +169,7 @@ func (s *devicesStatements) deleteDevice(
|
|||
|
||||
// deleteDevices removes a single or multiple devices by ids and user localpart.
|
||||
// Returns an error if the execution failed.
|
||||
func (s *devicesStatements) deleteDevices(
|
||||
func (s *devicesStatements) DeleteDevices(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt)
|
||||
|
|
@ -196,7 +179,7 @@ func (s *devicesStatements) deleteDevices(
|
|||
|
||||
// deleteDevicesByLocalpart removes all devices for the
|
||||
// given user localpart.
|
||||
func (s *devicesStatements) deleteDevicesByLocalpart(
|
||||
func (s *devicesStatements) DeleteDevicesByLocalpart(
|
||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
||||
|
|
@ -204,7 +187,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) updateDeviceName(
|
||||
func (s *devicesStatements) UpdateDeviceName(
|
||||
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
||||
|
|
@ -212,7 +195,7 @@ func (s *devicesStatements) updateDeviceName(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) selectDeviceByToken(
|
||||
func (s *devicesStatements) SelectDeviceByToken(
|
||||
ctx context.Context, accessToken string,
|
||||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
|
|
@ -228,7 +211,7 @@ func (s *devicesStatements) selectDeviceByToken(
|
|||
|
||||
// selectDeviceByID retrieves a device from the database with the given user
|
||||
// localpart and deviceID
|
||||
func (s *devicesStatements) selectDeviceByID(
|
||||
func (s *devicesStatements) SelectDeviceByID(
|
||||
ctx context.Context, localpart, deviceID string,
|
||||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
|
|
@ -245,7 +228,7 @@ func (s *devicesStatements) selectDeviceByID(
|
|||
return &dev, err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||
func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||
rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -268,7 +251,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
|
|||
return devices, rows.Err()
|
||||
}
|
||||
|
||||
func (s *devicesStatements) selectDevicesByLocalpart(
|
||||
func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||
) ([]api.Device, error) {
|
||||
devices := []api.Device{}
|
||||
|
|
@ -310,7 +293,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
|
|||
return devices, rows.Err()
|
||||
}
|
||||
|
||||
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
|
||||
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
|
||||
lastSeenTs := time.Now().UnixNano() / 1000000
|
||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
|
||||
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID)
|
||||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
const keyBackupTableSchema = `
|
||||
|
|
@ -72,12 +73,13 @@ type keyBackupStatements struct {
|
|||
selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
|
||||
_, err = db.Exec(keyBackupTableSchema)
|
||||
func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
|
||||
s := &keyBackupStatements{}
|
||||
_, err := db.Exec(keyBackupTableSchema)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
return sqlutil.StatementList{
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertBackupKeyStmt, insertBackupKeySQL},
|
||||
{&s.updateBackupKeyStmt, updateBackupKeySQL},
|
||||
{&s.countKeysStmt, countKeysSQL},
|
||||
|
|
@ -87,14 +89,14 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
|
|||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s keyBackupStatements) countKeys(
|
||||
func (s keyBackupStatements) CountKeys(
|
||||
ctx context.Context, txn *sql.Tx, userID, version string,
|
||||
) (count int64, err error) {
|
||||
err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *keyBackupStatements) insertBackupKey(
|
||||
func (s *keyBackupStatements) InsertBackupKey(
|
||||
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
|
||||
) (err error) {
|
||||
_, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext(
|
||||
|
|
@ -103,7 +105,7 @@ func (s *keyBackupStatements) insertBackupKey(
|
|||
return
|
||||
}
|
||||
|
||||
func (s *keyBackupStatements) updateBackupKey(
|
||||
func (s *keyBackupStatements) UpdateBackupKey(
|
||||
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
|
||||
) (err error) {
|
||||
_, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext(
|
||||
|
|
@ -112,7 +114,7 @@ func (s *keyBackupStatements) updateBackupKey(
|
|||
return
|
||||
}
|
||||
|
||||
func (s *keyBackupStatements) selectKeys(
|
||||
func (s *keyBackupStatements) SelectKeys(
|
||||
ctx context.Context, txn *sql.Tx, userID, version string,
|
||||
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
|
||||
|
|
@ -122,7 +124,7 @@ func (s *keyBackupStatements) selectKeys(
|
|||
return unpackKeys(ctx, rows)
|
||||
}
|
||||
|
||||
func (s *keyBackupStatements) selectKeysByRoomID(
|
||||
func (s *keyBackupStatements) SelectKeysByRoomID(
|
||||
ctx context.Context, txn *sql.Tx, userID, version, roomID string,
|
||||
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||
rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
|
||||
|
|
@ -132,7 +134,7 @@ func (s *keyBackupStatements) selectKeysByRoomID(
|
|||
return unpackKeys(ctx, rows)
|
||||
}
|
||||
|
||||
func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID(
|
||||
func (s *keyBackupStatements) SelectKeysByRoomIDAndSessionID(
|
||||
ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
|
||||
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||
rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)
|
||||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"strconv"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
const keyBackupVersionTableSchema = `
|
||||
|
|
@ -69,12 +70,13 @@ type keyBackupVersionStatements struct {
|
|||
updateKeyBackupETagStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
|
||||
_, err = db.Exec(keyBackupVersionTableSchema)
|
||||
func NewPostgresKeyBackupVersionTable(db *sql.DB) (tables.KeyBackupVersionTable, error) {
|
||||
s := &keyBackupVersionStatements{}
|
||||
_, err := db.Exec(keyBackupVersionTableSchema)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
return sqlutil.StatementList{
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertKeyBackupStmt, insertKeyBackupSQL},
|
||||
{&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL},
|
||||
{&s.deleteKeyBackupStmt, deleteKeyBackupSQL},
|
||||
|
|
@ -84,7 +86,7 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
|
|||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *keyBackupVersionStatements) insertKeyBackup(
|
||||
func (s *keyBackupVersionStatements) InsertKeyBackup(
|
||||
ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string,
|
||||
) (version string, err error) {
|
||||
var versionInt int64
|
||||
|
|
@ -92,7 +94,7 @@ func (s *keyBackupVersionStatements) insertKeyBackup(
|
|||
return strconv.FormatInt(versionInt, 10), err
|
||||
}
|
||||
|
||||
func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
|
||||
func (s *keyBackupVersionStatements) UpdateKeyBackupAuthData(
|
||||
ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage,
|
||||
) error {
|
||||
versionInt, err := strconv.ParseInt(version, 10, 64)
|
||||
|
|
@ -103,7 +105,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *keyBackupVersionStatements) updateKeyBackupETag(
|
||||
func (s *keyBackupVersionStatements) UpdateKeyBackupETag(
|
||||
ctx context.Context, txn *sql.Tx, userID, version, etag string,
|
||||
) error {
|
||||
versionInt, err := strconv.ParseInt(version, 10, 64)
|
||||
|
|
@ -114,7 +116,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *keyBackupVersionStatements) deleteKeyBackup(
|
||||
func (s *keyBackupVersionStatements) DeleteKeyBackup(
|
||||
ctx context.Context, txn *sql.Tx, userID, version string,
|
||||
) (bool, error) {
|
||||
versionInt, err := strconv.ParseInt(version, 10, 64)
|
||||
|
|
@ -132,7 +134,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup(
|
|||
return ra == 1, nil
|
||||
}
|
||||
|
||||
func (s *keyBackupVersionStatements) selectKeyBackup(
|
||||
func (s *keyBackupVersionStatements) SelectKeyBackup(
|
||||
ctx context.Context, txn *sql.Tx, userID, version string,
|
||||
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
|
||||
var versionInt int64
|
||||
|
|
@ -21,18 +21,11 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
type loginTokenStatements struct {
|
||||
insertStmt *sql.Stmt
|
||||
deleteStmt *sql.Stmt
|
||||
selectByTokenStmt *sql.Stmt
|
||||
}
|
||||
|
||||
// execSchema ensures tables and indices exist.
|
||||
func (s *loginTokenStatements) execSchema(db *sql.DB) error {
|
||||
_, err := db.Exec(`
|
||||
const loginTokenSchema = `
|
||||
CREATE TABLE IF NOT EXISTS login_tokens (
|
||||
-- The random value of the token issued to a user
|
||||
token TEXT NOT NULL PRIMARY KEY,
|
||||
|
|
@ -45,21 +38,38 @@ CREATE TABLE IF NOT EXISTS login_tokens (
|
|||
|
||||
-- This index allows efficient garbage collection of expired tokens.
|
||||
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at);
|
||||
`)
|
||||
return err
|
||||
`
|
||||
|
||||
const insertLoginTokenSQL = "" +
|
||||
"INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
|
||||
|
||||
const deleteLoginTokenSQL = "" +
|
||||
"DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"
|
||||
|
||||
const selectLoginTokenSQL = "" +
|
||||
"SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"
|
||||
|
||||
type loginTokenStatements struct {
|
||||
insertStmt *sql.Stmt
|
||||
deleteStmt *sql.Stmt
|
||||
selectStmt *sql.Stmt
|
||||
}
|
||||
|
||||
// prepare runs statement preparation.
|
||||
func (s *loginTokenStatements) prepare(db *sql.DB) error {
|
||||
return sqlutil.StatementList{
|
||||
{&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"},
|
||||
{&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"},
|
||||
{&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"},
|
||||
func NewPostgresLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) {
|
||||
s := &loginTokenStatements{}
|
||||
_, err := db.Exec(loginTokenSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertStmt, insertLoginTokenSQL},
|
||||
{&s.deleteStmt, deleteLoginTokenSQL},
|
||||
{&s.selectStmt, selectLoginTokenSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
// insert adds an already generated token to the database.
|
||||
func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error {
|
||||
func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertStmt)
|
||||
_, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID)
|
||||
return err
|
||||
|
|
@ -69,7 +79,7 @@ func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata
|
|||
//
|
||||
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
|
||||
// The login_tokens_expiration_idx index should make that efficient.
|
||||
func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error {
|
||||
func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteStmt)
|
||||
res, err := stmt.ExecContext(ctx, token, time.Now().UTC())
|
||||
if err != nil {
|
||||
|
|
@ -82,9 +92,9 @@ func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, t
|
|||
}
|
||||
|
||||
// selectByToken returns the data associated with the given token. May return sql.ErrNoRows.
|
||||
func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
|
||||
func (s *loginTokenStatements) SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
|
||||
var data api.LoginTokenData
|
||||
err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID)
|
||||
err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -6,6 +6,7 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
|
@ -22,33 +23,35 @@ CREATE TABLE IF NOT EXISTS open_id_tokens (
|
|||
);
|
||||
`
|
||||
|
||||
const insertTokenSQL = "" +
|
||||
const insertOpenIDTokenSQL = "" +
|
||||
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
||||
|
||||
const selectTokenSQL = "" +
|
||||
const selectOpenIDTokenSQL = "" +
|
||||
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
|
||||
|
||||
type tokenStatements struct {
|
||||
type openIDTokenStatements struct {
|
||||
insertTokenStmt *sql.Stmt
|
||||
selectTokenStmt *sql.Stmt
|
||||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
||||
_, err = db.Exec(openIDTokenSchema)
|
||||
if err != nil {
|
||||
return
|
||||
func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.OpenIDTable, error) {
|
||||
s := &openIDTokenStatements{
|
||||
serverName: serverName,
|
||||
}
|
||||
s.serverName = server
|
||||
return sqlutil.StatementList{
|
||||
{&s.insertTokenStmt, insertTokenSQL},
|
||||
{&s.selectTokenStmt, selectTokenSQL},
|
||||
_, err := db.Exec(openIDTokenSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertTokenStmt, insertOpenIDTokenSQL},
|
||||
{&s.selectTokenStmt, selectOpenIDTokenSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
// insertToken inserts a new OpenID Connect token to the DB.
|
||||
// Returns new token, otherwise returns error if the token already exists.
|
||||
func (s *tokenStatements) insertToken(
|
||||
func (s *openIDTokenStatements) InsertOpenIDToken(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
token, localpart string,
|
||||
|
|
@ -61,7 +64,7 @@ func (s *tokenStatements) insertToken(
|
|||
|
||||
// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
|
||||
// Returns the existing token's attributes, or err if no token is found
|
||||
func (s *tokenStatements) selectOpenIDTokenAtrributes(
|
||||
func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
|
||||
ctx context.Context,
|
||||
token string,
|
||||
) (*api.OpenIDTokenAttributes, error) {
|
||||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
const profilesSchema = `
|
||||
|
|
@ -59,12 +60,13 @@ type profilesStatements struct {
|
|||
selectProfilesBySearchStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
|
||||
_, err = db.Exec(profilesSchema)
|
||||
func NewPostgresProfilesTable(db *sql.DB) (tables.ProfileTable, error) {
|
||||
s := &profilesStatements{}
|
||||
_, err := db.Exec(profilesSchema)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
return sqlutil.StatementList{
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertProfileStmt, insertProfileSQL},
|
||||
{&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL},
|
||||
{&s.setAvatarURLStmt, setAvatarURLSQL},
|
||||
|
|
@ -73,14 +75,14 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
|
|||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *profilesStatements) insertProfile(
|
||||
func (s *profilesStatements) InsertProfile(
|
||||
ctx context.Context, txn *sql.Tx, localpart string,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
||||
return
|
||||
}
|
||||
|
||||
func (s *profilesStatements) selectProfileByLocalpart(
|
||||
func (s *profilesStatements) SelectProfileByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) (*authtypes.Profile, error) {
|
||||
var profile authtypes.Profile
|
||||
|
|
@ -93,21 +95,21 @@ func (s *profilesStatements) selectProfileByLocalpart(
|
|||
return &profile, nil
|
||||
}
|
||||
|
||||
func (s *profilesStatements) setAvatarURL(
|
||||
ctx context.Context, localpart string, avatarURL string,
|
||||
func (s *profilesStatements) SetAvatarURL(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
|
||||
) (err error) {
|
||||
_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *profilesStatements) setDisplayName(
|
||||
ctx context.Context, localpart string, displayName string,
|
||||
func (s *profilesStatements) SetDisplayName(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
|
||||
) (err error) {
|
||||
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *profilesStatements) selectProfilesBySearch(
|
||||
func (s *profilesStatements) SelectProfilesBySearch(
|
||||
ctx context.Context, searchString string, limit int,
|
||||
) ([]authtypes.Profile, error) {
|
||||
var profiles []authtypes.Profile
|
||||
105
userapi/storage/postgres/storage.go
Normal file
105
userapi/storage/postgres/storage.go
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/shared"
|
||||
|
||||
// Import the postgres database driver.
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// NewDatabase creates a new accounts and profiles database
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*shared.Database, error) {
|
||||
db, err := sqlutil.Open(dbProperties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := sqlutil.NewMigrations()
|
||||
if _, err = db.Exec(accountsSchema); err != nil {
|
||||
// do this so that the migration can and we don't fail on
|
||||
// preparing statements for columns that don't exist yet
|
||||
return nil, err
|
||||
}
|
||||
deltas.LoadIsActive(m)
|
||||
//deltas.LoadLastSeenTSIP(m)
|
||||
deltas.LoadAddAccountType(m)
|
||||
if err = m.RunDeltas(db, dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accountDataTable, err := NewPostgresAccountDataTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)
|
||||
}
|
||||
accountsTable, err := NewPostgresAccountsTable(db, serverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err)
|
||||
}
|
||||
devicesTable, err := NewPostgresDevicesTable(db, serverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err)
|
||||
}
|
||||
keyBackupTable, err := NewPostgresKeyBackupTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresKeyBackupTable: %w", err)
|
||||
}
|
||||
keyBackupVersionTable, err := NewPostgresKeyBackupVersionTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresKeyBackupVersionTable: %w", err)
|
||||
}
|
||||
loginTokenTable, err := NewPostgresLoginTokenTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresLoginTokenTable: %w", err)
|
||||
}
|
||||
openIDTable, err := NewPostgresOpenIDTable(db, serverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresOpenIDTable: %w", err)
|
||||
}
|
||||
profilesTable, err := NewPostgresProfilesTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresProfilesTable: %w", err)
|
||||
}
|
||||
threePIDTable, err := NewPostgresThreePIDTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresThreePIDTable: %w", err)
|
||||
}
|
||||
return &shared.Database{
|
||||
AccountDatas: accountDataTable,
|
||||
Accounts: accountsTable,
|
||||
Devices: devicesTable,
|
||||
KeyBackups: keyBackupTable,
|
||||
KeyBackupVersions: keyBackupVersionTable,
|
||||
LoginTokens: loginTokenTable,
|
||||
OpenIDTokens: openIDTable,
|
||||
Profiles: profilesTable,
|
||||
ThreePIDs: threePIDTable,
|
||||
ServerName: serverName,
|
||||
DB: db,
|
||||
Writer: sqlutil.NewDummyWriter(),
|
||||
LoginTokenLifetime: loginTokenLifetime,
|
||||
BcryptCost: bcryptCost,
|
||||
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -19,6 +19,7 @@ import (
|
|||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
)
|
||||
|
|
@ -58,12 +59,13 @@ type threepidStatements struct {
|
|||
deleteThreePIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *threepidStatements) prepare(db *sql.DB) (err error) {
|
||||
_, err = db.Exec(threepidSchema)
|
||||
func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
|
||||
s := &threepidStatements{}
|
||||
_, err := db.Exec(threepidSchema)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
return sqlutil.StatementList{
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL},
|
||||
{&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL},
|
||||
{&s.insertThreePIDStmt, insertThreePIDSQL},
|
||||
|
|
@ -71,7 +73,7 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) {
|
|||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *threepidStatements) selectLocalpartForThreePID(
|
||||
func (s *threepidStatements) SelectLocalpartForThreePID(
|
||||
ctx context.Context, txn *sql.Tx, threepid string, medium string,
|
||||
) (localpart string, err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
|
||||
|
|
@ -82,7 +84,7 @@ func (s *threepidStatements) selectLocalpartForThreePID(
|
|||
return
|
||||
}
|
||||
|
||||
func (s *threepidStatements) selectThreePIDsForLocalpart(
|
||||
func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) (threepids []authtypes.ThreePID, err error) {
|
||||
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
|
||||
|
|
@ -106,7 +108,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
|
|||
return
|
||||
}
|
||||
|
||||
func (s *threepidStatements) insertThreePID(
|
||||
func (s *threepidStatements) InsertThreePID(
|
||||
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
|
||||
|
|
@ -114,8 +116,9 @@ func (s *threepidStatements) insertThreePID(
|
|||
return
|
||||
}
|
||||
|
||||
func (s *threepidStatements) deleteThreePID(
|
||||
ctx context.Context, threepid string, medium string) (err error) {
|
||||
_, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium)
|
||||
func (s *threepidStatements) DeleteThreePID(
|
||||
ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt)
|
||||
_, err = stmt.ExecContext(ctx, threepid, medium)
|
||||
return
|
||||
}
|
||||
|
|
@ -12,16 +12,17 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
package shared
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
|
@ -30,102 +31,48 @@ import (
|
|||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
// Database represents an account database
|
||||
type Database struct {
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
|
||||
sqlutil.PartitionOffsetStatements
|
||||
accounts accountsStatements
|
||||
profiles profilesStatements
|
||||
accountDatas accountDataStatements
|
||||
threepids threepidStatements
|
||||
openIDTokens tokenStatements
|
||||
keyBackupVersions keyBackupVersionStatements
|
||||
keyBackups keyBackupStatements
|
||||
serverName gomatrixserverlib.ServerName
|
||||
bcryptCost int
|
||||
openIDTokenLifetimeMS int64
|
||||
|
||||
accountsMu sync.Mutex
|
||||
profilesMu sync.Mutex
|
||||
accountDatasMu sync.Mutex
|
||||
threepidsMu sync.Mutex
|
||||
DB *sql.DB
|
||||
Writer sqlutil.Writer
|
||||
Accounts tables.AccountsTable
|
||||
Profiles tables.ProfileTable
|
||||
AccountDatas tables.AccountDataTable
|
||||
ThreePIDs tables.ThreePIDTable
|
||||
OpenIDTokens tables.OpenIDTable
|
||||
KeyBackups tables.KeyBackupTable
|
||||
KeyBackupVersions tables.KeyBackupVersionTable
|
||||
Devices tables.DevicesTable
|
||||
LoginTokens tables.LoginTokenTable
|
||||
LoginTokenLifetime time.Duration
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
BcryptCost int
|
||||
OpenIDTokenLifetimeMS int64
|
||||
}
|
||||
|
||||
// NewDatabase creates a new accounts and profiles database
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
|
||||
db, err := sqlutil.Open(dbProperties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d := &Database{
|
||||
serverName: serverName,
|
||||
db: db,
|
||||
writer: sqlutil.NewExclusiveWriter(),
|
||||
bcryptCost: bcryptCost,
|
||||
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||
}
|
||||
|
||||
// Create tables before executing migrations so we don't fail if the table is missing,
|
||||
// and THEN prepare statements so we don't fail due to referencing new columns
|
||||
if err = d.accounts.execSchema(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := sqlutil.NewMigrations()
|
||||
deltas.LoadIsActive(m)
|
||||
deltas.LoadAddAccountType(m)
|
||||
if err = m.RunDeltas(db, dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
partitions := sqlutil.PartitionOffsetStatements{}
|
||||
if err = partitions.Prepare(db, d.writer, "account"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.accounts.prepare(db, serverName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.profiles.prepare(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.accountDatas.prepare(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.threepids.prepare(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.openIDTokens.prepare(db, serverName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.keyBackupVersions.prepare(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.keyBackups.prepare(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return d, nil
|
||||
}
|
||||
const (
|
||||
// The length of generated device IDs
|
||||
deviceIDByteLength = 6
|
||||
loginTokenByteLength = 32
|
||||
)
|
||||
|
||||
// GetAccountByPassword returns the account associated with the given localpart and password.
|
||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||
func (d *Database) GetAccountByPassword(
|
||||
ctx context.Context, localpart, plaintextPassword string,
|
||||
) (*api.Account, error) {
|
||||
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
|
||||
hash, err := d.Accounts.SelectPasswordHash(ctx, localpart)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
||||
return d.Accounts.SelectAccountByLocalpart(ctx, localpart)
|
||||
}
|
||||
|
||||
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
||||
|
|
@ -133,7 +80,7 @@ func (d *Database) GetAccountByPassword(
|
|||
func (d *Database) GetProfileByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) (*authtypes.Profile, error) {
|
||||
return d.profiles.selectProfileByLocalpart(ctx, localpart)
|
||||
return d.Profiles.SelectProfileByLocalpart(ctx, localpart)
|
||||
}
|
||||
|
||||
// SetAvatarURL updates the avatar URL of the profile associated with the given
|
||||
|
|
@ -141,10 +88,8 @@ func (d *Database) GetProfileByLocalpart(
|
|||
func (d *Database) SetAvatarURL(
|
||||
ctx context.Context, localpart string, avatarURL string,
|
||||
) error {
|
||||
d.profilesMu.Lock()
|
||||
defer d.profilesMu.Unlock()
|
||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
return d.profiles.setAvatarURL(ctx, txn, localpart, avatarURL)
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -153,10 +98,8 @@ func (d *Database) SetAvatarURL(
|
|||
func (d *Database) SetDisplayName(
|
||||
ctx context.Context, localpart string, displayName string,
|
||||
) error {
|
||||
d.profilesMu.Lock()
|
||||
defer d.profilesMu.Unlock()
|
||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
return d.profiles.setDisplayName(ctx, txn, localpart, displayName)
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -168,8 +111,8 @@ func (d *Database) SetPassword(
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return d.writer.Do(nil, nil, func(txn *sql.Tx) error {
|
||||
return d.accounts.updatePassword(ctx, localpart, hash)
|
||||
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
|
||||
return d.Accounts.UpdatePassword(ctx, localpart, hash)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -179,18 +122,11 @@ func (d *Database) SetPassword(
|
|||
func (d *Database) CreateAccount(
|
||||
ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||
) (acc *api.Account, err error) {
|
||||
// Create one account at a time else we can get 'database is locked'.
|
||||
d.profilesMu.Lock()
|
||||
d.accountDatasMu.Lock()
|
||||
d.accountsMu.Lock()
|
||||
defer d.profilesMu.Unlock()
|
||||
defer d.accountDatasMu.Unlock()
|
||||
defer d.accountsMu.Unlock()
|
||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
// For guest accounts, we create a new numeric local part
|
||||
if accountType == api.AccountTypeGuest {
|
||||
var numLocalpart int64
|
||||
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
|
||||
numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -219,18 +155,18 @@ func (d *Database) createAccount(
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
|
||||
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
|
||||
return nil, sqlutil.ErrUserExists
|
||||
}
|
||||
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
|
||||
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.serverName)
|
||||
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
|
||||
prbs, err := json.Marshal(pushRuleSets)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
|
||||
if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return account, nil
|
||||
|
|
@ -244,10 +180,8 @@ func (d *Database) createAccount(
|
|||
func (d *Database) SaveAccountData(
|
||||
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
|
||||
) error {
|
||||
d.accountDatasMu.Lock()
|
||||
defer d.accountDatasMu.Unlock()
|
||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.AccountDatas.InsertAccountData(ctx, txn, localpart, roomID, dataType, content)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -259,7 +193,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
|||
rooms map[string]map[string]json.RawMessage,
|
||||
err error,
|
||||
) {
|
||||
return d.accountDatas.selectAccountData(ctx, localpart)
|
||||
return d.AccountDatas.SelectAccountData(ctx, localpart)
|
||||
}
|
||||
|
||||
// GetAccountDataByType returns account data matching a given
|
||||
|
|
@ -269,7 +203,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
|||
func (d *Database) GetAccountDataByType(
|
||||
ctx context.Context, localpart, roomID, dataType string,
|
||||
) (data json.RawMessage, err error) {
|
||||
return d.accountDatas.selectAccountDataByType(
|
||||
return d.AccountDatas.SelectAccountDataByType(
|
||||
ctx, localpart, roomID, dataType,
|
||||
)
|
||||
}
|
||||
|
|
@ -278,11 +212,11 @@ func (d *Database) GetAccountDataByType(
|
|||
func (d *Database) GetNewNumericLocalpart(
|
||||
ctx context.Context,
|
||||
) (int64, error) {
|
||||
return d.accounts.selectNewNumericLocalpart(ctx, nil)
|
||||
return d.Accounts.SelectNewNumericLocalpart(ctx, nil)
|
||||
}
|
||||
|
||||
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
|
||||
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.bcryptCost)
|
||||
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.BcryptCost)
|
||||
return string(hashBytes), err
|
||||
}
|
||||
|
||||
|
|
@ -297,10 +231,8 @@ var Err3PIDInUse = errors.New("this third-party identifier is already in use")
|
|||
func (d *Database) SaveThreePIDAssociation(
|
||||
ctx context.Context, threepid, localpart, medium string,
|
||||
) (err error) {
|
||||
d.threepidsMu.Lock()
|
||||
defer d.threepidsMu.Unlock()
|
||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
user, err := d.threepids.selectLocalpartForThreePID(
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
user, err := d.ThreePIDs.SelectLocalpartForThreePID(
|
||||
ctx, txn, threepid, medium,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
@ -311,7 +243,7 @@ func (d *Database) SaveThreePIDAssociation(
|
|||
return Err3PIDInUse
|
||||
}
|
||||
|
||||
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
|
||||
return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -322,10 +254,8 @@ func (d *Database) SaveThreePIDAssociation(
|
|||
func (d *Database) RemoveThreePIDAssociation(
|
||||
ctx context.Context, threepid string, medium string,
|
||||
) (err error) {
|
||||
d.threepidsMu.Lock()
|
||||
defer d.threepidsMu.Unlock()
|
||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
return d.threepids.deleteThreePID(ctx, txn, threepid, medium)
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.ThreePIDs.DeleteThreePID(ctx, txn, threepid, medium)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -337,7 +267,7 @@ func (d *Database) RemoveThreePIDAssociation(
|
|||
func (d *Database) GetLocalpartForThreePID(
|
||||
ctx context.Context, threepid string, medium string,
|
||||
) (localpart string, err error) {
|
||||
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
|
||||
return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium)
|
||||
}
|
||||
|
||||
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
|
||||
|
|
@ -347,14 +277,14 @@ func (d *Database) GetLocalpartForThreePID(
|
|||
func (d *Database) GetThreePIDsForLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) (threepids []authtypes.ThreePID, err error) {
|
||||
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
|
||||
return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart)
|
||||
}
|
||||
|
||||
// CheckAccountAvailability checks if the username/localpart is already present
|
||||
// in the database.
|
||||
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
|
||||
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
|
||||
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
|
||||
_, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart)
|
||||
if err == sql.ErrNoRows {
|
||||
return true, nil
|
||||
}
|
||||
|
|
@ -366,20 +296,20 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin
|
|||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
|
||||
) (*api.Account, error) {
|
||||
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
||||
return d.Accounts.SelectAccountByLocalpart(ctx, localpart)
|
||||
}
|
||||
|
||||
// SearchProfiles returns all profiles where the provided localpart or display name
|
||||
// match any part of the profiles in the database.
|
||||
func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int,
|
||||
) ([]authtypes.Profile, error) {
|
||||
return d.profiles.selectProfilesBySearch(ctx, searchString, limit)
|
||||
return d.Profiles.SelectProfilesBySearch(ctx, searchString, limit)
|
||||
}
|
||||
|
||||
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
|
||||
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
|
||||
return d.writer.Do(nil, nil, func(txn *sql.Tx) error {
|
||||
return d.accounts.deactivateAccount(ctx, localpart)
|
||||
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
|
||||
return d.Accounts.DeactivateAccount(ctx, localpart)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -388,9 +318,9 @@ func (d *Database) CreateOpenIDToken(
|
|||
ctx context.Context,
|
||||
token, localpart string,
|
||||
) (int64, error) {
|
||||
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
|
||||
err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
|
||||
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS
|
||||
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, expiresAtMS)
|
||||
})
|
||||
return expiresAtMS, err
|
||||
}
|
||||
|
|
@ -400,14 +330,14 @@ func (d *Database) GetOpenIDTokenAttributes(
|
|||
ctx context.Context,
|
||||
token string,
|
||||
) (*api.OpenIDTokenAttributes, error) {
|
||||
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
|
||||
return d.OpenIDTokens.SelectOpenIDTokenAtrributes(ctx, token)
|
||||
}
|
||||
|
||||
func (d *Database) CreateKeyBackup(
|
||||
ctx context.Context, userID, algorithm string, authData json.RawMessage,
|
||||
) (version string, err error) {
|
||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "")
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
version, err = d.KeyBackupVersions.InsertKeyBackup(ctx, txn, userID, algorithm, authData, "")
|
||||
return err
|
||||
})
|
||||
return
|
||||
|
|
@ -416,8 +346,8 @@ func (d *Database) CreateKeyBackup(
|
|||
func (d *Database) UpdateKeyBackupAuthData(
|
||||
ctx context.Context, userID, version string, authData json.RawMessage,
|
||||
) (err error) {
|
||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData)
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.KeyBackupVersions.UpdateKeyBackupAuthData(ctx, txn, userID, version, authData)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
|
@ -425,8 +355,8 @@ func (d *Database) UpdateKeyBackupAuthData(
|
|||
func (d *Database) DeleteKeyBackup(
|
||||
ctx context.Context, userID, version string,
|
||||
) (exists bool, err error) {
|
||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version)
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
exists, err = d.KeyBackupVersions.DeleteKeyBackup(ctx, txn, userID, version)
|
||||
return err
|
||||
})
|
||||
return
|
||||
|
|
@ -435,8 +365,8 @@ func (d *Database) DeleteKeyBackup(
|
|||
func (d *Database) GetKeyBackup(
|
||||
ctx context.Context, userID, version string,
|
||||
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
|
||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
versionResult, algorithm, authData, etag, deleted, err = d.KeyBackupVersions.SelectKeyBackup(ctx, txn, userID, version)
|
||||
return err
|
||||
})
|
||||
return
|
||||
|
|
@ -445,16 +375,16 @@ func (d *Database) GetKeyBackup(
|
|||
func (d *Database) GetBackupKeys(
|
||||
ctx context.Context, version, userID, filterRoomID, filterSessionID string,
|
||||
) (result map[string]map[string]api.KeyBackupSession, err error) {
|
||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
if filterSessionID != "" {
|
||||
result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
|
||||
result, err = d.KeyBackups.SelectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
|
||||
return err
|
||||
}
|
||||
if filterRoomID != "" {
|
||||
result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
|
||||
result, err = d.KeyBackups.SelectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
|
||||
return err
|
||||
}
|
||||
result, err = d.keyBackups.selectKeys(ctx, txn, userID, version)
|
||||
result, err = d.KeyBackups.SelectKeys(ctx, txn, userID, version)
|
||||
return err
|
||||
})
|
||||
return
|
||||
|
|
@ -463,8 +393,8 @@ func (d *Database) GetBackupKeys(
|
|||
func (d *Database) CountBackupKeys(
|
||||
ctx context.Context, version, userID string,
|
||||
) (count int64, err error) {
|
||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
count, err = d.KeyBackups.CountKeys(ctx, txn, userID, version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -478,8 +408,8 @@ func (d *Database) UpsertBackupKeys(
|
|||
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
|
||||
) (count int64, etag string, err error) {
|
||||
// wrap the following logic in a txn to ensure we atomically upload keys
|
||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
_, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
_, _, _, oldETag, deleted, err := d.KeyBackupVersions.SelectKeyBackup(ctx, txn, userID, version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -487,7 +417,7 @@ func (d *Database) UpsertBackupKeys(
|
|||
return fmt.Errorf("backup was deleted")
|
||||
}
|
||||
// pull out all keys for this (user_id, version)
|
||||
existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version)
|
||||
existingKeys, err := d.KeyBackups.SelectKeys(ctx, txn, userID, version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -501,10 +431,10 @@ func (d *Database) UpsertBackupKeys(
|
|||
existingSession, ok := existingRoom[newKey.SessionID]
|
||||
if ok {
|
||||
if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
|
||||
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
|
||||
err = d.KeyBackups.UpdateBackupKey(ctx, txn, userID, version, newKey)
|
||||
changed = true
|
||||
if err != nil {
|
||||
return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err)
|
||||
return fmt.Errorf("d.KeyBackups.UpdateBackupKey: %w", err)
|
||||
}
|
||||
}
|
||||
// if we shouldn't replace the key we do nothing with it
|
||||
|
|
@ -512,14 +442,14 @@ func (d *Database) UpsertBackupKeys(
|
|||
}
|
||||
}
|
||||
// if we're here, either the room or session are new, either way, we insert
|
||||
err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey)
|
||||
err = d.KeyBackups.InsertBackupKey(ctx, txn, userID, version, newKey)
|
||||
changed = true
|
||||
if err != nil {
|
||||
return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err)
|
||||
return fmt.Errorf("d.KeyBackups.InsertBackupKey: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
|
||||
count, err = d.KeyBackups.CountKeys(ctx, txn, userID, version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -536,7 +466,7 @@ func (d *Database) UpsertBackupKeys(
|
|||
newETag = strconv.FormatInt(oldETagInt+1, 10)
|
||||
}
|
||||
etag = newETag
|
||||
return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag)
|
||||
return d.KeyBackupVersions.UpdateKeyBackupETag(ctx, txn, userID, version, newETag)
|
||||
} else {
|
||||
etag = oldETag
|
||||
}
|
||||
|
|
@ -545,3 +475,196 @@ func (d *Database) UpsertBackupKeys(
|
|||
})
|
||||
return
|
||||
}
|
||||
|
||||
// GetDeviceByAccessToken returns the device matching the given access token.
|
||||
// Returns sql.ErrNoRows if no matching device was found.
|
||||
func (d *Database) GetDeviceByAccessToken(
|
||||
ctx context.Context, token string,
|
||||
) (*api.Device, error) {
|
||||
return d.Devices.SelectDeviceByToken(ctx, token)
|
||||
}
|
||||
|
||||
// GetDeviceByID returns the device matching the given ID.
|
||||
// Returns sql.ErrNoRows if no matching device was found.
|
||||
func (d *Database) GetDeviceByID(
|
||||
ctx context.Context, localpart, deviceID string,
|
||||
) (*api.Device, error) {
|
||||
return d.Devices.SelectDeviceByID(ctx, localpart, deviceID)
|
||||
}
|
||||
|
||||
// GetDevicesByLocalpart returns the devices matching the given localpart.
|
||||
func (d *Database) GetDevicesByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) ([]api.Device, error) {
|
||||
return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, "")
|
||||
}
|
||||
|
||||
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||
return d.Devices.SelectDevicesByID(ctx, deviceIDs)
|
||||
}
|
||||
|
||||
// CreateDevice makes a new device associated with the given user ID localpart.
|
||||
// If there is already a device with the same device ID for this user, that access token will be revoked
|
||||
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
||||
// an error will be returned.
|
||||
// If no device ID is given one is generated.
|
||||
// Returns the device on success.
|
||||
func (d *Database) CreateDevice(
|
||||
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
||||
displayName *string, ipAddr, userAgent string,
|
||||
) (dev *api.Device, returnErr error) {
|
||||
if deviceID != nil {
|
||||
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
// Revoke existing tokens for this device
|
||||
if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||
return err
|
||||
})
|
||||
} else {
|
||||
// We generate device IDs in a loop in case its already taken.
|
||||
// We cap this at going round 5 times to ensure we don't spin forever
|
||||
var newDeviceID string
|
||||
for i := 1; i <= 5; i++ {
|
||||
newDeviceID, returnErr = generateDeviceID()
|
||||
if returnErr != nil {
|
||||
return
|
||||
}
|
||||
|
||||
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||
return err
|
||||
})
|
||||
if returnErr == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// generateDeviceID creates a new device id. Returns an error if failed to generate
|
||||
// random bytes.
|
||||
func generateDeviceID() (string, error) {
|
||||
b := make([]byte, deviceIDByteLength)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// url-safe no padding
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// UpdateDevice updates the given device with the display name.
|
||||
// Returns SQL error if there are problems and nil on success.
|
||||
func (d *Database) UpdateDevice(
|
||||
ctx context.Context, localpart, deviceID string, displayName *string,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.Devices.UpdateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveDevice revokes a device by deleting the entry in the database
|
||||
// matching with the given device ID and user ID localpart.
|
||||
// If the device doesn't exist, it will not return an error
|
||||
// If something went wrong during the deletion, it will return the SQL error.
|
||||
func (d *Database) RemoveDevice(
|
||||
ctx context.Context, deviceID, localpart string,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
if err := d.Devices.DeleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveDevices revokes one or more devices by deleting the entry in the database
|
||||
// matching with the given device IDs and user ID localpart.
|
||||
// If the devices don't exist, it will not return an error
|
||||
// If something went wrong during the deletion, it will return the SQL error.
|
||||
func (d *Database) RemoveDevices(
|
||||
ctx context.Context, localpart string, devices []string,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
if err := d.Devices.DeleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveAllDevices revokes devices by deleting the entry in the
|
||||
// database matching the given user ID localpart.
|
||||
// If something went wrong during the deletion, it will return the SQL error.
|
||||
func (d *Database) RemoveAllDevices(
|
||||
ctx context.Context, localpart, exceptDeviceID string,
|
||||
) (devices []api.Device, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
|
||||
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
|
||||
})
|
||||
}
|
||||
|
||||
// CreateLoginToken generates a token, stores and returns it. The lifetime is
|
||||
// determined by the loginTokenLifetime given to the Database constructor.
|
||||
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
|
||||
tok, err := generateLoginToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
meta := &api.LoginTokenMetadata{
|
||||
Token: tok,
|
||||
Expiration: time.Now().Add(d.LoginTokenLifetime),
|
||||
}
|
||||
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.LoginTokens.InsertLoginToken(ctx, txn, meta, data)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
func generateLoginToken() (string, error) {
|
||||
b := make([]byte, loginTokenByteLength)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
|
||||
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.LoginTokens.DeleteLoginToken(ctx, txn, token)
|
||||
})
|
||||
}
|
||||
|
||||
// GetLoginTokenDataByToken returns the data associated with the given token.
|
||||
// May return sql.ErrNoRows.
|
||||
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
|
||||
return d.LoginTokens.SelectLoginToken(ctx, token)
|
||||
}
|
||||
|
|
@ -20,6 +20,7 @@ import (
|
|||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
const accountDataSchema = `
|
||||
|
|
@ -56,27 +57,29 @@ type accountDataStatements struct {
|
|||
selectAccountDataByTypeStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
|
||||
s.db = db
|
||||
_, err = db.Exec(accountDataSchema)
|
||||
if err != nil {
|
||||
return
|
||||
func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
|
||||
s := &accountDataStatements{
|
||||
db: db,
|
||||
}
|
||||
return sqlutil.StatementList{
|
||||
_, err := db.Exec(accountDataSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertAccountDataStmt, insertAccountDataSQL},
|
||||
{&s.selectAccountDataStmt, selectAccountDataSQL},
|
||||
{&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) insertAccountData(
|
||||
func (s *accountDataStatements) InsertAccountData(
|
||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) selectAccountData(
|
||||
func (s *accountDataStatements) SelectAccountData(
|
||||
ctx context.Context, localpart string,
|
||||
) (
|
||||
/* global */ map[string]json.RawMessage,
|
||||
|
|
@ -113,7 +116,7 @@ func (s *accountDataStatements) selectAccountData(
|
|||
return global, rooms, nil
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) selectAccountDataByType(
|
||||
func (s *accountDataStatements) SelectAccountDataByType(
|
||||
ctx context.Context, localpart, roomID, dataType string,
|
||||
) (data json.RawMessage, err error) {
|
||||
var bytes []byte
|
||||
|
|
@ -24,6 +24,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
|
@ -77,15 +78,16 @@ type accountsStatements struct {
|
|||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
func (s *accountsStatements) execSchema(db *sql.DB) error {
|
||||
_, err := db.Exec(accountsSchema)
|
||||
return err
|
||||
func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) {
|
||||
s := &accountsStatements{
|
||||
db: db,
|
||||
serverName: serverName,
|
||||
}
|
||||
|
||||
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
||||
s.db = db
|
||||
s.serverName = server
|
||||
return sqlutil.StatementList{
|
||||
_, err := db.Exec(accountsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertAccountStmt, insertAccountSQL},
|
||||
{&s.updatePasswordStmt, updatePasswordSQL},
|
||||
{&s.deactivateAccountStmt, deactivateAccountSQL},
|
||||
|
|
@ -98,7 +100,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
|||
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
|
||||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||
// on success.
|
||||
func (s *accountsStatements) insertAccount(
|
||||
func (s *accountsStatements) InsertAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
|
||||
) (*api.Account, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
|
|
@ -122,28 +124,28 @@ func (s *accountsStatements) insertAccount(
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (s *accountsStatements) updatePassword(
|
||||
func (s *accountsStatements) UpdatePassword(
|
||||
ctx context.Context, localpart, passwordHash string,
|
||||
) (err error) {
|
||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) deactivateAccount(
|
||||
func (s *accountsStatements) DeactivateAccount(
|
||||
ctx context.Context, localpart string,
|
||||
) (err error) {
|
||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) selectPasswordHash(
|
||||
func (s *accountsStatements) SelectPasswordHash(
|
||||
ctx context.Context, localpart string,
|
||||
) (hash string, err error) {
|
||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) selectAccountByLocalpart(
|
||||
func (s *accountsStatements) SelectAccountByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) (*api.Account, error) {
|
||||
var appserviceIDPtr sql.NullString
|
||||
|
|
@ -167,7 +169,7 @@ func (s *accountsStatements) selectAccountByLocalpart(
|
|||
return &acc, nil
|
||||
}
|
||||
|
||||
func (s *accountsStatements) selectNewNumericLocalpart(
|
||||
func (s *accountsStatements) SelectNewNumericLocalpart(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
) (id int64, err error) {
|
||||
stmt := s.selectNewNumericLocalpartStmt
|
||||
|
|
@ -5,13 +5,8 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/pressly/goose"
|
||||
)
|
||||
|
||||
func LoadFromGoose() {
|
||||
goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
|
||||
}
|
||||
|
||||
func LoadLastSeenTSIP(m *sqlutil.Migrations) {
|
||||
m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
|
||||
}
|
||||
|
|
@ -23,6 +23,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
|
@ -84,7 +85,6 @@ const updateDeviceLastSeen = "" +
|
|||
|
||||
type devicesStatements struct {
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
insertDeviceStmt *sql.Stmt
|
||||
selectDevicesCountStmt *sql.Stmt
|
||||
selectDeviceByTokenStmt *sql.Stmt
|
||||
|
|
@ -98,52 +98,33 @@ type devicesStatements struct {
|
|||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
func (s *devicesStatements) execSchema(db *sql.DB) error {
|
||||
func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) {
|
||||
s := &devicesStatements{
|
||||
db: db,
|
||||
serverName: serverName,
|
||||
}
|
||||
_, err := db.Exec(devicesSchema)
|
||||
return err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
|
||||
s.db = db
|
||||
s.writer = writer
|
||||
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil {
|
||||
return
|
||||
}
|
||||
s.serverName = server
|
||||
return
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertDeviceStmt, insertDeviceSQL},
|
||||
{&s.selectDevicesCountStmt, selectDevicesCountSQL},
|
||||
{&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL},
|
||||
{&s.selectDeviceByIDStmt, selectDeviceByIDSQL},
|
||||
{&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL},
|
||||
{&s.updateDeviceNameStmt, updateDeviceNameSQL},
|
||||
{&s.deleteDeviceStmt, deleteDeviceSQL},
|
||||
{&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL},
|
||||
{&s.selectDevicesByIDStmt, selectDevicesByIDSQL},
|
||||
{&s.updateDeviceLastSeenStmt, updateDeviceLastSeen},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
|
||||
// Returns an error if the user already has a device with the given device ID.
|
||||
// Returns the device on success.
|
||||
func (s *devicesStatements) insertDevice(
|
||||
func (s *devicesStatements) InsertDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
||||
displayName *string, ipAddr, userAgent string,
|
||||
) (*api.Device, error) {
|
||||
|
|
@ -169,7 +150,7 @@ func (s *devicesStatements) insertDevice(
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (s *devicesStatements) deleteDevice(
|
||||
func (s *devicesStatements) DeleteDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
||||
|
|
@ -177,7 +158,7 @@ func (s *devicesStatements) deleteDevice(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) deleteDevices(
|
||||
func (s *devicesStatements) DeleteDevices(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
||||
) error {
|
||||
orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1)
|
||||
|
|
@ -195,7 +176,7 @@ func (s *devicesStatements) deleteDevices(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) deleteDevicesByLocalpart(
|
||||
func (s *devicesStatements) DeleteDevicesByLocalpart(
|
||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
||||
|
|
@ -203,7 +184,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) updateDeviceName(
|
||||
func (s *devicesStatements) UpdateDeviceName(
|
||||
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
||||
|
|
@ -211,7 +192,7 @@ func (s *devicesStatements) updateDeviceName(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) selectDeviceByToken(
|
||||
func (s *devicesStatements) SelectDeviceByToken(
|
||||
ctx context.Context, accessToken string,
|
||||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
|
|
@ -227,7 +208,7 @@ func (s *devicesStatements) selectDeviceByToken(
|
|||
|
||||
// selectDeviceByID retrieves a device from the database with the given user
|
||||
// localpart and deviceID
|
||||
func (s *devicesStatements) selectDeviceByID(
|
||||
func (s *devicesStatements) SelectDeviceByID(
|
||||
ctx context.Context, localpart, deviceID string,
|
||||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
|
|
@ -244,7 +225,7 @@ func (s *devicesStatements) selectDeviceByID(
|
|||
return &dev, err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) selectDevicesByLocalpart(
|
||||
func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||
) ([]api.Device, error) {
|
||||
devices := []api.Device{}
|
||||
|
|
@ -285,7 +266,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
|
|||
return devices, nil
|
||||
}
|
||||
|
||||
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||
func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||
sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1)
|
||||
iDeviceIDs := make([]interface{}, len(deviceIDs))
|
||||
for i := range deviceIDs {
|
||||
|
|
@ -314,7 +295,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
|
|||
return devices, rows.Err()
|
||||
}
|
||||
|
||||
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
|
||||
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
|
||||
lastSeenTs := time.Now().UnixNano() / 1000000
|
||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
|
||||
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID)
|
||||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
const keyBackupTableSchema = `
|
||||
|
|
@ -72,12 +73,13 @@ type keyBackupStatements struct {
|
|||
selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
|
||||
_, err = db.Exec(keyBackupTableSchema)
|
||||
func NewSQLiteKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
|
||||
s := &keyBackupStatements{}
|
||||
_, err := db.Exec(keyBackupTableSchema)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
return sqlutil.StatementList{
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertBackupKeyStmt, insertBackupKeySQL},
|
||||
{&s.updateBackupKeyStmt, updateBackupKeySQL},
|
||||
{&s.countKeysStmt, countKeysSQL},
|
||||
|
|
@ -87,14 +89,14 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
|
|||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s keyBackupStatements) countKeys(
|
||||
func (s keyBackupStatements) CountKeys(
|
||||
ctx context.Context, txn *sql.Tx, userID, version string,
|
||||
) (count int64, err error) {
|
||||
err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *keyBackupStatements) insertBackupKey(
|
||||
func (s *keyBackupStatements) InsertBackupKey(
|
||||
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
|
||||
) (err error) {
|
||||
_, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext(
|
||||
|
|
@ -103,7 +105,7 @@ func (s *keyBackupStatements) insertBackupKey(
|
|||
return
|
||||
}
|
||||
|
||||
func (s *keyBackupStatements) updateBackupKey(
|
||||
func (s *keyBackupStatements) UpdateBackupKey(
|
||||
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
|
||||
) (err error) {
|
||||
_, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext(
|
||||
|
|
@ -112,7 +114,7 @@ func (s *keyBackupStatements) updateBackupKey(
|
|||
return
|
||||
}
|
||||
|
||||
func (s *keyBackupStatements) selectKeys(
|
||||
func (s *keyBackupStatements) SelectKeys(
|
||||
ctx context.Context, txn *sql.Tx, userID, version string,
|
||||
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
|
||||
|
|
@ -122,7 +124,7 @@ func (s *keyBackupStatements) selectKeys(
|
|||
return unpackKeys(ctx, rows)
|
||||
}
|
||||
|
||||
func (s *keyBackupStatements) selectKeysByRoomID(
|
||||
func (s *keyBackupStatements) SelectKeysByRoomID(
|
||||
ctx context.Context, txn *sql.Tx, userID, version, roomID string,
|
||||
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||
rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
|
||||
|
|
@ -132,7 +134,7 @@ func (s *keyBackupStatements) selectKeysByRoomID(
|
|||
return unpackKeys(ctx, rows)
|
||||
}
|
||||
|
||||
func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID(
|
||||
func (s *keyBackupStatements) SelectKeysByRoomIDAndSessionID(
|
||||
ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
|
||||
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||
rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)
|
||||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"strconv"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
const keyBackupVersionTableSchema = `
|
||||
|
|
@ -67,12 +68,13 @@ type keyBackupVersionStatements struct {
|
|||
updateKeyBackupETagStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
|
||||
_, err = db.Exec(keyBackupVersionTableSchema)
|
||||
func NewSQLiteKeyBackupVersionTable(db *sql.DB) (tables.KeyBackupVersionTable, error) {
|
||||
s := &keyBackupVersionStatements{}
|
||||
_, err := db.Exec(keyBackupVersionTableSchema)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
return sqlutil.StatementList{
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertKeyBackupStmt, insertKeyBackupSQL},
|
||||
{&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL},
|
||||
{&s.deleteKeyBackupStmt, deleteKeyBackupSQL},
|
||||
|
|
@ -82,7 +84,7 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
|
|||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *keyBackupVersionStatements) insertKeyBackup(
|
||||
func (s *keyBackupVersionStatements) InsertKeyBackup(
|
||||
ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string,
|
||||
) (version string, err error) {
|
||||
var versionInt int64
|
||||
|
|
@ -90,7 +92,7 @@ func (s *keyBackupVersionStatements) insertKeyBackup(
|
|||
return strconv.FormatInt(versionInt, 10), err
|
||||
}
|
||||
|
||||
func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
|
||||
func (s *keyBackupVersionStatements) UpdateKeyBackupAuthData(
|
||||
ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage,
|
||||
) error {
|
||||
versionInt, err := strconv.ParseInt(version, 10, 64)
|
||||
|
|
@ -101,7 +103,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *keyBackupVersionStatements) updateKeyBackupETag(
|
||||
func (s *keyBackupVersionStatements) UpdateKeyBackupETag(
|
||||
ctx context.Context, txn *sql.Tx, userID, version, etag string,
|
||||
) error {
|
||||
versionInt, err := strconv.ParseInt(version, 10, 64)
|
||||
|
|
@ -112,7 +114,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *keyBackupVersionStatements) deleteKeyBackup(
|
||||
func (s *keyBackupVersionStatements) DeleteKeyBackup(
|
||||
ctx context.Context, txn *sql.Tx, userID, version string,
|
||||
) (bool, error) {
|
||||
versionInt, err := strconv.ParseInt(version, 10, 64)
|
||||
|
|
@ -130,7 +132,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup(
|
|||
return ra == 1, nil
|
||||
}
|
||||
|
||||
func (s *keyBackupVersionStatements) selectKeyBackup(
|
||||
func (s *keyBackupVersionStatements) SelectKeyBackup(
|
||||
ctx context.Context, txn *sql.Tx, userID, version string,
|
||||
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
|
||||
var versionInt int64
|
||||
|
|
@ -21,18 +21,17 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
type loginTokenStatements struct {
|
||||
insertStmt *sql.Stmt
|
||||
deleteStmt *sql.Stmt
|
||||
selectByTokenStmt *sql.Stmt
|
||||
selectStmt *sql.Stmt
|
||||
}
|
||||
|
||||
// execSchema ensures tables and indices exist.
|
||||
func (s *loginTokenStatements) execSchema(db *sql.DB) error {
|
||||
_, err := db.Exec(`
|
||||
const loginTokenSchema = `
|
||||
CREATE TABLE IF NOT EXISTS login_tokens (
|
||||
-- The random value of the token issued to a user
|
||||
token TEXT NOT NULL PRIMARY KEY,
|
||||
|
|
@ -45,21 +44,32 @@ CREATE TABLE IF NOT EXISTS login_tokens (
|
|||
|
||||
-- This index allows efficient garbage collection of expired tokens.
|
||||
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at);
|
||||
`)
|
||||
return err
|
||||
}
|
||||
`
|
||||
|
||||
// prepare runs statement preparation.
|
||||
func (s *loginTokenStatements) prepare(db *sql.DB) error {
|
||||
return sqlutil.StatementList{
|
||||
{&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"},
|
||||
{&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"},
|
||||
{&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"},
|
||||
const insertLoginTokenSQL = "" +
|
||||
"INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
|
||||
|
||||
const deleteLoginTokenSQL = "" +
|
||||
"DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"
|
||||
|
||||
const selectLoginTokenSQL = "" +
|
||||
"SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"
|
||||
|
||||
func NewSQLiteLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) {
|
||||
s := &loginTokenStatements{}
|
||||
_, err := db.Exec(loginTokenSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertStmt, insertLoginTokenSQL},
|
||||
{&s.deleteStmt, deleteLoginTokenSQL},
|
||||
{&s.selectStmt, selectLoginTokenSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
// insert adds an already generated token to the database.
|
||||
func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error {
|
||||
func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertStmt)
|
||||
_, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID)
|
||||
return err
|
||||
|
|
@ -69,7 +79,7 @@ func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata
|
|||
//
|
||||
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
|
||||
// The login_tokens_expiration_idx index should make that efficient.
|
||||
func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error {
|
||||
func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteStmt)
|
||||
res, err := stmt.ExecContext(ctx, token, time.Now().UTC())
|
||||
if err != nil {
|
||||
|
|
@ -82,9 +92,9 @@ func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, t
|
|||
}
|
||||
|
||||
// selectByToken returns the data associated with the given token. May return sql.ErrNoRows.
|
||||
func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
|
||||
func (s *loginTokenStatements) SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
|
||||
var data api.LoginTokenData
|
||||
err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID)
|
||||
err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -6,6 +6,7 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
|
@ -22,35 +23,37 @@ CREATE TABLE IF NOT EXISTS open_id_tokens (
|
|||
);
|
||||
`
|
||||
|
||||
const insertTokenSQL = "" +
|
||||
const insertOpenIDTokenSQL = "" +
|
||||
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
||||
|
||||
const selectTokenSQL = "" +
|
||||
const selectOpenIDTokenSQL = "" +
|
||||
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
|
||||
|
||||
type tokenStatements struct {
|
||||
type openIDTokenStatements struct {
|
||||
db *sql.DB
|
||||
insertTokenStmt *sql.Stmt
|
||||
selectTokenStmt *sql.Stmt
|
||||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
||||
s.db = db
|
||||
_, err = db.Exec(openIDTokenSchema)
|
||||
if err != nil {
|
||||
return err
|
||||
func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.OpenIDTable, error) {
|
||||
s := &openIDTokenStatements{
|
||||
db: db,
|
||||
serverName: serverName,
|
||||
}
|
||||
s.serverName = server
|
||||
return sqlutil.StatementList{
|
||||
{&s.insertTokenStmt, insertTokenSQL},
|
||||
{&s.selectTokenStmt, selectTokenSQL},
|
||||
_, err := db.Exec(openIDTokenSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertTokenStmt, insertOpenIDTokenSQL},
|
||||
{&s.selectTokenStmt, selectOpenIDTokenSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
// insertToken inserts a new OpenID Connect token to the DB.
|
||||
// Returns new token, otherwise returns error if the token already exists.
|
||||
func (s *tokenStatements) insertToken(
|
||||
func (s *openIDTokenStatements) InsertOpenIDToken(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
token, localpart string,
|
||||
|
|
@ -63,7 +66,7 @@ func (s *tokenStatements) insertToken(
|
|||
|
||||
// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
|
||||
// Returns the existing token's attributes, or err if no token is found
|
||||
func (s *tokenStatements) selectOpenIDTokenAtrributes(
|
||||
func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
|
||||
ctx context.Context,
|
||||
token string,
|
||||
) (*api.OpenIDTokenAttributes, error) {
|
||||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
const profilesSchema = `
|
||||
|
|
@ -60,13 +61,15 @@ type profilesStatements struct {
|
|||
selectProfilesBySearchStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
|
||||
s.db = db
|
||||
_, err = db.Exec(profilesSchema)
|
||||
if err != nil {
|
||||
return
|
||||
func NewSQLiteProfilesTable(db *sql.DB) (tables.ProfileTable, error) {
|
||||
s := &profilesStatements{
|
||||
db: db,
|
||||
}
|
||||
return sqlutil.StatementList{
|
||||
_, err := db.Exec(profilesSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertProfileStmt, insertProfileSQL},
|
||||
{&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL},
|
||||
{&s.setAvatarURLStmt, setAvatarURLSQL},
|
||||
|
|
@ -75,14 +78,14 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
|
|||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *profilesStatements) insertProfile(
|
||||
func (s *profilesStatements) InsertProfile(
|
||||
ctx context.Context, txn *sql.Tx, localpart string,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *profilesStatements) selectProfileByLocalpart(
|
||||
func (s *profilesStatements) SelectProfileByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) (*authtypes.Profile, error) {
|
||||
var profile authtypes.Profile
|
||||
|
|
@ -95,7 +98,7 @@ func (s *profilesStatements) selectProfileByLocalpart(
|
|||
return &profile, nil
|
||||
}
|
||||
|
||||
func (s *profilesStatements) setAvatarURL(
|
||||
func (s *profilesStatements) SetAvatarURL(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
|
||||
|
|
@ -103,7 +106,7 @@ func (s *profilesStatements) setAvatarURL(
|
|||
return
|
||||
}
|
||||
|
||||
func (s *profilesStatements) setDisplayName(
|
||||
func (s *profilesStatements) SetDisplayName(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
|
||||
|
|
@ -111,7 +114,7 @@ func (s *profilesStatements) setDisplayName(
|
|||
return
|
||||
}
|
||||
|
||||
func (s *profilesStatements) selectProfilesBySearch(
|
||||
func (s *profilesStatements) SelectProfilesBySearch(
|
||||
ctx context.Context, searchString string, limit int,
|
||||
) ([]authtypes.Profile, error) {
|
||||
var profiles []authtypes.Profile
|
||||
106
userapi/storage/sqlite3/storage.go
Normal file
106
userapi/storage/sqlite3/storage.go
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
|
||||
"github.com/matrix-org/dendrite/userapi/storage/shared"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
|
||||
|
||||
// Import the postgres database driver.
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// NewDatabase creates a new accounts and profiles database
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*shared.Database, error) {
|
||||
db, err := sqlutil.Open(dbProperties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := sqlutil.NewMigrations()
|
||||
if _, err = db.Exec(accountsSchema); err != nil {
|
||||
// do this so that the migration can and we don't fail on
|
||||
// preparing statements for columns that don't exist yet
|
||||
return nil, err
|
||||
}
|
||||
deltas.LoadIsActive(m)
|
||||
//deltas.LoadLastSeenTSIP(m)
|
||||
deltas.LoadAddAccountType(m)
|
||||
if err = m.RunDeltas(db, dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accountDataTable, err := NewSQLiteAccountDataTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
|
||||
}
|
||||
accountsTable, err := NewSQLiteAccountsTable(db, serverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err)
|
||||
}
|
||||
devicesTable, err := NewSQLiteDevicesTable(db, serverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err)
|
||||
}
|
||||
keyBackupTable, err := NewSQLiteKeyBackupTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteKeyBackupTable: %w", err)
|
||||
}
|
||||
keyBackupVersionTable, err := NewSQLiteKeyBackupVersionTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteKeyBackupVersionTable: %w", err)
|
||||
}
|
||||
loginTokenTable, err := NewSQLiteLoginTokenTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteLoginTokenTable: %w", err)
|
||||
}
|
||||
openIDTable, err := NewSQLiteOpenIDTable(db, serverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteOpenIDTable: %w", err)
|
||||
}
|
||||
profilesTable, err := NewSQLiteProfilesTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteProfilesTable: %w", err)
|
||||
}
|
||||
threePIDTable, err := NewSQLiteThreePIDTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteThreePIDTable: %w", err)
|
||||
}
|
||||
return &shared.Database{
|
||||
AccountDatas: accountDataTable,
|
||||
Accounts: accountsTable,
|
||||
Devices: devicesTable,
|
||||
KeyBackups: keyBackupTable,
|
||||
KeyBackupVersions: keyBackupVersionTable,
|
||||
LoginTokens: loginTokenTable,
|
||||
OpenIDTokens: openIDTable,
|
||||
Profiles: profilesTable,
|
||||
ThreePIDs: threePIDTable,
|
||||
ServerName: serverName,
|
||||
DB: db,
|
||||
Writer: sqlutil.NewExclusiveWriter(),
|
||||
LoginTokenLifetime: loginTokenLifetime,
|
||||
BcryptCost: bcryptCost,
|
||||
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -20,6 +20,7 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
)
|
||||
|
|
@ -60,13 +61,15 @@ type threepidStatements struct {
|
|||
deleteThreePIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *threepidStatements) prepare(db *sql.DB) (err error) {
|
||||
s.db = db
|
||||
_, err = db.Exec(threepidSchema)
|
||||
if err != nil {
|
||||
return
|
||||
func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
|
||||
s := &threepidStatements{
|
||||
db: db,
|
||||
}
|
||||
return sqlutil.StatementList{
|
||||
_, err := db.Exec(threepidSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL},
|
||||
{&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL},
|
||||
{&s.insertThreePIDStmt, insertThreePIDSQL},
|
||||
|
|
@ -74,7 +77,7 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) {
|
|||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *threepidStatements) selectLocalpartForThreePID(
|
||||
func (s *threepidStatements) SelectLocalpartForThreePID(
|
||||
ctx context.Context, txn *sql.Tx, threepid string, medium string,
|
||||
) (localpart string, err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
|
||||
|
|
@ -85,7 +88,7 @@ func (s *threepidStatements) selectLocalpartForThreePID(
|
|||
return
|
||||
}
|
||||
|
||||
func (s *threepidStatements) selectThreePIDsForLocalpart(
|
||||
func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) (threepids []authtypes.ThreePID, err error) {
|
||||
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
|
||||
|
|
@ -109,7 +112,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
|
|||
return threepids, rows.Err()
|
||||
}
|
||||
|
||||
func (s *threepidStatements) insertThreePID(
|
||||
func (s *threepidStatements) InsertThreePID(
|
||||
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
|
||||
|
|
@ -117,7 +120,7 @@ func (s *threepidStatements) insertThreePID(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *threepidStatements) deleteThreePID(
|
||||
func (s *threepidStatements) DeleteThreePID(
|
||||
ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt)
|
||||
_, err = stmt.ExecContext(ctx, threepid, medium)
|
||||
|
|
@ -15,26 +15,27 @@
|
|||
//go:build !wasm
|
||||
// +build !wasm
|
||||
|
||||
package accounts
|
||||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts/postgres"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
|
||||
)
|
||||
|
||||
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
|
||||
// and sets postgres connection parameters
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (Database, error) {
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (Database, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
|
||||
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
|
||||
return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected database type")
|
||||
}
|
||||
|
|
@ -12,13 +12,14 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package accounts
|
||||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
|
|
@ -27,10 +28,11 @@ func NewDatabase(
|
|||
serverName gomatrixserverlib.ServerName,
|
||||
bcryptCost int,
|
||||
openIDTokenLifetimeMS int64,
|
||||
loginTokenLifetime time.Duration,
|
||||
) (Database, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
|
||||
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return nil, fmt.Errorf("can't use Postgres implementation")
|
||||
default:
|
||||
95
userapi/storage/tables/interface.go
Normal file
95
userapi/storage/tables/interface.go
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
)
|
||||
|
||||
type AccountDataTable interface {
|
||||
InsertAccountData(ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage) error
|
||||
SelectAccountData(ctx context.Context, localpart string) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)
|
||||
SelectAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
|
||||
}
|
||||
|
||||
type AccountsTable interface {
|
||||
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||
UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error)
|
||||
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
||||
SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error)
|
||||
SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
||||
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error)
|
||||
}
|
||||
|
||||
type DevicesTable interface {
|
||||
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
|
||||
DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error
|
||||
DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error
|
||||
DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error
|
||||
UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string) error
|
||||
SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error)
|
||||
SelectDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
|
||||
SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) ([]api.Device, error)
|
||||
SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
|
||||
UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error
|
||||
}
|
||||
|
||||
type KeyBackupTable interface {
|
||||
CountKeys(ctx context.Context, txn *sql.Tx, userID, version string) (count int64, err error)
|
||||
InsertBackupKey(ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession) (err error)
|
||||
UpdateBackupKey(ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession) (err error)
|
||||
SelectKeys(ctx context.Context, txn *sql.Tx, userID, version string) (map[string]map[string]api.KeyBackupSession, error)
|
||||
SelectKeysByRoomID(ctx context.Context, txn *sql.Tx, userID, version, roomID string) (map[string]map[string]api.KeyBackupSession, error)
|
||||
SelectKeysByRoomIDAndSessionID(ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string) (map[string]map[string]api.KeyBackupSession, error)
|
||||
}
|
||||
|
||||
type KeyBackupVersionTable interface {
|
||||
InsertKeyBackup(ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string) (version string, err error)
|
||||
UpdateKeyBackupAuthData(ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage) error
|
||||
UpdateKeyBackupETag(ctx context.Context, txn *sql.Tx, userID, version, etag string) error
|
||||
DeleteKeyBackup(ctx context.Context, txn *sql.Tx, userID, version string) (bool, error)
|
||||
SelectKeyBackup(ctx context.Context, txn *sql.Tx, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error)
|
||||
}
|
||||
|
||||
type LoginTokenTable interface {
|
||||
InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error
|
||||
DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error
|
||||
SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error)
|
||||
}
|
||||
|
||||
type OpenIDTable interface {
|
||||
InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, expiresAtMS int64) (err error)
|
||||
SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
||||
}
|
||||
|
||||
type ProfileTable interface {
|
||||
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error
|
||||
SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
|
||||
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (err error)
|
||||
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (err error)
|
||||
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
||||
}
|
||||
|
||||
type ThreePIDTable interface {
|
||||
SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, err error)
|
||||
SelectThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
|
||||
InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error)
|
||||
DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error)
|
||||
}
|
||||
|
|
@ -23,18 +23,10 @@ import (
|
|||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/internal"
|
||||
"github.com/matrix-org/dendrite/userapi/inthttp"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/devices"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// defaultLoginTokenLifetime determines how old a valid token may be.
|
||||
//
|
||||
// NOTSPEC: The current spec says "SHOULD be limited to around five
|
||||
// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low.
|
||||
// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325).
|
||||
const defaultLoginTokenLifetime = 2 * time.Minute
|
||||
|
||||
// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions
|
||||
// on the given input API.
|
||||
func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) {
|
||||
|
|
@ -44,26 +36,24 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) {
|
|||
// NewInternalAPI returns a concerete implementation of the internal API. Callers
|
||||
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
|
||||
func NewInternalAPI(
|
||||
accountDB accounts.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI,
|
||||
accountDB storage.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI,
|
||||
) api.UserInternalAPI {
|
||||
deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName, defaultLoginTokenLifetime)
|
||||
db, err := storage.NewDatabase(&cfg.AccountDatabase, cfg.Matrix.ServerName, cfg.BCryptCost, int64(api.DefaultLoginTokenLifetime*time.Millisecond), api.DefaultLoginTokenLifetime)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panicf("failed to connect to device db")
|
||||
}
|
||||
|
||||
return newInternalAPI(accountDB, deviceDB, cfg, appServices, keyAPI)
|
||||
return newInternalAPI(db, cfg, appServices, keyAPI)
|
||||
}
|
||||
|
||||
func newInternalAPI(
|
||||
accountDB accounts.Database,
|
||||
deviceDB devices.Database,
|
||||
db storage.Database,
|
||||
cfg *config.UserAPI,
|
||||
appServices []config.ApplicationService,
|
||||
keyAPI keyapi.KeyInternalAPI,
|
||||
) api.UserInternalAPI {
|
||||
return &internal.UserInternalAPI{
|
||||
AccountDB: accountDB,
|
||||
DeviceDB: deviceDB,
|
||||
DB: db,
|
||||
ServerName: cfg.Matrix.ServerName,
|
||||
AppServices: appServices,
|
||||
KeyAPI: keyAPI,
|
||||
|
|
|
|||
|
|
@ -31,8 +31,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/inthttp"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/devices"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -43,23 +42,19 @@ type apiTestOpts struct {
|
|||
loginTokenLifetime time.Duration
|
||||
}
|
||||
|
||||
func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, accounts.Database) {
|
||||
func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, storage.Database) {
|
||||
if opts.loginTokenLifetime == 0 {
|
||||
opts.loginTokenLifetime = defaultLoginTokenLifetime
|
||||
opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond
|
||||
}
|
||||
dbopts := &config.DatabaseOptions{
|
||||
ConnectionString: "file::memory:",
|
||||
MaxOpenConnections: 1,
|
||||
MaxIdleConnections: 1,
|
||||
}
|
||||
accountDB, err := accounts.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS)
|
||||
accountDB, err := storage.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create account DB: %s", err)
|
||||
}
|
||||
deviceDB, err := devices.NewDatabase(dbopts, serverName, opts.loginTokenLifetime)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create device DB: %s", err)
|
||||
}
|
||||
|
||||
cfg := &config.UserAPI{
|
||||
Matrix: &config.Global{
|
||||
|
|
@ -67,7 +62,7 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, a
|
|||
},
|
||||
}
|
||||
|
||||
return newInternalAPI(accountDB, deviceDB, cfg, nil, nil), accountDB
|
||||
return newInternalAPI(accountDB, cfg, nil, nil), accountDB
|
||||
}
|
||||
|
||||
func TestQueryProfile(t *testing.T) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue