// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package internal import ( "context" "database/sql" "encoding/json" "errors" "fmt" "strconv" "time" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/internal/sqlutil" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" synctypes "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage/tables" userapiUtil "github.com/matrix-org/dendrite/userapi/util" ) type UserInternalAPI struct { DB storage.UserDatabase KeyDatabase storage.KeyDatabase SyncProducer *producers.SyncAPI KeyChangeProducer *producers.KeyChange Config *config.UserAPI DisableTLSValidation bool // AppServices is the list of all registered AS AppServices []config.ApplicationService RSAPI rsapi.UserRoomserverAPI PgClient pushgateway.Client FedClient fedsenderapi.KeyserverFederationAPI Updater *DeviceListUpdater } func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { return err } if !a.Config.Matrix.IsLocalServerName(domain) { return fmt.Errorf("cannot update account data of remote users (server name %s)", domain) } if req.DataType == "" { return fmt.Errorf("data type must not be empty") } if err := a.DB.SaveAccountData(ctx, local, domain, req.RoomID, req.DataType, req.AccountData); err != nil { util.GetLogger(ctx).WithError(err).Error("a.DB.SaveAccountData failed") return fmt.Errorf("failed to save account data: %w", err) } var ignoredUsers *synctypes.IgnoredUsers if req.DataType == "m.ignored_user_list" { ignoredUsers = &synctypes.IgnoredUsers{} _ = json.Unmarshal(req.AccountData, ignoredUsers) } if req.DataType == "m.fully_read" { if err := a.setFullyRead(ctx, req); err != nil { return err } } if err := a.SyncProducer.SendAccountData(req.UserID, eventutil.AccountData{ RoomID: req.RoomID, Type: req.DataType, IgnoredUsers: ignoredUsers, }); err != nil { util.GetLogger(ctx).WithError(err).Error("a.SyncProducer.SendAccountData failed") return fmt.Errorf("failed to send account data to output: %w", err) } return nil } func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccountDataRequest) error { var output eventutil.ReadMarkerJSON if err := json.Unmarshal(req.AccountData, &output); err != nil { return err } localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { logrus.WithError(err).Error("UserInternalAPI.setFullyRead: SplitID failure") return nil } if !a.Config.Matrix.IsLocalServerName(domain) { return nil } deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, domain, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now()))) if err != nil { logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed") return err } if err = a.SyncProducer.GetAndSendNotificationData(ctx, req.UserID, req.RoomID); err != nil { logrus.WithError(err).Error("UserInternalAPI.setFullyRead: GetAndSendNotificationData failed") return err } // nothing changed, no need to notify the push gateway if !deleted { return nil } if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, domain, a.DB); err != nil { logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed") return err } return nil } func postRegisterJoinRooms(cfg *config.UserAPI, acc *api.Account, rsAPI rsapi.UserRoomserverAPI) { // POST register behaviour: check if the user is a normal user. // If the user is a normal user, add user to room specified in the configuration "auto_join_rooms". if acc.AccountType != api.AccountTypeAppService && acc.AppServiceID == "" { for room := range cfg.AutoJoinRooms { userID := userutil.MakeUserID(acc.Localpart, cfg.Matrix.ServerName) err := addUserToRoom(context.Background(), rsAPI, cfg.AutoJoinRooms[room], acc.Localpart, userID) if err != nil { logrus.WithFields(logrus.Fields{ "user_id": userID, "room": cfg.AutoJoinRooms[room], }).WithError(err).Errorf("user failed to auto-join room") } } } } // Add user to a room. This function currently working for auto_join_rooms config, // which can add a newly registered user to a specified room. func addUserToRoom( ctx context.Context, rsAPI rsapi.UserRoomserverAPI, roomID string, username string, userID string, ) error { addGroupContent := make(map[string]interface{}) // This make sure the user's username can be displayed correctly. // Because the newly-registered user doesn't have an avatar, the avatar_url is not needed. addGroupContent["displayname"] = username joinReq := rsapi.PerformJoinRequest{ RoomIDOrAlias: roomID, UserID: userID, Content: addGroupContent, } joinRes := rsapi.PerformJoinResponse{} return rsAPI.PerformJoin(ctx, &joinReq, &joinRes) } func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { serverName := req.ServerName if serverName == "" { serverName = a.Config.Matrix.ServerName } if !a.Config.Matrix.IsLocalServerName(serverName) { return fmt.Errorf("server name %s is not local", serverName) } acc, err := a.DB.CreateAccount(ctx, req.Localpart, serverName, req.Password, req.AppServiceID, req.AccountType) if err != nil { if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists switch req.OnConflict { case api.ConflictUpdate: break case api.ConflictAbort: return &api.ErrorConflict{ Message: err.Error(), } } } // account already exists res.AccountCreated = false res.Account = &api.Account{ AppServiceID: req.AppServiceID, Localpart: req.Localpart, ServerName: serverName, UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), AccountType: req.AccountType, } return nil } // Inform the SyncAPI about the newly created push_rules if err = a.SyncProducer.SendAccountData(acc.UserID, eventutil.AccountData{ Type: "m.push_rules", }); err != nil { util.GetLogger(ctx).WithFields(logrus.Fields{ "user_id": acc.UserID, }).WithError(err).Warn("failed to send account data to the SyncAPI") } if req.AccountType == api.AccountTypeGuest { res.AccountCreated = true res.Account = acc return nil } if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, serverName, req.Localpart); err != nil { return fmt.Errorf("a.DB.SetDisplayName: %w", err) } postRegisterJoinRooms(a.Config, acc, a.RSAPI) res.AccountCreated = true res.Account = acc return nil } func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { if !a.Config.Matrix.IsLocalServerName(req.ServerName) { return fmt.Errorf("server name %s is not local", req.ServerName) } if err := a.DB.SetPassword(ctx, req.Localpart, req.ServerName, req.Password); err != nil { return err } if req.LogoutDevices { if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, req.ServerName, ""); err != nil { return err } } res.PasswordUpdated = true return nil } func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error { serverName := req.ServerName if serverName == "" { serverName = a.Config.Matrix.ServerName } if !a.Config.Matrix.IsLocalServerName(serverName) { return fmt.Errorf("server name %s is not local", serverName) } // If a device ID was specified, check if it already exists and // avoid sending an empty device list update which would remove // existing device keys. isExisting := false if req.DeviceID != nil && *req.DeviceID != "" { existingDev, err := a.DB.GetDeviceByID(ctx, req.Localpart, req.ServerName, *req.DeviceID) if err != nil && !errors.Is(err, sql.ErrNoRows) { return err } isExisting = existingDev.ID == *req.DeviceID } util.GetLogger(ctx).WithFields(logrus.Fields{ "localpart": req.Localpart, "device_id": req.DeviceID, "display_name": req.DeviceDisplayName, }).Info("PerformDeviceCreation") dev, err := a.DB.CreateDevice(ctx, req.Localpart, serverName, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) if err != nil { return err } res.DeviceCreated = true res.Device = dev if req.NoDeviceListUpdate || isExisting { return nil } // create empty device keys and upload them to trigger device list changes return a.deviceListUpdate(dev.UserID, []string{dev.ID}) } func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.PerformDeviceDeletionRequest, res *api.PerformDeviceDeletionResponse) error { util.GetLogger(ctx).WithField("user_id", req.UserID).WithField("devices", req.DeviceIDs).Info("PerformDeviceDeletion") local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { return err } if !a.Config.Matrix.IsLocalServerName(domain) { return fmt.Errorf("cannot PerformDeviceDeletion of remote users (server name %s)", domain) } deletedDeviceIDs := req.DeviceIDs if len(req.DeviceIDs) == 0 { var devices []api.Device devices, err = a.DB.RemoveAllDevices(ctx, local, domain, req.ExceptDeviceID) for _, d := range devices { deletedDeviceIDs = append(deletedDeviceIDs, d.ID) } } else { err = a.DB.RemoveDevices(ctx, local, domain, req.DeviceIDs) } if err != nil { return err } // Ask the keyserver to delete device keys and signatures for those devices deleteReq := &api.PerformDeleteKeysRequest{ UserID: req.UserID, } for _, keyID := range req.DeviceIDs { deleteReq.KeyIDs = append(deleteReq.KeyIDs, gomatrixserverlib.KeyID(keyID)) } deleteRes := &api.PerformDeleteKeysResponse{} if err := a.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil { return err } if err := deleteRes.Error; err != nil { return fmt.Errorf("a.KeyAPI.PerformDeleteKeys: %w", err) } // create empty device keys and upload them to delete what was once there and trigger device list changes return a.deviceListUpdate(req.UserID, deletedDeviceIDs) } func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error { deviceKeys := make([]api.DeviceKeys, len(deviceIDs)) for i, did := range deviceIDs { deviceKeys[i] = api.DeviceKeys{ UserID: userID, DeviceID: did, KeyJSON: nil, } } var uploadRes api.PerformUploadKeysResponse if err := a.PerformUploadKeys(context.Background(), &api.PerformUploadKeysRequest{ UserID: userID, DeviceKeys: deviceKeys, }, &uploadRes); err != nil { return err } if uploadRes.Error != nil { return fmt.Errorf("failed to delete device keys: %v", uploadRes.Error) } if len(uploadRes.KeyErrors) > 0 { return fmt.Errorf("failed to delete device keys, key errors: %+v", uploadRes.KeyErrors) } return nil } func (a *UserInternalAPI) PerformLastSeenUpdate( ctx context.Context, req *api.PerformLastSeenUpdateRequest, res *api.PerformLastSeenUpdateResponse, ) error { localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } if !a.Config.Matrix.IsLocalServerName(domain) { return fmt.Errorf("server name %s is not local", domain) } if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, domain, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil { return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err) } return nil } func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error { localpart, domain, err := gomatrixserverlib.SplitID('@', req.RequestingUserID) if err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") return err } if !a.Config.Matrix.IsLocalServerName(domain) { return fmt.Errorf("server name %s is not local", domain) } dev, err := a.DB.GetDeviceByID(ctx, localpart, domain, req.DeviceID) if err == sql.ErrNoRows { res.DeviceExists = false return nil } else if err != nil { util.GetLogger(ctx).WithError(err).Error("deviceDB.GetDeviceByID failed") return err } res.DeviceExists = true err = a.DB.UpdateDevice(ctx, localpart, domain, req.DeviceID, req.DisplayName) if err != nil { util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed") return err } if req.DisplayName != nil && dev.DisplayName != *req.DisplayName { // display name has changed: update the device key var uploadRes api.PerformUploadKeysResponse if err := a.PerformUploadKeys(context.Background(), &api.PerformUploadKeysRequest{ UserID: req.RequestingUserID, DeviceKeys: []api.DeviceKeys{ { DeviceID: dev.ID, DisplayName: *req.DisplayName, KeyJSON: nil, UserID: dev.UserID, }, }, OnlyDisplayNameUpdates: true, }, &uploadRes); err != nil { return err } if uploadRes.Error != nil { return fmt.Errorf("failed to update device key display name: %v", uploadRes.Error) } if len(uploadRes.KeyErrors) > 0 { return fmt.Errorf("failed to update device key display name, key errors: %+v", uploadRes.KeyErrors) } } return nil } var ( ErrIsRemoteServer = errors.New("cannot query profile of remote users") ) func (a *UserInternalAPI) QueryProfile(ctx context.Context, userID string) (*authtypes.Profile, error) { local, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { return nil, err } if !a.Config.Matrix.IsLocalServerName(domain) { return nil, ErrIsRemoteServer } prof, err := a.DB.GetProfileByLocalpart(ctx, local, domain) if err != nil { if err == sql.ErrNoRows { return nil, appserviceAPI.ErrProfileNotExists } return nil, err } return prof, nil } func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error { profiles, err := a.DB.SearchProfiles(ctx, req.SearchString, req.Limit) if err != nil { return err } res.Profiles = profiles return nil } func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error { devices, err := a.DB.GetDevicesByID(ctx, req.DeviceIDs) if err != nil { return err } res.DeviceInfo = make(map[string]struct { DisplayName string UserID string }) for _, d := range devices { res.DeviceInfo[d.ID] = struct { DisplayName string UserID string }{ DisplayName: d.DisplayName, UserID: d.UserID, } } return nil } func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevicesRequest, res *api.QueryDevicesResponse) error { local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { return err } if !a.Config.Matrix.IsLocalServerName(domain) { return fmt.Errorf("cannot query devices of remote users (server name %s)", domain) } devs, err := a.DB.GetDevicesByLocalpart(ctx, local, domain) if err != nil { return err } res.UserExists = true res.Devices = devs return nil } func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAccountDataRequest, res *api.QueryAccountDataResponse) error { local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { return err } if !a.Config.Matrix.IsLocalServerName(domain) { return fmt.Errorf("cannot query account data of remote users (server name %s)", domain) } if req.DataType != "" { var data json.RawMessage data, err = a.DB.GetAccountDataByType(ctx, local, domain, req.RoomID, req.DataType) if err != nil { return err } res.RoomAccountData = make(map[string]map[string]json.RawMessage) res.GlobalAccountData = make(map[string]json.RawMessage) if data != nil { if req.RoomID != "" { if _, ok := res.RoomAccountData[req.RoomID]; !ok { res.RoomAccountData[req.RoomID] = make(map[string]json.RawMessage) } res.RoomAccountData[req.RoomID][req.DataType] = data } else { res.GlobalAccountData[req.DataType] = data } } return nil } global, rooms, err := a.DB.GetAccountData(ctx, local, domain) if err != nil { return err } res.RoomAccountData = rooms res.GlobalAccountData = global return nil } func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAccessTokenRequest, res *api.QueryAccessTokenResponse) error { if req.AppServiceUserID != "" { appServiceDevice, err := a.queryAppServiceToken(ctx, req.AccessToken, req.AppServiceUserID) if err != nil { res.Err = err.Error() } res.Device = appServiceDevice return nil } device, err := a.DB.GetDeviceByAccessToken(ctx, req.AccessToken) if err != nil { if err == sql.ErrNoRows { return nil } return err } localPart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { return err } if !a.Config.Matrix.IsLocalServerName(domain) { return nil } acc, err := a.DB.GetAccountByLocalpart(ctx, localPart, domain) if err != nil { return err } device.AccountType = acc.AccountType res.Device = device return nil } func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api.QueryAccountByLocalpartRequest, res *api.QueryAccountByLocalpartResponse) (err error) { res.Account, err = a.DB.GetAccountByLocalpart(ctx, req.Localpart, req.ServerName) return } // Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem // creating a 'device'. func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) { // Search for app service with given access_token var appService *config.ApplicationService for _, as := range a.AppServices { if as.ASToken == token { appService = &as break } } if appService == nil { return nil, nil } // Create a dummy device for AS user dev := api.Device{ // Use AS dummy device ID ID: "AS_Device", // AS dummy device has AS's token. AccessToken: token, AppserviceID: appService.ID, AccountType: api.AccountTypeAppService, } localpart, domain, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix) if err != nil { return nil, err } if localpart != "" { // AS is masquerading as another user // Verify that the user is registered account, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain) // Verify that the account exists and either appServiceID matches or // it belongs to the appservice user namespaces if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) { // Set the userID of dummy device dev.UserID = appServiceUserID return &dev, nil } return nil, &api.ErrorForbidden{Message: "appservice has not registered this user"} } // AS is not masquerading as any user, so use AS's sender_localpart dev.UserID = appService.SenderLocalpart return &dev, nil } // PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again. func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error { serverName := req.ServerName if serverName == "" { serverName = a.Config.Matrix.ServerName } if !a.Config.Matrix.IsLocalServerName(serverName) { return fmt.Errorf("server name %q not locally configured", serverName) } evacuateReq := &rsapi.PerformAdminEvacuateUserRequest{ UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), } evacuateRes := &rsapi.PerformAdminEvacuateUserResponse{} if err := a.RSAPI.PerformAdminEvacuateUser(ctx, evacuateReq, evacuateRes); err != nil { return err } if err := evacuateRes.Error; err != nil { logrus.WithError(err).Errorf("Failed to evacuate user after account deactivation") } deviceReq := &api.PerformDeviceDeletionRequest{ UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), } deviceRes := &api.PerformDeviceDeletionResponse{} if err := a.PerformDeviceDeletion(ctx, deviceReq, deviceRes); err != nil { return err } pusherReq := &api.PerformPusherDeletionRequest{ Localpart: req.Localpart, } if err := a.PerformPusherDeletion(ctx, pusherReq, &struct{}{}); err != nil { return err } err := a.DB.DeactivateAccount(ctx, req.Localpart, serverName) res.AccountDeactivated = err == nil return err } // PerformOpenIDTokenCreation creates a new token that a relying party uses to authenticate a user func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error { token := util.RandomString(24) exp, err := a.DB.CreateOpenIDToken(ctx, token, req.UserID) res.Token = api.OpenIDToken{ Token: token, UserID: req.UserID, ExpiresAtMS: exp, } return err } // QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error { openIDTokenAttrs, err := a.DB.GetOpenIDTokenAttributes(ctx, req.Token) if err != nil { return err } res.Sub = openIDTokenAttrs.UserID res.ExpiresAtMS = openIDTokenAttrs.ExpiresAtMS return nil } func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) error { // Delete metadata if req.DeleteBackup { if req.Version == "" { res.BadInput = true res.Error = "must specify a version to delete" return nil } exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version) if err != nil { res.Error = fmt.Sprintf("failed to delete backup: %s", err) } res.Exists = exists res.Version = req.Version return nil } // Create metadata if req.Version == "" { version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData) if err != nil { res.Error = fmt.Sprintf("failed to create backup: %s", err) } res.Exists = err == nil res.Version = version return nil } // Update metadata if len(req.Keys.Rooms) == 0 { err := a.DB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData) if err != nil { res.Error = fmt.Sprintf("failed to update backup: %s", err) } res.Exists = err == nil res.Version = req.Version return nil } // Upload Keys for a specific version metadata a.uploadBackupKeys(ctx, req, res) return nil } func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { // you can only upload keys for the CURRENT version version, _, _, _, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, "") if err != nil { res.Error = fmt.Sprintf("failed to query version: %s", err) return } if deleted { res.Error = "backup was deleted" return } if version != req.Version { res.BadInput = true res.Error = fmt.Sprintf("%s isn't the current version, %s is.", req.Version, version) return } res.Exists = true res.Version = version // map keys to a form we can upload more easily - the map ensures we have no duplicates. var uploads []api.InternalKeyBackupSession for roomID, data := range req.Keys.Rooms { for sessionID, sessionData := range data.Sessions { uploads = append(uploads, api.InternalKeyBackupSession{ RoomID: roomID, SessionID: sessionID, KeyBackupSession: sessionData, }) } } count, etag, err := a.DB.UpsertBackupKeys(ctx, version, req.UserID, uploads) if err != nil { res.Error = fmt.Sprintf("failed to upsert keys: %s", err) return } res.KeyCount = count res.KeyETag = etag } func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) error { version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version) res.Version = version if err != nil { if err == sql.ErrNoRows { res.Exists = false return nil } res.Error = fmt.Sprintf("failed to query key backup: %s", err) return nil } res.Algorithm = algorithm res.AuthData = authData res.ETag = etag res.Exists = !deleted if !req.ReturnKeys { res.Count, err = a.DB.CountBackupKeys(ctx, version, req.UserID) if err != nil { res.Error = fmt.Sprintf("failed to count keys: %s", err) } return nil } result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID) if err != nil { res.Error = fmt.Sprintf("failed to query keys: %s", err) return nil } res.Keys = result return nil } func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error { if req.Limit == 0 || req.Limit > 1000 { req.Limit = 1000 } var fromID int64 var err error if req.From != "" { fromID, err = strconv.ParseInt(req.From, 10, 64) if err != nil { return fmt.Errorf("QueryNotifications: parsing 'from': %w", err) } } var filter tables.NotificationFilter = tables.AllNotifications if req.Only == "highlight" { filter = tables.HighlightNotifications } notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, req.ServerName, fromID, req.Limit, filter) if err != nil { return err } if notifs == nil { // This ensures empty is JSON-encoded as [] instead of null. notifs = []*api.Notification{} } res.Notifications = notifs if lastID >= 0 { res.NextToken = strconv.FormatInt(lastID+1, 10) } return nil } func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.PerformPusherSetRequest, res *struct{}) error { util.GetLogger(ctx).WithFields(logrus.Fields{ "localpart": req.Localpart, "pushkey": req.Pusher.PushKey, "display_name": req.Pusher.AppDisplayName, }).Info("PerformPusherCreation") if !req.Append { err := a.DB.RemovePushers(ctx, req.Pusher.AppID, req.Pusher.PushKey) if err != nil { return err } } if req.Pusher.Kind == "" { return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart, req.ServerName) } if req.Pusher.PushKeyTS == 0 { req.Pusher.PushKeyTS = int64(time.Now().Unix()) } return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart, req.ServerName) } func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error { pushers, err := a.DB.GetPushers(ctx, req.Localpart, req.ServerName) if err != nil { return err } for i := range pushers { logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID) if pushers[i].SessionID != req.SessionID { err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart, req.ServerName) if err != nil { return err } } } return nil } func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error { var err error res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart, req.ServerName) return err } func (a *UserInternalAPI) PerformPushRulesPut( ctx context.Context, userID string, ruleSets *pushrules.AccountRuleSets, ) error { bs, err := json.Marshal(ruleSets) if err != nil { return err } userReq := api.InputAccountDataRequest{ UserID: userID, DataType: pushRulesAccountDataType, AccountData: json.RawMessage(bs), } var userRes api.InputAccountDataResponse // empty return a.InputAccountData(ctx, &userReq, &userRes) } func (a *UserInternalAPI) QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error) { localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { return nil, fmt.Errorf("failed to split user ID %q for push rules", userID) } return a.DB.QueryPushRules(ctx, localpart, domain) } func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error) { return a.DB.SetAvatarURL(ctx, localpart, serverName, avatarURL) } func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, req *api.QueryNumericLocalpartRequest, res *api.QueryNumericLocalpartResponse) error { id, err := a.DB.GetNewNumericLocalpart(ctx, req.ServerName) if err != nil { return err } res.ID = id return nil } func (a *UserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error { var err error res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart, req.ServerName) return err } func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error { acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.ServerName, req.PlaintextPassword) switch err { case sql.ErrNoRows: // user does not exist return nil case bcrypt.ErrMismatchedHashAndPassword: // user exists, but password doesn't match return nil case bcrypt.ErrHashTooShort: // user exists, but probably a passwordless account return nil default: res.Exists = true res.Account = acc return nil } } func (a *UserInternalAPI) SetDisplayName(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error) { return a.DB.SetDisplayName(ctx, localpart, serverName, displayName) } func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error { localpart, domain, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium) if err != nil { return err } res.Localpart = localpart res.ServerName = domain return nil } func (a *UserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error { r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart, req.ServerName) if err != nil { return err } res.ThreePIDs = r return nil } func (a *UserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.PerformForgetThreePIDRequest, res *struct{}) error { return a.DB.RemoveThreePIDAssociation(ctx, req.ThreePID, req.Medium) } func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error { return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.ServerName, req.Medium) } const pushRulesAccountDataType = "m.push_rules"