Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/context

This commit is contained in:
Till Faelligen 2022-02-18 15:16:10 +01:00
commit b0e10ec75a
103 changed files with 1445 additions and 2130 deletions

View file

@ -23,7 +23,7 @@ import (
"errors"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
)
@ -85,7 +85,7 @@ func RetrieveUserProfile(
ctx context.Context,
userID string,
asAPI AppServiceQueryAPI,
accountDB accounts.Database,
accountDB userdb.Database,
) (*authtypes.Profile, error) {
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {

View file

@ -283,8 +283,7 @@ func (m *DendriteMonolith) Start() {
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/%s", m.StorageDirectory, prefix))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-account.db", m.StorageDirectory, prefix))
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-device.db", m.StorageDirectory, prefix))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-mediaapi.db", m.CacheDirectory, prefix))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-syncapi.db", m.StorageDirectory, prefix))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-roomserver.db", m.StorageDirectory, prefix))
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-keyserver.db", m.StorageDirectory, prefix))

View file

@ -88,7 +88,6 @@ func (m *DendriteMonolith) Start() {
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", m.StorageDirectory))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-account.db", m.StorageDirectory))
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-device.db", m.StorageDirectory))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-syncapi.db", m.StorageDirectory))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-roomserver.db", m.StorageDirectory))

View file

@ -28,7 +28,7 @@ import (
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
)
@ -37,7 +37,7 @@ func AddPublicRoutes(
router *mux.Router,
synapseAdminRouter *mux.Router,
cfg *config.ClientAPI,
accountsDB accounts.Database,
accountsDB userdb.Database,
federation *gomatrixserverlib.FederationClient,
rsAPI roomserverAPI.RoomserverInternalAPI,
eduInputAPI eduServerAPI.EDUServerInputAPI,

View file

@ -30,7 +30,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
log "github.com/sirupsen/logrus"
@ -137,7 +137,7 @@ type fledglingEvent struct {
func CreateRoom(
req *http.Request, device *api.Device,
cfg *config.ClientAPI,
accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
// TODO (#267): Check room ID doesn't clash with an existing one, and we
@ -151,7 +151,7 @@ func CreateRoom(
func createRoom(
req *http.Request, device *api.Device,
cfg *config.ClientAPI, roomID string,
accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
logger := util.GetLogger(req.Context())

View file

@ -23,7 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
@ -32,7 +32,7 @@ func JoinRoomByIDOrAlias(
req *http.Request,
device *api.Device,
rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB accounts.Database,
accountDB userdb.Database,
roomIDOrAlias string,
) util.JSONResponse {
// Prepare to ask the roomserver to perform the room join.

View file

@ -24,7 +24,7 @@ import (
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/util"
)
@ -36,7 +36,7 @@ type crossSigningRequest struct {
func UploadCrossSigningDeviceKeys(
req *http.Request, userInteractiveAuth *auth.UserInteractive,
keyserverAPI api.KeyInternalAPI, device *userapi.Device,
accountDB accounts.Database, cfg *config.ClientAPI,
accountDB userdb.Database, cfg *config.ClientAPI,
) util.JSONResponse {
uploadReq := &crossSigningRequest{}
uploadRes := &api.PerformUploadDeviceKeysResponse{}

View file

@ -23,7 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
@ -54,7 +54,7 @@ func passwordLogin() flows {
// Login implements GET and POST /login
func Login(
req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI,
req *http.Request, accountDB userdb.Database, userAPI userapi.UserInternalAPI,
cfg *config.ClientAPI,
) util.JSONResponse {
if req.Method == http.MethodGet {

View file

@ -30,7 +30,7 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
@ -39,7 +39,7 @@ import (
var errMissingUserID = errors.New("'user_id' must be supplied")
func SendBan(
req *http.Request, accountDB accounts.Database, device *userapi.Device,
req *http.Request, accountDB userdb.Database, device *userapi.Device,
roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
@ -81,7 +81,7 @@ func SendBan(
return sendMembership(req.Context(), accountDB, device, roomID, "ban", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI)
}
func sendMembership(ctx context.Context, accountDB accounts.Database, device *userapi.Device,
func sendMembership(ctx context.Context, accountDB userdb.Database, device *userapi.Device,
roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time,
roomVer gomatrixserverlib.RoomVersion,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI) util.JSONResponse {
@ -125,7 +125,7 @@ func sendMembership(ctx context.Context, accountDB accounts.Database, device *us
}
func SendKick(
req *http.Request, accountDB accounts.Database, device *userapi.Device,
req *http.Request, accountDB userdb.Database, device *userapi.Device,
roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
@ -165,7 +165,7 @@ func SendKick(
}
func SendUnban(
req *http.Request, accountDB accounts.Database, device *userapi.Device,
req *http.Request, accountDB userdb.Database, device *userapi.Device,
roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
@ -200,7 +200,7 @@ func SendUnban(
}
func SendInvite(
req *http.Request, accountDB accounts.Database, device *userapi.Device,
req *http.Request, accountDB userdb.Database, device *userapi.Device,
roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
@ -271,7 +271,7 @@ func SendInvite(
func buildMembershipEvent(
ctx context.Context,
targetUserID, reason string, accountDB accounts.Database,
targetUserID, reason string, accountDB userdb.Database,
device *userapi.Device,
membership, roomID string, isDirect bool,
cfg *config.ClientAPI, evTime time.Time,
@ -312,7 +312,7 @@ func loadProfile(
ctx context.Context,
userID string,
cfg *config.ClientAPI,
accountDB accounts.Database,
accountDB userdb.Database,
asAPI appserviceAPI.AppServiceQueryAPI,
) (*authtypes.Profile, error) {
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
@ -366,7 +366,7 @@ func checkAndProcessThreepid(
body *threepid.MembershipRequest,
cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB accounts.Database,
accountDB userdb.Database,
roomID string,
evTime time.Time,
) (inviteStored bool, errRes *util.JSONResponse) {

View file

@ -9,7 +9,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
@ -29,7 +29,7 @@ type newPasswordAuth struct {
func Password(
req *http.Request,
userAPI api.UserInternalAPI,
accountDB accounts.Database,
accountDB userdb.Database,
device *api.Device,
cfg *config.ClientAPI,
) util.JSONResponse {

View file

@ -19,7 +19,7 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
@ -28,7 +28,7 @@ func PeekRoomByIDOrAlias(
req *http.Request,
device *api.Device,
rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB accounts.Database,
accountDB userdb.Database,
roomIDOrAlias string,
) util.JSONResponse {
// if this is a remote roomIDOrAlias, we have to ask the roomserver (or federation sender?) to
@ -82,7 +82,7 @@ func UnpeekRoomByID(
req *http.Request,
device *api.Device,
rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB accounts.Database,
accountDB userdb.Database,
roomID string,
) util.JSONResponse {
unpeekReq := roomserverAPI.PerformUnpeekRequest{

View file

@ -27,7 +27,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrix"
@ -36,7 +36,7 @@ import (
// GetProfile implements GET /profile/{userID}
func GetProfile(
req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI,
req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI,
userID string,
asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient,
@ -65,7 +65,7 @@ func GetProfile(
// GetAvatarURL implements GET /profile/{userID}/avatar_url
func GetAvatarURL(
req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI,
req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI,
userID string, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient,
) util.JSONResponse {
@ -92,7 +92,7 @@ func GetAvatarURL(
// SetAvatarURL implements PUT /profile/{userID}/avatar_url
func SetAvatarURL(
req *http.Request, accountDB accounts.Database,
req *http.Request, accountDB userdb.Database,
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI,
) util.JSONResponse {
if userID != device.UserID {
@ -182,7 +182,7 @@ func SetAvatarURL(
// GetDisplayName implements GET /profile/{userID}/displayname
func GetDisplayName(
req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI,
req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI,
userID string, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient,
) util.JSONResponse {
@ -209,7 +209,7 @@ func GetDisplayName(
// SetDisplayName implements PUT /profile/{userID}/displayname
func SetDisplayName(
req *http.Request, accountDB accounts.Database,
req *http.Request, accountDB userdb.Database,
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI,
) util.JSONResponse {
if userID != device.UserID {
@ -302,7 +302,7 @@ func SetDisplayName(
// Returns an error when something goes wrong or specifically
// eventutil.ErrProfileNoExists when the profile doesn't exist.
func getProfile(
ctx context.Context, accountDB accounts.Database, cfg *config.ClientAPI,
ctx context.Context, accountDB userdb.Database, cfg *config.ClientAPI,
userID string,
asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient,

View file

@ -44,7 +44,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/userutil"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
)
var (
@ -448,7 +448,7 @@ func validateApplicationService(
func Register(
req *http.Request,
userAPI userapi.UserInternalAPI,
accountDB accounts.Database,
accountDB userdb.Database,
cfg *config.ClientAPI,
) util.JSONResponse {
var r registerRequest
@ -532,6 +532,13 @@ func handleGuestRegistration(
cfg *config.ClientAPI,
userAPI userapi.UserInternalAPI,
) util.JSONResponse {
if cfg.RegistrationDisabled || cfg.GuestsDisabled {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Guest registration is disabled"),
}
}
var res userapi.PerformAccountCreationResponse
err := userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeGuest,
@ -892,7 +899,7 @@ type availableResponse struct {
func RegisterAvailable(
req *http.Request,
cfg *config.ClientAPI,
accountDB accounts.Database,
accountDB userdb.Database,
) util.JSONResponse {
username := req.URL.Query().Get("username")

View file

@ -34,7 +34,7 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
@ -51,7 +51,7 @@ func Setup(
eduAPI eduServerAPI.EDUServerInputAPI,
rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI,
accountDB accounts.Database,
accountDB userdb.Database,
userAPI userapi.UserInternalAPI,
federation *gomatrixserverlib.FederationClient,
syncProducer *producers.SyncAPIProducer,
@ -117,15 +117,22 @@ func Setup(
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
}
r0mux := publicAPIMux.PathPrefix("/r0").Subrouter()
// You can't just do PathPrefix("/(r0|v3)") because regexps only apply when inside named path variables.
// So make a named path variable called 'apiversion' (which we will never read in handlers) and then do
// (r0|v3) - BUT this is a captured group, which makes no sense because you cannot extract this group
// from a match (gorilla/mux exposes no way to do this) so it demands you make it a non-capturing group
// using ?: so the final regexp becomes what is below. We also need a trailing slash to stop 'v33333' matching.
// Note that 'apiversion' is chosen because it must not collide with a variable used in any of the routing!
v3mux := publicAPIMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter()
unstableMux := publicAPIMux.PathPrefix("/unstable").Subrouter()
r0mux.Handle("/createRoom",
v3mux.Handle("/createRoom",
httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return CreateRoom(req, device, cfg, accountDB, rsAPI, asAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/join/{roomIDOrAlias}",
v3mux.Handle("/join/{roomIDOrAlias}",
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -141,7 +148,7 @@ func Setup(
).Methods(http.MethodPost, http.MethodOptions)
if mscCfg.Enabled("msc2753") {
r0mux.Handle("/peek/{roomIDOrAlias}",
v3mux.Handle("/peek/{roomIDOrAlias}",
httputil.MakeAuthAPI(gomatrixserverlib.Peek, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -156,12 +163,12 @@ func Setup(
}),
).Methods(http.MethodPost, http.MethodOptions)
}
r0mux.Handle("/joined_rooms",
v3mux.Handle("/joined_rooms",
httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetJoinedRooms(req, device, rsAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/join",
v3mux.Handle("/rooms/{roomID}/join",
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -175,7 +182,7 @@ func Setup(
)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/leave",
v3mux.Handle("/rooms/{roomID}/leave",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -189,7 +196,7 @@ func Setup(
)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/unpeek",
v3mux.Handle("/rooms/{roomID}/unpeek",
httputil.MakeAuthAPI("unpeek", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -200,7 +207,7 @@ func Setup(
)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/ban",
v3mux.Handle("/rooms/{roomID}/ban",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -209,7 +216,7 @@ func Setup(
return SendBan(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/invite",
v3mux.Handle("/rooms/{roomID}/invite",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -221,7 +228,7 @@ func Setup(
return SendInvite(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/kick",
v3mux.Handle("/rooms/{roomID}/kick",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -230,7 +237,7 @@ func Setup(
return SendKick(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/unban",
v3mux.Handle("/rooms/{roomID}/unban",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -239,7 +246,7 @@ func Setup(
return SendUnban(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/send/{eventType}",
v3mux.Handle("/rooms/{roomID}/send/{eventType}",
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -248,7 +255,7 @@ func Setup(
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}",
v3mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}",
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -259,7 +266,7 @@ func Setup(
nil, cfg, rsAPI, transactionsCache)
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/event/{eventID}",
v3mux.Handle("/rooms/{roomID}/event/{eventID}",
httputil.MakeAuthAPI("rooms_get_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -269,7 +276,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@ -277,7 +284,7 @@ func Setup(
return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"])
})).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@ -285,7 +292,7 @@ func Setup(
return GetAliases(req, rsAPI, device, vars["roomID"])
})).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
v3mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@ -296,7 +303,7 @@ func Setup(
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat)
})).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@ -305,7 +312,7 @@ func Setup(
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat)
})).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}",
v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}",
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -317,7 +324,7 @@ func Setup(
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}",
v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}",
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -328,21 +335,21 @@ func Setup(
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
}
return Register(req, userAPI, accountDB, cfg)
})).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse {
v3mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
}
return RegisterAvailable(req, cfg, accountDB)
})).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/directory/room/{roomAlias}",
v3mux.Handle("/directory/room/{roomAlias}",
httputil.MakeExternalAPI("directory_room", func(req *http.Request) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -352,7 +359,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/directory/room/{roomAlias}",
v3mux.Handle("/directory/room/{roomAlias}",
httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -362,7 +369,7 @@ func Setup(
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/directory/room/{roomAlias}",
v3mux.Handle("/directory/room/{roomAlias}",
httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -371,7 +378,7 @@ func Setup(
return RemoveLocalAlias(req, device, vars["roomAlias"], rsAPI)
}),
).Methods(http.MethodDelete, http.MethodOptions)
r0mux.Handle("/directory/list/room/{roomID}",
v3mux.Handle("/directory/list/room/{roomID}",
httputil.MakeExternalAPI("directory_list", func(req *http.Request) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -381,7 +388,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
// TODO: Add AS support
r0mux.Handle("/directory/list/room/{roomID}",
v3mux.Handle("/directory/list/room/{roomID}",
httputil.MakeAuthAPI("directory_list", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -390,25 +397,25 @@ func Setup(
return SetVisibility(req, rsAPI, device, vars["roomID"])
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/publicRooms",
v3mux.Handle("/publicRooms",
httputil.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse {
return GetPostPublicRooms(req, rsAPI, extRoomsProvider, federation, cfg)
}),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
r0mux.Handle("/logout",
v3mux.Handle("/logout",
httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return Logout(req, userAPI, device)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/logout/all",
v3mux.Handle("/logout/all",
httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return LogoutAll(req, userAPI, device)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/typing/{userID}",
v3mux.Handle("/rooms/{roomID}/typing/{userID}",
httputil.MakeAuthAPI("rooms_typing", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -420,7 +427,7 @@ func Setup(
return SendTyping(req, device, vars["roomID"], vars["userID"], accountDB, eduAPI, rsAPI)
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/redact/{eventID}",
v3mux.Handle("/rooms/{roomID}/redact/{eventID}",
httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -429,7 +436,7 @@ func Setup(
return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}",
v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}",
httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -439,7 +446,7 @@ func Setup(
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/sendToDevice/{eventType}/{txnID}",
v3mux.Handle("/sendToDevice/{eventType}/{txnID}",
httputil.MakeAuthAPI("send_to_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -464,7 +471,7 @@ func Setup(
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/account/whoami",
v3mux.Handle("/account/whoami",
httputil.MakeAuthAPI("whoami", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -473,7 +480,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/account/password",
v3mux.Handle("/account/password",
httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -482,7 +489,7 @@ func Setup(
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/account/deactivate",
v3mux.Handle("/account/deactivate",
httputil.MakeAuthAPI("deactivate", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -493,7 +500,7 @@ func Setup(
// Stub endpoints required by Element
r0mux.Handle("/login",
v3mux.Handle("/login",
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -502,14 +509,14 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
r0mux.Handle("/auth/{authType}/fallback/web",
v3mux.Handle("/auth/{authType}/fallback/web",
httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse {
vars := mux.Vars(req)
return AuthFallback(w, req, vars["authType"], cfg)
}),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
r0mux.Handle("/pushrules/",
v3mux.Handle("/pushrules/",
httputil.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse {
// TODO: Implement push rules API
res := json.RawMessage(`{
@ -530,7 +537,7 @@ func Setup(
// Element user settings
r0mux.Handle("/profile/{userID}",
v3mux.Handle("/profile/{userID}",
httputil.MakeExternalAPI("profile", func(req *http.Request) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -540,7 +547,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/profile/{userID}/avatar_url",
v3mux.Handle("/profile/{userID}/avatar_url",
httputil.MakeExternalAPI("profile_avatar_url", func(req *http.Request) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -550,7 +557,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/profile/{userID}/avatar_url",
v3mux.Handle("/profile/{userID}/avatar_url",
httputil.MakeAuthAPI("profile_avatar_url", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -565,7 +572,7 @@ func Setup(
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows
// PUT requests, so we need to allow this method
r0mux.Handle("/profile/{userID}/displayname",
v3mux.Handle("/profile/{userID}/displayname",
httputil.MakeExternalAPI("profile_displayname", func(req *http.Request) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -575,7 +582,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/profile/{userID}/displayname",
v3mux.Handle("/profile/{userID}/displayname",
httputil.MakeAuthAPI("profile_displayname", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -590,13 +597,13 @@ func Setup(
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows
// PUT requests, so we need to allow this method
r0mux.Handle("/account/3pid",
v3mux.Handle("/account/3pid",
httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetAssociated3PIDs(req, accountDB, device)
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/account/3pid",
v3mux.Handle("/account/3pid",
httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return CheckAndSave3PIDAssociation(req, accountDB, device, cfg)
}),
@ -608,14 +615,14 @@ func Setup(
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken",
v3mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken",
httputil.MakeExternalAPI("account_3pid_request_token", func(req *http.Request) util.JSONResponse {
return RequestEmailToken(req, accountDB, cfg)
}),
).Methods(http.MethodPost, http.MethodOptions)
// Element logs get flooded unless this is handled
r0mux.Handle("/presence/{userID}/status",
v3mux.Handle("/presence/{userID}/status",
httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -628,7 +635,7 @@ func Setup(
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/voip/turnServer",
v3mux.Handle("/voip/turnServer",
httputil.MakeAuthAPI("turn_server", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -637,7 +644,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/thirdparty/protocols",
v3mux.Handle("/thirdparty/protocols",
httputil.MakeExternalAPI("thirdparty_protocols", func(req *http.Request) util.JSONResponse {
// TODO: Return the third party protcols
return util.JSONResponse{
@ -647,7 +654,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/initialSync",
v3mux.Handle("/rooms/{roomID}/initialSync",
httputil.MakeExternalAPI("rooms_initial_sync", func(req *http.Request) util.JSONResponse {
// TODO: Allow people to peek into rooms.
return util.JSONResponse{
@ -657,7 +664,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userID}/account_data/{type}",
v3mux.Handle("/user/{userID}/account_data/{type}",
httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -667,7 +674,7 @@ func Setup(
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}",
v3mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}",
httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -677,7 +684,7 @@ func Setup(
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/user/{userID}/account_data/{type}",
v3mux.Handle("/user/{userID}/account_data/{type}",
httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -687,7 +694,7 @@ func Setup(
}),
).Methods(http.MethodGet)
r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}",
v3mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}",
httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -697,7 +704,7 @@ func Setup(
}),
).Methods(http.MethodGet)
r0mux.Handle("/admin/whois/{userID}",
v3mux.Handle("/admin/whois/{userID}",
httputil.MakeAuthAPI("admin_whois", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -707,7 +714,7 @@ func Setup(
}),
).Methods(http.MethodGet)
r0mux.Handle("/user/{userID}/openid/request_token",
v3mux.Handle("/user/{userID}/openid/request_token",
httputil.MakeAuthAPI("openid_request_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -720,7 +727,7 @@ func Setup(
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/user_directory/search",
v3mux.Handle("/user_directory/search",
httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -745,7 +752,7 @@ func Setup(
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/members",
v3mux.Handle("/rooms/{roomID}/members",
httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -755,7 +762,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/joined_members",
v3mux.Handle("/rooms/{roomID}/joined_members",
httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -765,7 +772,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/read_markers",
v3mux.Handle("/rooms/{roomID}/read_markers",
httputil.MakeAuthAPI("rooms_read_markers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -778,7 +785,7 @@ func Setup(
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/forget",
v3mux.Handle("/rooms/{roomID}/forget",
httputil.MakeAuthAPI("rooms_forget", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -791,13 +798,13 @@ func Setup(
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/devices",
v3mux.Handle("/devices",
httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetDevicesByLocalpart(req, userAPI, device)
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/devices/{deviceID}",
v3mux.Handle("/devices/{deviceID}",
httputil.MakeAuthAPI("get_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -807,7 +814,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/devices/{deviceID}",
v3mux.Handle("/devices/{deviceID}",
httputil.MakeAuthAPI("device_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -817,7 +824,7 @@ func Setup(
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/devices/{deviceID}",
v3mux.Handle("/devices/{deviceID}",
httputil.MakeAuthAPI("delete_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -827,14 +834,14 @@ func Setup(
}),
).Methods(http.MethodDelete, http.MethodOptions)
r0mux.Handle("/delete_devices",
v3mux.Handle("/delete_devices",
httputil.MakeAuthAPI("delete_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return DeleteDevices(req, userAPI, device)
}),
).Methods(http.MethodPost, http.MethodOptions)
// Stub implementations for sytest
r0mux.Handle("/events",
v3mux.Handle("/events",
httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {
return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{
"chunk": []interface{}{},
@ -844,7 +851,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/initialSync",
v3mux.Handle("/initialSync",
httputil.MakeExternalAPI("initial_sync", func(req *http.Request) util.JSONResponse {
return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{
"end": "",
@ -852,7 +859,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userId}/rooms/{roomId}/tags",
v3mux.Handle("/user/{userId}/rooms/{roomId}/tags",
httputil.MakeAuthAPI("get_tags", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -862,7 +869,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
httputil.MakeAuthAPI("put_tag", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -872,7 +879,7 @@ func Setup(
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
httputil.MakeAuthAPI("delete_tag", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -882,7 +889,7 @@ func Setup(
}),
).Methods(http.MethodDelete, http.MethodOptions)
r0mux.Handle("/capabilities",
v3mux.Handle("/capabilities",
httputil.MakeAuthAPI("capabilities", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -925,11 +932,11 @@ func Setup(
return CreateKeyBackupVersion(req, userAPI, device)
})
r0mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/room_keys/version/{version}", putBackupKeysVersion).Methods(http.MethodPut)
r0mux.Handle("/room_keys/version/{version}", deleteBackupKeysVersion).Methods(http.MethodDelete)
r0mux.Handle("/room_keys/version", postNewBackupKeysVersion).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/room_keys/version/{version}", putBackupKeysVersion).Methods(http.MethodPut)
v3mux.Handle("/room_keys/version/{version}", deleteBackupKeysVersion).Methods(http.MethodDelete)
v3mux.Handle("/room_keys/version", postNewBackupKeysVersion).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
unstableMux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
@ -1021,9 +1028,9 @@ func Setup(
return UploadBackupKeys(req, userAPI, device, version, &keyReq)
})
r0mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut)
r0mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut)
r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", putBackupKeysRoomSession).Methods(http.MethodPut)
v3mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut)
v3mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut)
v3mux.Handle("/room_keys/keys/{roomID}/{sessionID}", putBackupKeysRoomSession).Methods(http.MethodPut)
unstableMux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut)
unstableMux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut)
@ -1051,9 +1058,9 @@ func Setup(
return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"])
})
r0mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", getBackupKeysRoomSession).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/room_keys/keys/{roomID}/{sessionID}", getBackupKeysRoomSession).Methods(http.MethodGet, http.MethodOptions)
unstableMux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions)
unstableMux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions)
@ -1071,34 +1078,34 @@ func Setup(
return UploadCrossSigningDeviceSignatures(req, keyAPI, device)
})
r0mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions)
// Supplying a device ID is deprecated.
r0mux.Handle("/keys/upload/{deviceID}",
v3mux.Handle("/keys/upload/{deviceID}",
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadKeys(req, keyAPI, device)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/keys/upload",
v3mux.Handle("/keys/upload",
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadKeys(req, keyAPI, device)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/keys/query",
v3mux.Handle("/keys/query",
httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return QueryKeys(req, keyAPI, device)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/keys/claim",
v3mux.Handle("/keys/claim",
httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return ClaimKeys(req, keyAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}",
v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}",
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r

View file

@ -20,7 +20,7 @@ import (
"github.com/matrix-org/dendrite/eduserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/util"
)
@ -33,7 +33,7 @@ type typingContentJSON struct {
// sends the typing events to client API typingProducer
func SendTyping(
req *http.Request, device *userapi.Device, roomID string,
userID string, accountDB accounts.Database,
userID string, accountDB userdb.Database,
eduAPI api.EDUServerInputAPI,
rsAPI roomserverAPI.RoomserverInternalAPI,
) util.JSONResponse {

View file

@ -23,7 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/threepid"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
@ -40,7 +40,7 @@ type threePIDsResponse struct {
// RequestEmailToken implements:
// POST /account/3pid/email/requestToken
// POST /register/email/requestToken
func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI) util.JSONResponse {
func RequestEmailToken(req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI) util.JSONResponse {
var body threepid.EmailAssociationRequest
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr
@ -61,7 +61,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf
Code: http.StatusBadRequest,
JSON: jsonerror.MatrixError{
ErrCode: "M_THREEPID_IN_USE",
Err: accounts.Err3PIDInUse.Error(),
Err: userdb.Err3PIDInUse.Error(),
},
}
}
@ -85,7 +85,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf
// CheckAndSave3PIDAssociation implements POST /account/3pid
func CheckAndSave3PIDAssociation(
req *http.Request, accountDB accounts.Database, device *api.Device,
req *http.Request, accountDB userdb.Database, device *api.Device,
cfg *config.ClientAPI,
) util.JSONResponse {
var body threepid.EmailAssociationCheckRequest
@ -149,7 +149,7 @@ func CheckAndSave3PIDAssociation(
// GetAssociated3PIDs implements GET /account/3pid
func GetAssociated3PIDs(
req *http.Request, accountDB accounts.Database, device *api.Device,
req *http.Request, accountDB userdb.Database, device *api.Device,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
@ -170,7 +170,7 @@ func GetAssociated3PIDs(
}
// Forget3PID implements POST /account/3pid/delete
func Forget3PID(req *http.Request, accountDB accounts.Database) util.JSONResponse {
func Forget3PID(req *http.Request, accountDB userdb.Database) util.JSONResponse {
var body authtypes.ThreePID
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr

View file

@ -22,6 +22,8 @@ import (
// whoamiResponse represents an response for a `whoami` request
type whoamiResponse struct {
UserID string `json:"user_id"`
DeviceID string `json:"device_id"`
IsGuest bool `json:"is_guest"`
}
// Whoami implements `/account/whoami` which enables client to query their account user id.
@ -29,6 +31,10 @@ type whoamiResponse struct {
func Whoami(req *http.Request, device *api.Device) util.JSONResponse {
return util.JSONResponse{
Code: http.StatusOK,
JSON: whoamiResponse{UserID: device.UserID},
JSON: whoamiResponse{
UserID: device.UserID,
DeviceID: device.ID,
IsGuest: device.AccountType == api.AccountTypeGuest,
},
}
}

View file

@ -29,7 +29,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
)
@ -87,7 +87,7 @@ var (
func CheckAndProcessInvite(
ctx context.Context,
device *userapi.Device, body *MembershipRequest, cfg *config.ClientAPI,
rsAPI api.RoomserverInternalAPI, db accounts.Database,
rsAPI api.RoomserverInternalAPI, db userdb.Database,
roomID string,
evTime time.Time,
) (inviteStoredOnIDServer bool, err error) {
@ -137,7 +137,7 @@ func CheckAndProcessInvite(
// Returns an error if a check or a request failed.
func queryIDServer(
ctx context.Context,
db accounts.Database, cfg *config.ClientAPI, device *userapi.Device,
db userdb.Database, cfg *config.ClientAPI, device *userapi.Device,
body *MembershipRequest, roomID string,
) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) {
if err = isTrusted(body.IDServer, cfg); err != nil {
@ -206,7 +206,7 @@ func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServe
// Returns an error if the request failed to send or if the response couldn't be parsed.
func queryIDServerStoreInvite(
ctx context.Context,
db accounts.Database, cfg *config.ClientAPI, device *userapi.Device,
db userdb.Database, cfg *config.ClientAPI, device *userapi.Device,
body *MembershipRequest, roomID string,
) (*idServerStoreInviteResponse, error) {
// Retrieve the sender's profile to get their display name

View file

@ -30,7 +30,7 @@ import (
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
)
const usage = `Usage: %s
@ -77,9 +77,14 @@ func main() {
pass := getPassword(password, pwdFile, pwdStdin, askPass, os.Stdin)
accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{
accountDB, err := userdb.NewDatabase(
&config.DatabaseOptions{
ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString,
}, cfg.Global.ServerName, bcrypt.DefaultCost, cfg.UserAPI.OpenIDTokenLifetimeMS)
},
cfg.Global.ServerName, bcrypt.DefaultCost,
cfg.UserAPI.OpenIDTokenLifetimeMS,
api.DefaultLoginTokenLifetime,
)
if err != nil {
logrus.Fatalln("Failed to connect to the database:", err.Error())
}

View file

@ -126,7 +126,6 @@ func main() {
cfg.FederationAPI.FederationMaxRetries = 6
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))

View file

@ -160,7 +160,6 @@ func main() {
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))

View file

@ -79,7 +79,6 @@ func main() {
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))

View file

@ -164,7 +164,6 @@ func startup() {
cfg.Defaults(true)
cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db"
cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db"
cfg.UserAPI.DeviceDatabase.ConnectionString = "file:/idb/dendritejs_device.db"
cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db"
cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db"
cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db"

View file

@ -167,7 +167,6 @@ func main() {
cfg.Defaults(true)
cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db"
cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db"
cfg.UserAPI.DeviceDatabase.ConnectionString = "file:/idb/dendritejs_device.db"
cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db"
cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db"
cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db"

View file

@ -32,7 +32,6 @@ func main() {
cfg.RoomServer.Database.ConnectionString = config.DataSource(*dbURI)
cfg.SyncAPI.Database.ConnectionString = config.DataSource(*dbURI)
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(*dbURI)
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(*dbURI)
}
cfg.Global.TrustedIDServers = []string{
"matrix.org",
@ -91,6 +90,7 @@ func main() {
cfg.Logging[0].Type = "std"
cfg.UserAPI.BCryptCost = bcrypt.MinCost
cfg.Global.JetStream.InMemory = true
cfg.ClientAPI.RegistrationSharedSecret = "complement"
}
j, err := yaml.Marshal(cfg)

View file

@ -8,12 +8,11 @@ import (
"log"
"os"
pgaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas"
slaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas"
pgdevices "github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas"
sldevices "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas"
"github.com/pressly/goose"
pgusers "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
slusers "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
)
@ -26,8 +25,7 @@ const (
RoomServer = "roomserver"
SigningKeyServer = "signingkeyserver"
SyncAPI = "syncapi"
UserAPIAccounts = "userapi_accounts"
UserAPIDevices = "userapi_devices"
UserAPI = "userapi"
)
var (
@ -35,7 +33,7 @@ var (
flags = flag.NewFlagSet("goose", flag.ExitOnError)
component = flags.String("component", "", "dendrite component name")
knownDBs = []string{
AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPIAccounts, UserAPIDevices,
AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPI,
}
)
@ -143,18 +141,14 @@ Commands:
func loadSQLiteDeltas(component string) {
switch component {
case UserAPIAccounts:
slaccounts.LoadFromGoose()
case UserAPIDevices:
sldevices.LoadFromGoose()
case UserAPI:
slusers.LoadFromGoose()
}
}
func loadPostgresDeltas(component string) {
switch component {
case UserAPIAccounts:
pgaccounts.LoadFromGoose()
case UserAPIDevices:
pgdevices.LoadFromGoose()
case UserAPI:
pgusers.LoadFromGoose()
}
}

View file

@ -142,6 +142,10 @@ client_api:
# using the registration shared secret below.
registration_disabled: false
# Prevents new guest accounts from being created. Guest registration is also
# disabled implicitly by setting 'registration_disabled' above.
guests_disabled: true
# If set, allows registration by anyone who knows the shared secret, regardless of
# whether registration is otherwise disabled.
registration_shared_secret: ""
@ -204,13 +208,6 @@ federation_api:
# enable this option in production as it presents a security risk!
disable_tls_validation: false
# Use the following proxy server for outbound federation traffic.
proxy_outbound:
enabled: false
protocol: http
host: localhost
port: 8080
# Perspective keyservers to use as a backup when direct key fetches fail. This may
# be required to satisfy key requests for servers that are no longer online when
# joining some rooms.

4
go.mod
View file

@ -1,6 +1,6 @@
module github.com/matrix-org/dendrite
replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.3.3-0.20220104162330-c76d5fd70423
replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad
replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c
@ -45,7 +45,7 @@ require (
github.com/mattn/go-sqlite3 v1.14.10
github.com/morikuni/aec v1.0.0 // indirect
github.com/nats-io/nats-server/v2 v2.3.2
github.com/nats-io/nats.go v1.13.1-0.20211122170419-d7c1d78a50fc
github.com/nats-io/nats.go v1.13.1-0.20220121202836-972a071d373d
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/ngrok/sqlmw v0.0.0-20211220175533-9d16fdc47b31

15
go.sum
View file

@ -1122,8 +1122,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8m
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
github.com/nats-io/jwt/v2 v2.2.0 h1:Yg/4WFK6vsqMudRg91eBb7Dh6XeVcDMPHycDE8CfltE=
github.com/nats-io/jwt/v2 v2.2.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k=
github.com/nats-io/jwt/v2 v2.2.1-0.20220113022732-58e87895b296 h1:vU9tpM3apjYlLLeY23zRWJ9Zktr5jp+mloR942LEOpY=
github.com/nats-io/jwt/v2 v2.2.1-0.20220113022732-58e87895b296/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k=
github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8=
github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4=
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
@ -1132,8 +1132,8 @@ github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uY
github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM=
github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo=
github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
github.com/neilalexander/nats-server/v2 v2.3.3-0.20220104162330-c76d5fd70423 h1:BLQVdjMH5XD4BYb0fa+c2Oh2Nr1vrO7GKvRnIJDxChc=
github.com/neilalexander/nats-server/v2 v2.3.3-0.20220104162330-c76d5fd70423/go.mod h1:9sdEkBhyZMQG1M9TevnlYUwMusRACn2vlgOeqoHKwVo=
github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad h1:Z2nWMQsXWWqzj89nW6OaLJSdkFknqhaR5whEOz4++Y8=
github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad/go.mod h1:tckmrt0M6bVaDT3kmh9UrIq/CBOBBse+TpXQi5ldaa8=
github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c h1:G2qsv7D0rY94HAu8pXmElMluuMHQ85waxIDQBhIzV2Q=
github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w=
github.com/neilalexander/utp v0.1.1-0.20210622132614-ee9a34a30488/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8=
@ -1508,8 +1508,8 @@ golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a h1:atOEWVSedO4ksXBe/UrlbSLVxQQ9RxM/tT2Jy10IaHo=
golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -1735,6 +1735,7 @@ golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220207234003-57398862261d h1:Bm7BNOQt2Qv7ZqysjeLjgCBanX+88Z/OtdvsrEv1Djc=
golang.org/x/sys v0.0.0-20220207234003-57398862261d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@ -1756,10 +1757,10 @@ golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxb
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 h1:GZokNIeuVkl3aZHJchRrr13WCsols02MLUcz1U9is6M=
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View file

@ -7,14 +7,6 @@ import (
)
const (
RoomServerStateKeyNIDsCacheName = "roomserver_statekey_nids"
RoomServerStateKeyNIDsCacheMaxEntries = 1024
RoomServerStateKeyNIDsCacheMutable = false
RoomServerEventTypeNIDsCacheName = "roomserver_eventtype_nids"
RoomServerEventTypeNIDsCacheMaxEntries = 64
RoomServerEventTypeNIDsCacheMutable = false
RoomServerRoomIDsCacheName = "roomserver_room_ids"
RoomServerRoomIDsCacheMaxEntries = 1024
RoomServerRoomIDsCacheMutable = false
@ -29,44 +21,10 @@ type RoomServerCaches interface {
// RoomServerNIDsCache contains the subset of functions needed for
// a roomserver NID cache.
type RoomServerNIDsCache interface {
GetRoomServerStateKeyNID(stateKey string) (types.EventStateKeyNID, bool)
StoreRoomServerStateKeyNID(stateKey string, nid types.EventStateKeyNID)
GetRoomServerEventTypeNID(eventType string) (types.EventTypeNID, bool)
StoreRoomServerEventTypeNID(eventType string, nid types.EventTypeNID)
GetRoomServerRoomID(roomNID types.RoomNID) (string, bool)
StoreRoomServerRoomID(roomNID types.RoomNID, roomID string)
}
func (c Caches) GetRoomServerStateKeyNID(stateKey string) (types.EventStateKeyNID, bool) {
val, found := c.RoomServerStateKeyNIDs.Get(stateKey)
if found && val != nil {
if stateKeyNID, ok := val.(types.EventStateKeyNID); ok {
return stateKeyNID, true
}
}
return 0, false
}
func (c Caches) StoreRoomServerStateKeyNID(stateKey string, nid types.EventStateKeyNID) {
c.RoomServerStateKeyNIDs.Set(stateKey, nid)
}
func (c Caches) GetRoomServerEventTypeNID(eventType string) (types.EventTypeNID, bool) {
val, found := c.RoomServerEventTypeNIDs.Get(eventType)
if found && val != nil {
if eventTypeNID, ok := val.(types.EventTypeNID); ok {
return eventTypeNID, true
}
}
return 0, false
}
func (c Caches) StoreRoomServerEventTypeNID(eventType string, nid types.EventTypeNID) {
c.RoomServerEventTypeNIDs.Set(eventType, nid)
}
func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) {
val, found := c.RoomServerRoomIDs.Get(strconv.Itoa(int(roomNID)))
if found && val != nil {

View file

@ -6,8 +6,6 @@ package caching
type Caches struct {
RoomVersions Cache // RoomVersionCache
ServerKeys Cache // ServerKeyCache
RoomServerStateKeyNIDs Cache // RoomServerNIDsCache
RoomServerEventTypeNIDs Cache // RoomServerNIDsCache
RoomServerRoomNIDs Cache // RoomServerNIDsCache
RoomServerRoomIDs Cache // RoomServerNIDsCache
RoomInfos Cache // RoomInfoCache

View file

@ -28,24 +28,6 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
if err != nil {
return nil, err
}
roomServerStateKeyNIDs, err := NewInMemoryLRUCachePartition(
RoomServerStateKeyNIDsCacheName,
RoomServerStateKeyNIDsCacheMutable,
RoomServerStateKeyNIDsCacheMaxEntries,
enablePrometheus,
)
if err != nil {
return nil, err
}
roomServerEventTypeNIDs, err := NewInMemoryLRUCachePartition(
RoomServerEventTypeNIDsCacheName,
RoomServerEventTypeNIDsCacheMutable,
RoomServerEventTypeNIDsCacheMaxEntries,
enablePrometheus,
)
if err != nil {
return nil, err
}
roomServerRoomIDs, err := NewInMemoryLRUCachePartition(
RoomServerRoomIDsCacheName,
RoomServerRoomIDsCacheMutable,
@ -74,15 +56,12 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
return nil, err
}
go cacheCleaner(
roomVersions, serverKeys, roomServerStateKeyNIDs,
roomServerEventTypeNIDs, roomServerRoomIDs,
roomVersions, serverKeys, roomServerRoomIDs,
roomInfos, federationEvents,
)
return &Caches{
RoomVersions: roomVersions,
ServerKeys: serverKeys,
RoomServerStateKeyNIDs: roomServerStateKeyNIDs,
RoomServerEventTypeNIDs: roomServerEventTypeNIDs,
RoomServerRoomIDs: roomServerRoomIDs,
RoomInfos: roomInfos,
FederationEvents: federationEvents,

View file

@ -95,7 +95,6 @@ func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*con
cfg.RoomServer.Database.ConnectionString = config.DataSource(database)
cfg.SyncAPI.Database.ConnectionString = config.DataSource(database)
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(database)
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(database)
cfg.AppServiceAPI.InternalAPI.Listen = assignAddress()
cfg.EDUServer.InternalAPI.Listen = assignAddress()

View file

@ -367,10 +367,13 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
waitTime = fcerr.RetryAfter
} else if fcerr.Blacklisted {
waitTime = time.Hour * 8
} else {
// For all other errors (DNS resolution, network etc.) wait 1 hour.
waitTime = time.Hour
}
} else {
waitTime = time.Hour
logger.WithError(err).Warn("GetUserDevices returned unknown error type")
logger.WithError(err).WithField("user_id", userID).Warn("GetUserDevices returned unknown error type")
}
continue
}

View file

@ -198,7 +198,7 @@ func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOne
}
func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) {
msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil)
msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
@ -244,7 +244,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
domain := string(serverName)
// query local devices
if serverName == a.ThisServer {
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query local device keys: %s", err),
@ -525,7 +525,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string,
) error {
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
// if we can't query the db or there are fewer keys than requested, fetch from remote.
if err != nil {
return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
@ -554,10 +554,58 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
}
func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
// get a list of devices from the user API that actually exist, as
// we won't store keys for devices that don't exist
uapidevices := &userapi.QueryDevicesResponse{}
if err := a.UserAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil {
res.Error = &api.KeyError{
Err: err.Error(),
}
return
}
if !uapidevices.UserExists {
res.Error = &api.KeyError{
Err: fmt.Sprintf("user %q does not exist", req.UserID),
}
return
}
existingDeviceMap := make(map[string]struct{}, len(uapidevices.Devices))
for _, key := range uapidevices.Devices {
existingDeviceMap[key.ID] = struct{}{}
}
// Get all of the user existing device keys so we can check for changes.
existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, true)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
}
return
}
// Work out whether we have device keys in the keyserver for devices that
// no longer exist in the user API. This is mostly an exercise to ensure
// that we keep some integrity between the two.
var toClean []gomatrixserverlib.KeyID
for _, k := range existingKeys {
if _, ok := existingDeviceMap[k.DeviceID]; !ok {
toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID))
}
}
if len(toClean) > 0 {
if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil {
logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean))
} else {
logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean))
}
}
var keysToStore []api.DeviceMessage
// assert that the user ID / device ID are not lying for each key
for _, key := range req.DeviceKeys {
_, serverName, err := gomatrixserverlib.SplitID('@', key.UserID)
var serverName gomatrixserverlib.ServerName
_, serverName, err = gomatrixserverlib.SplitID('@', key.UserID)
if err != nil {
continue // ignore invalid users
}
@ -568,6 +616,11 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
keysToStore = append(keysToStore, key.WithStreamID(0))
continue // deleted keys don't need sanity checking
}
// check that the device in question actually exists in the user
// API before we try and store a key for it
if _, ok := existingDeviceMap[key.DeviceID]; !ok {
continue
}
gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
@ -583,29 +636,12 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
})
}
// get existing device keys so we can check for changes
existingKeys := make([]api.DeviceMessage, len(keysToStore))
for i := range keysToStore {
existingKeys[i] = api.DeviceMessage{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
UserID: keysToStore[i].UserID,
DeviceID: keysToStore[i].DeviceID,
},
}
}
if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
}
return
}
if req.OnlyDisplayNameUpdates {
// add the display name field from keysToStore into existingKeys
keysToStore = appendDisplayNames(existingKeys, keysToStore)
}
// store the device keys and emit changes
err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),

View file

@ -53,7 +53,7 @@ type Database interface {
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
// cross-signing signatures relating to that device.

View file

@ -56,6 +56,9 @@ const selectDeviceKeysSQL = "" +
const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
const selectBatchDeviceKeysWithEmptiesSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
@ -73,6 +76,7 @@ type deviceKeysStatements struct {
upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt
countStreamIDsForUserStmt *sql.Stmt
deleteDeviceKeysStmt *sql.Stmt
@ -96,6 +100,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err
}
if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
return nil, err
}
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err
}
@ -180,8 +187,14 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql
return err
}
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
var stmt *sql.Stmt
if includeEmpty {
stmt = s.selectBatchDeviceKeysWithEmptiesStmt
} else {
stmt = s.selectBatchDeviceKeysStmt
}
rows, err := stmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
}

View file

@ -108,8 +108,8 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe
})
}
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs)
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty)
}
func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {

View file

@ -52,6 +52,9 @@ const selectDeviceKeysSQL = "" +
const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
const selectBatchDeviceKeysWithEmptiesSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
@ -69,6 +72,7 @@ type deviceKeysStatements struct {
upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt
deleteDeviceKeysStmt *sql.Stmt
deleteAllDeviceKeysStmt *sql.Stmt
@ -91,6 +95,9 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err
}
if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
return nil, err
}
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err
}
@ -113,12 +120,18 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql
return err
}
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
deviceIDMap := make(map[string]bool)
for _, d := range deviceIDs {
deviceIDMap[d] = true
}
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
var stmt *sql.Stmt
if includeEmpty {
stmt = s.selectBatchDeviceKeysWithEmptiesStmt
} else {
stmt = s.selectBatchDeviceKeysStmt
}
rows, err := stmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
}

View file

@ -173,7 +173,7 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
}
// Querying for device keys returns the latest stream IDs
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"})
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false)
if err != nil {
t.Fatalf("DeviceKeysForUser returned error: %s", err)
}

View file

@ -38,7 +38,7 @@ type DeviceKeys interface {
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
}

View file

@ -14,6 +14,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/internal/query"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
@ -32,6 +33,7 @@ type RoomserverInternalAPI struct {
*perform.Publisher
*perform.Backfiller
*perform.Forgetter
ProcessContext *process.ProcessContext
DB storage.Database
Cfg *config.RoomServer
Cache caching.RoomServerCaches
@ -48,12 +50,13 @@ type RoomserverInternalAPI struct {
}
func NewRoomserverAPI(
cfg *config.RoomServer, roomserverDB storage.Database, consumer nats.JetStreamContext,
inputRoomEventTopic, outputRoomEventTopic string, caches caching.RoomServerCaches,
perspectiveServerNames []gomatrixserverlib.ServerName,
processCtx *process.ProcessContext, cfg *config.RoomServer, roomserverDB storage.Database,
consumer nats.JetStreamContext, inputRoomEventTopic, outputRoomEventTopic string,
caches caching.RoomServerCaches, perspectiveServerNames []gomatrixserverlib.ServerName,
) *RoomserverInternalAPI {
serverACLs := acls.NewServerACLs(roomserverDB)
a := &RoomserverInternalAPI{
ProcessContext: processCtx,
DB: roomserverDB,
Cfg: cfg,
Cache: caches,
@ -83,6 +86,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.FederationInternalA
r.KeyRing = keyRing
r.Inputer = &input.Inputer{
ProcessContext: r.ProcessContext,
DB: r.DB,
InputRoomEventTopic: r.InputRoomEventTopic,
OutputRoomEventTopic: r.OutputRoomEventTopic,

View file

@ -31,6 +31,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/internal/query"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/prometheus/client_golang/prometheus"
@ -59,6 +60,7 @@ var keyContentFields = map[string]string{
}
type Inputer struct {
ProcessContext *process.ProcessContext
DB storage.Database
JetStream nats.JetStreamContext
Durable nats.SubOpt
@ -115,7 +117,7 @@ func (r *Inputer) Start() error {
_ = msg.InProgress() // resets the acknowledgement wait timer
defer eventsInProgress.Delete(index)
defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec()
action, err := r.processRoomEventUsingUpdater(context.Background(), roomID, &inputRoomEvent)
action, err := r.processRoomEventUsingUpdater(r.ProcessContext.Context(), roomID, &inputRoomEvent)
if err != nil {
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
sentry.CaptureException(err)

View file

@ -405,7 +405,7 @@ func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.Ro
if len(extraEventIDs) == 0 {
return nil, nil
}
extraEvents, err := u.updater.EventsFromIDs(u.ctx, extraEventIDs)
extraEvents, err := u.updater.UnsentEventsFromIDs(u.ctx, extraEventIDs)
if err != nil {
return nil, err
}

View file

@ -53,7 +53,7 @@ func NewInternalAPI(
js := jetstream.Prepare(&cfg.Matrix.JetStream)
return internal.NewRoomserverAPI(
cfg, roomserverDB, js,
base.ProcessContext, cfg, roomserverDB, js,
cfg.Matrix.JetStream.TopicFor(jetstream.InputRoomEvent),
cfg.Matrix.JetStream.TopicFor(jetstream.OutputRoomEvent),
base.Caches, perspectiveServerNames,

View file

@ -127,6 +127,9 @@ const bulkSelectEventIDSQL = "" +
const bulkSelectEventNIDSQL = "" +
"SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1)"
const bulkSelectUnsentEventNIDSQL = "" +
"SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1) AND sent_to_output = FALSE"
const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)"
@ -147,6 +150,7 @@ type eventStatements struct {
bulkSelectEventReferenceStmt *sql.Stmt
bulkSelectEventIDStmt *sql.Stmt
bulkSelectEventNIDStmt *sql.Stmt
bulkSelectUnsentEventNIDStmt *sql.Stmt
selectMaxEventDepthStmt *sql.Stmt
selectRoomNIDsForEventNIDsStmt *sql.Stmt
}
@ -173,6 +177,7 @@ func prepareEventsTable(db *sql.DB) (tables.Events, error) {
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
{&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL},
{&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL},
{&s.selectRoomNIDsForEventNIDsStmt, selectRoomNIDsForEventNIDsSQL},
}.Prepare(db)
@ -458,10 +463,28 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev
return results, nil
}
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt)
return s.bulkSelectEventNID(ctx, txn, eventIDs, false)
}
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID
// only for events that haven't already been sent to the roomserver output.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
return s.bulkSelectEventNID(ctx, txn, eventIDs, true)
}
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) {
var stmt *sql.Stmt
if onlyUnsent {
stmt = sqlutil.TxStmt(txn, s.bulkSelectUnsentEventNIDStmt)
} else {
stmt = sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt)
}
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil {
return nil, err

View file

@ -136,7 +136,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
}
// Look up the NID of the new join event
nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID})
nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false)
if err != nil {
return fmt.Errorf("u.d.EventNIDs: %w", err)
}
@ -170,7 +170,7 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s
}
// Look up the NID of the new leave event
nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID})
nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false)
if err != nil {
return fmt.Errorf("u.d.EventNIDs: %w", err)
}
@ -196,7 +196,7 @@ func (u *MembershipUpdater) SetToKnock(event *gomatrixserverlib.Event) (bool, er
}
if u.membership != tables.MembershipStateKnock {
// Look up the NID of the new knock event
nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{event.EventID()})
nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{event.EventID()}, false)
if err != nil {
return fmt.Errorf("u.d.EventNIDs: %w", err)
}

View file

@ -215,7 +215,13 @@ func (u *RoomUpdater) EventIDs(
func (u *RoomUpdater) EventNIDs(
ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) {
return u.d.eventNIDs(ctx, u.txn, eventIDs)
return u.d.eventNIDs(ctx, u.txn, eventIDs, NoFilter)
}
func (u *RoomUpdater) UnsentEventNIDs(
ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) {
return u.d.eventNIDs(ctx, u.txn, eventIDs, FilterUnsentOnly)
}
func (u *RoomUpdater) StateAtEventIDs(
@ -231,7 +237,11 @@ func (u *RoomUpdater) StateEntriesForEventIDs(
}
func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, eventIDs)
return u.d.eventsFromIDs(ctx, u.txn, eventIDs, false)
}
func (u *RoomUpdater) UnsentEventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, eventIDs, true)
}
func (u *RoomUpdater) GetMembershipEventNIDsForRoom(

View file

@ -59,23 +59,12 @@ func (d *Database) eventTypeNIDs(
ctx context.Context, txn *sql.Tx, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
result := make(map[string]types.EventTypeNID)
remaining := []string{}
for _, eventType := range eventTypes {
if nid, ok := d.Cache.GetRoomServerEventTypeNID(eventType); ok {
result[eventType] = nid
} else {
remaining = append(remaining, eventType)
}
}
if len(remaining) > 0 {
nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, remaining)
nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, eventTypes)
if err != nil {
return nil, err
}
for eventType, nid := range nids {
result[eventType] = nid
d.Cache.StoreRoomServerEventTypeNID(eventType, nid)
}
}
return result, nil
}
@ -96,23 +85,12 @@ func (d *Database) eventStateKeyNIDs(
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
result := make(map[string]types.EventStateKeyNID)
remaining := []string{}
for _, eventStateKey := range eventStateKeys {
if nid, ok := d.Cache.GetRoomServerStateKeyNID(eventStateKey); ok {
result[eventStateKey] = nid
} else {
remaining = append(remaining, eventStateKey)
}
}
if len(remaining) > 0 {
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, remaining)
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, eventStateKeys)
if err != nil {
return nil, err
}
for eventStateKey, nid := range nids {
result[eventStateKey] = nid
d.Cache.StoreRoomServerStateKeyNID(eventStateKey, nid)
}
}
return result, nil
}
@ -238,13 +216,27 @@ func (d *Database) addState(
func (d *Database) EventNIDs(
ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) {
return d.eventNIDs(ctx, nil, eventIDs)
return d.eventNIDs(ctx, nil, eventIDs, NoFilter)
}
type UnsentFilter bool
const (
NoFilter UnsentFilter = false
FilterUnsentOnly UnsentFilter = true
)
func (d *Database) eventNIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string,
ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter,
) (map[string]types.EventNID, error) {
switch filter {
case FilterUnsentOnly:
return d.EventsTable.BulkSelectUnsentEventNID(ctx, txn, eventIDs)
case NoFilter:
return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs)
default:
panic("impossible case")
}
}
func (d *Database) SetState(
@ -281,11 +273,11 @@ func (d *Database) EventIDs(
}
func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return d.eventsFromIDs(ctx, nil, eventIDs)
return d.eventsFromIDs(ctx, nil, eventIDs, NoFilter)
}
func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.Event, error) {
nidMap, err := d.eventNIDs(ctx, txn, eventIDs)
func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter) ([]types.Event, error) {
nidMap, err := d.eventNIDs(ctx, txn, eventIDs, filter)
if err != nil {
return nil, err
}
@ -704,9 +696,6 @@ func (d *Database) assignRoomNID(
func (d *Database) assignEventTypeNID(
ctx context.Context, txn *sql.Tx, eventType string,
) (types.EventTypeNID, error) {
if eventTypeNID, ok := d.Cache.GetRoomServerEventTypeNID(eventType); ok {
return eventTypeNID, nil
}
// Check if we already have a numeric ID in the database.
eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType)
if err == sql.ErrNoRows {
@ -717,18 +706,12 @@ func (d *Database) assignEventTypeNID(
eventTypeNID, err = d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType)
}
}
if err == nil {
d.Cache.StoreRoomServerEventTypeNID(eventType, eventTypeNID)
}
return eventTypeNID, err
}
func (d *Database) assignStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) {
if eventStateKeyNID, ok := d.Cache.GetRoomServerStateKeyNID(eventStateKey); ok {
return eventStateKeyNID, nil
}
// Check if we already have a numeric ID in the database.
eventStateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey)
if err == sql.ErrNoRows {
@ -739,9 +722,6 @@ func (d *Database) assignStateKeyNID(
eventStateKeyNID, err = d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey)
}
}
if err == nil {
d.Cache.StoreRoomServerStateKeyNID(eventStateKey, eventStateKeyNID)
}
return eventStateKeyNID, err
}

View file

@ -99,6 +99,9 @@ const bulkSelectEventIDSQL = "" +
const bulkSelectEventNIDSQL = "" +
"SELECT event_id, event_nid FROM roomserver_events WHERE event_id IN ($1)"
const bulkSelectUnsentEventNIDSQL = "" +
"SELECT event_id, event_nid FROM roomserver_events WHERE sent_to_output = 0 AND event_id IN ($1)"
const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)"
@ -118,7 +121,8 @@ type eventStatements struct {
bulkSelectStateAtEventAndReferenceStmt *sql.Stmt
bulkSelectEventReferenceStmt *sql.Stmt
bulkSelectEventIDStmt *sql.Stmt
bulkSelectEventNIDStmt *sql.Stmt
//bulkSelectEventNIDStmt *sql.Stmt
//bulkSelectUnsentEventNIDStmt *sql.Stmt
//selectRoomNIDsForEventNIDsStmt *sql.Stmt
}
@ -144,7 +148,8 @@ func prepareEventsTable(db *sql.DB) (tables.Events, error) {
{&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL},
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
//{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
//{&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL},
//{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL},
}.Prepare(db)
}
@ -494,15 +499,33 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev
return results, nil
}
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
return s.bulkSelectEventNID(ctx, txn, eventIDs, false)
}
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID
// only for events that haven't already been sent to the roomserver output.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
return s.bulkSelectEventNID(ctx, txn, eventIDs, true)
}
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) {
///////////////
iEventIDs := make([]interface{}, len(eventIDs))
for k, v := range eventIDs {
iEventIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
var selectOrig string
if onlyUnsent {
selectOrig = strings.Replace(bulkSelectUnsentEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
} else {
selectOrig = strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
}
selectStmt, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err

View file

@ -59,6 +59,7 @@ type Events interface {
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error)
BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error)
SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error)
SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
}

View file

@ -38,7 +38,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/gorilla/mux"
@ -273,8 +273,14 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI {
// CreateAccountsDB creates a new instance of the accounts database. Should only
// be called once per component.
func (b *BaseDendrite) CreateAccountsDB() accounts.Database {
db, err := accounts.NewDatabase(&b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost, b.Cfg.UserAPI.OpenIDTokenLifetimeMS)
func (b *BaseDendrite) CreateAccountsDB() userdb.Database {
db, err := userdb.NewDatabase(
&b.Cfg.UserAPI.AccountDatabase,
b.Cfg.Global.ServerName,
b.Cfg.UserAPI.BCryptCost,
b.Cfg.UserAPI.OpenIDTokenLifetimeMS,
userapi.DefaultLoginTokenLifetime,
)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to accounts db")
}

View file

@ -18,6 +18,10 @@ type ClientAPI struct {
// If set, allows registration by anyone who also has the shared
// secret, even if registration is otherwise disabled.
RegistrationSharedSecret string `yaml:"registration_shared_secret"`
// If set, prevents guest accounts from being created. Only takes
// effect if registration is enabled, otherwise guests registration
// is forbidden either way.
GuestsDisabled bool `yaml:"guests_disabled"`
// Boolean stating whether catpcha registration is enabled
// and required

View file

@ -29,8 +29,6 @@ type FederationAPI struct {
// on remote federation endpoints. This is not recommended in production!
DisableTLSValidation bool `yaml:"disable_tls_validation"`
Proxy Proxy `yaml:"proxy_outbound"`
// Perspective keyservers, to use as a backup when direct key fetch
// requests don't succeed
KeyPerspectives KeyPerspectives `yaml:"key_perspectives"`
@ -50,8 +48,6 @@ func (c *FederationAPI) Defaults(generate bool) {
c.FederationMaxRetries = 16
c.DisableTLSValidation = false
c.Proxy.Defaults()
}
func (c *FederationAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {

View file

@ -118,11 +118,6 @@ federation_sender:
conn_max_lifetime: -1
send_max_retries: 16
disable_tls_validation: false
proxy_outbound:
enabled: false
protocol: http
host: localhost
port: 8080
key_server:
internal_api:
listen: http://localhost:7779

View file

@ -16,9 +16,6 @@ type UserAPI struct {
// The Account database stores the login details and account information
// for local users. It is accessed by the UserAPI.
AccountDatabase DatabaseOptions `yaml:"account_database"`
// The Device database stores session information for the devices of logged
// in local users. It is accessed by the UserAPI.
DeviceDatabase DatabaseOptions `yaml:"device_database"`
}
const DefaultOpenIDTokenLifetimeMS = 3600000 // 60 minutes
@ -27,10 +24,8 @@ func (c *UserAPI) Defaults(generate bool) {
c.InternalAPI.Listen = "http://localhost:7781"
c.InternalAPI.Connect = "http://localhost:7781"
c.AccountDatabase.Defaults(10)
c.DeviceDatabase.Defaults(10)
if generate {
c.AccountDatabase.ConnectionString = "file:userapi_accounts.db"
c.DeviceDatabase.ConnectionString = "file:userapi_devices.db"
}
c.BCryptCost = bcrypt.DefaultCost
c.OpenIDTokenLifetimeMS = DefaultOpenIDTokenLifetimeMS
@ -40,6 +35,5 @@ func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
checkURL(configErrs, "user_api.internal_api.listen", string(c.InternalAPI.Listen))
checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect))
checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString))
checkNotEmpty(configErrs, "user_api.device_database.connection_string", string(c.DeviceDatabase.ConnectionString))
checkPositive(configErrs, "user_api.openid_token_lifetime_ms", c.OpenIDTokenLifetimeMS)
}

View file

@ -29,7 +29,6 @@ func Prepare(cfg *config.JetStream) natsclient.JetStreamContext {
JetStream: true,
StoreDir: string(cfg.StoragePath),
NoSystemAccount: true,
AllowNewAccounts: false,
MaxPayload: 16 * 1024 * 1024,
})
if err != nil {

View file

@ -30,7 +30,7 @@ import (
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
)
@ -38,7 +38,7 @@ import (
// all components of Dendrite, for use in monolith mode.
type Monolith struct {
Config *config.Dendrite
AccountDB accounts.Database
AccountDB userdb.Database
KeyRing *gomatrixserverlib.KeyRing
Client *gomatrixserverlib.Client
FedClient *gomatrixserverlib.FederationClient

View file

@ -39,14 +39,14 @@ func Setup(
rsAPI api.RoomserverInternalAPI,
cfg *config.SyncAPI,
) {
r0mux := csMux.PathPrefix("/r0").Subrouter()
v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter()
// TODO: Add AS support for all handlers below.
r0mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
v3mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return srp.OnIncomingSyncRequest(req, device)
})).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
v3mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@ -54,7 +54,7 @@ func Setup(
return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], device, federation, rsAPI, cfg, srp)
})).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userId}/filter",
v3mux.Handle("/user/{userId}/filter",
httputil.MakeAuthAPI("put_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -64,7 +64,7 @@ func Setup(
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/user/{userId}/filter/{filterId}",
v3mux.Handle("/user/{userId}/filter/{filterId}",
httputil.MakeAuthAPI("get_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -74,7 +74,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
v3mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return srp.OnIncomingKeyChangeRequest(req, device)
})).Methods(http.MethodGet, http.MethodOptions)
}

View file

@ -279,7 +279,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
parts := strings.Split(tok[1:], "_")
var positions [7]StreamPosition
for i, p := range parts {
if i > len(positions) {
if i >= len(positions) {
break
}
var pos int

View file

@ -19,6 +19,13 @@ import (
"time"
)
// DefaultLoginTokenLifetime determines how old a valid token may be.
//
// NOTSPEC: The current spec says "SHOULD be limited to around five
// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low.
// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325).
const DefaultLoginTokenLifetime = 2 * time.Minute
type LoginTokenInternalAPI interface {
// PerformLoginTokenCreation creates a new login token and associates it with the provided data.
PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error

View file

@ -31,13 +31,11 @@ import (
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/dendrite/userapi/storage/devices"
"github.com/matrix-org/dendrite/userapi/storage"
)
type UserInternalAPI struct {
AccountDB accounts.Database
DeviceDB devices.Database
DB storage.Database
ServerName gomatrixserverlib.ServerName
// AppServices is the list of all registered AS
AppServices []config.ApplicationService
@ -55,11 +53,11 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
if req.DataType == "" {
return fmt.Errorf("data type must not be empty")
}
return a.AccountDB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
return a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
}
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
if err != nil {
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
switch req.OnConflict {
@ -89,7 +87,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
return nil
}
if err = a.AccountDB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
return err
}
@ -99,7 +97,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
}
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
return err
}
res.PasswordUpdated = true
@ -112,7 +110,7 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
"device_id": req.DeviceID,
"display_name": req.DeviceDisplayName,
}).Info("PerformDeviceCreation")
dev, err := a.DeviceDB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
if err != nil {
return err
}
@ -137,12 +135,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 {
var devices []api.Device
devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
for _, d := range devices {
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
}
} else {
err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs)
err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs)
}
if err != nil {
return err
@ -196,7 +194,7 @@ func (a *UserInternalAPI) PerformLastSeenUpdate(
if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
if err := a.DeviceDB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil {
if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil {
return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err)
}
return nil
@ -208,7 +206,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
return err
}
dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID)
dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID)
if err == sql.ErrNoRows {
res.DeviceExists = false
return nil
@ -223,7 +221,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
return nil
}
err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
return err
@ -261,7 +259,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
if domain != a.ServerName {
return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName)
}
prof, err := a.AccountDB.GetProfileByLocalpart(ctx, local)
prof, err := a.DB.GetProfileByLocalpart(ctx, local)
if err != nil {
if err == sql.ErrNoRows {
return nil
@ -275,7 +273,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
}
func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error {
profiles, err := a.AccountDB.SearchProfiles(ctx, req.SearchString, req.Limit)
profiles, err := a.DB.SearchProfiles(ctx, req.SearchString, req.Limit)
if err != nil {
return err
}
@ -284,7 +282,7 @@ func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.Quer
}
func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error {
devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs)
devices, err := a.DB.GetDevicesByID(ctx, req.DeviceIDs)
if err != nil {
return err
}
@ -312,10 +310,11 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
if domain != a.ServerName {
return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName)
}
devs, err := a.DeviceDB.GetDevicesByLocalpart(ctx, local)
devs, err := a.DB.GetDevicesByLocalpart(ctx, local)
if err != nil {
return err
}
res.UserExists = true
res.Devices = devs
return nil
}
@ -330,7 +329,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
}
if req.DataType != "" {
var data json.RawMessage
data, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
if err != nil {
return err
}
@ -348,7 +347,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
}
return nil
}
global, rooms, err := a.AccountDB.GetAccountData(ctx, local)
global, rooms, err := a.DB.GetAccountData(ctx, local)
if err != nil {
return err
}
@ -367,7 +366,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
return nil
}
device, err := a.DeviceDB.GetDeviceByAccessToken(ctx, req.AccessToken)
device, err := a.DB.GetDeviceByAccessToken(ctx, req.AccessToken)
if err != nil {
if err == sql.ErrNoRows {
return nil
@ -378,7 +377,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
if err != nil {
return err
}
acc, err := a.AccountDB.GetAccountByLocalpart(ctx, localPart)
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart)
if err != nil {
return err
}
@ -419,7 +418,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
if localpart != "" { // AS is masquerading as another user
// Verify that the user is registered
account, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart)
account, err := a.DB.GetAccountByLocalpart(ctx, localpart)
// Verify that the account exists and either appServiceID matches or
// it belongs to the appservice user namespaces
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
@ -437,7 +436,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
// PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again.
func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error {
err := a.AccountDB.DeactivateAccount(ctx, req.Localpart)
err := a.DB.DeactivateAccount(ctx, req.Localpart)
res.AccountDeactivated = err == nil
return err
}
@ -446,7 +445,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error {
token := util.RandomString(24)
exp, err := a.AccountDB.CreateOpenIDToken(ctx, token, req.UserID)
exp, err := a.DB.CreateOpenIDToken(ctx, token, req.UserID)
res.Token = api.OpenIDToken{
Token: token,
@ -459,7 +458,7 @@ func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *a
// QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation
func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, req.Token)
openIDTokenAttrs, err := a.DB.GetOpenIDTokenAttributes(ctx, req.Token)
if err != nil {
return err
}
@ -481,7 +480,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
return nil
}
exists, err := a.AccountDB.DeleteKeyBackup(ctx, req.UserID, req.Version)
exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version)
if err != nil {
res.Error = fmt.Sprintf("failed to delete backup: %s", err)
}
@ -494,7 +493,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
// Create metadata
if req.Version == "" {
version, err := a.AccountDB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
if err != nil {
res.Error = fmt.Sprintf("failed to create backup: %s", err)
}
@ -507,7 +506,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
// Update metadata
if len(req.Keys.Rooms) == 0 {
err := a.AccountDB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
err := a.DB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
if err != nil {
res.Error = fmt.Sprintf("failed to update backup: %s", err)
}
@ -528,7 +527,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) {
// you can only upload keys for the CURRENT version
version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, "")
version, _, _, _, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, "")
if err != nil {
res.Error = fmt.Sprintf("failed to query version: %s", err)
return
@ -556,7 +555,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
})
}
}
count, etag, err := a.AccountDB.UpsertBackupKeys(ctx, version, req.UserID, uploads)
count, etag, err := a.DB.UpsertBackupKeys(ctx, version, req.UserID, uploads)
if err != nil {
res.Error = fmt.Sprintf("failed to upsert keys: %s", err)
return
@ -566,7 +565,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
}
func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) {
version, algorithm, authData, etag, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version)
version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version)
res.Version = version
if err != nil {
if err == sql.ErrNoRows {
@ -582,14 +581,14 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
res.Exists = !deleted
if !req.ReturnKeys {
res.Count, err = a.AccountDB.CountBackupKeys(ctx, version, req.UserID)
res.Count, err = a.DB.CountBackupKeys(ctx, version, req.UserID)
if err != nil {
res.Error = fmt.Sprintf("failed to count keys: %s", err)
}
return
}
result, err := a.AccountDB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
if err != nil {
res.Error = fmt.Sprintf("failed to query keys: %s", err)
return

View file

@ -34,7 +34,7 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap
if domain != a.ServerName {
return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName)
}
tokenMeta, err := a.DeviceDB.CreateLoginToken(ctx, &req.Data)
tokenMeta, err := a.DB.CreateLoginToken(ctx, &req.Data)
if err != nil {
return err
}
@ -45,13 +45,13 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap
// PerformLoginTokenDeletion ensures the token doesn't exist.
func (a *UserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *api.PerformLoginTokenDeletionRequest, res *api.PerformLoginTokenDeletionResponse) error {
util.GetLogger(ctx).Info("PerformLoginTokenDeletion")
return a.DeviceDB.RemoveLoginToken(ctx, req.Token)
return a.DB.RemoveLoginToken(ctx, req.Token)
}
// QueryLoginToken returns the data associated with a login token. If
// the token is not valid, success is returned, but res.Data == nil.
func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLoginTokenRequest, res *api.QueryLoginTokenResponse) error {
tokenData, err := a.DeviceDB.GetLoginTokenDataByToken(ctx, req.Token)
tokenData, err := a.DB.GetLoginTokenDataByToken(ctx, req.Token)
if err != nil {
res.Data = nil
if err == sql.ErrNoRows {
@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog
if domain != a.ServerName {
return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName)
}
if _, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart); err != nil {
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil {
res.Data = nil
if err == sql.ErrNoRows {
return nil

View file

@ -1,517 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strconv"
"time"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas"
// Import the postgres database driver.
_ "github.com/lib/pq"
)
// Database represents an account database
type Database struct {
db *sql.DB
writer sqlutil.Writer
sqlutil.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
accountDatas accountDataStatements
threepids threepidStatements
openIDTokens tokenStatements
keyBackupVersions keyBackupVersionStatements
keyBackups keyBackupStatements
serverName gomatrixserverlib.ServerName
bcryptCost int
openIDTokenLifetimeMS int64
}
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
d := &Database{
serverName: serverName,
db: db,
writer: sqlutil.NewDummyWriter(),
bcryptCost: bcryptCost,
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
}
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
if err = d.accounts.execSchema(db); err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
deltas.LoadIsActive(m)
deltas.LoadAddAccountType(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
if err = d.PartitionOffsetStatements.Prepare(db, d.writer, "account"); err != nil {
return nil, err
}
if err = d.accounts.prepare(db, serverName); err != nil {
return nil, err
}
if err = d.profiles.prepare(db); err != nil {
return nil, err
}
if err = d.accountDatas.prepare(db); err != nil {
return nil, err
}
if err = d.threepids.prepare(db); err != nil {
return nil, err
}
if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err
}
if err = d.keyBackupVersions.prepare(db); err != nil {
return nil, err
}
if err = d.keyBackups.prepare(db); err != nil {
return nil, err
}
return d, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(
ctx context.Context, localpart, plaintextPassword string,
) (*api.Account, error) {
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
if err != nil {
return nil, err
}
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
return nil, err
}
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
return d.profiles.selectProfileByLocalpart(ctx, localpart)
}
// SetAvatarURL updates the avatar URL of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) error {
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
}
// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
) error {
return d.profiles.setDisplayName(ctx, localpart, displayName)
}
// SetPassword sets the account password to the given hash.
func (d *Database) SetPassword(
ctx context.Context, localpart, plaintextPassword string,
) error {
hash, err := d.hashPassword(plaintextPassword)
if err != nil {
return err
}
return d.accounts.updatePassword(ctx, localpart, hash)
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, sqlutil.ErrUserExists.
func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
) (acc *api.Account, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
// For guest accounts, we create a new numeric local part
if accountType == api.AccountTypeGuest {
var numLocalpart int64
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
if err != nil {
return err
}
localpart = strconv.FormatInt(numLocalpart, 10)
plaintextPassword = ""
appserviceID = ""
}
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType)
return err
})
return
}
func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
) (*api.Account, error) {
var account *api.Account
var err error
// Generate a password hash if this is not a password-less user
hash := ""
if plaintextPassword != "" {
hash, err = d.hashPassword(plaintextPassword)
if err != nil {
return nil, err
}
}
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
if sqlutil.IsUniqueConstraintViolationErr(err) {
return nil, sqlutil.ErrUserExists
}
return nil, err
}
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
return nil, err
}
if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
"global": {
"content": [],
"override": [],
"room": [],
"sender": [],
"underride": []
}
}`)); err != nil {
return nil, err
}
return account, nil
}
// SaveAccountData saves new account data for a given user and a given room.
// If the account data is not specific to a room, the room ID should be an empty string
// If an account data already exists for a given set (user, room, data type), it will
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
})
}
// GetAccountData returns account data related to a given localpart
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global map[string]json.RawMessage,
rooms map[string]map[string]json.RawMessage,
err error,
) {
return d.accountDatas.selectAccountData(ctx, localpart)
}
// GetAccountDataByType returns account data matching a given
// localpart, room ID and type.
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data json.RawMessage, err error) {
return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType,
)
}
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx, nil)
}
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.bcryptCost)
return string(hashBytes), err
}
// Err3PIDInUse is the error returned when trying to save an association involving
// a third-party identifier which is already associated to a local user.
var Err3PIDInUse = errors.New("this third-party identifier is already in use")
// SaveThreePIDAssociation saves the association between a third party identifier
// and a local Matrix user (identified by the user's ID's local part).
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string,
) (err error) {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
user, err := d.threepids.selectLocalpartForThreePID(
ctx, txn, threepid, medium,
)
if err != nil {
return err
}
if len(user) > 0 {
return Err3PIDInUse
}
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
})
}
// RemoveThreePIDAssociation removes the association involving a given third-party
// identifier.
// If no association exists involving this third-party identifier, returns nothing.
// If there was a problem talking to the database, returns an error.
func (d *Database) RemoveThreePIDAssociation(
ctx context.Context, threepid string, medium string,
) (err error) {
return d.threepids.deleteThreePID(ctx, threepid, medium)
}
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
// identifier.
// If no association involves the given third-party idenfitier, returns an empty
// string.
// Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
}
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
// a given local user.
// If no association is known for this user, returns an empty slice.
// Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
}
// CheckAccountAvailability checks if the username/localpart is already present
// in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
if err == sql.ErrNoRows {
return true, nil
}
return false, err
}
// GetAccountByLocalpart returns the account associated with the given localpart.
// This function assumes the request is authenticated or the account data is used only internally.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
) (*api.Account, error) {
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// SearchProfiles returns all profiles where the provided localpart or display name
// match any part of the profiles in the database.
func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) {
return d.profiles.selectProfilesBySearch(ctx, searchString, limit)
}
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
return d.accounts.deactivateAccount(ctx, localpart)
}
// CreateOpenIDToken persists a new token that was issued through OpenID Connect
func (d *Database) CreateOpenIDToken(
ctx context.Context,
token, localpart string,
) (int64, error) {
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
err := sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
})
return expiresAtMS, err
}
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
func (d *Database) GetOpenIDTokenAttributes(
ctx context.Context,
token string,
) (*api.OpenIDTokenAttributes, error) {
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
}
func (d *Database) CreateKeyBackup(
ctx context.Context, userID, algorithm string, authData json.RawMessage,
) (version string, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "")
return err
})
return
}
func (d *Database) UpdateKeyBackupAuthData(
ctx context.Context, userID, version string, authData json.RawMessage,
) (err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData)
})
return
}
func (d *Database) DeleteKeyBackup(
ctx context.Context, userID, version string,
) (exists bool, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) GetKeyBackup(
ctx context.Context, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) GetBackupKeys(
ctx context.Context, version, userID, filterRoomID, filterSessionID string,
) (result map[string]map[string]api.KeyBackupSession, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if filterSessionID != "" {
result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
return err
}
if filterRoomID != "" {
result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
return err
}
result, err = d.keyBackups.selectKeys(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) CountBackupKeys(
ctx context.Context, version, userID string,
) (count int64, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
if err != nil {
return err
}
return nil
})
return
}
// nolint:nakedret
func (d *Database) UpsertBackupKeys(
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
) (count int64, etag string, err error) {
// wrap the following logic in a txn to ensure we atomically upload keys
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
_, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
if err != nil {
return err
}
if deleted {
return fmt.Errorf("backup was deleted")
}
// pull out all keys for this (user_id, version)
existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version)
if err != nil {
return err
}
changed := false
// loop over all the new keys (which should be smaller than the set of backed up keys)
for _, newKey := range uploads {
// if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them.
existingRoom := existingKeys[newKey.RoomID]
if existingRoom != nil {
existingSession, ok := existingRoom[newKey.SessionID]
if ok {
if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
changed = true
if err != nil {
return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err)
}
}
// if we shouldn't replace the key we do nothing with it
continue
}
}
// if we're here, either the room or session are new, either way, we insert
err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey)
changed = true
if err != nil {
return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err)
}
}
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
if err != nil {
return err
}
if changed {
// update the etag
var newETag string
if oldETag == "" {
newETag = "1"
} else {
oldETagInt, err := strconv.ParseInt(oldETag, 10, 64)
if err != nil {
return fmt.Errorf("failed to parse old etag: %s", err)
}
newETag = strconv.FormatInt(oldETagInt+1, 10)
}
etag = newETag
return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag)
} else {
etag = oldETag
}
return nil
})
return
}

View file

@ -1,52 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package devices
import (
"context"
"github.com/matrix-org/dendrite/userapi/api"
)
type Database interface {
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
RemoveLoginToken(ctx context.Context, token string) error
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
}

View file

@ -1,270 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas"
"github.com/matrix-org/gomatrixserverlib"
)
const (
// The length of generated device IDs
deviceIDByteLength = 6
loginTokenByteLength = 32
)
// Database represents a device database.
type Database struct {
db *sql.DB
devices devicesStatements
loginTokens loginTokenStatements
loginTokenLifetime time.Duration
}
// NewDatabase creates a new device database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
var d devicesStatements
var lt loginTokenStatements
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
if err = d.execSchema(db); err != nil {
return nil, err
}
if err = lt.execSchema(db); err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
deltas.LoadLastSeenTSIP(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
if err = d.prepare(db, serverName); err != nil {
return nil, err
}
if err = lt.prepare(db); err != nil {
return nil, err
}
return &Database{db, d, lt, loginTokenLifetime}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*api.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.devices.selectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil
})
return
}
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
})
}
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
tok, err := generateLoginToken()
if err != nil {
return nil, err
}
meta := &api.LoginTokenMetadata{
Token: tok,
Expiration: time.Now().Add(d.loginTokenLifetime),
}
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.loginTokens.insert(ctx, txn, meta, data)
})
if err != nil {
return nil, err
}
return meta, nil
}
func generateLoginToken() (string, error) {
b := make([]byte, loginTokenByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.loginTokens.deleteByToken(ctx, txn, token)
})
}
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
return d.loginTokens.selectByToken(ctx, token)
}

View file

@ -1,271 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas"
"github.com/matrix-org/gomatrixserverlib"
)
const (
// The length of generated device IDs
deviceIDByteLength = 6
loginTokenByteLength = 32
)
// Database represents a device database.
type Database struct {
db *sql.DB
writer sqlutil.Writer
devices devicesStatements
loginTokens loginTokenStatements
loginTokenLifetime time.Duration
}
// NewDatabase creates a new device database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
writer := sqlutil.NewExclusiveWriter()
var d devicesStatements
var lt loginTokenStatements
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
if err = d.execSchema(db); err != nil {
return nil, err
}
if err = lt.execSchema(db); err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
deltas.LoadLastSeenTSIP(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
if err = d.prepare(db, writer, serverName); err != nil {
return nil, err
}
if err = lt.prepare(db); err != nil {
return nil, err
}
return &Database{db, writer, d, lt, loginTokenLifetime}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*api.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.devices.selectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil
})
return
}
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
})
}
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
tok, err := generateLoginToken()
if err != nil {
return nil, err
}
meta := &api.LoginTokenMetadata{
Token: tok,
Expiration: time.Now().Add(d.loginTokenLifetime),
}
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.loginTokens.insert(ctx, txn, meta, data)
})
if err != nil {
return nil, err
}
return meta, nil
}
func generateLoginToken() (string, error) {
b := make([]byte, loginTokenByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.loginTokens.deleteByToken(ctx, txn, token)
})
}
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
return d.loginTokens.selectByToken(ctx, token)
}

View file

@ -1,42 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !wasm
// +build !wasm
package devices
import (
"fmt"
"time"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/devices/postgres"
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
)
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets postgres connection parameters. loginTokenLifetime determines how long a
// login token from CreateLoginToken is valid.
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime)
case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(dbProperties, serverName, loginTokenLifetime)
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View file

@ -1,39 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package devices
import (
"fmt"
"time"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
)
func NewDatabase(
dbProperties *config.DatabaseOptions,
serverName gomatrixserverlib.ServerName,
loginTokenLifetime time.Duration,
) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime)
case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation")
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package accounts
package storage
import (
"context"
@ -60,6 +60,35 @@ type Database interface {
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
RemoveLoginToken(ctx context.Context, token string) error
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
}
// Err3PIDInUse is the error returned when trying to save an association involving

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
const accountDataSchema = `
@ -56,19 +57,20 @@ type accountDataStatements struct {
selectAccountDataByTypeStmt *sql.Stmt
}
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(accountDataSchema)
func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
s := &accountDataStatements{}
_, err := db.Exec(accountDataSchema)
if err != nil {
return
return nil, err
}
return sqlutil.StatementList{
return s, sqlutil.StatementList{
{&s.insertAccountDataStmt, insertAccountDataSQL},
{&s.selectAccountDataStmt, selectAccountDataSQL},
{&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL},
}.Prepare(db)
}
func (s *accountDataStatements) insertAccountData(
func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
@ -76,7 +78,7 @@ func (s *accountDataStatements) insertAccountData(
return
}
func (s *accountDataStatements) selectAccountData(
func (s *accountDataStatements) SelectAccountData(
ctx context.Context, localpart string,
) (
/* global */ map[string]json.RawMessage,
@ -114,7 +116,7 @@ func (s *accountDataStatements) selectAccountData(
return global, rooms, rows.Err()
}
func (s *accountDataStatements) selectAccountDataByType(
func (s *accountDataStatements) SelectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data json.RawMessage, err error) {
var bytes []byte

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
log "github.com/sirupsen/logrus"
)
@ -78,14 +79,15 @@ type accountsStatements struct {
serverName gomatrixserverlib.ServerName
}
func (s *accountsStatements) execSchema(db *sql.DB) error {
_, err := db.Exec(accountsSchema)
return err
func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) {
s := &accountsStatements{
serverName: serverName,
}
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
s.serverName = server
return sqlutil.StatementList{
_, err := db.Exec(accountsSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertAccountStmt, insertAccountSQL},
{&s.updatePasswordStmt, updatePasswordSQL},
{&s.deactivateAccountStmt, deactivateAccountSQL},
@ -98,7 +100,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
// this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success.
func (s *accountsStatements) insertAccount(
func (s *accountsStatements) InsertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
@ -123,28 +125,28 @@ func (s *accountsStatements) insertAccount(
}, nil
}
func (s *accountsStatements) updatePassword(
func (s *accountsStatements) UpdatePassword(
ctx context.Context, localpart, passwordHash string,
) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
return
}
func (s *accountsStatements) deactivateAccount(
func (s *accountsStatements) DeactivateAccount(
ctx context.Context, localpart string,
) (err error) {
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
return
}
func (s *accountsStatements) selectPasswordHash(
func (s *accountsStatements) SelectPasswordHash(
ctx context.Context, localpart string,
) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
return
}
func (s *accountsStatements) selectAccountByLocalpart(
func (s *accountsStatements) SelectAccountByLocalpart(
ctx context.Context, localpart string,
) (*api.Account, error) {
var appserviceIDPtr sql.NullString
@ -168,7 +170,7 @@ func (s *accountsStatements) selectAccountByLocalpart(
return &acc, nil
}
func (s *accountsStatements) selectNewNumericLocalpart(
func (s *accountsStatements) SelectNewNumericLocalpart(
ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt

View file

@ -5,13 +5,8 @@ import (
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
)
func LoadFromGoose() {
goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
}
func LoadLastSeenTSIP(m *sqlutil.Migrations) {
m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
}

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
@ -111,50 +112,32 @@ type devicesStatements struct {
serverName gomatrixserverlib.ServerName
}
func (s *devicesStatements) execSchema(db *sql.DB) error {
func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) {
s := &devicesStatements{
serverName: serverName,
}
_, err := db.Exec(devicesSchema)
return err
if err != nil {
return nil, err
}
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
return
}
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
return
}
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
return
}
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
return
}
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
return
}
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
return
}
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return
}
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
return
}
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
return
}
if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil {
return
}
s.serverName = server
return
return s, sqlutil.StatementList{
{&s.insertDeviceStmt, insertDeviceSQL},
{&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL},
{&s.selectDeviceByIDStmt, selectDeviceByIDSQL},
{&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL},
{&s.updateDeviceNameStmt, updateDeviceNameSQL},
{&s.deleteDeviceStmt, deleteDeviceSQL},
{&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL},
{&s.deleteDevicesStmt, deleteDevicesSQL},
{&s.selectDevicesByIDStmt, selectDevicesByIDSQL},
{&s.updateDeviceLastSeenStmt, updateDeviceLastSeen},
}.Prepare(db)
}
// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
// Returns an error if the user already has a device with the given device ID.
// Returns the device on success.
func (s *devicesStatements) insertDevice(
func (s *devicesStatements) InsertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string, ipAddr, userAgent string,
) (*api.Device, error) {
@ -176,7 +159,7 @@ func (s *devicesStatements) insertDevice(
}
// deleteDevice removes a single device by id and user localpart.
func (s *devicesStatements) deleteDevice(
func (s *devicesStatements) DeleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
@ -186,7 +169,7 @@ func (s *devicesStatements) deleteDevice(
// deleteDevices removes a single or multiple devices by ids and user localpart.
// Returns an error if the execution failed.
func (s *devicesStatements) deleteDevices(
func (s *devicesStatements) DeleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt)
@ -196,7 +179,7 @@ func (s *devicesStatements) deleteDevices(
// deleteDevicesByLocalpart removes all devices for the
// given user localpart.
func (s *devicesStatements) deleteDevicesByLocalpart(
func (s *devicesStatements) DeleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
@ -204,7 +187,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
return err
}
func (s *devicesStatements) updateDeviceName(
func (s *devicesStatements) UpdateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
@ -212,7 +195,7 @@ func (s *devicesStatements) updateDeviceName(
return err
}
func (s *devicesStatements) selectDeviceByToken(
func (s *devicesStatements) SelectDeviceByToken(
ctx context.Context, accessToken string,
) (*api.Device, error) {
var dev api.Device
@ -228,7 +211,7 @@ func (s *devicesStatements) selectDeviceByToken(
// selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID
func (s *devicesStatements) selectDeviceByID(
func (s *devicesStatements) SelectDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
var dev api.Device
@ -245,7 +228,7 @@ func (s *devicesStatements) selectDeviceByID(
return &dev, err
}
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs))
if err != nil {
return nil, err
@ -268,7 +251,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
return devices, rows.Err()
}
func (s *devicesStatements) selectDevicesByLocalpart(
func (s *devicesStatements) SelectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) {
devices := []api.Device{}
@ -310,7 +293,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
return devices, rows.Err()
}
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
lastSeenTs := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID)

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
const keyBackupTableSchema = `
@ -72,12 +73,13 @@ type keyBackupStatements struct {
selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
}
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupTableSchema)
func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
s := &keyBackupStatements{}
_, err := db.Exec(keyBackupTableSchema)
if err != nil {
return
return nil, err
}
return sqlutil.StatementList{
return s, sqlutil.StatementList{
{&s.insertBackupKeyStmt, insertBackupKeySQL},
{&s.updateBackupKeyStmt, updateBackupKeySQL},
{&s.countKeysStmt, countKeysSQL},
@ -87,14 +89,14 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db)
}
func (s keyBackupStatements) countKeys(
func (s keyBackupStatements) CountKeys(
ctx context.Context, txn *sql.Tx, userID, version string,
) (count int64, err error) {
err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count)
return
}
func (s *keyBackupStatements) insertBackupKey(
func (s *keyBackupStatements) InsertBackupKey(
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
) (err error) {
_, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext(
@ -103,7 +105,7 @@ func (s *keyBackupStatements) insertBackupKey(
return
}
func (s *keyBackupStatements) updateBackupKey(
func (s *keyBackupStatements) UpdateBackupKey(
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
) (err error) {
_, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext(
@ -112,7 +114,7 @@ func (s *keyBackupStatements) updateBackupKey(
return
}
func (s *keyBackupStatements) selectKeys(
func (s *keyBackupStatements) SelectKeys(
ctx context.Context, txn *sql.Tx, userID, version string,
) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
@ -122,7 +124,7 @@ func (s *keyBackupStatements) selectKeys(
return unpackKeys(ctx, rows)
}
func (s *keyBackupStatements) selectKeysByRoomID(
func (s *keyBackupStatements) SelectKeysByRoomID(
ctx context.Context, txn *sql.Tx, userID, version, roomID string,
) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
@ -132,7 +134,7 @@ func (s *keyBackupStatements) selectKeysByRoomID(
return unpackKeys(ctx, rows)
}
func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID(
func (s *keyBackupStatements) SelectKeysByRoomIDAndSessionID(
ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)

View file

@ -22,6 +22,7 @@ import (
"strconv"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
const keyBackupVersionTableSchema = `
@ -69,12 +70,13 @@ type keyBackupVersionStatements struct {
updateKeyBackupETagStmt *sql.Stmt
}
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupVersionTableSchema)
func NewPostgresKeyBackupVersionTable(db *sql.DB) (tables.KeyBackupVersionTable, error) {
s := &keyBackupVersionStatements{}
_, err := db.Exec(keyBackupVersionTableSchema)
if err != nil {
return
return nil, err
}
return sqlutil.StatementList{
return s, sqlutil.StatementList{
{&s.insertKeyBackupStmt, insertKeyBackupSQL},
{&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL},
{&s.deleteKeyBackupStmt, deleteKeyBackupSQL},
@ -84,7 +86,7 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db)
}
func (s *keyBackupVersionStatements) insertKeyBackup(
func (s *keyBackupVersionStatements) InsertKeyBackup(
ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string,
) (version string, err error) {
var versionInt int64
@ -92,7 +94,7 @@ func (s *keyBackupVersionStatements) insertKeyBackup(
return strconv.FormatInt(versionInt, 10), err
}
func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
func (s *keyBackupVersionStatements) UpdateKeyBackupAuthData(
ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage,
) error {
versionInt, err := strconv.ParseInt(version, 10, 64)
@ -103,7 +105,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
return err
}
func (s *keyBackupVersionStatements) updateKeyBackupETag(
func (s *keyBackupVersionStatements) UpdateKeyBackupETag(
ctx context.Context, txn *sql.Tx, userID, version, etag string,
) error {
versionInt, err := strconv.ParseInt(version, 10, 64)
@ -114,7 +116,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag(
return err
}
func (s *keyBackupVersionStatements) deleteKeyBackup(
func (s *keyBackupVersionStatements) DeleteKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string,
) (bool, error) {
versionInt, err := strconv.ParseInt(version, 10, 64)
@ -132,7 +134,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup(
return ra == 1, nil
}
func (s *keyBackupVersionStatements) selectKeyBackup(
func (s *keyBackupVersionStatements) SelectKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
var versionInt int64

View file

@ -21,18 +21,11 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/util"
)
type loginTokenStatements struct {
insertStmt *sql.Stmt
deleteStmt *sql.Stmt
selectByTokenStmt *sql.Stmt
}
// execSchema ensures tables and indices exist.
func (s *loginTokenStatements) execSchema(db *sql.DB) error {
_, err := db.Exec(`
const loginTokenSchema = `
CREATE TABLE IF NOT EXISTS login_tokens (
-- The random value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY,
@ -45,21 +38,38 @@ CREATE TABLE IF NOT EXISTS login_tokens (
-- This index allows efficient garbage collection of expired tokens.
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at);
`)
return err
`
const insertLoginTokenSQL = "" +
"INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
const deleteLoginTokenSQL = "" +
"DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"
const selectLoginTokenSQL = "" +
"SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"
type loginTokenStatements struct {
insertStmt *sql.Stmt
deleteStmt *sql.Stmt
selectStmt *sql.Stmt
}
// prepare runs statement preparation.
func (s *loginTokenStatements) prepare(db *sql.DB) error {
return sqlutil.StatementList{
{&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"},
{&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"},
{&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"},
func NewPostgresLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) {
s := &loginTokenStatements{}
_, err := db.Exec(loginTokenSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertStmt, insertLoginTokenSQL},
{&s.deleteStmt, deleteLoginTokenSQL},
{&s.selectStmt, selectLoginTokenSQL},
}.Prepare(db)
}
// insert adds an already generated token to the database.
func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error {
func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error {
stmt := sqlutil.TxStmt(txn, s.insertStmt)
_, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID)
return err
@ -69,7 +79,7 @@ func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata
//
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
// The login_tokens_expiration_idx index should make that efficient.
func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error {
func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error {
stmt := sqlutil.TxStmt(txn, s.deleteStmt)
res, err := stmt.ExecContext(ctx, token, time.Now().UTC())
if err != nil {
@ -82,9 +92,9 @@ func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, t
}
// selectByToken returns the data associated with the given token. May return sql.ErrNoRows.
func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
func (s *loginTokenStatements) SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
var data api.LoginTokenData
err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID)
err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID)
if err != nil {
return nil, err
}

View file

@ -6,6 +6,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
@ -22,33 +23,35 @@ CREATE TABLE IF NOT EXISTS open_id_tokens (
);
`
const insertTokenSQL = "" +
const insertOpenIDTokenSQL = "" +
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
const selectTokenSQL = "" +
const selectOpenIDTokenSQL = "" +
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
type tokenStatements struct {
type openIDTokenStatements struct {
insertTokenStmt *sql.Stmt
selectTokenStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
_, err = db.Exec(openIDTokenSchema)
if err != nil {
return
func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.OpenIDTable, error) {
s := &openIDTokenStatements{
serverName: serverName,
}
s.serverName = server
return sqlutil.StatementList{
{&s.insertTokenStmt, insertTokenSQL},
{&s.selectTokenStmt, selectTokenSQL},
_, err := db.Exec(openIDTokenSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertTokenStmt, insertOpenIDTokenSQL},
{&s.selectTokenStmt, selectOpenIDTokenSQL},
}.Prepare(db)
}
// insertToken inserts a new OpenID Connect token to the DB.
// Returns new token, otherwise returns error if the token already exists.
func (s *tokenStatements) insertToken(
func (s *openIDTokenStatements) InsertOpenIDToken(
ctx context.Context,
txn *sql.Tx,
token, localpart string,
@ -61,7 +64,7 @@ func (s *tokenStatements) insertToken(
// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
// Returns the existing token's attributes, or err if no token is found
func (s *tokenStatements) selectOpenIDTokenAtrributes(
func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
ctx context.Context,
token string,
) (*api.OpenIDTokenAttributes, error) {

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
const profilesSchema = `
@ -59,12 +60,13 @@ type profilesStatements struct {
selectProfilesBySearchStmt *sql.Stmt
}
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(profilesSchema)
func NewPostgresProfilesTable(db *sql.DB) (tables.ProfileTable, error) {
s := &profilesStatements{}
_, err := db.Exec(profilesSchema)
if err != nil {
return
return nil, err
}
return sqlutil.StatementList{
return s, sqlutil.StatementList{
{&s.insertProfileStmt, insertProfileSQL},
{&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL},
{&s.setAvatarURLStmt, setAvatarURLSQL},
@ -73,14 +75,14 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db)
}
func (s *profilesStatements) insertProfile(
func (s *profilesStatements) InsertProfile(
ctx context.Context, txn *sql.Tx, localpart string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
return
}
func (s *profilesStatements) selectProfileByLocalpart(
func (s *profilesStatements) SelectProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
var profile authtypes.Profile
@ -93,21 +95,21 @@ func (s *profilesStatements) selectProfileByLocalpart(
return &profile, nil
}
func (s *profilesStatements) setAvatarURL(
ctx context.Context, localpart string, avatarURL string,
func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
) (err error) {
_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart)
return
}
func (s *profilesStatements) setDisplayName(
ctx context.Context, localpart string, displayName string,
func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
) (err error) {
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
return
}
func (s *profilesStatements) selectProfilesBySearch(
func (s *profilesStatements) SelectProfilesBySearch(
ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) {
var profiles []authtypes.Profile

View file

@ -0,0 +1,105 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"fmt"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/userapi/storage/shared"
// Import the postgres database driver.
_ "github.com/lib/pq"
)
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*shared.Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
if _, err = db.Exec(accountsSchema); err != nil {
// do this so that the migration can and we don't fail on
// preparing statements for columns that don't exist yet
return nil, err
}
deltas.LoadIsActive(m)
//deltas.LoadLastSeenTSIP(m)
deltas.LoadAddAccountType(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
accountDataTable, err := NewPostgresAccountDataTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)
}
accountsTable, err := NewPostgresAccountsTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err)
}
devicesTable, err := NewPostgresDevicesTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err)
}
keyBackupTable, err := NewPostgresKeyBackupTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresKeyBackupTable: %w", err)
}
keyBackupVersionTable, err := NewPostgresKeyBackupVersionTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresKeyBackupVersionTable: %w", err)
}
loginTokenTable, err := NewPostgresLoginTokenTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresLoginTokenTable: %w", err)
}
openIDTable, err := NewPostgresOpenIDTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewPostgresOpenIDTable: %w", err)
}
profilesTable, err := NewPostgresProfilesTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresProfilesTable: %w", err)
}
threePIDTable, err := NewPostgresThreePIDTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresThreePIDTable: %w", err)
}
return &shared.Database{
AccountDatas: accountDataTable,
Accounts: accountsTable,
Devices: devicesTable,
KeyBackups: keyBackupTable,
KeyBackupVersions: keyBackupVersionTable,
LoginTokens: loginTokenTable,
OpenIDTokens: openIDTable,
Profiles: profilesTable,
ThreePIDs: threePIDTable,
ServerName: serverName,
DB: db,
Writer: sqlutil.NewDummyWriter(),
LoginTokenLifetime: loginTokenLifetime,
BcryptCost: bcryptCost,
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
}, nil
}

View file

@ -19,6 +19,7 @@ import (
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
@ -58,12 +59,13 @@ type threepidStatements struct {
deleteThreePIDStmt *sql.Stmt
}
func (s *threepidStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(threepidSchema)
func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
s := &threepidStatements{}
_, err := db.Exec(threepidSchema)
if err != nil {
return
return nil, err
}
return sqlutil.StatementList{
return s, sqlutil.StatementList{
{&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL},
{&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL},
{&s.insertThreePIDStmt, insertThreePIDSQL},
@ -71,7 +73,7 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db)
}
func (s *threepidStatements) selectLocalpartForThreePID(
func (s *threepidStatements) SelectLocalpartForThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string,
) (localpart string, err error) {
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
@ -82,7 +84,7 @@ func (s *threepidStatements) selectLocalpartForThreePID(
return
}
func (s *threepidStatements) selectThreePIDsForLocalpart(
func (s *threepidStatements) SelectThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
@ -106,7 +108,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
return
}
func (s *threepidStatements) insertThreePID(
func (s *threepidStatements) InsertThreePID(
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
@ -114,8 +116,9 @@ func (s *threepidStatements) insertThreePID(
return
}
func (s *threepidStatements) deleteThreePID(
ctx context.Context, threepid string, medium string) (err error) {
_, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium)
func (s *threepidStatements) DeleteThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) {
stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium)
return
}

View file

@ -12,16 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
package shared
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"strconv"
"sync"
"time"
"github.com/matrix-org/gomatrixserverlib"
@ -29,102 +30,48 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
// Database represents an account database
type Database struct {
db *sql.DB
writer sqlutil.Writer
sqlutil.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
accountDatas accountDataStatements
threepids threepidStatements
openIDTokens tokenStatements
keyBackupVersions keyBackupVersionStatements
keyBackups keyBackupStatements
serverName gomatrixserverlib.ServerName
bcryptCost int
openIDTokenLifetimeMS int64
accountsMu sync.Mutex
profilesMu sync.Mutex
accountDatasMu sync.Mutex
threepidsMu sync.Mutex
DB *sql.DB
Writer sqlutil.Writer
Accounts tables.AccountsTable
Profiles tables.ProfileTable
AccountDatas tables.AccountDataTable
ThreePIDs tables.ThreePIDTable
OpenIDTokens tables.OpenIDTable
KeyBackups tables.KeyBackupTable
KeyBackupVersions tables.KeyBackupVersionTable
Devices tables.DevicesTable
LoginTokens tables.LoginTokenTable
LoginTokenLifetime time.Duration
ServerName gomatrixserverlib.ServerName
BcryptCost int
OpenIDTokenLifetimeMS int64
}
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
d := &Database{
serverName: serverName,
db: db,
writer: sqlutil.NewExclusiveWriter(),
bcryptCost: bcryptCost,
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
}
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
if err = d.accounts.execSchema(db); err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
deltas.LoadIsActive(m)
deltas.LoadAddAccountType(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
partitions := sqlutil.PartitionOffsetStatements{}
if err = partitions.Prepare(db, d.writer, "account"); err != nil {
return nil, err
}
if err = d.accounts.prepare(db, serverName); err != nil {
return nil, err
}
if err = d.profiles.prepare(db); err != nil {
return nil, err
}
if err = d.accountDatas.prepare(db); err != nil {
return nil, err
}
if err = d.threepids.prepare(db); err != nil {
return nil, err
}
if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err
}
if err = d.keyBackupVersions.prepare(db); err != nil {
return nil, err
}
if err = d.keyBackups.prepare(db); err != nil {
return nil, err
}
return d, nil
}
const (
// The length of generated device IDs
deviceIDByteLength = 6
loginTokenByteLength = 32
)
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(
ctx context.Context, localpart, plaintextPassword string,
) (*api.Account, error) {
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
hash, err := d.Accounts.SelectPasswordHash(ctx, localpart)
if err != nil {
return nil, err
}
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
return nil, err
}
return d.accounts.selectAccountByLocalpart(ctx, localpart)
return d.Accounts.SelectAccountByLocalpart(ctx, localpart)
}
// GetProfileByLocalpart returns the profile associated with the given localpart.
@ -132,7 +79,7 @@ func (d *Database) GetAccountByPassword(
func (d *Database) GetProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
return d.profiles.selectProfileByLocalpart(ctx, localpart)
return d.Profiles.SelectProfileByLocalpart(ctx, localpart)
}
// SetAvatarURL updates the avatar URL of the profile associated with the given
@ -140,10 +87,8 @@ func (d *Database) GetProfileByLocalpart(
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) error {
d.profilesMu.Lock()
defer d.profilesMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.profiles.setAvatarURL(ctx, txn, localpart, avatarURL)
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
})
}
@ -152,10 +97,8 @@ func (d *Database) SetAvatarURL(
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
) error {
d.profilesMu.Lock()
defer d.profilesMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.profiles.setDisplayName(ctx, txn, localpart, displayName)
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
})
}
@ -167,8 +110,8 @@ func (d *Database) SetPassword(
if err != nil {
return err
}
return d.writer.Do(nil, nil, func(txn *sql.Tx) error {
return d.accounts.updatePassword(ctx, localpart, hash)
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
return d.Accounts.UpdatePassword(ctx, localpart, hash)
})
}
@ -178,18 +121,11 @@ func (d *Database) SetPassword(
func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
) (acc *api.Account, err error) {
// Create one account at a time else we can get 'database is locked'.
d.profilesMu.Lock()
d.accountDatasMu.Lock()
d.accountsMu.Lock()
defer d.profilesMu.Unlock()
defer d.accountDatasMu.Unlock()
defer d.accountsMu.Unlock()
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// For guest accounts, we create a new numeric local part
if accountType == api.AccountTypeGuest {
var numLocalpart int64
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn)
if err != nil {
return err
}
@ -218,13 +154,13 @@ func (d *Database) createAccount(
return nil, err
}
}
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
return nil, sqlutil.ErrUserExists
}
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
return nil, err
}
if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
"global": {
"content": [],
"override": [],
@ -246,10 +182,8 @@ func (d *Database) createAccount(
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error {
d.accountDatasMu.Lock()
defer d.accountDatasMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.AccountDatas.InsertAccountData(ctx, txn, localpart, roomID, dataType, content)
})
}
@ -261,7 +195,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
rooms map[string]map[string]json.RawMessage,
err error,
) {
return d.accountDatas.selectAccountData(ctx, localpart)
return d.AccountDatas.SelectAccountData(ctx, localpart)
}
// GetAccountDataByType returns account data matching a given
@ -271,7 +205,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data json.RawMessage, err error) {
return d.accountDatas.selectAccountDataByType(
return d.AccountDatas.SelectAccountDataByType(
ctx, localpart, roomID, dataType,
)
}
@ -280,11 +214,11 @@ func (d *Database) GetAccountDataByType(
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx, nil)
return d.Accounts.SelectNewNumericLocalpart(ctx, nil)
}
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.bcryptCost)
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.BcryptCost)
return string(hashBytes), err
}
@ -299,10 +233,8 @@ var Err3PIDInUse = errors.New("this third-party identifier is already in use")
func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string,
) (err error) {
d.threepidsMu.Lock()
defer d.threepidsMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
user, err := d.threepids.selectLocalpartForThreePID(
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
user, err := d.ThreePIDs.SelectLocalpartForThreePID(
ctx, txn, threepid, medium,
)
if err != nil {
@ -313,7 +245,7 @@ func (d *Database) SaveThreePIDAssociation(
return Err3PIDInUse
}
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart)
})
}
@ -324,10 +256,8 @@ func (d *Database) SaveThreePIDAssociation(
func (d *Database) RemoveThreePIDAssociation(
ctx context.Context, threepid string, medium string,
) (err error) {
d.threepidsMu.Lock()
defer d.threepidsMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.threepids.deleteThreePID(ctx, txn, threepid, medium)
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.ThreePIDs.DeleteThreePID(ctx, txn, threepid, medium)
})
}
@ -339,7 +269,7 @@ func (d *Database) RemoveThreePIDAssociation(
func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium)
}
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
@ -349,14 +279,14 @@ func (d *Database) GetLocalpartForThreePID(
func (d *Database) GetThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart)
}
// CheckAccountAvailability checks if the username/localpart is already present
// in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
_, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart)
if err == sql.ErrNoRows {
return true, nil
}
@ -368,20 +298,20 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
) (*api.Account, error) {
return d.accounts.selectAccountByLocalpart(ctx, localpart)
return d.Accounts.SelectAccountByLocalpart(ctx, localpart)
}
// SearchProfiles returns all profiles where the provided localpart or display name
// match any part of the profiles in the database.
func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) {
return d.profiles.selectProfilesBySearch(ctx, searchString, limit)
return d.Profiles.SelectProfilesBySearch(ctx, searchString, limit)
}
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
return d.writer.Do(nil, nil, func(txn *sql.Tx) error {
return d.accounts.deactivateAccount(ctx, localpart)
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
return d.Accounts.DeactivateAccount(ctx, localpart)
})
}
@ -390,9 +320,9 @@ func (d *Database) CreateOpenIDToken(
ctx context.Context,
token, localpart string,
) (int64, error) {
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, expiresAtMS)
})
return expiresAtMS, err
}
@ -402,14 +332,14 @@ func (d *Database) GetOpenIDTokenAttributes(
ctx context.Context,
token string,
) (*api.OpenIDTokenAttributes, error) {
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
return d.OpenIDTokens.SelectOpenIDTokenAtrributes(ctx, token)
}
func (d *Database) CreateKeyBackup(
ctx context.Context, userID, algorithm string, authData json.RawMessage,
) (version string, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "")
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
version, err = d.KeyBackupVersions.InsertKeyBackup(ctx, txn, userID, algorithm, authData, "")
return err
})
return
@ -418,8 +348,8 @@ func (d *Database) CreateKeyBackup(
func (d *Database) UpdateKeyBackupAuthData(
ctx context.Context, userID, version string, authData json.RawMessage,
) (err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData)
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.KeyBackupVersions.UpdateKeyBackupAuthData(ctx, txn, userID, version, authData)
})
return
}
@ -427,8 +357,8 @@ func (d *Database) UpdateKeyBackupAuthData(
func (d *Database) DeleteKeyBackup(
ctx context.Context, userID, version string,
) (exists bool, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version)
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
exists, err = d.KeyBackupVersions.DeleteKeyBackup(ctx, txn, userID, version)
return err
})
return
@ -437,8 +367,8 @@ func (d *Database) DeleteKeyBackup(
func (d *Database) GetKeyBackup(
ctx context.Context, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
versionResult, algorithm, authData, etag, deleted, err = d.KeyBackupVersions.SelectKeyBackup(ctx, txn, userID, version)
return err
})
return
@ -447,16 +377,16 @@ func (d *Database) GetKeyBackup(
func (d *Database) GetBackupKeys(
ctx context.Context, version, userID, filterRoomID, filterSessionID string,
) (result map[string]map[string]api.KeyBackupSession, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if filterSessionID != "" {
result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
result, err = d.KeyBackups.SelectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
return err
}
if filterRoomID != "" {
result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
result, err = d.KeyBackups.SelectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
return err
}
result, err = d.keyBackups.selectKeys(ctx, txn, userID, version)
result, err = d.KeyBackups.SelectKeys(ctx, txn, userID, version)
return err
})
return
@ -465,8 +395,8 @@ func (d *Database) GetBackupKeys(
func (d *Database) CountBackupKeys(
ctx context.Context, version, userID string,
) (count int64, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
count, err = d.KeyBackups.CountKeys(ctx, txn, userID, version)
if err != nil {
return err
}
@ -480,8 +410,8 @@ func (d *Database) UpsertBackupKeys(
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
) (count int64, etag string, err error) {
// wrap the following logic in a txn to ensure we atomically upload keys
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
_, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
_, _, _, oldETag, deleted, err := d.KeyBackupVersions.SelectKeyBackup(ctx, txn, userID, version)
if err != nil {
return err
}
@ -489,7 +419,7 @@ func (d *Database) UpsertBackupKeys(
return fmt.Errorf("backup was deleted")
}
// pull out all keys for this (user_id, version)
existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version)
existingKeys, err := d.KeyBackups.SelectKeys(ctx, txn, userID, version)
if err != nil {
return err
}
@ -503,10 +433,10 @@ func (d *Database) UpsertBackupKeys(
existingSession, ok := existingRoom[newKey.SessionID]
if ok {
if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
err = d.KeyBackups.UpdateBackupKey(ctx, txn, userID, version, newKey)
changed = true
if err != nil {
return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err)
return fmt.Errorf("d.KeyBackups.UpdateBackupKey: %w", err)
}
}
// if we shouldn't replace the key we do nothing with it
@ -514,14 +444,14 @@ func (d *Database) UpsertBackupKeys(
}
}
// if we're here, either the room or session are new, either way, we insert
err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey)
err = d.KeyBackups.InsertBackupKey(ctx, txn, userID, version, newKey)
changed = true
if err != nil {
return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err)
return fmt.Errorf("d.KeyBackups.InsertBackupKey: %w", err)
}
}
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
count, err = d.KeyBackups.CountKeys(ctx, txn, userID, version)
if err != nil {
return err
}
@ -538,7 +468,7 @@ func (d *Database) UpsertBackupKeys(
newETag = strconv.FormatInt(oldETagInt+1, 10)
}
etag = newETag
return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag)
return d.KeyBackupVersions.UpdateKeyBackupETag(ctx, txn, userID, version, newETag)
} else {
etag = oldETag
}
@ -547,3 +477,196 @@ func (d *Database) UpsertBackupKeys(
})
return
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*api.Device, error) {
return d.Devices.SelectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
return d.Devices.SelectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.Devices.SelectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error
dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Devices.UpdateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.Devices.DeleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.Devices.DeleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil
})
return
}
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
})
}
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
tok, err := generateLoginToken()
if err != nil {
return nil, err
}
meta := &api.LoginTokenMetadata{
Token: tok,
Expiration: time.Now().Add(d.LoginTokenLifetime),
}
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.LoginTokens.InsertLoginToken(ctx, txn, meta, data)
})
if err != nil {
return nil, err
}
return meta, nil
}
func generateLoginToken() (string, error) {
b := make([]byte, loginTokenByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.LoginTokens.DeleteLoginToken(ctx, txn, token)
})
}
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
return d.LoginTokens.SelectLoginToken(ctx, token)
}

View file

@ -20,6 +20,7 @@ import (
"encoding/json"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
const accountDataSchema = `
@ -56,27 +57,29 @@ type accountDataStatements struct {
selectAccountDataByTypeStmt *sql.Stmt
}
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
s.db = db
_, err = db.Exec(accountDataSchema)
if err != nil {
return
func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
s := &accountDataStatements{
db: db,
}
return sqlutil.StatementList{
_, err := db.Exec(accountDataSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertAccountDataStmt, insertAccountDataSQL},
{&s.selectAccountDataStmt, selectAccountDataSQL},
{&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL},
}.Prepare(db)
}
func (s *accountDataStatements) insertAccountData(
func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) error {
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
return err
}
func (s *accountDataStatements) selectAccountData(
func (s *accountDataStatements) SelectAccountData(
ctx context.Context, localpart string,
) (
/* global */ map[string]json.RawMessage,
@ -113,7 +116,7 @@ func (s *accountDataStatements) selectAccountData(
return global, rooms, nil
}
func (s *accountDataStatements) selectAccountDataByType(
func (s *accountDataStatements) SelectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data json.RawMessage, err error) {
var bytes []byte

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
log "github.com/sirupsen/logrus"
)
@ -77,15 +78,16 @@ type accountsStatements struct {
serverName gomatrixserverlib.ServerName
}
func (s *accountsStatements) execSchema(db *sql.DB) error {
_, err := db.Exec(accountsSchema)
return err
func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) {
s := &accountsStatements{
db: db,
serverName: serverName,
}
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
s.db = db
s.serverName = server
return sqlutil.StatementList{
_, err := db.Exec(accountsSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertAccountStmt, insertAccountSQL},
{&s.updatePasswordStmt, updatePasswordSQL},
{&s.deactivateAccountStmt, deactivateAccountSQL},
@ -98,7 +100,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
// this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success.
func (s *accountsStatements) insertAccount(
func (s *accountsStatements) InsertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
@ -122,28 +124,28 @@ func (s *accountsStatements) insertAccount(
}, nil
}
func (s *accountsStatements) updatePassword(
func (s *accountsStatements) UpdatePassword(
ctx context.Context, localpart, passwordHash string,
) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
return
}
func (s *accountsStatements) deactivateAccount(
func (s *accountsStatements) DeactivateAccount(
ctx context.Context, localpart string,
) (err error) {
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
return
}
func (s *accountsStatements) selectPasswordHash(
func (s *accountsStatements) SelectPasswordHash(
ctx context.Context, localpart string,
) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
return
}
func (s *accountsStatements) selectAccountByLocalpart(
func (s *accountsStatements) SelectAccountByLocalpart(
ctx context.Context, localpart string,
) (*api.Account, error) {
var appserviceIDPtr sql.NullString
@ -167,7 +169,7 @@ func (s *accountsStatements) selectAccountByLocalpart(
return &acc, nil
}
func (s *accountsStatements) selectNewNumericLocalpart(
func (s *accountsStatements) SelectNewNumericLocalpart(
ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt

View file

@ -5,13 +5,8 @@ import (
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
)
func LoadFromGoose() {
goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
}
func LoadLastSeenTSIP(m *sqlutil.Migrations) {
m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
}

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/gomatrixserverlib"
@ -84,7 +85,6 @@ const updateDeviceLastSeen = "" +
type devicesStatements struct {
db *sql.DB
writer sqlutil.Writer
insertDeviceStmt *sql.Stmt
selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt
@ -98,52 +98,33 @@ type devicesStatements struct {
serverName gomatrixserverlib.ServerName
}
func (s *devicesStatements) execSchema(db *sql.DB) error {
func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) {
s := &devicesStatements{
db: db,
serverName: serverName,
}
_, err := db.Exec(devicesSchema)
return err
if err != nil {
return nil, err
}
func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
s.db = db
s.writer = writer
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
return
}
if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil {
return
}
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
return
}
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
return
}
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
return
}
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
return
}
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
return
}
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return
}
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
return
}
if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil {
return
}
s.serverName = server
return
return s, sqlutil.StatementList{
{&s.insertDeviceStmt, insertDeviceSQL},
{&s.selectDevicesCountStmt, selectDevicesCountSQL},
{&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL},
{&s.selectDeviceByIDStmt, selectDeviceByIDSQL},
{&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL},
{&s.updateDeviceNameStmt, updateDeviceNameSQL},
{&s.deleteDeviceStmt, deleteDeviceSQL},
{&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL},
{&s.selectDevicesByIDStmt, selectDevicesByIDSQL},
{&s.updateDeviceLastSeenStmt, updateDeviceLastSeen},
}.Prepare(db)
}
// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
// Returns an error if the user already has a device with the given device ID.
// Returns the device on success.
func (s *devicesStatements) insertDevice(
func (s *devicesStatements) InsertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string, ipAddr, userAgent string,
) (*api.Device, error) {
@ -169,7 +150,7 @@ func (s *devicesStatements) insertDevice(
}, nil
}
func (s *devicesStatements) deleteDevice(
func (s *devicesStatements) DeleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
@ -177,7 +158,7 @@ func (s *devicesStatements) deleteDevice(
return err
}
func (s *devicesStatements) deleteDevices(
func (s *devicesStatements) DeleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
) error {
orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1)
@ -195,7 +176,7 @@ func (s *devicesStatements) deleteDevices(
return err
}
func (s *devicesStatements) deleteDevicesByLocalpart(
func (s *devicesStatements) DeleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
@ -203,7 +184,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
return err
}
func (s *devicesStatements) updateDeviceName(
func (s *devicesStatements) UpdateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
@ -211,7 +192,7 @@ func (s *devicesStatements) updateDeviceName(
return err
}
func (s *devicesStatements) selectDeviceByToken(
func (s *devicesStatements) SelectDeviceByToken(
ctx context.Context, accessToken string,
) (*api.Device, error) {
var dev api.Device
@ -227,7 +208,7 @@ func (s *devicesStatements) selectDeviceByToken(
// selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID
func (s *devicesStatements) selectDeviceByID(
func (s *devicesStatements) SelectDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
var dev api.Device
@ -244,7 +225,7 @@ func (s *devicesStatements) selectDeviceByID(
return &dev, err
}
func (s *devicesStatements) selectDevicesByLocalpart(
func (s *devicesStatements) SelectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) {
devices := []api.Device{}
@ -285,7 +266,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
return devices, nil
}
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1)
iDeviceIDs := make([]interface{}, len(deviceIDs))
for i := range deviceIDs {
@ -314,7 +295,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
return devices, rows.Err()
}
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
lastSeenTs := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID)

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
const keyBackupTableSchema = `
@ -72,12 +73,13 @@ type keyBackupStatements struct {
selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
}
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupTableSchema)
func NewSQLiteKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
s := &keyBackupStatements{}
_, err := db.Exec(keyBackupTableSchema)
if err != nil {
return
return nil, err
}
return sqlutil.StatementList{
return s, sqlutil.StatementList{
{&s.insertBackupKeyStmt, insertBackupKeySQL},
{&s.updateBackupKeyStmt, updateBackupKeySQL},
{&s.countKeysStmt, countKeysSQL},
@ -87,14 +89,14 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db)
}
func (s keyBackupStatements) countKeys(
func (s keyBackupStatements) CountKeys(
ctx context.Context, txn *sql.Tx, userID, version string,
) (count int64, err error) {
err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count)
return
}
func (s *keyBackupStatements) insertBackupKey(
func (s *keyBackupStatements) InsertBackupKey(
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
) (err error) {
_, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext(
@ -103,7 +105,7 @@ func (s *keyBackupStatements) insertBackupKey(
return
}
func (s *keyBackupStatements) updateBackupKey(
func (s *keyBackupStatements) UpdateBackupKey(
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
) (err error) {
_, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext(
@ -112,7 +114,7 @@ func (s *keyBackupStatements) updateBackupKey(
return
}
func (s *keyBackupStatements) selectKeys(
func (s *keyBackupStatements) SelectKeys(
ctx context.Context, txn *sql.Tx, userID, version string,
) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
@ -122,7 +124,7 @@ func (s *keyBackupStatements) selectKeys(
return unpackKeys(ctx, rows)
}
func (s *keyBackupStatements) selectKeysByRoomID(
func (s *keyBackupStatements) SelectKeysByRoomID(
ctx context.Context, txn *sql.Tx, userID, version, roomID string,
) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
@ -132,7 +134,7 @@ func (s *keyBackupStatements) selectKeysByRoomID(
return unpackKeys(ctx, rows)
}
func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID(
func (s *keyBackupStatements) SelectKeysByRoomIDAndSessionID(
ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)

View file

@ -22,6 +22,7 @@ import (
"strconv"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
const keyBackupVersionTableSchema = `
@ -67,12 +68,13 @@ type keyBackupVersionStatements struct {
updateKeyBackupETagStmt *sql.Stmt
}
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupVersionTableSchema)
func NewSQLiteKeyBackupVersionTable(db *sql.DB) (tables.KeyBackupVersionTable, error) {
s := &keyBackupVersionStatements{}
_, err := db.Exec(keyBackupVersionTableSchema)
if err != nil {
return
return nil, err
}
return sqlutil.StatementList{
return s, sqlutil.StatementList{
{&s.insertKeyBackupStmt, insertKeyBackupSQL},
{&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL},
{&s.deleteKeyBackupStmt, deleteKeyBackupSQL},
@ -82,7 +84,7 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db)
}
func (s *keyBackupVersionStatements) insertKeyBackup(
func (s *keyBackupVersionStatements) InsertKeyBackup(
ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string,
) (version string, err error) {
var versionInt int64
@ -90,7 +92,7 @@ func (s *keyBackupVersionStatements) insertKeyBackup(
return strconv.FormatInt(versionInt, 10), err
}
func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
func (s *keyBackupVersionStatements) UpdateKeyBackupAuthData(
ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage,
) error {
versionInt, err := strconv.ParseInt(version, 10, 64)
@ -101,7 +103,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
return err
}
func (s *keyBackupVersionStatements) updateKeyBackupETag(
func (s *keyBackupVersionStatements) UpdateKeyBackupETag(
ctx context.Context, txn *sql.Tx, userID, version, etag string,
) error {
versionInt, err := strconv.ParseInt(version, 10, 64)
@ -112,7 +114,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag(
return err
}
func (s *keyBackupVersionStatements) deleteKeyBackup(
func (s *keyBackupVersionStatements) DeleteKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string,
) (bool, error) {
versionInt, err := strconv.ParseInt(version, 10, 64)
@ -130,7 +132,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup(
return ra == 1, nil
}
func (s *keyBackupVersionStatements) selectKeyBackup(
func (s *keyBackupVersionStatements) SelectKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
var versionInt int64

View file

@ -21,18 +21,17 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/util"
)
type loginTokenStatements struct {
insertStmt *sql.Stmt
deleteStmt *sql.Stmt
selectByTokenStmt *sql.Stmt
selectStmt *sql.Stmt
}
// execSchema ensures tables and indices exist.
func (s *loginTokenStatements) execSchema(db *sql.DB) error {
_, err := db.Exec(`
const loginTokenSchema = `
CREATE TABLE IF NOT EXISTS login_tokens (
-- The random value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY,
@ -45,21 +44,32 @@ CREATE TABLE IF NOT EXISTS login_tokens (
-- This index allows efficient garbage collection of expired tokens.
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at);
`)
return err
}
`
// prepare runs statement preparation.
func (s *loginTokenStatements) prepare(db *sql.DB) error {
return sqlutil.StatementList{
{&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"},
{&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"},
{&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"},
const insertLoginTokenSQL = "" +
"INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
const deleteLoginTokenSQL = "" +
"DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"
const selectLoginTokenSQL = "" +
"SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"
func NewSQLiteLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) {
s := &loginTokenStatements{}
_, err := db.Exec(loginTokenSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertStmt, insertLoginTokenSQL},
{&s.deleteStmt, deleteLoginTokenSQL},
{&s.selectStmt, selectLoginTokenSQL},
}.Prepare(db)
}
// insert adds an already generated token to the database.
func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error {
func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error {
stmt := sqlutil.TxStmt(txn, s.insertStmt)
_, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID)
return err
@ -69,7 +79,7 @@ func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata
//
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
// The login_tokens_expiration_idx index should make that efficient.
func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error {
func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error {
stmt := sqlutil.TxStmt(txn, s.deleteStmt)
res, err := stmt.ExecContext(ctx, token, time.Now().UTC())
if err != nil {
@ -82,9 +92,9 @@ func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, t
}
// selectByToken returns the data associated with the given token. May return sql.ErrNoRows.
func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
func (s *loginTokenStatements) SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
var data api.LoginTokenData
err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID)
err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID)
if err != nil {
return nil, err
}

View file

@ -6,6 +6,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
@ -22,35 +23,37 @@ CREATE TABLE IF NOT EXISTS open_id_tokens (
);
`
const insertTokenSQL = "" +
const insertOpenIDTokenSQL = "" +
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
const selectTokenSQL = "" +
const selectOpenIDTokenSQL = "" +
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
type tokenStatements struct {
type openIDTokenStatements struct {
db *sql.DB
insertTokenStmt *sql.Stmt
selectTokenStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
s.db = db
_, err = db.Exec(openIDTokenSchema)
if err != nil {
return err
func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.OpenIDTable, error) {
s := &openIDTokenStatements{
db: db,
serverName: serverName,
}
s.serverName = server
return sqlutil.StatementList{
{&s.insertTokenStmt, insertTokenSQL},
{&s.selectTokenStmt, selectTokenSQL},
_, err := db.Exec(openIDTokenSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertTokenStmt, insertOpenIDTokenSQL},
{&s.selectTokenStmt, selectOpenIDTokenSQL},
}.Prepare(db)
}
// insertToken inserts a new OpenID Connect token to the DB.
// Returns new token, otherwise returns error if the token already exists.
func (s *tokenStatements) insertToken(
func (s *openIDTokenStatements) InsertOpenIDToken(
ctx context.Context,
txn *sql.Tx,
token, localpart string,
@ -63,7 +66,7 @@ func (s *tokenStatements) insertToken(
// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
// Returns the existing token's attributes, or err if no token is found
func (s *tokenStatements) selectOpenIDTokenAtrributes(
func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
ctx context.Context,
token string,
) (*api.OpenIDTokenAttributes, error) {

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
const profilesSchema = `
@ -60,13 +61,15 @@ type profilesStatements struct {
selectProfilesBySearchStmt *sql.Stmt
}
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
s.db = db
_, err = db.Exec(profilesSchema)
if err != nil {
return
func NewSQLiteProfilesTable(db *sql.DB) (tables.ProfileTable, error) {
s := &profilesStatements{
db: db,
}
return sqlutil.StatementList{
_, err := db.Exec(profilesSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertProfileStmt, insertProfileSQL},
{&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL},
{&s.setAvatarURLStmt, setAvatarURLSQL},
@ -75,14 +78,14 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db)
}
func (s *profilesStatements) insertProfile(
func (s *profilesStatements) InsertProfile(
ctx context.Context, txn *sql.Tx, localpart string,
) error {
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
return err
}
func (s *profilesStatements) selectProfileByLocalpart(
func (s *profilesStatements) SelectProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
var profile authtypes.Profile
@ -95,7 +98,7 @@ func (s *profilesStatements) selectProfileByLocalpart(
return &profile, nil
}
func (s *profilesStatements) setAvatarURL(
func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
@ -103,7 +106,7 @@ func (s *profilesStatements) setAvatarURL(
return
}
func (s *profilesStatements) setDisplayName(
func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
@ -111,7 +114,7 @@ func (s *profilesStatements) setDisplayName(
return
}
func (s *profilesStatements) selectProfilesBySearch(
func (s *profilesStatements) SelectProfilesBySearch(
ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) {
var profiles []authtypes.Profile

View file

@ -0,0 +1,106 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"fmt"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/shared"
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
// Import the postgres database driver.
_ "github.com/lib/pq"
)
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*shared.Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
if _, err = db.Exec(accountsSchema); err != nil {
// do this so that the migration can and we don't fail on
// preparing statements for columns that don't exist yet
return nil, err
}
deltas.LoadIsActive(m)
//deltas.LoadLastSeenTSIP(m)
deltas.LoadAddAccountType(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
accountDataTable, err := NewSQLiteAccountDataTable(db)
if err != nil {
return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
}
accountsTable, err := NewSQLiteAccountsTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err)
}
devicesTable, err := NewSQLiteDevicesTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err)
}
keyBackupTable, err := NewSQLiteKeyBackupTable(db)
if err != nil {
return nil, fmt.Errorf("NewSQLiteKeyBackupTable: %w", err)
}
keyBackupVersionTable, err := NewSQLiteKeyBackupVersionTable(db)
if err != nil {
return nil, fmt.Errorf("NewSQLiteKeyBackupVersionTable: %w", err)
}
loginTokenTable, err := NewSQLiteLoginTokenTable(db)
if err != nil {
return nil, fmt.Errorf("NewSQLiteLoginTokenTable: %w", err)
}
openIDTable, err := NewSQLiteOpenIDTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewSQLiteOpenIDTable: %w", err)
}
profilesTable, err := NewSQLiteProfilesTable(db)
if err != nil {
return nil, fmt.Errorf("NewSQLiteProfilesTable: %w", err)
}
threePIDTable, err := NewSQLiteThreePIDTable(db)
if err != nil {
return nil, fmt.Errorf("NewSQLiteThreePIDTable: %w", err)
}
return &shared.Database{
AccountDatas: accountDataTable,
Accounts: accountsTable,
Devices: devicesTable,
KeyBackups: keyBackupTable,
KeyBackupVersions: keyBackupVersionTable,
LoginTokens: loginTokenTable,
OpenIDTokens: openIDTable,
Profiles: profilesTable,
ThreePIDs: threePIDTable,
ServerName: serverName,
DB: db,
Writer: sqlutil.NewExclusiveWriter(),
LoginTokenLifetime: loginTokenLifetime,
BcryptCost: bcryptCost,
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
}, nil
}

View file

@ -20,6 +20,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
@ -60,13 +61,15 @@ type threepidStatements struct {
deleteThreePIDStmt *sql.Stmt
}
func (s *threepidStatements) prepare(db *sql.DB) (err error) {
s.db = db
_, err = db.Exec(threepidSchema)
if err != nil {
return
func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
s := &threepidStatements{
db: db,
}
return sqlutil.StatementList{
_, err := db.Exec(threepidSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL},
{&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL},
{&s.insertThreePIDStmt, insertThreePIDSQL},
@ -74,7 +77,7 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db)
}
func (s *threepidStatements) selectLocalpartForThreePID(
func (s *threepidStatements) SelectLocalpartForThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string,
) (localpart string, err error) {
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
@ -85,7 +88,7 @@ func (s *threepidStatements) selectLocalpartForThreePID(
return
}
func (s *threepidStatements) selectThreePIDsForLocalpart(
func (s *threepidStatements) SelectThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
@ -109,7 +112,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
return threepids, rows.Err()
}
func (s *threepidStatements) insertThreePID(
func (s *threepidStatements) InsertThreePID(
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
@ -117,7 +120,7 @@ func (s *threepidStatements) insertThreePID(
return err
}
func (s *threepidStatements) deleteThreePID(
func (s *threepidStatements) DeleteThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) {
stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium)

View file

@ -15,26 +15,27 @@
//go:build !wasm
// +build !wasm
package accounts
package storage
import (
"fmt"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/accounts/postgres"
"github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3"
"github.com/matrix-org/dendrite/userapi/storage/postgres"
"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
)
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets postgres connection parameters
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (Database, error) {
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime)
case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime)
default:
return nil, fmt.Errorf("unexpected database type")
}

View file

@ -12,13 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package accounts
package storage
import (
"fmt"
"time"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3"
"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
)
@ -27,10 +28,11 @@ func NewDatabase(
serverName gomatrixserverlib.ServerName,
bcryptCost int,
openIDTokenLifetimeMS int64,
loginTokenLifetime time.Duration,
) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime)
case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation")
default:

Some files were not shown because too many files have changed in this diff Show more