Merge branch 'main' into implement-push-notifications

This commit is contained in:
Neil Alexander 2022-02-18 14:16:28 +00:00
commit 857b75d66e
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
87 changed files with 1338 additions and 2086 deletions

View file

@ -23,7 +23,7 @@ import (
"errors" "errors"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -85,7 +85,7 @@ func RetrieveUserProfile(
ctx context.Context, ctx context.Context,
userID string, userID string,
asAPI AppServiceQueryAPI, asAPI AppServiceQueryAPI,
accountDB accounts.Database, accountDB userdb.Database,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {

View file

@ -283,8 +283,7 @@ func (m *DendriteMonolith) Start() {
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/%s", m.StorageDirectory, prefix)) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/%s", m.StorageDirectory, prefix))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-account.db", m.StorageDirectory, prefix)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-account.db", m.StorageDirectory, prefix))
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-device.db", m.StorageDirectory, prefix)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-mediaapi.db", m.CacheDirectory, prefix))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-syncapi.db", m.StorageDirectory, prefix)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-syncapi.db", m.StorageDirectory, prefix))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-roomserver.db", m.StorageDirectory, prefix)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-roomserver.db", m.StorageDirectory, prefix))
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-keyserver.db", m.StorageDirectory, prefix)) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-keyserver.db", m.StorageDirectory, prefix))

View file

@ -88,7 +88,6 @@ func (m *DendriteMonolith) Start() {
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", m.StorageDirectory)) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", m.StorageDirectory))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-account.db", m.StorageDirectory)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-account.db", m.StorageDirectory))
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-device.db", m.StorageDirectory))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-syncapi.db", m.StorageDirectory)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-syncapi.db", m.StorageDirectory))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-roomserver.db", m.StorageDirectory)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-roomserver.db", m.StorageDirectory))

View file

@ -29,7 +29,7 @@ import (
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -38,7 +38,7 @@ func AddPublicRoutes(
router *mux.Router, router *mux.Router,
synapseAdminRouter *mux.Router, synapseAdminRouter *mux.Router,
cfg *config.ClientAPI, cfg *config.ClientAPI,
accountsDB accounts.Database, accountsDB userdb.Database,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
eduInputAPI eduServerAPI.EDUServerInputAPI, eduInputAPI eduServerAPI.EDUServerInputAPI,

View file

@ -30,7 +30,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -137,7 +137,7 @@ type fledglingEvent struct {
func CreateRoom( func CreateRoom(
req *http.Request, device *api.Device, req *http.Request, device *api.Device,
cfg *config.ClientAPI, cfg *config.ClientAPI,
accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI, accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
// TODO (#267): Check room ID doesn't clash with an existing one, and we // TODO (#267): Check room ID doesn't clash with an existing one, and we
@ -151,7 +151,7 @@ func CreateRoom(
func createRoom( func createRoom(
req *http.Request, device *api.Device, req *http.Request, device *api.Device,
cfg *config.ClientAPI, roomID string, cfg *config.ClientAPI, roomID string,
accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI, accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
logger := util.GetLogger(req.Context()) logger := util.GetLogger(req.Context())

View file

@ -23,7 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -32,7 +32,7 @@ func JoinRoomByIDOrAlias(
req *http.Request, req *http.Request,
device *api.Device, device *api.Device,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB accounts.Database, accountDB userdb.Database,
roomIDOrAlias string, roomIDOrAlias string,
) util.JSONResponse { ) util.JSONResponse {
// Prepare to ask the roomserver to perform the room join. // Prepare to ask the roomserver to perform the room join.

View file

@ -24,7 +24,7 @@ import (
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -36,7 +36,7 @@ type crossSigningRequest struct {
func UploadCrossSigningDeviceKeys( func UploadCrossSigningDeviceKeys(
req *http.Request, userInteractiveAuth *auth.UserInteractive, req *http.Request, userInteractiveAuth *auth.UserInteractive,
keyserverAPI api.KeyInternalAPI, device *userapi.Device, keyserverAPI api.KeyInternalAPI, device *userapi.Device,
accountDB accounts.Database, cfg *config.ClientAPI, accountDB userdb.Database, cfg *config.ClientAPI,
) util.JSONResponse { ) util.JSONResponse {
uploadReq := &crossSigningRequest{} uploadReq := &crossSigningRequest{}
uploadRes := &api.PerformUploadDeviceKeysResponse{} uploadRes := &api.PerformUploadDeviceKeysResponse{}

View file

@ -23,7 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -54,7 +54,7 @@ func passwordLogin() flows {
// Login implements GET and POST /login // Login implements GET and POST /login
func Login( func Login(
req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI, req *http.Request, accountDB userdb.Database, userAPI userapi.UserInternalAPI,
cfg *config.ClientAPI, cfg *config.ClientAPI,
) util.JSONResponse { ) util.JSONResponse {
if req.Method == http.MethodGet { if req.Method == http.MethodGet {

View file

@ -30,7 +30,7 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -39,7 +39,7 @@ import (
var errMissingUserID = errors.New("'user_id' must be supplied") var errMissingUserID = errors.New("'user_id' must be supplied")
func SendBan( func SendBan(
req *http.Request, accountDB accounts.Database, device *userapi.Device, req *http.Request, accountDB userdb.Database, device *userapi.Device,
roomID string, cfg *config.ClientAPI, roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
@ -81,7 +81,7 @@ func SendBan(
return sendMembership(req.Context(), accountDB, device, roomID, "ban", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI) return sendMembership(req.Context(), accountDB, device, roomID, "ban", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI)
} }
func sendMembership(ctx context.Context, accountDB accounts.Database, device *userapi.Device, func sendMembership(ctx context.Context, accountDB userdb.Database, device *userapi.Device,
roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time, roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time,
roomVer gomatrixserverlib.RoomVersion, roomVer gomatrixserverlib.RoomVersion,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI) util.JSONResponse { rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI) util.JSONResponse {
@ -125,7 +125,7 @@ func sendMembership(ctx context.Context, accountDB accounts.Database, device *us
} }
func SendKick( func SendKick(
req *http.Request, accountDB accounts.Database, device *userapi.Device, req *http.Request, accountDB userdb.Database, device *userapi.Device,
roomID string, cfg *config.ClientAPI, roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
@ -165,7 +165,7 @@ func SendKick(
} }
func SendUnban( func SendUnban(
req *http.Request, accountDB accounts.Database, device *userapi.Device, req *http.Request, accountDB userdb.Database, device *userapi.Device,
roomID string, cfg *config.ClientAPI, roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
@ -200,7 +200,7 @@ func SendUnban(
} }
func SendInvite( func SendInvite(
req *http.Request, accountDB accounts.Database, device *userapi.Device, req *http.Request, accountDB userdb.Database, device *userapi.Device,
roomID string, cfg *config.ClientAPI, roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
@ -271,7 +271,7 @@ func SendInvite(
func buildMembershipEvent( func buildMembershipEvent(
ctx context.Context, ctx context.Context,
targetUserID, reason string, accountDB accounts.Database, targetUserID, reason string, accountDB userdb.Database,
device *userapi.Device, device *userapi.Device,
membership, roomID string, isDirect bool, membership, roomID string, isDirect bool,
cfg *config.ClientAPI, evTime time.Time, cfg *config.ClientAPI, evTime time.Time,
@ -312,7 +312,7 @@ func loadProfile(
ctx context.Context, ctx context.Context,
userID string, userID string,
cfg *config.ClientAPI, cfg *config.ClientAPI,
accountDB accounts.Database, accountDB userdb.Database,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
_, serverName, err := gomatrixserverlib.SplitID('@', userID) _, serverName, err := gomatrixserverlib.SplitID('@', userID)
@ -366,7 +366,7 @@ func checkAndProcessThreepid(
body *threepid.MembershipRequest, body *threepid.MembershipRequest,
cfg *config.ClientAPI, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB accounts.Database, accountDB userdb.Database,
roomID string, roomID string,
evTime time.Time, evTime time.Time,
) (inviteStored bool, errRes *util.JSONResponse) { ) (inviteStored bool, errRes *util.JSONResponse) {

View file

@ -10,7 +10,7 @@ import (
pushserverapi "github.com/matrix-org/dendrite/pushserver/api" pushserverapi "github.com/matrix-org/dendrite/pushserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -32,7 +32,7 @@ func Password(
req *http.Request, req *http.Request,
psAPI pushserverapi.PushserverInternalAPI, psAPI pushserverapi.PushserverInternalAPI,
userAPI api.UserInternalAPI, userAPI api.UserInternalAPI,
accountDB accounts.Database, accountDB userdb.Database,
device *api.Device, device *api.Device,
cfg *config.ClientAPI, cfg *config.ClientAPI,
) util.JSONResponse { ) util.JSONResponse {

View file

@ -19,7 +19,7 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -28,7 +28,7 @@ func PeekRoomByIDOrAlias(
req *http.Request, req *http.Request,
device *api.Device, device *api.Device,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB accounts.Database, accountDB userdb.Database,
roomIDOrAlias string, roomIDOrAlias string,
) util.JSONResponse { ) util.JSONResponse {
// if this is a remote roomIDOrAlias, we have to ask the roomserver (or federation sender?) to // if this is a remote roomIDOrAlias, we have to ask the roomserver (or federation sender?) to
@ -82,7 +82,7 @@ func UnpeekRoomByID(
req *http.Request, req *http.Request,
device *api.Device, device *api.Device,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB accounts.Database, accountDB userdb.Database,
roomID string, roomID string,
) util.JSONResponse { ) util.JSONResponse {
unpeekReq := roomserverAPI.PerformUnpeekRequest{ unpeekReq := roomserverAPI.PerformUnpeekRequest{

View file

@ -27,7 +27,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
@ -36,7 +36,7 @@ import (
// GetProfile implements GET /profile/{userID} // GetProfile implements GET /profile/{userID}
func GetProfile( func GetProfile(
req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI, req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI,
userID string, userID string,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
@ -65,7 +65,7 @@ func GetProfile(
// GetAvatarURL implements GET /profile/{userID}/avatar_url // GetAvatarURL implements GET /profile/{userID}/avatar_url
func GetAvatarURL( func GetAvatarURL(
req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI, req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI,
userID string, asAPI appserviceAPI.AppServiceQueryAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
) util.JSONResponse { ) util.JSONResponse {
@ -92,7 +92,7 @@ func GetAvatarURL(
// SetAvatarURL implements PUT /profile/{userID}/avatar_url // SetAvatarURL implements PUT /profile/{userID}/avatar_url
func SetAvatarURL( func SetAvatarURL(
req *http.Request, accountDB accounts.Database, req *http.Request, accountDB userdb.Database,
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { if userID != device.UserID {
@ -182,7 +182,7 @@ func SetAvatarURL(
// GetDisplayName implements GET /profile/{userID}/displayname // GetDisplayName implements GET /profile/{userID}/displayname
func GetDisplayName( func GetDisplayName(
req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI, req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI,
userID string, asAPI appserviceAPI.AppServiceQueryAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
) util.JSONResponse { ) util.JSONResponse {
@ -209,7 +209,7 @@ func GetDisplayName(
// SetDisplayName implements PUT /profile/{userID}/displayname // SetDisplayName implements PUT /profile/{userID}/displayname
func SetDisplayName( func SetDisplayName(
req *http.Request, accountDB accounts.Database, req *http.Request, accountDB userdb.Database,
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { if userID != device.UserID {
@ -302,7 +302,7 @@ func SetDisplayName(
// Returns an error when something goes wrong or specifically // Returns an error when something goes wrong or specifically
// eventutil.ErrProfileNoExists when the profile doesn't exist. // eventutil.ErrProfileNoExists when the profile doesn't exist.
func getProfile( func getProfile(
ctx context.Context, accountDB accounts.Database, cfg *config.ClientAPI, ctx context.Context, accountDB userdb.Database, cfg *config.ClientAPI,
userID string, userID string,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,

View file

@ -44,7 +44,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
) )
var ( var (
@ -448,7 +448,7 @@ func validateApplicationService(
func Register( func Register(
req *http.Request, req *http.Request,
userAPI userapi.UserInternalAPI, userAPI userapi.UserInternalAPI,
accountDB accounts.Database, accountDB userdb.Database,
cfg *config.ClientAPI, cfg *config.ClientAPI,
) util.JSONResponse { ) util.JSONResponse {
var r registerRequest var r registerRequest
@ -532,6 +532,13 @@ func handleGuestRegistration(
cfg *config.ClientAPI, cfg *config.ClientAPI,
userAPI userapi.UserInternalAPI, userAPI userapi.UserInternalAPI,
) util.JSONResponse { ) util.JSONResponse {
if cfg.RegistrationDisabled || cfg.GuestsDisabled {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Guest registration is disabled"),
}
}
var res userapi.PerformAccountCreationResponse var res userapi.PerformAccountCreationResponse
err := userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{ err := userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeGuest, AccountType: userapi.AccountTypeGuest,
@ -892,7 +899,7 @@ type availableResponse struct {
func RegisterAvailable( func RegisterAvailable(
req *http.Request, req *http.Request,
cfg *config.ClientAPI, cfg *config.ClientAPI,
accountDB accounts.Database, accountDB userdb.Database,
) util.JSONResponse { ) util.JSONResponse {
username := req.URL.Query().Get("username") username := req.URL.Query().Get("username")

View file

@ -34,7 +34,7 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -51,7 +51,7 @@ func Setup(
eduAPI eduServerAPI.EDUServerInputAPI, eduAPI eduServerAPI.EDUServerInputAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
accountDB accounts.Database, accountDB userdb.Database,
userAPI userapi.UserInternalAPI, userAPI userapi.UserInternalAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
syncProducer *producers.SyncAPIProducer, syncProducer *producers.SyncAPIProducer,
@ -118,15 +118,22 @@ func Setup(
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
} }
r0mux := publicAPIMux.PathPrefix("/r0").Subrouter() // You can't just do PathPrefix("/(r0|v3)") because regexps only apply when inside named path variables.
// So make a named path variable called 'apiversion' (which we will never read in handlers) and then do
// (r0|v3) - BUT this is a captured group, which makes no sense because you cannot extract this group
// from a match (gorilla/mux exposes no way to do this) so it demands you make it a non-capturing group
// using ?: so the final regexp becomes what is below. We also need a trailing slash to stop 'v33333' matching.
// Note that 'apiversion' is chosen because it must not collide with a variable used in any of the routing!
v3mux := publicAPIMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter()
unstableMux := publicAPIMux.PathPrefix("/unstable").Subrouter() unstableMux := publicAPIMux.PathPrefix("/unstable").Subrouter()
r0mux.Handle("/createRoom", v3mux.Handle("/createRoom",
httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return CreateRoom(req, device, cfg, accountDB, rsAPI, asAPI) return CreateRoom(req, device, cfg, accountDB, rsAPI, asAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/join/{roomIDOrAlias}", v3mux.Handle("/join/{roomIDOrAlias}",
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -142,7 +149,7 @@ func Setup(
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
if mscCfg.Enabled("msc2753") { if mscCfg.Enabled("msc2753") {
r0mux.Handle("/peek/{roomIDOrAlias}", v3mux.Handle("/peek/{roomIDOrAlias}",
httputil.MakeAuthAPI(gomatrixserverlib.Peek, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI(gomatrixserverlib.Peek, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -157,12 +164,12 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
} }
r0mux.Handle("/joined_rooms", v3mux.Handle("/joined_rooms",
httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetJoinedRooms(req, device, rsAPI) return GetJoinedRooms(req, device, rsAPI)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/join", v3mux.Handle("/rooms/{roomID}/join",
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -176,7 +183,7 @@ func Setup(
) )
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/leave", v3mux.Handle("/rooms/{roomID}/leave",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -190,7 +197,7 @@ func Setup(
) )
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/unpeek", v3mux.Handle("/rooms/{roomID}/unpeek",
httputil.MakeAuthAPI("unpeek", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("unpeek", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -201,7 +208,7 @@ func Setup(
) )
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/ban", v3mux.Handle("/rooms/{roomID}/ban",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -210,7 +217,7 @@ func Setup(
return SendBan(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) return SendBan(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/invite", v3mux.Handle("/rooms/{roomID}/invite",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -222,7 +229,7 @@ func Setup(
return SendInvite(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) return SendInvite(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/kick", v3mux.Handle("/rooms/{roomID}/kick",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -231,7 +238,7 @@ func Setup(
return SendKick(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) return SendKick(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/unban", v3mux.Handle("/rooms/{roomID}/unban",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -240,7 +247,7 @@ func Setup(
return SendUnban(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) return SendUnban(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/send/{eventType}", v3mux.Handle("/rooms/{roomID}/send/{eventType}",
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -249,7 +256,7 @@ func Setup(
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil) return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}", v3mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}",
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -260,7 +267,7 @@ func Setup(
nil, cfg, rsAPI, transactionsCache) nil, cfg, rsAPI, transactionsCache)
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/event/{eventID}", v3mux.Handle("/rooms/{roomID}/event/{eventID}",
httputil.MakeAuthAPI("rooms_get_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_get_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -270,7 +277,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -278,7 +285,7 @@ func Setup(
return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"]) return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"])
})).Methods(http.MethodGet, http.MethodOptions) })).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -286,7 +293,7 @@ func Setup(
return GetAliases(req, rsAPI, device, vars["roomID"]) return GetAliases(req, rsAPI, device, vars["roomID"])
})).Methods(http.MethodGet, http.MethodOptions) })).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -297,7 +304,7 @@ func Setup(
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat) return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat)
})).Methods(http.MethodGet, http.MethodOptions) })).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -306,7 +313,7 @@ func Setup(
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat) return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat)
})).Methods(http.MethodGet, http.MethodOptions) })).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}", v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}",
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -318,7 +325,7 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}", v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}",
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -329,21 +336,21 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
} }
return Register(req, userAPI, accountDB, cfg) return Register(req, userAPI, accountDB, cfg)
})).Methods(http.MethodPost, http.MethodOptions) })).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { v3mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
} }
return RegisterAvailable(req, cfg, accountDB) return RegisterAvailable(req, cfg, accountDB)
})).Methods(http.MethodGet, http.MethodOptions) })).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/directory/room/{roomAlias}", v3mux.Handle("/directory/room/{roomAlias}",
httputil.MakeExternalAPI("directory_room", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("directory_room", func(req *http.Request) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -353,7 +360,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/directory/room/{roomAlias}", v3mux.Handle("/directory/room/{roomAlias}",
httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -363,7 +370,7 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/directory/room/{roomAlias}", v3mux.Handle("/directory/room/{roomAlias}",
httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -372,7 +379,7 @@ func Setup(
return RemoveLocalAlias(req, device, vars["roomAlias"], rsAPI) return RemoveLocalAlias(req, device, vars["roomAlias"], rsAPI)
}), }),
).Methods(http.MethodDelete, http.MethodOptions) ).Methods(http.MethodDelete, http.MethodOptions)
r0mux.Handle("/directory/list/room/{roomID}", v3mux.Handle("/directory/list/room/{roomID}",
httputil.MakeExternalAPI("directory_list", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("directory_list", func(req *http.Request) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -382,7 +389,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
// TODO: Add AS support // TODO: Add AS support
r0mux.Handle("/directory/list/room/{roomID}", v3mux.Handle("/directory/list/room/{roomID}",
httputil.MakeAuthAPI("directory_list", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("directory_list", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -391,25 +398,25 @@ func Setup(
return SetVisibility(req, rsAPI, device, vars["roomID"]) return SetVisibility(req, rsAPI, device, vars["roomID"])
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/publicRooms", v3mux.Handle("/publicRooms",
httputil.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse {
return GetPostPublicRooms(req, rsAPI, extRoomsProvider, federation, cfg) return GetPostPublicRooms(req, rsAPI, extRoomsProvider, federation, cfg)
}), }),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
r0mux.Handle("/logout", v3mux.Handle("/logout",
httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return Logout(req, userAPI, device) return Logout(req, userAPI, device)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/logout/all", v3mux.Handle("/logout/all",
httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return LogoutAll(req, userAPI, device) return LogoutAll(req, userAPI, device)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/typing/{userID}", v3mux.Handle("/rooms/{roomID}/typing/{userID}",
httputil.MakeAuthAPI("rooms_typing", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_typing", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -421,7 +428,7 @@ func Setup(
return SendTyping(req, device, vars["roomID"], vars["userID"], accountDB, eduAPI, rsAPI) return SendTyping(req, device, vars["roomID"], vars["userID"], accountDB, eduAPI, rsAPI)
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/redact/{eventID}", v3mux.Handle("/rooms/{roomID}/redact/{eventID}",
httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -430,7 +437,7 @@ func Setup(
return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI) return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}", v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}",
httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -440,7 +447,7 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/sendToDevice/{eventType}/{txnID}", v3mux.Handle("/sendToDevice/{eventType}/{txnID}",
httputil.MakeAuthAPI("send_to_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("send_to_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -465,7 +472,7 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/account/whoami", v3mux.Handle("/account/whoami",
httputil.MakeAuthAPI("whoami", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("whoami", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -474,7 +481,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/account/password", v3mux.Handle("/account/password",
httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -483,7 +490,7 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/account/deactivate", v3mux.Handle("/account/deactivate",
httputil.MakeAuthAPI("deactivate", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("deactivate", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -494,7 +501,7 @@ func Setup(
// Stub endpoints required by Element // Stub endpoints required by Element
r0mux.Handle("/login", v3mux.Handle("/login",
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -503,7 +510,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
r0mux.Handle("/auth/{authType}/fallback/web", v3mux.Handle("/auth/{authType}/fallback/web",
httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse {
vars := mux.Vars(req) vars := mux.Vars(req)
return AuthFallback(w, req, vars["authType"], cfg) return AuthFallback(w, req, vars["authType"], cfg)
@ -512,7 +519,7 @@ func Setup(
// Push rules // Push rules
r0mux.Handle("/pushrules", v3mux.Handle("/pushrules",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
@ -521,13 +528,13 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/pushrules/", v3mux.Handle("/pushrules/",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetAllPushRules(req.Context(), device, psAPI) return GetAllPushRules(req.Context(), device, psAPI)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/pushrules/", v3mux.Handle("/pushrules/",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
@ -536,7 +543,7 @@ func Setup(
}), }),
).Methods(http.MethodPut) ).Methods(http.MethodPut)
r0mux.Handle("/pushrules/{scope}/", v3mux.Handle("/pushrules/{scope}/",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -546,7 +553,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/pushrules/{scope}", v3mux.Handle("/pushrules/{scope}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
@ -555,7 +562,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/pushrules/{scope:[^/]+/?}", v3mux.Handle("/pushrules/{scope:[^/]+/?}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
@ -564,7 +571,7 @@ func Setup(
}), }),
).Methods(http.MethodPut) ).Methods(http.MethodPut)
r0mux.Handle("/pushrules/{scope}/{kind}/", v3mux.Handle("/pushrules/{scope}/{kind}/",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -574,7 +581,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/pushrules/{scope}/{kind}", v3mux.Handle("/pushrules/{scope}/{kind}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
@ -583,7 +590,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/pushrules/{scope}/{kind:[^/]+/?}", v3mux.Handle("/pushrules/{scope}/{kind:[^/]+/?}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
@ -592,7 +599,7 @@ func Setup(
}), }),
).Methods(http.MethodPut) ).Methods(http.MethodPut)
r0mux.Handle("/pushrules/{scope}/{kind}/{ruleID}", v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -602,7 +609,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/pushrules/{scope}/{kind}/{ruleID}", v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -616,7 +623,7 @@ func Setup(
}), }),
).Methods(http.MethodPut) ).Methods(http.MethodPut)
r0mux.Handle("/pushrules/{scope}/{kind}/{ruleID}", v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -626,7 +633,7 @@ func Setup(
}), }),
).Methods(http.MethodDelete) ).Methods(http.MethodDelete)
r0mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}", v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -636,7 +643,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}", v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -648,7 +655,7 @@ func Setup(
// Element user settings // Element user settings
r0mux.Handle("/profile/{userID}", v3mux.Handle("/profile/{userID}",
httputil.MakeExternalAPI("profile", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("profile", func(req *http.Request) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -658,7 +665,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/profile/{userID}/avatar_url", v3mux.Handle("/profile/{userID}/avatar_url",
httputil.MakeExternalAPI("profile_avatar_url", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("profile_avatar_url", func(req *http.Request) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -668,7 +675,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/profile/{userID}/avatar_url", v3mux.Handle("/profile/{userID}/avatar_url",
httputil.MakeAuthAPI("profile_avatar_url", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("profile_avatar_url", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -683,7 +690,7 @@ func Setup(
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows // Browsers use the OPTIONS HTTP method to check if the CORS policy allows
// PUT requests, so we need to allow this method // PUT requests, so we need to allow this method
r0mux.Handle("/profile/{userID}/displayname", v3mux.Handle("/profile/{userID}/displayname",
httputil.MakeExternalAPI("profile_displayname", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("profile_displayname", func(req *http.Request) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -693,7 +700,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/profile/{userID}/displayname", v3mux.Handle("/profile/{userID}/displayname",
httputil.MakeAuthAPI("profile_displayname", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("profile_displayname", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -708,13 +715,13 @@ func Setup(
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows // Browsers use the OPTIONS HTTP method to check if the CORS policy allows
// PUT requests, so we need to allow this method // PUT requests, so we need to allow this method
r0mux.Handle("/account/3pid", v3mux.Handle("/account/3pid",
httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetAssociated3PIDs(req, accountDB, device) return GetAssociated3PIDs(req, accountDB, device)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/account/3pid", v3mux.Handle("/account/3pid",
httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return CheckAndSave3PIDAssociation(req, accountDB, device, cfg) return CheckAndSave3PIDAssociation(req, accountDB, device, cfg)
}), }),
@ -726,14 +733,14 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken", v3mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken",
httputil.MakeExternalAPI("account_3pid_request_token", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("account_3pid_request_token", func(req *http.Request) util.JSONResponse {
return RequestEmailToken(req, accountDB, cfg) return RequestEmailToken(req, accountDB, cfg)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
// Element logs get flooded unless this is handled // Element logs get flooded unless this is handled
r0mux.Handle("/presence/{userID}/status", v3mux.Handle("/presence/{userID}/status",
httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -746,7 +753,7 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/voip/turnServer", v3mux.Handle("/voip/turnServer",
httputil.MakeAuthAPI("turn_server", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("turn_server", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -755,7 +762,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/thirdparty/protocols", v3mux.Handle("/thirdparty/protocols",
httputil.MakeExternalAPI("thirdparty_protocols", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("thirdparty_protocols", func(req *http.Request) util.JSONResponse {
// TODO: Return the third party protcols // TODO: Return the third party protcols
return util.JSONResponse{ return util.JSONResponse{
@ -765,7 +772,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/initialSync", v3mux.Handle("/rooms/{roomID}/initialSync",
httputil.MakeExternalAPI("rooms_initial_sync", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("rooms_initial_sync", func(req *http.Request) util.JSONResponse {
// TODO: Allow people to peek into rooms. // TODO: Allow people to peek into rooms.
return util.JSONResponse{ return util.JSONResponse{
@ -775,7 +782,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userID}/account_data/{type}", v3mux.Handle("/user/{userID}/account_data/{type}",
httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -785,7 +792,7 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", v3mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}",
httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -795,7 +802,7 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/user/{userID}/account_data/{type}", v3mux.Handle("/user/{userID}/account_data/{type}",
httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -805,7 +812,7 @@ func Setup(
}), }),
).Methods(http.MethodGet) ).Methods(http.MethodGet)
r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", v3mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}",
httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -815,7 +822,7 @@ func Setup(
}), }),
).Methods(http.MethodGet) ).Methods(http.MethodGet)
r0mux.Handle("/admin/whois/{userID}", v3mux.Handle("/admin/whois/{userID}",
httputil.MakeAuthAPI("admin_whois", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("admin_whois", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -825,7 +832,7 @@ func Setup(
}), }),
).Methods(http.MethodGet) ).Methods(http.MethodGet)
r0mux.Handle("/user/{userID}/openid/request_token", v3mux.Handle("/user/{userID}/openid/request_token",
httputil.MakeAuthAPI("openid_request_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("openid_request_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -838,7 +845,7 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/user_directory/search", v3mux.Handle("/user_directory/search",
httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -863,7 +870,7 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/members", v3mux.Handle("/rooms/{roomID}/members",
httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -873,7 +880,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/joined_members", v3mux.Handle("/rooms/{roomID}/joined_members",
httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -883,7 +890,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/read_markers", v3mux.Handle("/rooms/{roomID}/read_markers",
httputil.MakeAuthAPI("rooms_read_markers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_read_markers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -896,7 +903,7 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/forget", v3mux.Handle("/rooms/{roomID}/forget",
httputil.MakeAuthAPI("rooms_forget", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_forget", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -909,13 +916,13 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/devices", v3mux.Handle("/devices",
httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetDevicesByLocalpart(req, userAPI, device) return GetDevicesByLocalpart(req, userAPI, device)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/devices/{deviceID}", v3mux.Handle("/devices/{deviceID}",
httputil.MakeAuthAPI("get_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("get_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -925,7 +932,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/devices/{deviceID}", v3mux.Handle("/devices/{deviceID}",
httputil.MakeAuthAPI("device_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("device_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -935,7 +942,7 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/devices/{deviceID}", v3mux.Handle("/devices/{deviceID}",
httputil.MakeAuthAPI("delete_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("delete_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -945,7 +952,7 @@ func Setup(
}), }),
).Methods(http.MethodDelete, http.MethodOptions) ).Methods(http.MethodDelete, http.MethodOptions)
r0mux.Handle("/delete_devices", v3mux.Handle("/delete_devices",
httputil.MakeAuthAPI("delete_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("delete_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return DeleteDevices(req, userAPI, device) return DeleteDevices(req, userAPI, device)
}), }),
@ -957,13 +964,13 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/pushers", v3mux.Handle("/pushers",
httputil.MakeAuthAPI("get_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("get_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetPushers(req, device, psAPI) return GetPushers(req, device, psAPI)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/pushers/set", v3mux.Handle("/pushers/set",
httputil.MakeAuthAPI("set_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("set_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -973,7 +980,7 @@ func Setup(
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
// Stub implementations for sytest // Stub implementations for sytest
r0mux.Handle("/events", v3mux.Handle("/events",
httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {
return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{
"chunk": []interface{}{}, "chunk": []interface{}{},
@ -983,7 +990,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/initialSync", v3mux.Handle("/initialSync",
httputil.MakeExternalAPI("initial_sync", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("initial_sync", func(req *http.Request) util.JSONResponse {
return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{
"end": "", "end": "",
@ -991,7 +998,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userId}/rooms/{roomId}/tags", v3mux.Handle("/user/{userId}/rooms/{roomId}/tags",
httputil.MakeAuthAPI("get_tags", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("get_tags", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -1001,7 +1008,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
httputil.MakeAuthAPI("put_tag", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("put_tag", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -1011,7 +1018,7 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
httputil.MakeAuthAPI("delete_tag", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("delete_tag", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -1021,7 +1028,7 @@ func Setup(
}), }),
).Methods(http.MethodDelete, http.MethodOptions) ).Methods(http.MethodDelete, http.MethodOptions)
r0mux.Handle("/capabilities", v3mux.Handle("/capabilities",
httputil.MakeAuthAPI("capabilities", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("capabilities", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -1064,11 +1071,11 @@ func Setup(
return CreateKeyBackupVersion(req, userAPI, device) return CreateKeyBackupVersion(req, userAPI, device)
}) })
r0mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/room_keys/version/{version}", putBackupKeysVersion).Methods(http.MethodPut) v3mux.Handle("/room_keys/version/{version}", putBackupKeysVersion).Methods(http.MethodPut)
r0mux.Handle("/room_keys/version/{version}", deleteBackupKeysVersion).Methods(http.MethodDelete) v3mux.Handle("/room_keys/version/{version}", deleteBackupKeysVersion).Methods(http.MethodDelete)
r0mux.Handle("/room_keys/version", postNewBackupKeysVersion).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/room_keys/version", postNewBackupKeysVersion).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) unstableMux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
unstableMux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) unstableMux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
@ -1160,9 +1167,9 @@ func Setup(
return UploadBackupKeys(req, userAPI, device, version, &keyReq) return UploadBackupKeys(req, userAPI, device, version, &keyReq)
}) })
r0mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut) v3mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut)
r0mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut) v3mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut)
r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", putBackupKeysRoomSession).Methods(http.MethodPut) v3mux.Handle("/room_keys/keys/{roomID}/{sessionID}", putBackupKeysRoomSession).Methods(http.MethodPut)
unstableMux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut) unstableMux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut)
unstableMux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut) unstableMux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut)
@ -1190,9 +1197,9 @@ func Setup(
return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"]) return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"])
}) })
r0mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", getBackupKeysRoomSession).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/room_keys/keys/{roomID}/{sessionID}", getBackupKeysRoomSession).Methods(http.MethodGet, http.MethodOptions)
unstableMux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions) unstableMux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions)
unstableMux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions) unstableMux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions)
@ -1210,34 +1217,34 @@ func Setup(
return UploadCrossSigningDeviceSignatures(req, keyAPI, device) return UploadCrossSigningDeviceSignatures(req, keyAPI, device)
}) })
r0mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions) unstableMux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions) unstableMux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions)
// Supplying a device ID is deprecated. // Supplying a device ID is deprecated.
r0mux.Handle("/keys/upload/{deviceID}", v3mux.Handle("/keys/upload/{deviceID}",
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadKeys(req, keyAPI, device) return UploadKeys(req, keyAPI, device)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/keys/upload", v3mux.Handle("/keys/upload",
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadKeys(req, keyAPI, device) return UploadKeys(req, keyAPI, device)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/keys/query", v3mux.Handle("/keys/query",
httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return QueryKeys(req, keyAPI, device) return QueryKeys(req, keyAPI, device)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/keys/claim", v3mux.Handle("/keys/claim",
httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return ClaimKeys(req, keyAPI) return ClaimKeys(req, keyAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}", v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}",
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r

View file

@ -20,7 +20,7 @@ import (
"github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -33,7 +33,7 @@ type typingContentJSON struct {
// sends the typing events to client API typingProducer // sends the typing events to client API typingProducer
func SendTyping( func SendTyping(
req *http.Request, device *userapi.Device, roomID string, req *http.Request, device *userapi.Device, roomID string,
userID string, accountDB accounts.Database, userID string, accountDB userdb.Database,
eduAPI api.EDUServerInputAPI, eduAPI api.EDUServerInputAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
) util.JSONResponse { ) util.JSONResponse {

View file

@ -23,7 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/threepid" "github.com/matrix-org/dendrite/clientapi/threepid"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -40,7 +40,7 @@ type threePIDsResponse struct {
// RequestEmailToken implements: // RequestEmailToken implements:
// POST /account/3pid/email/requestToken // POST /account/3pid/email/requestToken
// POST /register/email/requestToken // POST /register/email/requestToken
func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI) util.JSONResponse { func RequestEmailToken(req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI) util.JSONResponse {
var body threepid.EmailAssociationRequest var body threepid.EmailAssociationRequest
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr return *reqErr
@ -61,7 +61,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.MatrixError{ JSON: jsonerror.MatrixError{
ErrCode: "M_THREEPID_IN_USE", ErrCode: "M_THREEPID_IN_USE",
Err: accounts.Err3PIDInUse.Error(), Err: userdb.Err3PIDInUse.Error(),
}, },
} }
} }
@ -85,7 +85,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf
// CheckAndSave3PIDAssociation implements POST /account/3pid // CheckAndSave3PIDAssociation implements POST /account/3pid
func CheckAndSave3PIDAssociation( func CheckAndSave3PIDAssociation(
req *http.Request, accountDB accounts.Database, device *api.Device, req *http.Request, accountDB userdb.Database, device *api.Device,
cfg *config.ClientAPI, cfg *config.ClientAPI,
) util.JSONResponse { ) util.JSONResponse {
var body threepid.EmailAssociationCheckRequest var body threepid.EmailAssociationCheckRequest
@ -149,7 +149,7 @@ func CheckAndSave3PIDAssociation(
// GetAssociated3PIDs implements GET /account/3pid // GetAssociated3PIDs implements GET /account/3pid
func GetAssociated3PIDs( func GetAssociated3PIDs(
req *http.Request, accountDB accounts.Database, device *api.Device, req *http.Request, accountDB userdb.Database, device *api.Device,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
@ -170,7 +170,7 @@ func GetAssociated3PIDs(
} }
// Forget3PID implements POST /account/3pid/delete // Forget3PID implements POST /account/3pid/delete
func Forget3PID(req *http.Request, accountDB accounts.Database) util.JSONResponse { func Forget3PID(req *http.Request, accountDB userdb.Database) util.JSONResponse {
var body authtypes.ThreePID var body authtypes.ThreePID
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr return *reqErr

View file

@ -22,6 +22,8 @@ import (
// whoamiResponse represents an response for a `whoami` request // whoamiResponse represents an response for a `whoami` request
type whoamiResponse struct { type whoamiResponse struct {
UserID string `json:"user_id"` UserID string `json:"user_id"`
DeviceID string `json:"device_id"`
IsGuest bool `json:"is_guest"`
} }
// Whoami implements `/account/whoami` which enables client to query their account user id. // Whoami implements `/account/whoami` which enables client to query their account user id.
@ -29,6 +31,10 @@ type whoamiResponse struct {
func Whoami(req *http.Request, device *api.Device) util.JSONResponse { func Whoami(req *http.Request, device *api.Device) util.JSONResponse {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: whoamiResponse{UserID: device.UserID}, JSON: whoamiResponse{
UserID: device.UserID,
DeviceID: device.ID,
IsGuest: device.AccountType == api.AccountTypeGuest,
},
} }
} }

View file

@ -29,7 +29,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -87,7 +87,7 @@ var (
func CheckAndProcessInvite( func CheckAndProcessInvite(
ctx context.Context, ctx context.Context,
device *userapi.Device, body *MembershipRequest, cfg *config.ClientAPI, device *userapi.Device, body *MembershipRequest, cfg *config.ClientAPI,
rsAPI api.RoomserverInternalAPI, db accounts.Database, rsAPI api.RoomserverInternalAPI, db userdb.Database,
roomID string, roomID string,
evTime time.Time, evTime time.Time,
) (inviteStoredOnIDServer bool, err error) { ) (inviteStoredOnIDServer bool, err error) {
@ -137,7 +137,7 @@ func CheckAndProcessInvite(
// Returns an error if a check or a request failed. // Returns an error if a check or a request failed.
func queryIDServer( func queryIDServer(
ctx context.Context, ctx context.Context,
db accounts.Database, cfg *config.ClientAPI, device *userapi.Device, db userdb.Database, cfg *config.ClientAPI, device *userapi.Device,
body *MembershipRequest, roomID string, body *MembershipRequest, roomID string,
) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) { ) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) {
if err = isTrusted(body.IDServer, cfg); err != nil { if err = isTrusted(body.IDServer, cfg); err != nil {
@ -206,7 +206,7 @@ func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServe
// Returns an error if the request failed to send or if the response couldn't be parsed. // Returns an error if the request failed to send or if the response couldn't be parsed.
func queryIDServerStoreInvite( func queryIDServerStoreInvite(
ctx context.Context, ctx context.Context,
db accounts.Database, cfg *config.ClientAPI, device *userapi.Device, db userdb.Database, cfg *config.ClientAPI, device *userapi.Device,
body *MembershipRequest, roomID string, body *MembershipRequest, roomID string,
) (*idServerStoreInviteResponse, error) { ) (*idServerStoreInviteResponse, error) {
// Retrieve the sender's profile to get their display name // Retrieve the sender's profile to get their display name

View file

@ -30,7 +30,7 @@ import (
"github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
) )
const usage = `Usage: %s const usage = `Usage: %s
@ -77,9 +77,14 @@ func main() {
pass := getPassword(password, pwdFile, pwdStdin, askPass, os.Stdin) pass := getPassword(password, pwdFile, pwdStdin, askPass, os.Stdin)
accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{ accountDB, err := userdb.NewDatabase(
&config.DatabaseOptions{
ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString, ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString,
}, cfg.Global.ServerName, bcrypt.DefaultCost, cfg.UserAPI.OpenIDTokenLifetimeMS) },
cfg.Global.ServerName, bcrypt.DefaultCost,
cfg.UserAPI.OpenIDTokenLifetimeMS,
api.DefaultLoginTokenLifetime,
)
if err != nil { if err != nil {
logrus.Fatalln("Failed to connect to the database:", err.Error()) logrus.Fatalln("Failed to connect to the database:", err.Error())
} }

View file

@ -127,7 +127,6 @@ func main() {
cfg.FederationAPI.FederationMaxRetries = 6 cfg.FederationAPI.FederationMaxRetries = 6
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName)) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))

View file

@ -160,7 +160,6 @@ func main() {
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName)) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))

View file

@ -80,7 +80,6 @@ func main() {
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName)) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))

View file

@ -164,7 +164,6 @@ func startup() {
cfg.Defaults(true) cfg.Defaults(true)
cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db" cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db"
cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db" cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db"
cfg.UserAPI.DeviceDatabase.ConnectionString = "file:/idb/dendritejs_device.db"
cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db" cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db"
cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db" cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db"
cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db" cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db"

View file

@ -168,7 +168,6 @@ func main() {
cfg.Defaults(true) cfg.Defaults(true)
cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db" cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db"
cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db" cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db"
cfg.UserAPI.DeviceDatabase.ConnectionString = "file:/idb/dendritejs_device.db"
cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db" cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db"
cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db" cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db"
cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db" cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db"

View file

@ -32,7 +32,6 @@ func main() {
cfg.RoomServer.Database.ConnectionString = config.DataSource(*dbURI) cfg.RoomServer.Database.ConnectionString = config.DataSource(*dbURI)
cfg.SyncAPI.Database.ConnectionString = config.DataSource(*dbURI) cfg.SyncAPI.Database.ConnectionString = config.DataSource(*dbURI)
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(*dbURI) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(*dbURI)
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(*dbURI)
} }
cfg.Global.TrustedIDServers = []string{ cfg.Global.TrustedIDServers = []string{
"matrix.org", "matrix.org",

View file

@ -8,12 +8,11 @@ import (
"log" "log"
"os" "os"
pgaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas"
slaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas"
pgdevices "github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas"
sldevices "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas"
"github.com/pressly/goose" "github.com/pressly/goose"
pgusers "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
slusers "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
_ "github.com/lib/pq" _ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
@ -26,8 +25,7 @@ const (
RoomServer = "roomserver" RoomServer = "roomserver"
SigningKeyServer = "signingkeyserver" SigningKeyServer = "signingkeyserver"
SyncAPI = "syncapi" SyncAPI = "syncapi"
UserAPIAccounts = "userapi_accounts" UserAPI = "userapi"
UserAPIDevices = "userapi_devices"
) )
var ( var (
@ -35,7 +33,7 @@ var (
flags = flag.NewFlagSet("goose", flag.ExitOnError) flags = flag.NewFlagSet("goose", flag.ExitOnError)
component = flags.String("component", "", "dendrite component name") component = flags.String("component", "", "dendrite component name")
knownDBs = []string{ knownDBs = []string{
AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPIAccounts, UserAPIDevices, AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPI,
} }
) )
@ -143,18 +141,14 @@ Commands:
func loadSQLiteDeltas(component string) { func loadSQLiteDeltas(component string) {
switch component { switch component {
case UserAPIAccounts: case UserAPI:
slaccounts.LoadFromGoose() slusers.LoadFromGoose()
case UserAPIDevices:
sldevices.LoadFromGoose()
} }
} }
func loadPostgresDeltas(component string) { func loadPostgresDeltas(component string) {
switch component { switch component {
case UserAPIAccounts: case UserAPI:
pgaccounts.LoadFromGoose() pgusers.LoadFromGoose()
case UserAPIDevices:
pgdevices.LoadFromGoose()
} }
} }

View file

@ -142,6 +142,10 @@ client_api:
# using the registration shared secret below. # using the registration shared secret below.
registration_disabled: false registration_disabled: false
# Prevents new guest accounts from being created. Guest registration is also
# disabled implicitly by setting 'registration_disabled' above.
guests_disabled: true
# If set, allows registration by anyone who knows the shared secret, regardless of # If set, allows registration by anyone who knows the shared secret, regardless of
# whether registration is otherwise disabled. # whether registration is otherwise disabled.
registration_shared_secret: "" registration_shared_secret: ""

View file

@ -7,14 +7,6 @@ import (
) )
const ( const (
RoomServerStateKeyNIDsCacheName = "roomserver_statekey_nids"
RoomServerStateKeyNIDsCacheMaxEntries = 1024
RoomServerStateKeyNIDsCacheMutable = false
RoomServerEventTypeNIDsCacheName = "roomserver_eventtype_nids"
RoomServerEventTypeNIDsCacheMaxEntries = 64
RoomServerEventTypeNIDsCacheMutable = false
RoomServerRoomIDsCacheName = "roomserver_room_ids" RoomServerRoomIDsCacheName = "roomserver_room_ids"
RoomServerRoomIDsCacheMaxEntries = 1024 RoomServerRoomIDsCacheMaxEntries = 1024
RoomServerRoomIDsCacheMutable = false RoomServerRoomIDsCacheMutable = false
@ -29,44 +21,10 @@ type RoomServerCaches interface {
// RoomServerNIDsCache contains the subset of functions needed for // RoomServerNIDsCache contains the subset of functions needed for
// a roomserver NID cache. // a roomserver NID cache.
type RoomServerNIDsCache interface { type RoomServerNIDsCache interface {
GetRoomServerStateKeyNID(stateKey string) (types.EventStateKeyNID, bool)
StoreRoomServerStateKeyNID(stateKey string, nid types.EventStateKeyNID)
GetRoomServerEventTypeNID(eventType string) (types.EventTypeNID, bool)
StoreRoomServerEventTypeNID(eventType string, nid types.EventTypeNID)
GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool)
StoreRoomServerRoomID(roomNID types.RoomNID, roomID string) StoreRoomServerRoomID(roomNID types.RoomNID, roomID string)
} }
func (c Caches) GetRoomServerStateKeyNID(stateKey string) (types.EventStateKeyNID, bool) {
val, found := c.RoomServerStateKeyNIDs.Get(stateKey)
if found && val != nil {
if stateKeyNID, ok := val.(types.EventStateKeyNID); ok {
return stateKeyNID, true
}
}
return 0, false
}
func (c Caches) StoreRoomServerStateKeyNID(stateKey string, nid types.EventStateKeyNID) {
c.RoomServerStateKeyNIDs.Set(stateKey, nid)
}
func (c Caches) GetRoomServerEventTypeNID(eventType string) (types.EventTypeNID, bool) {
val, found := c.RoomServerEventTypeNIDs.Get(eventType)
if found && val != nil {
if eventTypeNID, ok := val.(types.EventTypeNID); ok {
return eventTypeNID, true
}
}
return 0, false
}
func (c Caches) StoreRoomServerEventTypeNID(eventType string, nid types.EventTypeNID) {
c.RoomServerEventTypeNIDs.Set(eventType, nid)
}
func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) { func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) {
val, found := c.RoomServerRoomIDs.Get(strconv.Itoa(int(roomNID))) val, found := c.RoomServerRoomIDs.Get(strconv.Itoa(int(roomNID)))
if found && val != nil { if found && val != nil {

View file

@ -6,8 +6,6 @@ package caching
type Caches struct { type Caches struct {
RoomVersions Cache // RoomVersionCache RoomVersions Cache // RoomVersionCache
ServerKeys Cache // ServerKeyCache ServerKeys Cache // ServerKeyCache
RoomServerStateKeyNIDs Cache // RoomServerNIDsCache
RoomServerEventTypeNIDs Cache // RoomServerNIDsCache
RoomServerRoomNIDs Cache // RoomServerNIDsCache RoomServerRoomNIDs Cache // RoomServerNIDsCache
RoomServerRoomIDs Cache // RoomServerNIDsCache RoomServerRoomIDs Cache // RoomServerNIDsCache
RoomInfos Cache // RoomInfoCache RoomInfos Cache // RoomInfoCache

View file

@ -28,24 +28,6 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
roomServerStateKeyNIDs, err := NewInMemoryLRUCachePartition(
RoomServerStateKeyNIDsCacheName,
RoomServerStateKeyNIDsCacheMutable,
RoomServerStateKeyNIDsCacheMaxEntries,
enablePrometheus,
)
if err != nil {
return nil, err
}
roomServerEventTypeNIDs, err := NewInMemoryLRUCachePartition(
RoomServerEventTypeNIDsCacheName,
RoomServerEventTypeNIDsCacheMutable,
RoomServerEventTypeNIDsCacheMaxEntries,
enablePrometheus,
)
if err != nil {
return nil, err
}
roomServerRoomIDs, err := NewInMemoryLRUCachePartition( roomServerRoomIDs, err := NewInMemoryLRUCachePartition(
RoomServerRoomIDsCacheName, RoomServerRoomIDsCacheName,
RoomServerRoomIDsCacheMutable, RoomServerRoomIDsCacheMutable,
@ -74,15 +56,12 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
return nil, err return nil, err
} }
go cacheCleaner( go cacheCleaner(
roomVersions, serverKeys, roomServerStateKeyNIDs, roomVersions, serverKeys, roomServerRoomIDs,
roomServerEventTypeNIDs, roomServerRoomIDs,
roomInfos, federationEvents, roomInfos, federationEvents,
) )
return &Caches{ return &Caches{
RoomVersions: roomVersions, RoomVersions: roomVersions,
ServerKeys: serverKeys, ServerKeys: serverKeys,
RoomServerStateKeyNIDs: roomServerStateKeyNIDs,
RoomServerEventTypeNIDs: roomServerEventTypeNIDs,
RoomServerRoomIDs: roomServerRoomIDs, RoomServerRoomIDs: roomServerRoomIDs,
RoomInfos: roomInfos, RoomInfos: roomInfos,
FederationEvents: federationEvents, FederationEvents: federationEvents,

View file

@ -95,7 +95,6 @@ func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*con
cfg.RoomServer.Database.ConnectionString = config.DataSource(database) cfg.RoomServer.Database.ConnectionString = config.DataSource(database)
cfg.SyncAPI.Database.ConnectionString = config.DataSource(database) cfg.SyncAPI.Database.ConnectionString = config.DataSource(database)
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(database) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(database)
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(database)
cfg.AppServiceAPI.InternalAPI.Listen = assignAddress() cfg.AppServiceAPI.InternalAPI.Listen = assignAddress()
cfg.EDUServer.InternalAPI.Listen = assignAddress() cfg.EDUServer.InternalAPI.Listen = assignAddress()

View file

@ -198,7 +198,7 @@ func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOne
} }
func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) { func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) {
msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil) msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false)
if err != nil { if err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query DB for device keys: %s", err), Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
@ -244,7 +244,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
domain := string(serverName) domain := string(serverName)
// query local devices // query local devices
if serverName == a.ThisServer { if serverName == a.ThisServer {
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs) deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
if err != nil { if err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query local device keys: %s", err), Err: fmt.Sprintf("failed to query local device keys: %s", err),
@ -525,7 +525,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string, ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string,
) error { ) error {
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs) keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
// if we can't query the db or there are fewer keys than requested, fetch from remote. // if we can't query the db or there are fewer keys than requested, fetch from remote.
if err != nil { if err != nil {
return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err) return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
@ -554,10 +554,58 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
} }
func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
// get a list of devices from the user API that actually exist, as
// we won't store keys for devices that don't exist
uapidevices := &userapi.QueryDevicesResponse{}
if err := a.UserAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil {
res.Error = &api.KeyError{
Err: err.Error(),
}
return
}
if !uapidevices.UserExists {
res.Error = &api.KeyError{
Err: fmt.Sprintf("user %q does not exist", req.UserID),
}
return
}
existingDeviceMap := make(map[string]struct{}, len(uapidevices.Devices))
for _, key := range uapidevices.Devices {
existingDeviceMap[key.ID] = struct{}{}
}
// Get all of the user existing device keys so we can check for changes.
existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, true)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
}
return
}
// Work out whether we have device keys in the keyserver for devices that
// no longer exist in the user API. This is mostly an exercise to ensure
// that we keep some integrity between the two.
var toClean []gomatrixserverlib.KeyID
for _, k := range existingKeys {
if _, ok := existingDeviceMap[k.DeviceID]; !ok {
toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID))
}
}
if len(toClean) > 0 {
if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil {
logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean))
} else {
logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean))
}
}
var keysToStore []api.DeviceMessage var keysToStore []api.DeviceMessage
// assert that the user ID / device ID are not lying for each key // assert that the user ID / device ID are not lying for each key
for _, key := range req.DeviceKeys { for _, key := range req.DeviceKeys {
_, serverName, err := gomatrixserverlib.SplitID('@', key.UserID) var serverName gomatrixserverlib.ServerName
_, serverName, err = gomatrixserverlib.SplitID('@', key.UserID)
if err != nil { if err != nil {
continue // ignore invalid users continue // ignore invalid users
} }
@ -568,6 +616,11 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
keysToStore = append(keysToStore, key.WithStreamID(0)) keysToStore = append(keysToStore, key.WithStreamID(0))
continue // deleted keys don't need sanity checking continue // deleted keys don't need sanity checking
} }
// check that the device in question actually exists in the user
// API before we try and store a key for it
if _, ok := existingDeviceMap[key.DeviceID]; !ok {
continue
}
gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
if gotUserID == key.UserID && gotDeviceID == key.DeviceID { if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
@ -583,29 +636,12 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
}) })
} }
// get existing device keys so we can check for changes
existingKeys := make([]api.DeviceMessage, len(keysToStore))
for i := range keysToStore {
existingKeys[i] = api.DeviceMessage{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
UserID: keysToStore[i].UserID,
DeviceID: keysToStore[i].DeviceID,
},
}
}
if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
}
return
}
if req.OnlyDisplayNameUpdates { if req.OnlyDisplayNameUpdates {
// add the display name field from keysToStore into existingKeys // add the display name field from keysToStore into existingKeys
keysToStore = appendDisplayNames(existingKeys, keysToStore) keysToStore = appendDisplayNames(existingKeys, keysToStore)
} }
// store the device keys and emit changes // store the device keys and emit changes
err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore) err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
if err != nil { if err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()), Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),

View file

@ -53,7 +53,7 @@ type Database interface {
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
// cross-signing signatures relating to that device. // cross-signing signatures relating to that device.

View file

@ -56,6 +56,9 @@ const selectDeviceKeysSQL = "" +
const selectBatchDeviceKeysSQL = "" + const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
const selectBatchDeviceKeysWithEmptiesSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" + const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
@ -73,6 +76,7 @@ type deviceKeysStatements struct {
upsertDeviceKeysStmt *sql.Stmt upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt selectMaxStreamForUserStmt *sql.Stmt
countStreamIDsForUserStmt *sql.Stmt countStreamIDsForUserStmt *sql.Stmt
deleteDeviceKeysStmt *sql.Stmt deleteDeviceKeysStmt *sql.Stmt
@ -96,6 +100,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err return nil, err
} }
if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
return nil, err
}
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err return nil, err
} }
@ -180,8 +187,14 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql
return err return err
} }
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) var stmt *sql.Stmt
if includeEmpty {
stmt = s.selectBatchDeviceKeysWithEmptiesStmt
} else {
stmt = s.selectBatchDeviceKeysStmt
}
rows, err := stmt.QueryContext(ctx, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -108,8 +108,8 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe
}) })
} }
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs) return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty)
} }
func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) { func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {

View file

@ -52,6 +52,9 @@ const selectDeviceKeysSQL = "" +
const selectBatchDeviceKeysSQL = "" + const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
const selectBatchDeviceKeysWithEmptiesSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" + const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
@ -69,6 +72,7 @@ type deviceKeysStatements struct {
upsertDeviceKeysStmt *sql.Stmt upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt selectMaxStreamForUserStmt *sql.Stmt
deleteDeviceKeysStmt *sql.Stmt deleteDeviceKeysStmt *sql.Stmt
deleteAllDeviceKeysStmt *sql.Stmt deleteAllDeviceKeysStmt *sql.Stmt
@ -91,6 +95,9 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err return nil, err
} }
if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
return nil, err
}
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err return nil, err
} }
@ -113,12 +120,18 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql
return err return err
} }
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
deviceIDMap := make(map[string]bool) deviceIDMap := make(map[string]bool)
for _, d := range deviceIDs { for _, d := range deviceIDs {
deviceIDMap[d] = true deviceIDMap[d] = true
} }
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) var stmt *sql.Stmt
if includeEmpty {
stmt = s.selectBatchDeviceKeysWithEmptiesStmt
} else {
stmt = s.selectBatchDeviceKeysStmt
}
rows, err := stmt.QueryContext(ctx, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -173,7 +173,7 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
} }
// Querying for device keys returns the latest stream IDs // Querying for device keys returns the latest stream IDs
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}) msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false)
if err != nil { if err != nil {
t.Fatalf("DeviceKeysForUser returned error: %s", err) t.Fatalf("DeviceKeysForUser returned error: %s", err)
} }

View file

@ -38,7 +38,7 @@ type DeviceKeys interface {
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
} }

View file

@ -59,23 +59,12 @@ func (d *Database) eventTypeNIDs(
ctx context.Context, txn *sql.Tx, eventTypes []string, ctx context.Context, txn *sql.Tx, eventTypes []string,
) (map[string]types.EventTypeNID, error) { ) (map[string]types.EventTypeNID, error) {
result := make(map[string]types.EventTypeNID) result := make(map[string]types.EventTypeNID)
remaining := []string{} nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, eventTypes)
for _, eventType := range eventTypes {
if nid, ok := d.Cache.GetRoomServerEventTypeNID(eventType); ok {
result[eventType] = nid
} else {
remaining = append(remaining, eventType)
}
}
if len(remaining) > 0 {
nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, remaining)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for eventType, nid := range nids { for eventType, nid := range nids {
result[eventType] = nid result[eventType] = nid
d.Cache.StoreRoomServerEventTypeNID(eventType, nid)
}
} }
return result, nil return result, nil
} }
@ -96,23 +85,12 @@ func (d *Database) eventStateKeyNIDs(
ctx context.Context, txn *sql.Tx, eventStateKeys []string, ctx context.Context, txn *sql.Tx, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) { ) (map[string]types.EventStateKeyNID, error) {
result := make(map[string]types.EventStateKeyNID) result := make(map[string]types.EventStateKeyNID)
remaining := []string{} nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, eventStateKeys)
for _, eventStateKey := range eventStateKeys {
if nid, ok := d.Cache.GetRoomServerStateKeyNID(eventStateKey); ok {
result[eventStateKey] = nid
} else {
remaining = append(remaining, eventStateKey)
}
}
if len(remaining) > 0 {
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, remaining)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for eventStateKey, nid := range nids { for eventStateKey, nid := range nids {
result[eventStateKey] = nid result[eventStateKey] = nid
d.Cache.StoreRoomServerStateKeyNID(eventStateKey, nid)
}
} }
return result, nil return result, nil
} }
@ -718,9 +696,6 @@ func (d *Database) assignRoomNID(
func (d *Database) assignEventTypeNID( func (d *Database) assignEventTypeNID(
ctx context.Context, txn *sql.Tx, eventType string, ctx context.Context, txn *sql.Tx, eventType string,
) (types.EventTypeNID, error) { ) (types.EventTypeNID, error) {
if eventTypeNID, ok := d.Cache.GetRoomServerEventTypeNID(eventType); ok {
return eventTypeNID, nil
}
// Check if we already have a numeric ID in the database. // Check if we already have a numeric ID in the database.
eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType) eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -731,18 +706,12 @@ func (d *Database) assignEventTypeNID(
eventTypeNID, err = d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType) eventTypeNID, err = d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType)
} }
} }
if err == nil {
d.Cache.StoreRoomServerEventTypeNID(eventType, eventTypeNID)
}
return eventTypeNID, err return eventTypeNID, err
} }
func (d *Database) assignStateKeyNID( func (d *Database) assignStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string, ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) { ) (types.EventStateKeyNID, error) {
if eventStateKeyNID, ok := d.Cache.GetRoomServerStateKeyNID(eventStateKey); ok {
return eventStateKeyNID, nil
}
// Check if we already have a numeric ID in the database. // Check if we already have a numeric ID in the database.
eventStateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey) eventStateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -753,9 +722,6 @@ func (d *Database) assignStateKeyNID(
eventStateKeyNID, err = d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey) eventStateKeyNID, err = d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey)
} }
} }
if err == nil {
d.Cache.StoreRoomServerStateKeyNID(eventStateKey, eventStateKeyNID)
}
return eventStateKeyNID, err return eventStateKeyNID, err
} }

View file

@ -39,7 +39,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -290,8 +290,14 @@ func (b *BaseDendrite) PushGatewayHTTPClient() pushgateway.Client {
// CreateAccountsDB creates a new instance of the accounts database. Should only // CreateAccountsDB creates a new instance of the accounts database. Should only
// be called once per component. // be called once per component.
func (b *BaseDendrite) CreateAccountsDB() accounts.Database { func (b *BaseDendrite) CreateAccountsDB() userdb.Database {
db, err := accounts.NewDatabase(&b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost, b.Cfg.UserAPI.OpenIDTokenLifetimeMS) db, err := userdb.NewDatabase(
&b.Cfg.UserAPI.AccountDatabase,
b.Cfg.Global.ServerName,
b.Cfg.UserAPI.BCryptCost,
b.Cfg.UserAPI.OpenIDTokenLifetimeMS,
userapi.DefaultLoginTokenLifetime,
)
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to connect to accounts db") logrus.WithError(err).Panicf("failed to connect to accounts db")
} }

View file

@ -18,6 +18,10 @@ type ClientAPI struct {
// If set, allows registration by anyone who also has the shared // If set, allows registration by anyone who also has the shared
// secret, even if registration is otherwise disabled. // secret, even if registration is otherwise disabled.
RegistrationSharedSecret string `yaml:"registration_shared_secret"` RegistrationSharedSecret string `yaml:"registration_shared_secret"`
// If set, prevents guest accounts from being created. Only takes
// effect if registration is enabled, otherwise guests registration
// is forbidden either way.
GuestsDisabled bool `yaml:"guests_disabled"`
// Boolean stating whether catpcha registration is enabled // Boolean stating whether catpcha registration is enabled
// and required // and required

View file

@ -16,9 +16,6 @@ type UserAPI struct {
// The Account database stores the login details and account information // The Account database stores the login details and account information
// for local users. It is accessed by the UserAPI. // for local users. It is accessed by the UserAPI.
AccountDatabase DatabaseOptions `yaml:"account_database"` AccountDatabase DatabaseOptions `yaml:"account_database"`
// The Device database stores session information for the devices of logged
// in local users. It is accessed by the UserAPI.
DeviceDatabase DatabaseOptions `yaml:"device_database"`
} }
const DefaultOpenIDTokenLifetimeMS = 3600000 // 60 minutes const DefaultOpenIDTokenLifetimeMS = 3600000 // 60 minutes
@ -27,10 +24,8 @@ func (c *UserAPI) Defaults(generate bool) {
c.InternalAPI.Listen = "http://localhost:7781" c.InternalAPI.Listen = "http://localhost:7781"
c.InternalAPI.Connect = "http://localhost:7781" c.InternalAPI.Connect = "http://localhost:7781"
c.AccountDatabase.Defaults(10) c.AccountDatabase.Defaults(10)
c.DeviceDatabase.Defaults(10)
if generate { if generate {
c.AccountDatabase.ConnectionString = "file:userapi_accounts.db" c.AccountDatabase.ConnectionString = "file:userapi_accounts.db"
c.DeviceDatabase.ConnectionString = "file:userapi_devices.db"
} }
c.BCryptCost = bcrypt.DefaultCost c.BCryptCost = bcrypt.DefaultCost
c.OpenIDTokenLifetimeMS = DefaultOpenIDTokenLifetimeMS c.OpenIDTokenLifetimeMS = DefaultOpenIDTokenLifetimeMS
@ -40,6 +35,5 @@ func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
checkURL(configErrs, "user_api.internal_api.listen", string(c.InternalAPI.Listen)) checkURL(configErrs, "user_api.internal_api.listen", string(c.InternalAPI.Listen))
checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect)) checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect))
checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString)) checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString))
checkNotEmpty(configErrs, "user_api.device_database.connection_string", string(c.DeviceDatabase.ConnectionString))
checkPositive(configErrs, "user_api.openid_token_lifetime_ms", c.OpenIDTokenLifetimeMS) checkPositive(configErrs, "user_api.openid_token_lifetime_ms", c.OpenIDTokenLifetimeMS)
} }

View file

@ -31,7 +31,7 @@ import (
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi" "github.com/matrix-org/dendrite/syncapi"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -39,7 +39,7 @@ import (
// all components of Dendrite, for use in monolith mode. // all components of Dendrite, for use in monolith mode.
type Monolith struct { type Monolith struct {
Config *config.Dendrite Config *config.Dendrite
AccountDB accounts.Database AccountDB userdb.Database
KeyRing *gomatrixserverlib.KeyRing KeyRing *gomatrixserverlib.KeyRing
Client *gomatrixserverlib.Client Client *gomatrixserverlib.Client
FedClient *gomatrixserverlib.FederationClient FedClient *gomatrixserverlib.FederationClient

View file

@ -39,14 +39,14 @@ func Setup(
rsAPI api.RoomserverInternalAPI, rsAPI api.RoomserverInternalAPI,
cfg *config.SyncAPI, cfg *config.SyncAPI,
) { ) {
r0mux := csMux.PathPrefix("/r0").Subrouter() v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter()
// TODO: Add AS support for all handlers below. // TODO: Add AS support for all handlers below.
r0mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return srp.OnIncomingSyncRequest(req, device) return srp.OnIncomingSyncRequest(req, device)
})).Methods(http.MethodGet, http.MethodOptions) })).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -54,7 +54,7 @@ func Setup(
return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], device, federation, rsAPI, cfg, srp) return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], device, federation, rsAPI, cfg, srp)
})).Methods(http.MethodGet, http.MethodOptions) })).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userId}/filter", v3mux.Handle("/user/{userId}/filter",
httputil.MakeAuthAPI("put_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("put_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -64,7 +64,7 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/user/{userId}/filter/{filterId}", v3mux.Handle("/user/{userId}/filter/{filterId}",
httputil.MakeAuthAPI("get_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("get_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -74,7 +74,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return srp.OnIncomingKeyChangeRequest(req, device) return srp.OnIncomingKeyChangeRequest(req, device)
})).Methods(http.MethodGet, http.MethodOptions) })).Methods(http.MethodGet, http.MethodOptions)
} }

View file

@ -19,6 +19,13 @@ import (
"time" "time"
) )
// DefaultLoginTokenLifetime determines how old a valid token may be.
//
// NOTSPEC: The current spec says "SHOULD be limited to around five
// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low.
// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325).
const DefaultLoginTokenLifetime = 2 * time.Minute
type LoginTokenInternalAPI interface { type LoginTokenInternalAPI interface {
// PerformLoginTokenCreation creates a new login token and associates it with the provided data. // PerformLoginTokenCreation creates a new login token and associates it with the provided data.
PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error

View file

@ -31,13 +31,11 @@ import (
keyapi "github.com/matrix-org/dendrite/keyserver/api" keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/devices"
) )
type UserInternalAPI struct { type UserInternalAPI struct {
AccountDB accounts.Database DB storage.Database
DeviceDB devices.Database
ServerName gomatrixserverlib.ServerName ServerName gomatrixserverlib.ServerName
// AppServices is the list of all registered AS // AppServices is the list of all registered AS
AppServices []config.ApplicationService AppServices []config.ApplicationService
@ -55,11 +53,11 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
if req.DataType == "" { if req.DataType == "" {
return fmt.Errorf("data type must not be empty") return fmt.Errorf("data type must not be empty")
} }
return a.AccountDB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData) return a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
} }
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
if err != nil { if err != nil {
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
switch req.OnConflict { switch req.OnConflict {
@ -89,7 +87,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
return nil return nil
} }
if err = a.AccountDB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
return err return err
} }
@ -99,7 +97,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
} }
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil { if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
return err return err
} }
res.PasswordUpdated = true res.PasswordUpdated = true
@ -112,7 +110,7 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
"device_id": req.DeviceID, "device_id": req.DeviceID,
"display_name": req.DeviceDisplayName, "display_name": req.DeviceDisplayName,
}).Info("PerformDeviceCreation") }).Info("PerformDeviceCreation")
dev, err := a.DeviceDB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
if err != nil { if err != nil {
return err return err
} }
@ -137,12 +135,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deletedDeviceIDs := req.DeviceIDs deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 { if len(req.DeviceIDs) == 0 {
var devices []api.Device var devices []api.Device
devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID) devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
for _, d := range devices { for _, d := range devices {
deletedDeviceIDs = append(deletedDeviceIDs, d.ID) deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
} }
} else { } else {
err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs) err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs)
} }
if err != nil { if err != nil {
return err return err
@ -196,7 +194,7 @@ func (a *UserInternalAPI) PerformLastSeenUpdate(
if err != nil { if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
} }
if err := a.DeviceDB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil { if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil {
return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err) return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err)
} }
return nil return nil
@ -208,7 +206,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
return err return err
} }
dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID) dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
res.DeviceExists = false res.DeviceExists = false
return nil return nil
@ -223,7 +221,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
return nil return nil
} }
err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName) err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed") util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
return err return err
@ -261,7 +259,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
if domain != a.ServerName { if domain != a.ServerName {
return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName)
} }
prof, err := a.AccountDB.GetProfileByLocalpart(ctx, local) prof, err := a.DB.GetProfileByLocalpart(ctx, local)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil return nil
@ -275,7 +273,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
} }
func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error { func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error {
profiles, err := a.AccountDB.SearchProfiles(ctx, req.SearchString, req.Limit) profiles, err := a.DB.SearchProfiles(ctx, req.SearchString, req.Limit)
if err != nil { if err != nil {
return err return err
} }
@ -284,7 +282,7 @@ func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.Quer
} }
func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error { func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error {
devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs) devices, err := a.DB.GetDevicesByID(ctx, req.DeviceIDs)
if err != nil { if err != nil {
return err return err
} }
@ -312,10 +310,11 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
if domain != a.ServerName { if domain != a.ServerName {
return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName) return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName)
} }
devs, err := a.DeviceDB.GetDevicesByLocalpart(ctx, local) devs, err := a.DB.GetDevicesByLocalpart(ctx, local)
if err != nil { if err != nil {
return err return err
} }
res.UserExists = true
res.Devices = devs res.Devices = devs
return nil return nil
} }
@ -330,7 +329,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
} }
if req.DataType != "" { if req.DataType != "" {
var data json.RawMessage var data json.RawMessage
data, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
if err != nil { if err != nil {
return err return err
} }
@ -348,7 +347,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
} }
return nil return nil
} }
global, rooms, err := a.AccountDB.GetAccountData(ctx, local) global, rooms, err := a.DB.GetAccountData(ctx, local)
if err != nil { if err != nil {
return err return err
} }
@ -367,7 +366,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
return nil return nil
} }
device, err := a.DeviceDB.GetDeviceByAccessToken(ctx, req.AccessToken) device, err := a.DB.GetDeviceByAccessToken(ctx, req.AccessToken)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil return nil
@ -378,7 +377,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
if err != nil { if err != nil {
return err return err
} }
acc, err := a.AccountDB.GetAccountByLocalpart(ctx, localPart) acc, err := a.DB.GetAccountByLocalpart(ctx, localPart)
if err != nil { if err != nil {
return err return err
} }
@ -419,7 +418,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
if localpart != "" { // AS is masquerading as another user if localpart != "" { // AS is masquerading as another user
// Verify that the user is registered // Verify that the user is registered
account, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart) account, err := a.DB.GetAccountByLocalpart(ctx, localpart)
// Verify that the account exists and either appServiceID matches or // Verify that the account exists and either appServiceID matches or
// it belongs to the appservice user namespaces // it belongs to the appservice user namespaces
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) { if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
@ -437,7 +436,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
// PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again. // PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again.
func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error { func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error {
err := a.AccountDB.DeactivateAccount(ctx, req.Localpart) err := a.DB.DeactivateAccount(ctx, req.Localpart)
res.AccountDeactivated = err == nil res.AccountDeactivated = err == nil
return err return err
} }
@ -446,7 +445,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error { func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error {
token := util.RandomString(24) token := util.RandomString(24)
exp, err := a.AccountDB.CreateOpenIDToken(ctx, token, req.UserID) exp, err := a.DB.CreateOpenIDToken(ctx, token, req.UserID)
res.Token = api.OpenIDToken{ res.Token = api.OpenIDToken{
Token: token, Token: token,
@ -459,7 +458,7 @@ func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *a
// QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation // QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation
func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error { func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, req.Token) openIDTokenAttrs, err := a.DB.GetOpenIDTokenAttributes(ctx, req.Token)
if err != nil { if err != nil {
return err return err
} }
@ -481,7 +480,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
} }
return nil return nil
} }
exists, err := a.AccountDB.DeleteKeyBackup(ctx, req.UserID, req.Version) exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version)
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to delete backup: %s", err) res.Error = fmt.Sprintf("failed to delete backup: %s", err)
} }
@ -494,7 +493,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
} }
// Create metadata // Create metadata
if req.Version == "" { if req.Version == "" {
version, err := a.AccountDB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData) version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to create backup: %s", err) res.Error = fmt.Sprintf("failed to create backup: %s", err)
} }
@ -507,7 +506,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
} }
// Update metadata // Update metadata
if len(req.Keys.Rooms) == 0 { if len(req.Keys.Rooms) == 0 {
err := a.AccountDB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData) err := a.DB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to update backup: %s", err) res.Error = fmt.Sprintf("failed to update backup: %s", err)
} }
@ -528,7 +527,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) {
// you can only upload keys for the CURRENT version // you can only upload keys for the CURRENT version
version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, "") version, _, _, _, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, "")
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to query version: %s", err) res.Error = fmt.Sprintf("failed to query version: %s", err)
return return
@ -556,7 +555,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
}) })
} }
} }
count, etag, err := a.AccountDB.UpsertBackupKeys(ctx, version, req.UserID, uploads) count, etag, err := a.DB.UpsertBackupKeys(ctx, version, req.UserID, uploads)
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to upsert keys: %s", err) res.Error = fmt.Sprintf("failed to upsert keys: %s", err)
return return
@ -566,7 +565,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
} }
func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) { func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) {
version, algorithm, authData, etag, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version) version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version)
res.Version = version res.Version = version
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -582,14 +581,14 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
res.Exists = !deleted res.Exists = !deleted
if !req.ReturnKeys { if !req.ReturnKeys {
res.Count, err = a.AccountDB.CountBackupKeys(ctx, version, req.UserID) res.Count, err = a.DB.CountBackupKeys(ctx, version, req.UserID)
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to count keys: %s", err) res.Error = fmt.Sprintf("failed to count keys: %s", err)
} }
return return
} }
result, err := a.AccountDB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID) result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to query keys: %s", err) res.Error = fmt.Sprintf("failed to query keys: %s", err)
return return

View file

@ -34,7 +34,7 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap
if domain != a.ServerName { if domain != a.ServerName {
return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName) return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName)
} }
tokenMeta, err := a.DeviceDB.CreateLoginToken(ctx, &req.Data) tokenMeta, err := a.DB.CreateLoginToken(ctx, &req.Data)
if err != nil { if err != nil {
return err return err
} }
@ -45,13 +45,13 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap
// PerformLoginTokenDeletion ensures the token doesn't exist. // PerformLoginTokenDeletion ensures the token doesn't exist.
func (a *UserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *api.PerformLoginTokenDeletionRequest, res *api.PerformLoginTokenDeletionResponse) error { func (a *UserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *api.PerformLoginTokenDeletionRequest, res *api.PerformLoginTokenDeletionResponse) error {
util.GetLogger(ctx).Info("PerformLoginTokenDeletion") util.GetLogger(ctx).Info("PerformLoginTokenDeletion")
return a.DeviceDB.RemoveLoginToken(ctx, req.Token) return a.DB.RemoveLoginToken(ctx, req.Token)
} }
// QueryLoginToken returns the data associated with a login token. If // QueryLoginToken returns the data associated with a login token. If
// the token is not valid, success is returned, but res.Data == nil. // the token is not valid, success is returned, but res.Data == nil.
func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLoginTokenRequest, res *api.QueryLoginTokenResponse) error { func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLoginTokenRequest, res *api.QueryLoginTokenResponse) error {
tokenData, err := a.DeviceDB.GetLoginTokenDataByToken(ctx, req.Token) tokenData, err := a.DB.GetLoginTokenDataByToken(ctx, req.Token)
if err != nil { if err != nil {
res.Data = nil res.Data = nil
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog
if domain != a.ServerName { if domain != a.ServerName {
return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName) return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName)
} }
if _, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart); err != nil { if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil {
res.Data = nil res.Data = nil
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil return nil

View file

@ -1,515 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strconv"
"time"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas"
// Import the postgres database driver.
_ "github.com/lib/pq"
)
// Database represents an account database
type Database struct {
db *sql.DB
writer sqlutil.Writer
sqlutil.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
accountDatas accountDataStatements
threepids threepidStatements
openIDTokens tokenStatements
keyBackupVersions keyBackupVersionStatements
keyBackups keyBackupStatements
serverName gomatrixserverlib.ServerName
bcryptCost int
openIDTokenLifetimeMS int64
}
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
d := &Database{
serverName: serverName,
db: db,
writer: sqlutil.NewDummyWriter(),
bcryptCost: bcryptCost,
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
}
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
if err = d.accounts.execSchema(db); err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
deltas.LoadIsActive(m)
deltas.LoadAddAccountType(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
if err = d.PartitionOffsetStatements.Prepare(db, d.writer, "account"); err != nil {
return nil, err
}
if err = d.accounts.prepare(db, serverName); err != nil {
return nil, err
}
if err = d.profiles.prepare(db); err != nil {
return nil, err
}
if err = d.accountDatas.prepare(db); err != nil {
return nil, err
}
if err = d.threepids.prepare(db); err != nil {
return nil, err
}
if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err
}
if err = d.keyBackupVersions.prepare(db); err != nil {
return nil, err
}
if err = d.keyBackups.prepare(db); err != nil {
return nil, err
}
return d, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(
ctx context.Context, localpart, plaintextPassword string,
) (*api.Account, error) {
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
if err != nil {
return nil, err
}
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
return nil, err
}
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
return d.profiles.selectProfileByLocalpart(ctx, localpart)
}
// SetAvatarURL updates the avatar URL of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) error {
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
}
// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
) error {
return d.profiles.setDisplayName(ctx, localpart, displayName)
}
// SetPassword sets the account password to the given hash.
func (d *Database) SetPassword(
ctx context.Context, localpart, plaintextPassword string,
) error {
hash, err := d.hashPassword(plaintextPassword)
if err != nil {
return err
}
return d.accounts.updatePassword(ctx, localpart, hash)
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, sqlutil.ErrUserExists.
func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
) (acc *api.Account, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
// For guest accounts, we create a new numeric local part
if accountType == api.AccountTypeGuest {
var numLocalpart int64
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
if err != nil {
return err
}
localpart = strconv.FormatInt(numLocalpart, 10)
plaintextPassword = ""
appserviceID = ""
}
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType)
return err
})
return
}
func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
) (*api.Account, error) {
var account *api.Account
var err error
// Generate a password hash if this is not a password-less user
hash := ""
if plaintextPassword != "" {
hash, err = d.hashPassword(plaintextPassword)
if err != nil {
return nil, err
}
}
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
if sqlutil.IsUniqueConstraintViolationErr(err) {
return nil, sqlutil.ErrUserExists
}
return nil, err
}
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
return nil, err
}
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.serverName)
prbs, err := json.Marshal(pushRuleSets)
if err != nil {
return nil, err
}
if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
return nil, err
}
return account, nil
}
// SaveAccountData saves new account data for a given user and a given room.
// If the account data is not specific to a room, the room ID should be an empty string
// If an account data already exists for a given set (user, room, data type), it will
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
})
}
// GetAccountData returns account data related to a given localpart
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global map[string]json.RawMessage,
rooms map[string]map[string]json.RawMessage,
err error,
) {
return d.accountDatas.selectAccountData(ctx, localpart)
}
// GetAccountDataByType returns account data matching a given
// localpart, room ID and type.
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data json.RawMessage, err error) {
return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType,
)
}
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx, nil)
}
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.bcryptCost)
return string(hashBytes), err
}
// Err3PIDInUse is the error returned when trying to save an association involving
// a third-party identifier which is already associated to a local user.
var Err3PIDInUse = errors.New("this third-party identifier is already in use")
// SaveThreePIDAssociation saves the association between a third party identifier
// and a local Matrix user (identified by the user's ID's local part).
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string,
) (err error) {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
user, err := d.threepids.selectLocalpartForThreePID(
ctx, txn, threepid, medium,
)
if err != nil {
return err
}
if len(user) > 0 {
return Err3PIDInUse
}
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
})
}
// RemoveThreePIDAssociation removes the association involving a given third-party
// identifier.
// If no association exists involving this third-party identifier, returns nothing.
// If there was a problem talking to the database, returns an error.
func (d *Database) RemoveThreePIDAssociation(
ctx context.Context, threepid string, medium string,
) (err error) {
return d.threepids.deleteThreePID(ctx, threepid, medium)
}
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
// identifier.
// If no association involves the given third-party idenfitier, returns an empty
// string.
// Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
}
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
// a given local user.
// If no association is known for this user, returns an empty slice.
// Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
}
// CheckAccountAvailability checks if the username/localpart is already present
// in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
if err == sql.ErrNoRows {
return true, nil
}
return false, err
}
// GetAccountByLocalpart returns the account associated with the given localpart.
// This function assumes the request is authenticated or the account data is used only internally.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
) (*api.Account, error) {
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// SearchProfiles returns all profiles where the provided localpart or display name
// match any part of the profiles in the database.
func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) {
return d.profiles.selectProfilesBySearch(ctx, searchString, limit)
}
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
return d.accounts.deactivateAccount(ctx, localpart)
}
// CreateOpenIDToken persists a new token that was issued through OpenID Connect
func (d *Database) CreateOpenIDToken(
ctx context.Context,
token, localpart string,
) (int64, error) {
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
err := sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
})
return expiresAtMS, err
}
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
func (d *Database) GetOpenIDTokenAttributes(
ctx context.Context,
token string,
) (*api.OpenIDTokenAttributes, error) {
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
}
func (d *Database) CreateKeyBackup(
ctx context.Context, userID, algorithm string, authData json.RawMessage,
) (version string, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "")
return err
})
return
}
func (d *Database) UpdateKeyBackupAuthData(
ctx context.Context, userID, version string, authData json.RawMessage,
) (err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData)
})
return
}
func (d *Database) DeleteKeyBackup(
ctx context.Context, userID, version string,
) (exists bool, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) GetKeyBackup(
ctx context.Context, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) GetBackupKeys(
ctx context.Context, version, userID, filterRoomID, filterSessionID string,
) (result map[string]map[string]api.KeyBackupSession, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if filterSessionID != "" {
result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
return err
}
if filterRoomID != "" {
result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
return err
}
result, err = d.keyBackups.selectKeys(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) CountBackupKeys(
ctx context.Context, version, userID string,
) (count int64, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
if err != nil {
return err
}
return nil
})
return
}
// nolint:nakedret
func (d *Database) UpsertBackupKeys(
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
) (count int64, etag string, err error) {
// wrap the following logic in a txn to ensure we atomically upload keys
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
_, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
if err != nil {
return err
}
if deleted {
return fmt.Errorf("backup was deleted")
}
// pull out all keys for this (user_id, version)
existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version)
if err != nil {
return err
}
changed := false
// loop over all the new keys (which should be smaller than the set of backed up keys)
for _, newKey := range uploads {
// if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them.
existingRoom := existingKeys[newKey.RoomID]
if existingRoom != nil {
existingSession, ok := existingRoom[newKey.SessionID]
if ok {
if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
changed = true
if err != nil {
return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err)
}
}
// if we shouldn't replace the key we do nothing with it
continue
}
}
// if we're here, either the room or session are new, either way, we insert
err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey)
changed = true
if err != nil {
return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err)
}
}
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
if err != nil {
return err
}
if changed {
// update the etag
var newETag string
if oldETag == "" {
newETag = "1"
} else {
oldETagInt, err := strconv.ParseInt(oldETag, 10, 64)
if err != nil {
return fmt.Errorf("failed to parse old etag: %s", err)
}
newETag = strconv.FormatInt(oldETagInt+1, 10)
}
etag = newETag
return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag)
} else {
etag = oldETag
}
return nil
})
return
}

View file

@ -1,52 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package devices
import (
"context"
"github.com/matrix-org/dendrite/userapi/api"
)
type Database interface {
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
RemoveLoginToken(ctx context.Context, token string) error
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
}

View file

@ -1,270 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas"
"github.com/matrix-org/gomatrixserverlib"
)
const (
// The length of generated device IDs
deviceIDByteLength = 6
loginTokenByteLength = 32
)
// Database represents a device database.
type Database struct {
db *sql.DB
devices devicesStatements
loginTokens loginTokenStatements
loginTokenLifetime time.Duration
}
// NewDatabase creates a new device database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
var d devicesStatements
var lt loginTokenStatements
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
if err = d.execSchema(db); err != nil {
return nil, err
}
if err = lt.execSchema(db); err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
deltas.LoadLastSeenTSIP(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
if err = d.prepare(db, serverName); err != nil {
return nil, err
}
if err = lt.prepare(db); err != nil {
return nil, err
}
return &Database{db, d, lt, loginTokenLifetime}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*api.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.devices.selectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil
})
return
}
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
})
}
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
tok, err := generateLoginToken()
if err != nil {
return nil, err
}
meta := &api.LoginTokenMetadata{
Token: tok,
Expiration: time.Now().Add(d.loginTokenLifetime),
}
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.loginTokens.insert(ctx, txn, meta, data)
})
if err != nil {
return nil, err
}
return meta, nil
}
func generateLoginToken() (string, error) {
b := make([]byte, loginTokenByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.loginTokens.deleteByToken(ctx, txn, token)
})
}
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
return d.loginTokens.selectByToken(ctx, token)
}

View file

@ -1,271 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas"
"github.com/matrix-org/gomatrixserverlib"
)
const (
// The length of generated device IDs
deviceIDByteLength = 6
loginTokenByteLength = 32
)
// Database represents a device database.
type Database struct {
db *sql.DB
writer sqlutil.Writer
devices devicesStatements
loginTokens loginTokenStatements
loginTokenLifetime time.Duration
}
// NewDatabase creates a new device database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
writer := sqlutil.NewExclusiveWriter()
var d devicesStatements
var lt loginTokenStatements
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
if err = d.execSchema(db); err != nil {
return nil, err
}
if err = lt.execSchema(db); err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
deltas.LoadLastSeenTSIP(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
if err = d.prepare(db, writer, serverName); err != nil {
return nil, err
}
if err = lt.prepare(db); err != nil {
return nil, err
}
return &Database{db, writer, d, lt, loginTokenLifetime}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*api.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.devices.selectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil
})
return
}
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
})
}
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
tok, err := generateLoginToken()
if err != nil {
return nil, err
}
meta := &api.LoginTokenMetadata{
Token: tok,
Expiration: time.Now().Add(d.loginTokenLifetime),
}
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.loginTokens.insert(ctx, txn, meta, data)
})
if err != nil {
return nil, err
}
return meta, nil
}
func generateLoginToken() (string, error) {
b := make([]byte, loginTokenByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.loginTokens.deleteByToken(ctx, txn, token)
})
}
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
return d.loginTokens.selectByToken(ctx, token)
}

View file

@ -1,42 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !wasm
// +build !wasm
package devices
import (
"fmt"
"time"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/devices/postgres"
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
)
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets postgres connection parameters. loginTokenLifetime determines how long a
// login token from CreateLoginToken is valid.
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime)
case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(dbProperties, serverName, loginTokenLifetime)
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View file

@ -1,39 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package devices
import (
"fmt"
"time"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
)
func NewDatabase(
dbProperties *config.DatabaseOptions,
serverName gomatrixserverlib.ServerName,
loginTokenLifetime time.Duration,
) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime)
case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation")
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package accounts package storage
import ( import (
"context" "context"
@ -60,6 +60,35 @@ type Database interface {
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error) UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error) GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error) CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
RemoveLoginToken(ctx context.Context, token string) error
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
} }
// Err3PIDInUse is the error returned when trying to save an association involving // Err3PIDInUse is the error returned when trying to save an association involving

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
) )
const accountDataSchema = ` const accountDataSchema = `
@ -56,19 +57,20 @@ type accountDataStatements struct {
selectAccountDataByTypeStmt *sql.Stmt selectAccountDataByTypeStmt *sql.Stmt
} }
func (s *accountDataStatements) prepare(db *sql.DB) (err error) { func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
_, err = db.Exec(accountDataSchema) s := &accountDataStatements{}
_, err := db.Exec(accountDataSchema)
if err != nil { if err != nil {
return return nil, err
} }
return sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertAccountDataStmt, insertAccountDataSQL}, {&s.insertAccountDataStmt, insertAccountDataSQL},
{&s.selectAccountDataStmt, selectAccountDataSQL}, {&s.selectAccountDataStmt, selectAccountDataSQL},
{&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL}, {&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL},
}.Prepare(db) }.Prepare(db)
} }
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt) stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
@ -76,7 +78,7 @@ func (s *accountDataStatements) insertAccountData(
return return
} }
func (s *accountDataStatements) selectAccountData( func (s *accountDataStatements) SelectAccountData(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) ( ) (
/* global */ map[string]json.RawMessage, /* global */ map[string]json.RawMessage,
@ -114,7 +116,7 @@ func (s *accountDataStatements) selectAccountData(
return global, rooms, rows.Err() return global, rooms, rows.Err()
} }
func (s *accountDataStatements) selectAccountDataByType( func (s *accountDataStatements) SelectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string, ctx context.Context, localpart, roomID, dataType string,
) (data json.RawMessage, err error) { ) (data json.RawMessage, err error) {
var bytes []byte var bytes []byte

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -78,14 +79,15 @@ type accountsStatements struct {
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
func (s *accountsStatements) execSchema(db *sql.DB) error { func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) {
s := &accountsStatements{
serverName: serverName,
}
_, err := db.Exec(accountsSchema) _, err := db.Exec(accountsSchema)
return err if err != nil {
} return nil, err
}
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { return s, sqlutil.StatementList{
s.serverName = server
return sqlutil.StatementList{
{&s.insertAccountStmt, insertAccountSQL}, {&s.insertAccountStmt, insertAccountSQL},
{&s.updatePasswordStmt, updatePasswordSQL}, {&s.updatePasswordStmt, updatePasswordSQL},
{&s.deactivateAccountStmt, deactivateAccountSQL}, {&s.deactivateAccountStmt, deactivateAccountSQL},
@ -98,7 +100,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) insertAccount( func (s *accountsStatements) InsertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) { ) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
@ -123,28 +125,28 @@ func (s *accountsStatements) insertAccount(
}, nil }, nil
} }
func (s *accountsStatements) updatePassword( func (s *accountsStatements) UpdatePassword(
ctx context.Context, localpart, passwordHash string, ctx context.Context, localpart, passwordHash string,
) (err error) { ) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
return return
} }
func (s *accountsStatements) deactivateAccount( func (s *accountsStatements) DeactivateAccount(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (err error) { ) (err error) {
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
return return
} }
func (s *accountsStatements) selectPasswordHash( func (s *accountsStatements) SelectPasswordHash(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (hash string, err error) { ) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
return return
} }
func (s *accountsStatements) selectAccountByLocalpart( func (s *accountsStatements) SelectAccountByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (*api.Account, error) { ) (*api.Account, error) {
var appserviceIDPtr sql.NullString var appserviceIDPtr sql.NullString
@ -168,7 +170,7 @@ func (s *accountsStatements) selectAccountByLocalpart(
return &acc, nil return &acc, nil
} }
func (s *accountsStatements) selectNewNumericLocalpart( func (s *accountsStatements) SelectNewNumericLocalpart(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
) (id int64, err error) { ) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt stmt := s.selectNewNumericLocalpartStmt

View file

@ -5,13 +5,8 @@ import (
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
) )
func LoadFromGoose() {
goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
}
func LoadLastSeenTSIP(m *sqlutil.Migrations) { func LoadLastSeenTSIP(m *sqlutil.Migrations) {
m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP) m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
} }

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -111,50 +112,32 @@ type devicesStatements struct {
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
func (s *devicesStatements) execSchema(db *sql.DB) error { func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) {
s := &devicesStatements{
serverName: serverName,
}
_, err := db.Exec(devicesSchema) _, err := db.Exec(devicesSchema)
return err if err != nil {
} return nil, err
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
return
} }
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil { return s, sqlutil.StatementList{
return {&s.insertDeviceStmt, insertDeviceSQL},
} {&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL},
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil { {&s.selectDeviceByIDStmt, selectDeviceByIDSQL},
return {&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL},
} {&s.updateDeviceNameStmt, updateDeviceNameSQL},
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil { {&s.deleteDeviceStmt, deleteDeviceSQL},
return {&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL},
} {&s.deleteDevicesStmt, deleteDevicesSQL},
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil { {&s.selectDevicesByIDStmt, selectDevicesByIDSQL},
return {&s.updateDeviceLastSeenStmt, updateDeviceLastSeen},
} }.Prepare(db)
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
return
}
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return
}
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
return
}
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
return
}
if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil {
return
}
s.serverName = server
return
} }
// insertDevice creates a new device. Returns an error if any device with the same access token already exists. // insertDevice creates a new device. Returns an error if any device with the same access token already exists.
// Returns an error if the user already has a device with the given device ID. // Returns an error if the user already has a device with the given device ID.
// Returns the device on success. // Returns the device on success.
func (s *devicesStatements) insertDevice( func (s *devicesStatements) InsertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string, ipAddr, userAgent string, displayName *string, ipAddr, userAgent string,
) (*api.Device, error) { ) (*api.Device, error) {
@ -176,7 +159,7 @@ func (s *devicesStatements) insertDevice(
} }
// deleteDevice removes a single device by id and user localpart. // deleteDevice removes a single device by id and user localpart.
func (s *devicesStatements) deleteDevice( func (s *devicesStatements) DeleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string, ctx context.Context, txn *sql.Tx, id, localpart string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
@ -186,7 +169,7 @@ func (s *devicesStatements) deleteDevice(
// deleteDevices removes a single or multiple devices by ids and user localpart. // deleteDevices removes a single or multiple devices by ids and user localpart.
// Returns an error if the execution failed. // Returns an error if the execution failed.
func (s *devicesStatements) deleteDevices( func (s *devicesStatements) DeleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string, ctx context.Context, txn *sql.Tx, localpart string, devices []string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt) stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt)
@ -196,7 +179,7 @@ func (s *devicesStatements) deleteDevices(
// deleteDevicesByLocalpart removes all devices for the // deleteDevicesByLocalpart removes all devices for the
// given user localpart. // given user localpart.
func (s *devicesStatements) deleteDevicesByLocalpart( func (s *devicesStatements) DeleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
@ -204,7 +187,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
return err return err
} }
func (s *devicesStatements) updateDeviceName( func (s *devicesStatements) UpdateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
@ -212,7 +195,7 @@ func (s *devicesStatements) updateDeviceName(
return err return err
} }
func (s *devicesStatements) selectDeviceByToken( func (s *devicesStatements) SelectDeviceByToken(
ctx context.Context, accessToken string, ctx context.Context, accessToken string,
) (*api.Device, error) { ) (*api.Device, error) {
var dev api.Device var dev api.Device
@ -228,7 +211,7 @@ func (s *devicesStatements) selectDeviceByToken(
// selectDeviceByID retrieves a device from the database with the given user // selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID // localpart and deviceID
func (s *devicesStatements) selectDeviceByID( func (s *devicesStatements) SelectDeviceByID(
ctx context.Context, localpart, deviceID string, ctx context.Context, localpart, deviceID string,
) (*api.Device, error) { ) (*api.Device, error) {
var dev api.Device var dev api.Device
@ -245,7 +228,7 @@ func (s *devicesStatements) selectDeviceByID(
return &dev, err return &dev, err
} }
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs)) rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs))
if err != nil { if err != nil {
return nil, err return nil, err
@ -268,7 +251,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
return devices, rows.Err() return devices, rows.Err()
} }
func (s *devicesStatements) selectDevicesByLocalpart( func (s *devicesStatements) SelectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) { ) ([]api.Device, error) {
devices := []api.Device{} devices := []api.Device{}
@ -310,7 +293,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
return devices, rows.Err() return devices, rows.Err()
} }
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
lastSeenTs := time.Now().UnixNano() / 1000000 lastSeenTs := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID) _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID)

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
) )
const keyBackupTableSchema = ` const keyBackupTableSchema = `
@ -72,12 +73,13 @@ type keyBackupStatements struct {
selectKeysByRoomIDAndSessionIDStmt *sql.Stmt selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
} }
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
_, err = db.Exec(keyBackupTableSchema) s := &keyBackupStatements{}
_, err := db.Exec(keyBackupTableSchema)
if err != nil { if err != nil {
return return nil, err
} }
return sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.insertBackupKeyStmt, insertBackupKeySQL},
{&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL},
{&s.countKeysStmt, countKeysSQL}, {&s.countKeysStmt, countKeysSQL},
@ -87,14 +89,14 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db) }.Prepare(db)
} }
func (s keyBackupStatements) countKeys( func (s keyBackupStatements) CountKeys(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (count int64, err error) { ) (count int64, err error) {
err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count) err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count)
return return
} }
func (s *keyBackupStatements) insertBackupKey( func (s *keyBackupStatements) InsertBackupKey(
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext( _, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext(
@ -103,7 +105,7 @@ func (s *keyBackupStatements) insertBackupKey(
return return
} }
func (s *keyBackupStatements) updateBackupKey( func (s *keyBackupStatements) UpdateBackupKey(
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext( _, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext(
@ -112,7 +114,7 @@ func (s *keyBackupStatements) updateBackupKey(
return return
} }
func (s *keyBackupStatements) selectKeys( func (s *keyBackupStatements) SelectKeys(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (map[string]map[string]api.KeyBackupSession, error) { ) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
@ -122,7 +124,7 @@ func (s *keyBackupStatements) selectKeys(
return unpackKeys(ctx, rows) return unpackKeys(ctx, rows)
} }
func (s *keyBackupStatements) selectKeysByRoomID( func (s *keyBackupStatements) SelectKeysByRoomID(
ctx context.Context, txn *sql.Tx, userID, version, roomID string, ctx context.Context, txn *sql.Tx, userID, version, roomID string,
) (map[string]map[string]api.KeyBackupSession, error) { ) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID) rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
@ -132,7 +134,7 @@ func (s *keyBackupStatements) selectKeysByRoomID(
return unpackKeys(ctx, rows) return unpackKeys(ctx, rows)
} }
func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID( func (s *keyBackupStatements) SelectKeysByRoomIDAndSessionID(
ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string, ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
) (map[string]map[string]api.KeyBackupSession, error) { ) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID) rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)

View file

@ -22,6 +22,7 @@ import (
"strconv" "strconv"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
) )
const keyBackupVersionTableSchema = ` const keyBackupVersionTableSchema = `
@ -69,12 +70,13 @@ type keyBackupVersionStatements struct {
updateKeyBackupETagStmt *sql.Stmt updateKeyBackupETagStmt *sql.Stmt
} }
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { func NewPostgresKeyBackupVersionTable(db *sql.DB) (tables.KeyBackupVersionTable, error) {
_, err = db.Exec(keyBackupVersionTableSchema) s := &keyBackupVersionStatements{}
_, err := db.Exec(keyBackupVersionTableSchema)
if err != nil { if err != nil {
return return nil, err
} }
return sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertKeyBackupStmt, insertKeyBackupSQL}, {&s.insertKeyBackupStmt, insertKeyBackupSQL},
{&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL}, {&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL},
{&s.deleteKeyBackupStmt, deleteKeyBackupSQL}, {&s.deleteKeyBackupStmt, deleteKeyBackupSQL},
@ -84,7 +86,7 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db) }.Prepare(db)
} }
func (s *keyBackupVersionStatements) insertKeyBackup( func (s *keyBackupVersionStatements) InsertKeyBackup(
ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string, ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string,
) (version string, err error) { ) (version string, err error) {
var versionInt int64 var versionInt int64
@ -92,7 +94,7 @@ func (s *keyBackupVersionStatements) insertKeyBackup(
return strconv.FormatInt(versionInt, 10), err return strconv.FormatInt(versionInt, 10), err
} }
func (s *keyBackupVersionStatements) updateKeyBackupAuthData( func (s *keyBackupVersionStatements) UpdateKeyBackupAuthData(
ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage, ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage,
) error { ) error {
versionInt, err := strconv.ParseInt(version, 10, 64) versionInt, err := strconv.ParseInt(version, 10, 64)
@ -103,7 +105,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
return err return err
} }
func (s *keyBackupVersionStatements) updateKeyBackupETag( func (s *keyBackupVersionStatements) UpdateKeyBackupETag(
ctx context.Context, txn *sql.Tx, userID, version, etag string, ctx context.Context, txn *sql.Tx, userID, version, etag string,
) error { ) error {
versionInt, err := strconv.ParseInt(version, 10, 64) versionInt, err := strconv.ParseInt(version, 10, 64)
@ -114,7 +116,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag(
return err return err
} }
func (s *keyBackupVersionStatements) deleteKeyBackup( func (s *keyBackupVersionStatements) DeleteKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (bool, error) { ) (bool, error) {
versionInt, err := strconv.ParseInt(version, 10, 64) versionInt, err := strconv.ParseInt(version, 10, 64)
@ -132,7 +134,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup(
return ra == 1, nil return ra == 1, nil
} }
func (s *keyBackupVersionStatements) selectKeyBackup( func (s *keyBackupVersionStatements) SelectKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { ) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
var versionInt int64 var versionInt int64

View file

@ -21,18 +21,11 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
type loginTokenStatements struct { const loginTokenSchema = `
insertStmt *sql.Stmt
deleteStmt *sql.Stmt
selectByTokenStmt *sql.Stmt
}
// execSchema ensures tables and indices exist.
func (s *loginTokenStatements) execSchema(db *sql.DB) error {
_, err := db.Exec(`
CREATE TABLE IF NOT EXISTS login_tokens ( CREATE TABLE IF NOT EXISTS login_tokens (
-- The random value of the token issued to a user -- The random value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY, token TEXT NOT NULL PRIMARY KEY,
@ -45,21 +38,38 @@ CREATE TABLE IF NOT EXISTS login_tokens (
-- This index allows efficient garbage collection of expired tokens. -- This index allows efficient garbage collection of expired tokens.
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at);
`) `
return err
const insertLoginTokenSQL = "" +
"INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
const deleteLoginTokenSQL = "" +
"DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"
const selectLoginTokenSQL = "" +
"SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"
type loginTokenStatements struct {
insertStmt *sql.Stmt
deleteStmt *sql.Stmt
selectStmt *sql.Stmt
} }
// prepare runs statement preparation. func NewPostgresLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) {
func (s *loginTokenStatements) prepare(db *sql.DB) error { s := &loginTokenStatements{}
return sqlutil.StatementList{ _, err := db.Exec(loginTokenSchema)
{&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, if err != nil {
{&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"}, return nil, err
{&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"}, }
return s, sqlutil.StatementList{
{&s.insertStmt, insertLoginTokenSQL},
{&s.deleteStmt, deleteLoginTokenSQL},
{&s.selectStmt, selectLoginTokenSQL},
}.Prepare(db) }.Prepare(db)
} }
// insert adds an already generated token to the database. // insert adds an already generated token to the database.
func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error {
stmt := sqlutil.TxStmt(txn, s.insertStmt) stmt := sqlutil.TxStmt(txn, s.insertStmt)
_, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID) _, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID)
return err return err
@ -69,7 +79,7 @@ func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata
// //
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens. // As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
// The login_tokens_expiration_idx index should make that efficient. // The login_tokens_expiration_idx index should make that efficient.
func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error { func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error {
stmt := sqlutil.TxStmt(txn, s.deleteStmt) stmt := sqlutil.TxStmt(txn, s.deleteStmt)
res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) res, err := stmt.ExecContext(ctx, token, time.Now().UTC())
if err != nil { if err != nil {
@ -82,9 +92,9 @@ func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, t
} }
// selectByToken returns the data associated with the given token. May return sql.ErrNoRows. // selectByToken returns the data associated with the given token. May return sql.ErrNoRows.
func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { func (s *loginTokenStatements) SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
var data api.LoginTokenData var data api.LoginTokenData
err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -6,6 +6,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -22,33 +23,35 @@ CREATE TABLE IF NOT EXISTS open_id_tokens (
); );
` `
const insertTokenSQL = "" + const insertOpenIDTokenSQL = "" +
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
const selectTokenSQL = "" + const selectOpenIDTokenSQL = "" +
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
type tokenStatements struct { type openIDTokenStatements struct {
insertTokenStmt *sql.Stmt insertTokenStmt *sql.Stmt
selectTokenStmt *sql.Stmt selectTokenStmt *sql.Stmt
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.OpenIDTable, error) {
_, err = db.Exec(openIDTokenSchema) s := &openIDTokenStatements{
if err != nil { serverName: serverName,
return
} }
s.serverName = server _, err := db.Exec(openIDTokenSchema)
return sqlutil.StatementList{ if err != nil {
{&s.insertTokenStmt, insertTokenSQL}, return nil, err
{&s.selectTokenStmt, selectTokenSQL}, }
return s, sqlutil.StatementList{
{&s.insertTokenStmt, insertOpenIDTokenSQL},
{&s.selectTokenStmt, selectOpenIDTokenSQL},
}.Prepare(db) }.Prepare(db)
} }
// insertToken inserts a new OpenID Connect token to the DB. // insertToken inserts a new OpenID Connect token to the DB.
// Returns new token, otherwise returns error if the token already exists. // Returns new token, otherwise returns error if the token already exists.
func (s *tokenStatements) insertToken( func (s *openIDTokenStatements) InsertOpenIDToken(
ctx context.Context, ctx context.Context,
txn *sql.Tx, txn *sql.Tx,
token, localpart string, token, localpart string,
@ -61,7 +64,7 @@ func (s *tokenStatements) insertToken(
// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB // selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
// Returns the existing token's attributes, or err if no token is found // Returns the existing token's attributes, or err if no token is found
func (s *tokenStatements) selectOpenIDTokenAtrributes( func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
ctx context.Context, ctx context.Context,
token string, token string,
) (*api.OpenIDTokenAttributes, error) { ) (*api.OpenIDTokenAttributes, error) {

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
) )
const profilesSchema = ` const profilesSchema = `
@ -59,12 +60,13 @@ type profilesStatements struct {
selectProfilesBySearchStmt *sql.Stmt selectProfilesBySearchStmt *sql.Stmt
} }
func (s *profilesStatements) prepare(db *sql.DB) (err error) { func NewPostgresProfilesTable(db *sql.DB) (tables.ProfileTable, error) {
_, err = db.Exec(profilesSchema) s := &profilesStatements{}
_, err := db.Exec(profilesSchema)
if err != nil { if err != nil {
return return nil, err
} }
return sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertProfileStmt, insertProfileSQL}, {&s.insertProfileStmt, insertProfileSQL},
{&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL}, {&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL},
{&s.setAvatarURLStmt, setAvatarURLSQL}, {&s.setAvatarURLStmt, setAvatarURLSQL},
@ -73,14 +75,14 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db) }.Prepare(db)
} }
func (s *profilesStatements) insertProfile( func (s *profilesStatements) InsertProfile(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart string,
) (err error) { ) (err error) {
_, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
return return
} }
func (s *profilesStatements) selectProfileByLocalpart( func (s *profilesStatements) SelectProfileByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
var profile authtypes.Profile var profile authtypes.Profile
@ -93,21 +95,21 @@ func (s *profilesStatements) selectProfileByLocalpart(
return &profile, nil return &profile, nil
} }
func (s *profilesStatements) setAvatarURL( func (s *profilesStatements) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string, ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
) (err error) { ) (err error) {
_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart) _, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart)
return return
} }
func (s *profilesStatements) setDisplayName( func (s *profilesStatements) SetDisplayName(
ctx context.Context, localpart string, displayName string, ctx context.Context, txn *sql.Tx, localpart string, displayName string,
) (err error) { ) (err error) {
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
return return
} }
func (s *profilesStatements) selectProfilesBySearch( func (s *profilesStatements) SelectProfilesBySearch(
ctx context.Context, searchString string, limit int, ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) { ) ([]authtypes.Profile, error) {
var profiles []authtypes.Profile var profiles []authtypes.Profile

View file

@ -0,0 +1,105 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"fmt"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/userapi/storage/shared"
// Import the postgres database driver.
_ "github.com/lib/pq"
)
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*shared.Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
if _, err = db.Exec(accountsSchema); err != nil {
// do this so that the migration can and we don't fail on
// preparing statements for columns that don't exist yet
return nil, err
}
deltas.LoadIsActive(m)
//deltas.LoadLastSeenTSIP(m)
deltas.LoadAddAccountType(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
accountDataTable, err := NewPostgresAccountDataTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)
}
accountsTable, err := NewPostgresAccountsTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err)
}
devicesTable, err := NewPostgresDevicesTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err)
}
keyBackupTable, err := NewPostgresKeyBackupTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresKeyBackupTable: %w", err)
}
keyBackupVersionTable, err := NewPostgresKeyBackupVersionTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresKeyBackupVersionTable: %w", err)
}
loginTokenTable, err := NewPostgresLoginTokenTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresLoginTokenTable: %w", err)
}
openIDTable, err := NewPostgresOpenIDTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewPostgresOpenIDTable: %w", err)
}
profilesTable, err := NewPostgresProfilesTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresProfilesTable: %w", err)
}
threePIDTable, err := NewPostgresThreePIDTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresThreePIDTable: %w", err)
}
return &shared.Database{
AccountDatas: accountDataTable,
Accounts: accountsTable,
Devices: devicesTable,
KeyBackups: keyBackupTable,
KeyBackupVersions: keyBackupVersionTable,
LoginTokens: loginTokenTable,
OpenIDTokens: openIDTable,
Profiles: profilesTable,
ThreePIDs: threePIDTable,
ServerName: serverName,
DB: db,
Writer: sqlutil.NewDummyWriter(),
LoginTokenLifetime: loginTokenLifetime,
BcryptCost: bcryptCost,
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
}, nil
}

View file

@ -19,6 +19,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
) )
@ -58,12 +59,13 @@ type threepidStatements struct {
deleteThreePIDStmt *sql.Stmt deleteThreePIDStmt *sql.Stmt
} }
func (s *threepidStatements) prepare(db *sql.DB) (err error) { func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
_, err = db.Exec(threepidSchema) s := &threepidStatements{}
_, err := db.Exec(threepidSchema)
if err != nil { if err != nil {
return return nil, err
} }
return sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL}, {&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL},
{&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL}, {&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL},
{&s.insertThreePIDStmt, insertThreePIDSQL}, {&s.insertThreePIDStmt, insertThreePIDSQL},
@ -71,7 +73,7 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db) }.Prepare(db)
} }
func (s *threepidStatements) selectLocalpartForThreePID( func (s *threepidStatements) SelectLocalpartForThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string, ctx context.Context, txn *sql.Tx, threepid string, medium string,
) (localpart string, err error) { ) (localpart string, err error) {
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
@ -82,7 +84,7 @@ func (s *threepidStatements) selectLocalpartForThreePID(
return return
} }
func (s *threepidStatements) selectThreePIDsForLocalpart( func (s *threepidStatements) SelectThreePIDsForLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) { ) (threepids []authtypes.ThreePID, err error) {
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
@ -106,7 +108,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
return return
} }
func (s *threepidStatements) insertThreePID( func (s *threepidStatements) InsertThreePID(
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
@ -114,8 +116,9 @@ func (s *threepidStatements) insertThreePID(
return return
} }
func (s *threepidStatements) deleteThreePID( func (s *threepidStatements) DeleteThreePID(
ctx context.Context, threepid string, medium string) (err error) { ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) {
_, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium) stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium)
return return
} }

View file

@ -12,16 +12,17 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package sqlite3 package shared
import ( import (
"context" "context"
"crypto/rand"
"database/sql" "database/sql"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
"sync"
"time" "time"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -30,102 +31,48 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas" "github.com/matrix-org/dendrite/userapi/storage/tables"
) )
// Database represents an account database // Database represents an account database
type Database struct { type Database struct {
db *sql.DB DB *sql.DB
writer sqlutil.Writer Writer sqlutil.Writer
Accounts tables.AccountsTable
sqlutil.PartitionOffsetStatements Profiles tables.ProfileTable
accounts accountsStatements AccountDatas tables.AccountDataTable
profiles profilesStatements ThreePIDs tables.ThreePIDTable
accountDatas accountDataStatements OpenIDTokens tables.OpenIDTable
threepids threepidStatements KeyBackups tables.KeyBackupTable
openIDTokens tokenStatements KeyBackupVersions tables.KeyBackupVersionTable
keyBackupVersions keyBackupVersionStatements Devices tables.DevicesTable
keyBackups keyBackupStatements LoginTokens tables.LoginTokenTable
serverName gomatrixserverlib.ServerName LoginTokenLifetime time.Duration
bcryptCost int ServerName gomatrixserverlib.ServerName
openIDTokenLifetimeMS int64 BcryptCost int
OpenIDTokenLifetimeMS int64
accountsMu sync.Mutex
profilesMu sync.Mutex
accountDatasMu sync.Mutex
threepidsMu sync.Mutex
} }
// NewDatabase creates a new accounts and profiles database const (
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) { // The length of generated device IDs
db, err := sqlutil.Open(dbProperties) deviceIDByteLength = 6
if err != nil { loginTokenByteLength = 32
return nil, err )
}
d := &Database{
serverName: serverName,
db: db,
writer: sqlutil.NewExclusiveWriter(),
bcryptCost: bcryptCost,
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
}
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
if err = d.accounts.execSchema(db); err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
deltas.LoadIsActive(m)
deltas.LoadAddAccountType(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
partitions := sqlutil.PartitionOffsetStatements{}
if err = partitions.Prepare(db, d.writer, "account"); err != nil {
return nil, err
}
if err = d.accounts.prepare(db, serverName); err != nil {
return nil, err
}
if err = d.profiles.prepare(db); err != nil {
return nil, err
}
if err = d.accountDatas.prepare(db); err != nil {
return nil, err
}
if err = d.threepids.prepare(db); err != nil {
return nil, err
}
if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err
}
if err = d.keyBackupVersions.prepare(db); err != nil {
return nil, err
}
if err = d.keyBackups.prepare(db); err != nil {
return nil, err
}
return d, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password. // GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart. // Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword( func (d *Database) GetAccountByPassword(
ctx context.Context, localpart, plaintextPassword string, ctx context.Context, localpart, plaintextPassword string,
) (*api.Account, error) { ) (*api.Account, error) {
hash, err := d.accounts.selectPasswordHash(ctx, localpart) hash, err := d.Accounts.SelectPasswordHash(ctx, localpart)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
return nil, err return nil, err
} }
return d.accounts.selectAccountByLocalpart(ctx, localpart) return d.Accounts.SelectAccountByLocalpart(ctx, localpart)
} }
// GetProfileByLocalpart returns the profile associated with the given localpart. // GetProfileByLocalpart returns the profile associated with the given localpart.
@ -133,7 +80,7 @@ func (d *Database) GetAccountByPassword(
func (d *Database) GetProfileByLocalpart( func (d *Database) GetProfileByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
return d.profiles.selectProfileByLocalpart(ctx, localpart) return d.Profiles.SelectProfileByLocalpart(ctx, localpart)
} }
// SetAvatarURL updates the avatar URL of the profile associated with the given // SetAvatarURL updates the avatar URL of the profile associated with the given
@ -141,10 +88,8 @@ func (d *Database) GetProfileByLocalpart(
func (d *Database) SetAvatarURL( func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string, ctx context.Context, localpart string, avatarURL string,
) error { ) error {
d.profilesMu.Lock() return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
defer d.profilesMu.Unlock() return d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.profiles.setAvatarURL(ctx, txn, localpart, avatarURL)
}) })
} }
@ -153,10 +98,8 @@ func (d *Database) SetAvatarURL(
func (d *Database) SetDisplayName( func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string, ctx context.Context, localpart string, displayName string,
) error { ) error {
d.profilesMu.Lock() return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
defer d.profilesMu.Unlock() return d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.profiles.setDisplayName(ctx, txn, localpart, displayName)
}) })
} }
@ -168,8 +111,8 @@ func (d *Database) SetPassword(
if err != nil { if err != nil {
return err return err
} }
return d.writer.Do(nil, nil, func(txn *sql.Tx) error { return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
return d.accounts.updatePassword(ctx, localpart, hash) return d.Accounts.UpdatePassword(ctx, localpart, hash)
}) })
} }
@ -179,18 +122,11 @@ func (d *Database) SetPassword(
func (d *Database) CreateAccount( func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
) (acc *api.Account, err error) { ) (acc *api.Account, err error) {
// Create one account at a time else we can get 'database is locked'. err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
d.profilesMu.Lock()
d.accountDatasMu.Lock()
d.accountsMu.Lock()
defer d.profilesMu.Unlock()
defer d.accountDatasMu.Unlock()
defer d.accountsMu.Unlock()
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
// For guest accounts, we create a new numeric local part // For guest accounts, we create a new numeric local part
if accountType == api.AccountTypeGuest { if accountType == api.AccountTypeGuest {
var numLocalpart int64 var numLocalpart int64
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn)
if err != nil { if err != nil {
return err return err
} }
@ -219,18 +155,18 @@ func (d *Database) createAccount(
return nil, err return nil, err
} }
} }
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil { if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
return nil, sqlutil.ErrUserExists return nil, sqlutil.ErrUserExists
} }
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil { if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
return nil, err return nil, err
} }
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.serverName) pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
prbs, err := json.Marshal(pushRuleSets) prbs, err := json.Marshal(pushRuleSets)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil { if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
return nil, err return nil, err
} }
return account, nil return account, nil
@ -244,10 +180,8 @@ func (d *Database) createAccount(
func (d *Database) SaveAccountData( func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error { ) error {
d.accountDatasMu.Lock() return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
defer d.accountDatasMu.Unlock() return d.AccountDatas.InsertAccountData(ctx, txn, localpart, roomID, dataType, content)
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
}) })
} }
@ -259,7 +193,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
rooms map[string]map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage,
err error, err error,
) { ) {
return d.accountDatas.selectAccountData(ctx, localpart) return d.AccountDatas.SelectAccountData(ctx, localpart)
} }
// GetAccountDataByType returns account data matching a given // GetAccountDataByType returns account data matching a given
@ -269,7 +203,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
func (d *Database) GetAccountDataByType( func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string, ctx context.Context, localpart, roomID, dataType string,
) (data json.RawMessage, err error) { ) (data json.RawMessage, err error) {
return d.accountDatas.selectAccountDataByType( return d.AccountDatas.SelectAccountDataByType(
ctx, localpart, roomID, dataType, ctx, localpart, roomID, dataType,
) )
} }
@ -278,11 +212,11 @@ func (d *Database) GetAccountDataByType(
func (d *Database) GetNewNumericLocalpart( func (d *Database) GetNewNumericLocalpart(
ctx context.Context, ctx context.Context,
) (int64, error) { ) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx, nil) return d.Accounts.SelectNewNumericLocalpart(ctx, nil)
} }
func (d *Database) hashPassword(plaintext string) (hash string, err error) { func (d *Database) hashPassword(plaintext string) (hash string, err error) {
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.bcryptCost) hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.BcryptCost)
return string(hashBytes), err return string(hashBytes), err
} }
@ -297,10 +231,8 @@ var Err3PIDInUse = errors.New("this third-party identifier is already in use")
func (d *Database) SaveThreePIDAssociation( func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string, ctx context.Context, threepid, localpart, medium string,
) (err error) { ) (err error) {
d.threepidsMu.Lock() return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
defer d.threepidsMu.Unlock() user, err := d.ThreePIDs.SelectLocalpartForThreePID(
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
user, err := d.threepids.selectLocalpartForThreePID(
ctx, txn, threepid, medium, ctx, txn, threepid, medium,
) )
if err != nil { if err != nil {
@ -311,7 +243,7 @@ func (d *Database) SaveThreePIDAssociation(
return Err3PIDInUse return Err3PIDInUse
} }
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart) return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart)
}) })
} }
@ -322,10 +254,8 @@ func (d *Database) SaveThreePIDAssociation(
func (d *Database) RemoveThreePIDAssociation( func (d *Database) RemoveThreePIDAssociation(
ctx context.Context, threepid string, medium string, ctx context.Context, threepid string, medium string,
) (err error) { ) (err error) {
d.threepidsMu.Lock() return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
defer d.threepidsMu.Unlock() return d.ThreePIDs.DeleteThreePID(ctx, txn, threepid, medium)
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.threepids.deleteThreePID(ctx, txn, threepid, medium)
}) })
} }
@ -337,7 +267,7 @@ func (d *Database) RemoveThreePIDAssociation(
func (d *Database) GetLocalpartForThreePID( func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string, ctx context.Context, threepid string, medium string,
) (localpart string, err error) { ) (localpart string, err error) {
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium) return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium)
} }
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with // GetThreePIDsForLocalpart looks up the third-party identifiers associated with
@ -347,14 +277,14 @@ func (d *Database) GetLocalpartForThreePID(
func (d *Database) GetThreePIDsForLocalpart( func (d *Database) GetThreePIDsForLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) { ) (threepids []authtypes.ThreePID, err error) {
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart) return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart)
} }
// CheckAccountAvailability checks if the username/localpart is already present // CheckAccountAvailability checks if the username/localpart is already present
// in the database. // in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken. // If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart) _, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return true, nil return true, nil
} }
@ -366,20 +296,20 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin
// Returns sql.ErrNoRows if no account exists which matches the given localpart. // Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
) (*api.Account, error) { ) (*api.Account, error) {
return d.accounts.selectAccountByLocalpart(ctx, localpart) return d.Accounts.SelectAccountByLocalpart(ctx, localpart)
} }
// SearchProfiles returns all profiles where the provided localpart or display name // SearchProfiles returns all profiles where the provided localpart or display name
// match any part of the profiles in the database. // match any part of the profiles in the database.
func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int, func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) { ) ([]authtypes.Profile, error) {
return d.profiles.selectProfilesBySearch(ctx, searchString, limit) return d.Profiles.SelectProfilesBySearch(ctx, searchString, limit)
} }
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again. // DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
return d.writer.Do(nil, nil, func(txn *sql.Tx) error { return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
return d.accounts.deactivateAccount(ctx, localpart) return d.Accounts.DeactivateAccount(ctx, localpart)
}) })
} }
@ -388,9 +318,9 @@ func (d *Database) CreateOpenIDToken(
ctx context.Context, ctx context.Context,
token, localpart string, token, localpart string,
) (int64, error) { ) (int64, error) {
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS
err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error { err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS) return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, expiresAtMS)
}) })
return expiresAtMS, err return expiresAtMS, err
} }
@ -400,14 +330,14 @@ func (d *Database) GetOpenIDTokenAttributes(
ctx context.Context, ctx context.Context,
token string, token string,
) (*api.OpenIDTokenAttributes, error) { ) (*api.OpenIDTokenAttributes, error) {
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) return d.OpenIDTokens.SelectOpenIDTokenAtrributes(ctx, token)
} }
func (d *Database) CreateKeyBackup( func (d *Database) CreateKeyBackup(
ctx context.Context, userID, algorithm string, authData json.RawMessage, ctx context.Context, userID, algorithm string, authData json.RawMessage,
) (version string, err error) { ) (version string, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "") version, err = d.KeyBackupVersions.InsertKeyBackup(ctx, txn, userID, algorithm, authData, "")
return err return err
}) })
return return
@ -416,8 +346,8 @@ func (d *Database) CreateKeyBackup(
func (d *Database) UpdateKeyBackupAuthData( func (d *Database) UpdateKeyBackupAuthData(
ctx context.Context, userID, version string, authData json.RawMessage, ctx context.Context, userID, version string, authData json.RawMessage,
) (err error) { ) (err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData) return d.KeyBackupVersions.UpdateKeyBackupAuthData(ctx, txn, userID, version, authData)
}) })
return return
} }
@ -425,8 +355,8 @@ func (d *Database) UpdateKeyBackupAuthData(
func (d *Database) DeleteKeyBackup( func (d *Database) DeleteKeyBackup(
ctx context.Context, userID, version string, ctx context.Context, userID, version string,
) (exists bool, err error) { ) (exists bool, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version) exists, err = d.KeyBackupVersions.DeleteKeyBackup(ctx, txn, userID, version)
return err return err
}) })
return return
@ -435,8 +365,8 @@ func (d *Database) DeleteKeyBackup(
func (d *Database) GetKeyBackup( func (d *Database) GetKeyBackup(
ctx context.Context, userID, version string, ctx context.Context, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { ) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) versionResult, algorithm, authData, etag, deleted, err = d.KeyBackupVersions.SelectKeyBackup(ctx, txn, userID, version)
return err return err
}) })
return return
@ -445,16 +375,16 @@ func (d *Database) GetKeyBackup(
func (d *Database) GetBackupKeys( func (d *Database) GetBackupKeys(
ctx context.Context, version, userID, filterRoomID, filterSessionID string, ctx context.Context, version, userID, filterRoomID, filterSessionID string,
) (result map[string]map[string]api.KeyBackupSession, err error) { ) (result map[string]map[string]api.KeyBackupSession, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if filterSessionID != "" { if filterSessionID != "" {
result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID) result, err = d.KeyBackups.SelectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
return err return err
} }
if filterRoomID != "" { if filterRoomID != "" {
result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID) result, err = d.KeyBackups.SelectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
return err return err
} }
result, err = d.keyBackups.selectKeys(ctx, txn, userID, version) result, err = d.KeyBackups.SelectKeys(ctx, txn, userID, version)
return err return err
}) })
return return
@ -463,8 +393,8 @@ func (d *Database) GetBackupKeys(
func (d *Database) CountBackupKeys( func (d *Database) CountBackupKeys(
ctx context.Context, version, userID string, ctx context.Context, version, userID string,
) (count int64, err error) { ) (count int64, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
count, err = d.keyBackups.countKeys(ctx, txn, userID, version) count, err = d.KeyBackups.CountKeys(ctx, txn, userID, version)
if err != nil { if err != nil {
return err return err
} }
@ -478,8 +408,8 @@ func (d *Database) UpsertBackupKeys(
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession, ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
) (count int64, etag string, err error) { ) (count int64, etag string, err error) {
// wrap the following logic in a txn to ensure we atomically upload keys // wrap the following logic in a txn to ensure we atomically upload keys
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
_, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) _, _, _, oldETag, deleted, err := d.KeyBackupVersions.SelectKeyBackup(ctx, txn, userID, version)
if err != nil { if err != nil {
return err return err
} }
@ -487,7 +417,7 @@ func (d *Database) UpsertBackupKeys(
return fmt.Errorf("backup was deleted") return fmt.Errorf("backup was deleted")
} }
// pull out all keys for this (user_id, version) // pull out all keys for this (user_id, version)
existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version) existingKeys, err := d.KeyBackups.SelectKeys(ctx, txn, userID, version)
if err != nil { if err != nil {
return err return err
} }
@ -501,10 +431,10 @@ func (d *Database) UpsertBackupKeys(
existingSession, ok := existingRoom[newKey.SessionID] existingSession, ok := existingRoom[newKey.SessionID]
if ok { if ok {
if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) { if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey) err = d.KeyBackups.UpdateBackupKey(ctx, txn, userID, version, newKey)
changed = true changed = true
if err != nil { if err != nil {
return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err) return fmt.Errorf("d.KeyBackups.UpdateBackupKey: %w", err)
} }
} }
// if we shouldn't replace the key we do nothing with it // if we shouldn't replace the key we do nothing with it
@ -512,14 +442,14 @@ func (d *Database) UpsertBackupKeys(
} }
} }
// if we're here, either the room or session are new, either way, we insert // if we're here, either the room or session are new, either way, we insert
err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey) err = d.KeyBackups.InsertBackupKey(ctx, txn, userID, version, newKey)
changed = true changed = true
if err != nil { if err != nil {
return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err) return fmt.Errorf("d.KeyBackups.InsertBackupKey: %w", err)
} }
} }
count, err = d.keyBackups.countKeys(ctx, txn, userID, version) count, err = d.KeyBackups.CountKeys(ctx, txn, userID, version)
if err != nil { if err != nil {
return err return err
} }
@ -536,7 +466,7 @@ func (d *Database) UpsertBackupKeys(
newETag = strconv.FormatInt(oldETagInt+1, 10) newETag = strconv.FormatInt(oldETagInt+1, 10)
} }
etag = newETag etag = newETag
return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag) return d.KeyBackupVersions.UpdateKeyBackupETag(ctx, txn, userID, version, newETag)
} else { } else {
etag = oldETag etag = oldETag
} }
@ -545,3 +475,196 @@ func (d *Database) UpsertBackupKeys(
}) })
return return
} }
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*api.Device, error) {
return d.Devices.SelectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
return d.Devices.SelectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.Devices.SelectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error
dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Devices.UpdateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.Devices.DeleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.Devices.DeleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil
})
return
}
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
})
}
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
tok, err := generateLoginToken()
if err != nil {
return nil, err
}
meta := &api.LoginTokenMetadata{
Token: tok,
Expiration: time.Now().Add(d.LoginTokenLifetime),
}
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.LoginTokens.InsertLoginToken(ctx, txn, meta, data)
})
if err != nil {
return nil, err
}
return meta, nil
}
func generateLoginToken() (string, error) {
b := make([]byte, loginTokenByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.LoginTokens.DeleteLoginToken(ctx, txn, token)
})
}
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
return d.LoginTokens.SelectLoginToken(ctx, token)
}

View file

@ -20,6 +20,7 @@ import (
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
) )
const accountDataSchema = ` const accountDataSchema = `
@ -56,27 +57,29 @@ type accountDataStatements struct {
selectAccountDataByTypeStmt *sql.Stmt selectAccountDataByTypeStmt *sql.Stmt
} }
func (s *accountDataStatements) prepare(db *sql.DB) (err error) { func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
s.db = db s := &accountDataStatements{
_, err = db.Exec(accountDataSchema) db: db,
if err != nil {
return
} }
return sqlutil.StatementList{ _, err := db.Exec(accountDataSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertAccountDataStmt, insertAccountDataSQL}, {&s.insertAccountDataStmt, insertAccountDataSQL},
{&s.selectAccountDataStmt, selectAccountDataSQL}, {&s.selectAccountDataStmt, selectAccountDataSQL},
{&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL}, {&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL},
}.Prepare(db) }.Prepare(db)
} }
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
return err return err
} }
func (s *accountDataStatements) selectAccountData( func (s *accountDataStatements) SelectAccountData(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) ( ) (
/* global */ map[string]json.RawMessage, /* global */ map[string]json.RawMessage,
@ -113,7 +116,7 @@ func (s *accountDataStatements) selectAccountData(
return global, rooms, nil return global, rooms, nil
} }
func (s *accountDataStatements) selectAccountDataByType( func (s *accountDataStatements) SelectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string, ctx context.Context, localpart, roomID, dataType string,
) (data json.RawMessage, err error) { ) (data json.RawMessage, err error) {
var bytes []byte var bytes []byte

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -77,15 +78,16 @@ type accountsStatements struct {
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
func (s *accountsStatements) execSchema(db *sql.DB) error { func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) {
s := &accountsStatements{
db: db,
serverName: serverName,
}
_, err := db.Exec(accountsSchema) _, err := db.Exec(accountsSchema)
return err if err != nil {
} return nil, err
}
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { return s, sqlutil.StatementList{
s.db = db
s.serverName = server
return sqlutil.StatementList{
{&s.insertAccountStmt, insertAccountSQL}, {&s.insertAccountStmt, insertAccountSQL},
{&s.updatePasswordStmt, updatePasswordSQL}, {&s.updatePasswordStmt, updatePasswordSQL},
{&s.deactivateAccountStmt, deactivateAccountSQL}, {&s.deactivateAccountStmt, deactivateAccountSQL},
@ -98,7 +100,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) insertAccount( func (s *accountsStatements) InsertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) { ) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
@ -122,28 +124,28 @@ func (s *accountsStatements) insertAccount(
}, nil }, nil
} }
func (s *accountsStatements) updatePassword( func (s *accountsStatements) UpdatePassword(
ctx context.Context, localpart, passwordHash string, ctx context.Context, localpart, passwordHash string,
) (err error) { ) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
return return
} }
func (s *accountsStatements) deactivateAccount( func (s *accountsStatements) DeactivateAccount(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (err error) { ) (err error) {
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
return return
} }
func (s *accountsStatements) selectPasswordHash( func (s *accountsStatements) SelectPasswordHash(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (hash string, err error) { ) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
return return
} }
func (s *accountsStatements) selectAccountByLocalpart( func (s *accountsStatements) SelectAccountByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (*api.Account, error) { ) (*api.Account, error) {
var appserviceIDPtr sql.NullString var appserviceIDPtr sql.NullString
@ -167,7 +169,7 @@ func (s *accountsStatements) selectAccountByLocalpart(
return &acc, nil return &acc, nil
} }
func (s *accountsStatements) selectNewNumericLocalpart( func (s *accountsStatements) SelectNewNumericLocalpart(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
) (id int64, err error) { ) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt stmt := s.selectNewNumericLocalpartStmt

View file

@ -5,13 +5,8 @@ import (
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
) )
func LoadFromGoose() {
goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
}
func LoadLastSeenTSIP(m *sqlutil.Migrations) { func LoadLastSeenTSIP(m *sqlutil.Migrations) {
m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP) m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
} }

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -84,7 +85,6 @@ const updateDeviceLastSeen = "" +
type devicesStatements struct { type devicesStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.Writer
insertDeviceStmt *sql.Stmt insertDeviceStmt *sql.Stmt
selectDevicesCountStmt *sql.Stmt selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt
@ -98,52 +98,33 @@ type devicesStatements struct {
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
func (s *devicesStatements) execSchema(db *sql.DB) error { func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) {
s := &devicesStatements{
db: db,
serverName: serverName,
}
_, err := db.Exec(devicesSchema) _, err := db.Exec(devicesSchema)
return err if err != nil {
} return nil, err
func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
s.db = db
s.writer = writer
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
return
} }
if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil { return s, sqlutil.StatementList{
return {&s.insertDeviceStmt, insertDeviceSQL},
} {&s.selectDevicesCountStmt, selectDevicesCountSQL},
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil { {&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL},
return {&s.selectDeviceByIDStmt, selectDeviceByIDSQL},
} {&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL},
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil { {&s.updateDeviceNameStmt, updateDeviceNameSQL},
return {&s.deleteDeviceStmt, deleteDeviceSQL},
} {&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL},
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil { {&s.selectDevicesByIDStmt, selectDevicesByIDSQL},
return {&s.updateDeviceLastSeenStmt, updateDeviceLastSeen},
} }.Prepare(db)
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
return
}
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
return
}
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return
}
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
return
}
if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil {
return
}
s.serverName = server
return
} }
// insertDevice creates a new device. Returns an error if any device with the same access token already exists. // insertDevice creates a new device. Returns an error if any device with the same access token already exists.
// Returns an error if the user already has a device with the given device ID. // Returns an error if the user already has a device with the given device ID.
// Returns the device on success. // Returns the device on success.
func (s *devicesStatements) insertDevice( func (s *devicesStatements) InsertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string, ipAddr, userAgent string, displayName *string, ipAddr, userAgent string,
) (*api.Device, error) { ) (*api.Device, error) {
@ -169,7 +150,7 @@ func (s *devicesStatements) insertDevice(
}, nil }, nil
} }
func (s *devicesStatements) deleteDevice( func (s *devicesStatements) DeleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string, ctx context.Context, txn *sql.Tx, id, localpart string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
@ -177,7 +158,7 @@ func (s *devicesStatements) deleteDevice(
return err return err
} }
func (s *devicesStatements) deleteDevices( func (s *devicesStatements) DeleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string, ctx context.Context, txn *sql.Tx, localpart string, devices []string,
) error { ) error {
orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1) orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1)
@ -195,7 +176,7 @@ func (s *devicesStatements) deleteDevices(
return err return err
} }
func (s *devicesStatements) deleteDevicesByLocalpart( func (s *devicesStatements) DeleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
@ -203,7 +184,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
return err return err
} }
func (s *devicesStatements) updateDeviceName( func (s *devicesStatements) UpdateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
@ -211,7 +192,7 @@ func (s *devicesStatements) updateDeviceName(
return err return err
} }
func (s *devicesStatements) selectDeviceByToken( func (s *devicesStatements) SelectDeviceByToken(
ctx context.Context, accessToken string, ctx context.Context, accessToken string,
) (*api.Device, error) { ) (*api.Device, error) {
var dev api.Device var dev api.Device
@ -227,7 +208,7 @@ func (s *devicesStatements) selectDeviceByToken(
// selectDeviceByID retrieves a device from the database with the given user // selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID // localpart and deviceID
func (s *devicesStatements) selectDeviceByID( func (s *devicesStatements) SelectDeviceByID(
ctx context.Context, localpart, deviceID string, ctx context.Context, localpart, deviceID string,
) (*api.Device, error) { ) (*api.Device, error) {
var dev api.Device var dev api.Device
@ -244,7 +225,7 @@ func (s *devicesStatements) selectDeviceByID(
return &dev, err return &dev, err
} }
func (s *devicesStatements) selectDevicesByLocalpart( func (s *devicesStatements) SelectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) { ) ([]api.Device, error) {
devices := []api.Device{} devices := []api.Device{}
@ -285,7 +266,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
return devices, nil return devices, nil
} }
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1) sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1)
iDeviceIDs := make([]interface{}, len(deviceIDs)) iDeviceIDs := make([]interface{}, len(deviceIDs))
for i := range deviceIDs { for i := range deviceIDs {
@ -314,7 +295,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
return devices, rows.Err() return devices, rows.Err()
} }
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
lastSeenTs := time.Now().UnixNano() / 1000000 lastSeenTs := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID) _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID)

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
) )
const keyBackupTableSchema = ` const keyBackupTableSchema = `
@ -72,12 +73,13 @@ type keyBackupStatements struct {
selectKeysByRoomIDAndSessionIDStmt *sql.Stmt selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
} }
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { func NewSQLiteKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
_, err = db.Exec(keyBackupTableSchema) s := &keyBackupStatements{}
_, err := db.Exec(keyBackupTableSchema)
if err != nil { if err != nil {
return return nil, err
} }
return sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.insertBackupKeyStmt, insertBackupKeySQL},
{&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL},
{&s.countKeysStmt, countKeysSQL}, {&s.countKeysStmt, countKeysSQL},
@ -87,14 +89,14 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db) }.Prepare(db)
} }
func (s keyBackupStatements) countKeys( func (s keyBackupStatements) CountKeys(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (count int64, err error) { ) (count int64, err error) {
err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count) err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count)
return return
} }
func (s *keyBackupStatements) insertBackupKey( func (s *keyBackupStatements) InsertBackupKey(
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext( _, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext(
@ -103,7 +105,7 @@ func (s *keyBackupStatements) insertBackupKey(
return return
} }
func (s *keyBackupStatements) updateBackupKey( func (s *keyBackupStatements) UpdateBackupKey(
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext( _, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext(
@ -112,7 +114,7 @@ func (s *keyBackupStatements) updateBackupKey(
return return
} }
func (s *keyBackupStatements) selectKeys( func (s *keyBackupStatements) SelectKeys(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (map[string]map[string]api.KeyBackupSession, error) { ) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
@ -122,7 +124,7 @@ func (s *keyBackupStatements) selectKeys(
return unpackKeys(ctx, rows) return unpackKeys(ctx, rows)
} }
func (s *keyBackupStatements) selectKeysByRoomID( func (s *keyBackupStatements) SelectKeysByRoomID(
ctx context.Context, txn *sql.Tx, userID, version, roomID string, ctx context.Context, txn *sql.Tx, userID, version, roomID string,
) (map[string]map[string]api.KeyBackupSession, error) { ) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID) rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
@ -132,7 +134,7 @@ func (s *keyBackupStatements) selectKeysByRoomID(
return unpackKeys(ctx, rows) return unpackKeys(ctx, rows)
} }
func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID( func (s *keyBackupStatements) SelectKeysByRoomIDAndSessionID(
ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string, ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
) (map[string]map[string]api.KeyBackupSession, error) { ) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID) rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)

View file

@ -22,6 +22,7 @@ import (
"strconv" "strconv"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
) )
const keyBackupVersionTableSchema = ` const keyBackupVersionTableSchema = `
@ -67,12 +68,13 @@ type keyBackupVersionStatements struct {
updateKeyBackupETagStmt *sql.Stmt updateKeyBackupETagStmt *sql.Stmt
} }
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { func NewSQLiteKeyBackupVersionTable(db *sql.DB) (tables.KeyBackupVersionTable, error) {
_, err = db.Exec(keyBackupVersionTableSchema) s := &keyBackupVersionStatements{}
_, err := db.Exec(keyBackupVersionTableSchema)
if err != nil { if err != nil {
return return nil, err
} }
return sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertKeyBackupStmt, insertKeyBackupSQL}, {&s.insertKeyBackupStmt, insertKeyBackupSQL},
{&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL}, {&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL},
{&s.deleteKeyBackupStmt, deleteKeyBackupSQL}, {&s.deleteKeyBackupStmt, deleteKeyBackupSQL},
@ -82,7 +84,7 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db) }.Prepare(db)
} }
func (s *keyBackupVersionStatements) insertKeyBackup( func (s *keyBackupVersionStatements) InsertKeyBackup(
ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string, ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string,
) (version string, err error) { ) (version string, err error) {
var versionInt int64 var versionInt int64
@ -90,7 +92,7 @@ func (s *keyBackupVersionStatements) insertKeyBackup(
return strconv.FormatInt(versionInt, 10), err return strconv.FormatInt(versionInt, 10), err
} }
func (s *keyBackupVersionStatements) updateKeyBackupAuthData( func (s *keyBackupVersionStatements) UpdateKeyBackupAuthData(
ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage, ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage,
) error { ) error {
versionInt, err := strconv.ParseInt(version, 10, 64) versionInt, err := strconv.ParseInt(version, 10, 64)
@ -101,7 +103,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
return err return err
} }
func (s *keyBackupVersionStatements) updateKeyBackupETag( func (s *keyBackupVersionStatements) UpdateKeyBackupETag(
ctx context.Context, txn *sql.Tx, userID, version, etag string, ctx context.Context, txn *sql.Tx, userID, version, etag string,
) error { ) error {
versionInt, err := strconv.ParseInt(version, 10, 64) versionInt, err := strconv.ParseInt(version, 10, 64)
@ -112,7 +114,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag(
return err return err
} }
func (s *keyBackupVersionStatements) deleteKeyBackup( func (s *keyBackupVersionStatements) DeleteKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (bool, error) { ) (bool, error) {
versionInt, err := strconv.ParseInt(version, 10, 64) versionInt, err := strconv.ParseInt(version, 10, 64)
@ -130,7 +132,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup(
return ra == 1, nil return ra == 1, nil
} }
func (s *keyBackupVersionStatements) selectKeyBackup( func (s *keyBackupVersionStatements) SelectKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { ) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
var versionInt int64 var versionInt int64

View file

@ -21,18 +21,17 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
type loginTokenStatements struct { type loginTokenStatements struct {
insertStmt *sql.Stmt insertStmt *sql.Stmt
deleteStmt *sql.Stmt deleteStmt *sql.Stmt
selectByTokenStmt *sql.Stmt selectStmt *sql.Stmt
} }
// execSchema ensures tables and indices exist. const loginTokenSchema = `
func (s *loginTokenStatements) execSchema(db *sql.DB) error {
_, err := db.Exec(`
CREATE TABLE IF NOT EXISTS login_tokens ( CREATE TABLE IF NOT EXISTS login_tokens (
-- The random value of the token issued to a user -- The random value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY, token TEXT NOT NULL PRIMARY KEY,
@ -45,21 +44,32 @@ CREATE TABLE IF NOT EXISTS login_tokens (
-- This index allows efficient garbage collection of expired tokens. -- This index allows efficient garbage collection of expired tokens.
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at);
`) `
return err
}
// prepare runs statement preparation. const insertLoginTokenSQL = "" +
func (s *loginTokenStatements) prepare(db *sql.DB) error { "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
return sqlutil.StatementList{
{&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, const deleteLoginTokenSQL = "" +
{&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"}, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"
{&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"},
const selectLoginTokenSQL = "" +
"SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"
func NewSQLiteLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) {
s := &loginTokenStatements{}
_, err := db.Exec(loginTokenSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertStmt, insertLoginTokenSQL},
{&s.deleteStmt, deleteLoginTokenSQL},
{&s.selectStmt, selectLoginTokenSQL},
}.Prepare(db) }.Prepare(db)
} }
// insert adds an already generated token to the database. // insert adds an already generated token to the database.
func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error {
stmt := sqlutil.TxStmt(txn, s.insertStmt) stmt := sqlutil.TxStmt(txn, s.insertStmt)
_, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID) _, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID)
return err return err
@ -69,7 +79,7 @@ func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata
// //
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens. // As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
// The login_tokens_expiration_idx index should make that efficient. // The login_tokens_expiration_idx index should make that efficient.
func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error { func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error {
stmt := sqlutil.TxStmt(txn, s.deleteStmt) stmt := sqlutil.TxStmt(txn, s.deleteStmt)
res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) res, err := stmt.ExecContext(ctx, token, time.Now().UTC())
if err != nil { if err != nil {
@ -82,9 +92,9 @@ func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, t
} }
// selectByToken returns the data associated with the given token. May return sql.ErrNoRows. // selectByToken returns the data associated with the given token. May return sql.ErrNoRows.
func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { func (s *loginTokenStatements) SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
var data api.LoginTokenData var data api.LoginTokenData
err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -6,6 +6,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -22,35 +23,37 @@ CREATE TABLE IF NOT EXISTS open_id_tokens (
); );
` `
const insertTokenSQL = "" + const insertOpenIDTokenSQL = "" +
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
const selectTokenSQL = "" + const selectOpenIDTokenSQL = "" +
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
type tokenStatements struct { type openIDTokenStatements struct {
db *sql.DB db *sql.DB
insertTokenStmt *sql.Stmt insertTokenStmt *sql.Stmt
selectTokenStmt *sql.Stmt selectTokenStmt *sql.Stmt
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.OpenIDTable, error) {
s.db = db s := &openIDTokenStatements{
_, err = db.Exec(openIDTokenSchema) db: db,
if err != nil { serverName: serverName,
return err
} }
s.serverName = server _, err := db.Exec(openIDTokenSchema)
return sqlutil.StatementList{ if err != nil {
{&s.insertTokenStmt, insertTokenSQL}, return nil, err
{&s.selectTokenStmt, selectTokenSQL}, }
return s, sqlutil.StatementList{
{&s.insertTokenStmt, insertOpenIDTokenSQL},
{&s.selectTokenStmt, selectOpenIDTokenSQL},
}.Prepare(db) }.Prepare(db)
} }
// insertToken inserts a new OpenID Connect token to the DB. // insertToken inserts a new OpenID Connect token to the DB.
// Returns new token, otherwise returns error if the token already exists. // Returns new token, otherwise returns error if the token already exists.
func (s *tokenStatements) insertToken( func (s *openIDTokenStatements) InsertOpenIDToken(
ctx context.Context, ctx context.Context,
txn *sql.Tx, txn *sql.Tx,
token, localpart string, token, localpart string,
@ -63,7 +66,7 @@ func (s *tokenStatements) insertToken(
// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB // selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
// Returns the existing token's attributes, or err if no token is found // Returns the existing token's attributes, or err if no token is found
func (s *tokenStatements) selectOpenIDTokenAtrributes( func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
ctx context.Context, ctx context.Context,
token string, token string,
) (*api.OpenIDTokenAttributes, error) { ) (*api.OpenIDTokenAttributes, error) {

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
) )
const profilesSchema = ` const profilesSchema = `
@ -60,13 +61,15 @@ type profilesStatements struct {
selectProfilesBySearchStmt *sql.Stmt selectProfilesBySearchStmt *sql.Stmt
} }
func (s *profilesStatements) prepare(db *sql.DB) (err error) { func NewSQLiteProfilesTable(db *sql.DB) (tables.ProfileTable, error) {
s.db = db s := &profilesStatements{
_, err = db.Exec(profilesSchema) db: db,
if err != nil {
return
} }
return sqlutil.StatementList{ _, err := db.Exec(profilesSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertProfileStmt, insertProfileSQL}, {&s.insertProfileStmt, insertProfileSQL},
{&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL}, {&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL},
{&s.setAvatarURLStmt, setAvatarURLSQL}, {&s.setAvatarURLStmt, setAvatarURLSQL},
@ -75,14 +78,14 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db) }.Prepare(db)
} }
func (s *profilesStatements) insertProfile( func (s *profilesStatements) InsertProfile(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart string,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
return err return err
} }
func (s *profilesStatements) selectProfileByLocalpart( func (s *profilesStatements) SelectProfileByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
var profile authtypes.Profile var profile authtypes.Profile
@ -95,7 +98,7 @@ func (s *profilesStatements) selectProfileByLocalpart(
return &profile, nil return &profile, nil
} }
func (s *profilesStatements) setAvatarURL( func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
@ -103,7 +106,7 @@ func (s *profilesStatements) setAvatarURL(
return return
} }
func (s *profilesStatements) setDisplayName( func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string, ctx context.Context, txn *sql.Tx, localpart string, displayName string,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
@ -111,7 +114,7 @@ func (s *profilesStatements) setDisplayName(
return return
} }
func (s *profilesStatements) selectProfilesBySearch( func (s *profilesStatements) SelectProfilesBySearch(
ctx context.Context, searchString string, limit int, ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) { ) ([]authtypes.Profile, error) {
var profiles []authtypes.Profile var profiles []authtypes.Profile

View file

@ -0,0 +1,106 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"fmt"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/shared"
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
// Import the postgres database driver.
_ "github.com/lib/pq"
)
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*shared.Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
if _, err = db.Exec(accountsSchema); err != nil {
// do this so that the migration can and we don't fail on
// preparing statements for columns that don't exist yet
return nil, err
}
deltas.LoadIsActive(m)
//deltas.LoadLastSeenTSIP(m)
deltas.LoadAddAccountType(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
accountDataTable, err := NewSQLiteAccountDataTable(db)
if err != nil {
return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
}
accountsTable, err := NewSQLiteAccountsTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err)
}
devicesTable, err := NewSQLiteDevicesTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err)
}
keyBackupTable, err := NewSQLiteKeyBackupTable(db)
if err != nil {
return nil, fmt.Errorf("NewSQLiteKeyBackupTable: %w", err)
}
keyBackupVersionTable, err := NewSQLiteKeyBackupVersionTable(db)
if err != nil {
return nil, fmt.Errorf("NewSQLiteKeyBackupVersionTable: %w", err)
}
loginTokenTable, err := NewSQLiteLoginTokenTable(db)
if err != nil {
return nil, fmt.Errorf("NewSQLiteLoginTokenTable: %w", err)
}
openIDTable, err := NewSQLiteOpenIDTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewSQLiteOpenIDTable: %w", err)
}
profilesTable, err := NewSQLiteProfilesTable(db)
if err != nil {
return nil, fmt.Errorf("NewSQLiteProfilesTable: %w", err)
}
threePIDTable, err := NewSQLiteThreePIDTable(db)
if err != nil {
return nil, fmt.Errorf("NewSQLiteThreePIDTable: %w", err)
}
return &shared.Database{
AccountDatas: accountDataTable,
Accounts: accountsTable,
Devices: devicesTable,
KeyBackups: keyBackupTable,
KeyBackupVersions: keyBackupVersionTable,
LoginTokens: loginTokenTable,
OpenIDTokens: openIDTable,
Profiles: profilesTable,
ThreePIDs: threePIDTable,
ServerName: serverName,
DB: db,
Writer: sqlutil.NewExclusiveWriter(),
LoginTokenLifetime: loginTokenLifetime,
BcryptCost: bcryptCost,
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
}, nil
}

View file

@ -20,6 +20,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
) )
@ -60,13 +61,15 @@ type threepidStatements struct {
deleteThreePIDStmt *sql.Stmt deleteThreePIDStmt *sql.Stmt
} }
func (s *threepidStatements) prepare(db *sql.DB) (err error) { func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
s.db = db s := &threepidStatements{
_, err = db.Exec(threepidSchema) db: db,
if err != nil {
return
} }
return sqlutil.StatementList{ _, err := db.Exec(threepidSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL}, {&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL},
{&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL}, {&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL},
{&s.insertThreePIDStmt, insertThreePIDSQL}, {&s.insertThreePIDStmt, insertThreePIDSQL},
@ -74,7 +77,7 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db) }.Prepare(db)
} }
func (s *threepidStatements) selectLocalpartForThreePID( func (s *threepidStatements) SelectLocalpartForThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string, ctx context.Context, txn *sql.Tx, threepid string, medium string,
) (localpart string, err error) { ) (localpart string, err error) {
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
@ -85,7 +88,7 @@ func (s *threepidStatements) selectLocalpartForThreePID(
return return
} }
func (s *threepidStatements) selectThreePIDsForLocalpart( func (s *threepidStatements) SelectThreePIDsForLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) { ) (threepids []authtypes.ThreePID, err error) {
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
@ -109,7 +112,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
return threepids, rows.Err() return threepids, rows.Err()
} }
func (s *threepidStatements) insertThreePID( func (s *threepidStatements) InsertThreePID(
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
@ -117,7 +120,7 @@ func (s *threepidStatements) insertThreePID(
return err return err
} }
func (s *threepidStatements) deleteThreePID( func (s *threepidStatements) DeleteThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) { ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) {
stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt) stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium) _, err = stmt.ExecContext(ctx, threepid, medium)

View file

@ -15,26 +15,27 @@
//go:build !wasm //go:build !wasm
// +build !wasm // +build !wasm
package accounts package storage
import ( import (
"fmt" "fmt"
"time"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/accounts/postgres" "github.com/matrix-org/dendrite/userapi/storage/postgres"
"github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3" "github.com/matrix-org/dendrite/userapi/storage/sqlite3"
) )
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets postgres connection parameters // and sets postgres connection parameters
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (Database, error) {
switch { switch {
case dbProperties.ConnectionString.IsSQLite(): case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS) return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime)
case dbProperties.ConnectionString.IsPostgres(): case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS) return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime)
default: default:
return nil, fmt.Errorf("unexpected database type") return nil, fmt.Errorf("unexpected database type")
} }

View file

@ -12,13 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package accounts package storage
import ( import (
"fmt" "fmt"
"time"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3" "github.com/matrix-org/dendrite/userapi/storage/sqlite3"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -27,10 +28,11 @@ func NewDatabase(
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
bcryptCost int, bcryptCost int,
openIDTokenLifetimeMS int64, openIDTokenLifetimeMS int64,
loginTokenLifetime time.Duration,
) (Database, error) { ) (Database, error) {
switch { switch {
case dbProperties.ConnectionString.IsSQLite(): case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS) return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime)
case dbProperties.ConnectionString.IsPostgres(): case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation") return nil, fmt.Errorf("can't use Postgres implementation")
default: default:

View file

@ -0,0 +1,95 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tables
import (
"context"
"database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/api"
)
type AccountDataTable interface {
InsertAccountData(ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage) error
SelectAccountData(ctx context.Context, localpart string) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)
SelectAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
}
type AccountsTable interface {
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error)
DeactivateAccount(ctx context.Context, localpart string) (err error)
SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error)
SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error)
}
type DevicesTable interface {
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error
DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error
DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error
UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string) error
SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error)
SelectDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) ([]api.Device, error)
SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error
}
type KeyBackupTable interface {
CountKeys(ctx context.Context, txn *sql.Tx, userID, version string) (count int64, err error)
InsertBackupKey(ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession) (err error)
UpdateBackupKey(ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession) (err error)
SelectKeys(ctx context.Context, txn *sql.Tx, userID, version string) (map[string]map[string]api.KeyBackupSession, error)
SelectKeysByRoomID(ctx context.Context, txn *sql.Tx, userID, version, roomID string) (map[string]map[string]api.KeyBackupSession, error)
SelectKeysByRoomIDAndSessionID(ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string) (map[string]map[string]api.KeyBackupSession, error)
}
type KeyBackupVersionTable interface {
InsertKeyBackup(ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string) (version string, err error)
UpdateKeyBackupAuthData(ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage) error
UpdateKeyBackupETag(ctx context.Context, txn *sql.Tx, userID, version, etag string) error
DeleteKeyBackup(ctx context.Context, txn *sql.Tx, userID, version string) (bool, error)
SelectKeyBackup(ctx context.Context, txn *sql.Tx, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error)
}
type LoginTokenTable interface {
InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error
DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error
SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error)
}
type OpenIDTable interface {
InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, expiresAtMS int64) (err error)
SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
}
type ProfileTable interface {
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error
SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (err error)
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (err error)
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
}
type ThreePIDTable interface {
SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, err error)
SelectThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error)
DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error)
}

View file

@ -23,18 +23,10 @@ import (
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/internal"
"github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/inthttp"
"github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/devices"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// defaultLoginTokenLifetime determines how old a valid token may be.
//
// NOTSPEC: The current spec says "SHOULD be limited to around five
// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low.
// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325).
const defaultLoginTokenLifetime = 2 * time.Minute
// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions // AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions
// on the given input API. // on the given input API.
func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) {
@ -44,26 +36,24 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) {
// NewInternalAPI returns a concerete implementation of the internal API. Callers // NewInternalAPI returns a concerete implementation of the internal API. Callers
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI( func NewInternalAPI(
accountDB accounts.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, accountDB storage.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI,
) api.UserInternalAPI { ) api.UserInternalAPI {
deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName, defaultLoginTokenLifetime) db, err := storage.NewDatabase(&cfg.AccountDatabase, cfg.Matrix.ServerName, cfg.BCryptCost, int64(api.DefaultLoginTokenLifetime*time.Millisecond), api.DefaultLoginTokenLifetime)
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to connect to device db") logrus.WithError(err).Panicf("failed to connect to device db")
} }
return newInternalAPI(accountDB, deviceDB, cfg, appServices, keyAPI) return newInternalAPI(db, cfg, appServices, keyAPI)
} }
func newInternalAPI( func newInternalAPI(
accountDB accounts.Database, db storage.Database,
deviceDB devices.Database,
cfg *config.UserAPI, cfg *config.UserAPI,
appServices []config.ApplicationService, appServices []config.ApplicationService,
keyAPI keyapi.KeyInternalAPI, keyAPI keyapi.KeyInternalAPI,
) api.UserInternalAPI { ) api.UserInternalAPI {
return &internal.UserInternalAPI{ return &internal.UserInternalAPI{
AccountDB: accountDB, DB: db,
DeviceDB: deviceDB,
ServerName: cfg.Matrix.ServerName, ServerName: cfg.Matrix.ServerName,
AppServices: appServices, AppServices: appServices,
KeyAPI: keyAPI, KeyAPI: keyAPI,

View file

@ -31,8 +31,7 @@ import (
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/inthttp"
"github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/devices"
) )
const ( const (
@ -43,23 +42,19 @@ type apiTestOpts struct {
loginTokenLifetime time.Duration loginTokenLifetime time.Duration
} }
func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, accounts.Database) { func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, storage.Database) {
if opts.loginTokenLifetime == 0 { if opts.loginTokenLifetime == 0 {
opts.loginTokenLifetime = defaultLoginTokenLifetime opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond
} }
dbopts := &config.DatabaseOptions{ dbopts := &config.DatabaseOptions{
ConnectionString: "file::memory:", ConnectionString: "file::memory:",
MaxOpenConnections: 1, MaxOpenConnections: 1,
MaxIdleConnections: 1, MaxIdleConnections: 1,
} }
accountDB, err := accounts.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS) accountDB, err := storage.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime)
if err != nil { if err != nil {
t.Fatalf("failed to create account DB: %s", err) t.Fatalf("failed to create account DB: %s", err)
} }
deviceDB, err := devices.NewDatabase(dbopts, serverName, opts.loginTokenLifetime)
if err != nil {
t.Fatalf("failed to create device DB: %s", err)
}
cfg := &config.UserAPI{ cfg := &config.UserAPI{
Matrix: &config.Global{ Matrix: &config.Global{
@ -67,7 +62,7 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, a
}, },
} }
return newInternalAPI(accountDB, deviceDB, cfg, nil, nil), accountDB return newInternalAPI(accountDB, cfg, nil, nil), accountDB
} }
func TestQueryProfile(t *testing.T) { func TestQueryProfile(t *testing.T) {