diff --git a/appservice/api/query.go b/appservice/api/query.go index cd74d866c..e53ad4259 100644 --- a/appservice/api/query.go +++ b/appservice/api/query.go @@ -23,7 +23,7 @@ import ( "errors" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" ) @@ -85,7 +85,7 @@ func RetrieveUserProfile( ctx context.Context, userID string, asAPI AppServiceQueryAPI, - accountDB accounts.Database, + accountDB userdb.Database, ) (*authtypes.Profile, error) { localpart, _, err := gomatrixserverlib.SplitID('@', userID) if err != nil { diff --git a/appservice/appservice.go b/appservice/appservice.go index 7e7c67f53..b33d7b701 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -22,6 +22,8 @@ import ( "time" "github.com/gorilla/mux" + "github.com/sirupsen/logrus" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/appservice/consumers" "github.com/matrix-org/dendrite/appservice/inthttp" @@ -34,7 +36,6 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/sirupsen/logrus" ) // AddInternalRoutes registers HTTP handlers for internal API calls @@ -121,7 +122,7 @@ func generateAppServiceAccount( ) error { var accRes userapi.PerformAccountCreationResponse err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{ - AccountType: userapi.AccountTypeUser, + AccountType: userapi.AccountTypeAppService, Localpart: as.SenderLocalpart, AppServiceID: as.ID, OnConflict: userapi.ConflictUpdate, diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index 4662c4ab0..ed11ee749 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -283,8 +283,7 @@ func (m *DendriteMonolith) Start() { cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/%s", m.StorageDirectory, prefix)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-account.db", m.StorageDirectory, prefix)) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-device.db", m.StorageDirectory, prefix)) - cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-mediaapi.db", m.CacheDirectory, prefix)) + cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-syncapi.db", m.StorageDirectory, prefix)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-roomserver.db", m.StorageDirectory, prefix)) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-keyserver.db", m.StorageDirectory, prefix)) diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index f568bad4d..cac4512cb 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -88,7 +88,6 @@ func (m *DendriteMonolith) Start() { cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", m.StorageDirectory)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-account.db", m.StorageDirectory)) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-device.db", m.StorageDirectory)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-syncapi.db", m.StorageDirectory)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-roomserver.db", m.StorageDirectory)) diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index a6b811300..48c2d531e 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -28,7 +28,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" ) @@ -38,7 +38,7 @@ func AddPublicRoutes( synapseAdminRouter *mux.Router, consentAPIMux *mux.Router, cfg *config.ClientAPI, - accountsDB accounts.Database, + accountsDB userdb.Database, federation *gomatrixserverlib.FederationClient, rsAPI roomserverAPI.RoomserverInternalAPI, eduInputAPI eduServerAPI.EDUServerInputAPI, diff --git a/clientapi/jsonerror/jsonerror.go b/clientapi/jsonerror/jsonerror.go index 89c2d59b1..1965fe3f8 100644 --- a/clientapi/jsonerror/jsonerror.go +++ b/clientapi/jsonerror/jsonerror.go @@ -156,6 +156,15 @@ func MissingParam(msg string) *MatrixError { return &MatrixError{"M_MISSING_PARAM", msg} } +// LeaveServerNoticeError is an error returned when trying to reject an invite +// for a server notice room. +func LeaveServerNoticeError() *MatrixError { + return &MatrixError{ + ErrCode: "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM", + Err: "You cannot reject this invite", + } +} + type IncompatibleRoomVersionError struct { RoomVersion string `json:"room_version"` Error string `json:"error"` diff --git a/clientapi/routing/admin_whois.go b/clientapi/routing/admin_whois.go index b448791c3..87bb79366 100644 --- a/clientapi/routing/admin_whois.go +++ b/clientapi/routing/admin_whois.go @@ -47,8 +47,8 @@ func GetAdminWhois( req *http.Request, userAPI api.UserInternalAPI, device *api.Device, userID string, ) util.JSONResponse { - if userID != device.UserID { - // TODO: Still allow if user is admin + allowed := device.AccountType == api.AccountTypeAdmin || userID == device.UserID + if !allowed { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("userID does not match the current user"), diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index e89d8ff24..fcacc76c0 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -15,6 +15,7 @@ package routing import ( + "context" "encoding/json" "fmt" "net/http" @@ -30,7 +31,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" log "github.com/sirupsen/logrus" @@ -137,36 +138,17 @@ type fledglingEvent struct { func CreateRoom( req *http.Request, device *api.Device, cfg *config.ClientAPI, - accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI, + accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { - // TODO (#267): Check room ID doesn't clash with an existing one, and we - // probably shouldn't be using pseudo-random strings, maybe GUIDs? - roomID := fmt.Sprintf("!%s:%s", util.RandomString(16), cfg.Matrix.ServerName) - return createRoom(req, device, cfg, roomID, accountDB, rsAPI, asAPI) -} - -// createRoom implements /createRoom -// nolint: gocyclo -func createRoom( - req *http.Request, device *api.Device, - cfg *config.ClientAPI, roomID string, - accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI, - asAPI appserviceAPI.AppServiceQueryAPI, -) util.JSONResponse { - logger := util.GetLogger(req.Context()) - userID := device.UserID var r createRoomRequest resErr := httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { return *resErr } - // TODO: apply rate-limit - if resErr = r.Validate(); resErr != nil { return *resErr } - evTime, err := httputil.ParseTSParam(req) if err != nil { return util.JSONResponse{ @@ -174,6 +156,25 @@ func createRoom( JSON: jsonerror.InvalidArgumentValue(err.Error()), } } + return createRoom(req.Context(), r, device, cfg, accountDB, rsAPI, asAPI, evTime) +} + +// createRoom implements /createRoom +// nolint: gocyclo +func createRoom( + ctx context.Context, + r createRoomRequest, device *api.Device, + cfg *config.ClientAPI, + accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI, + asAPI appserviceAPI.AppServiceQueryAPI, + evTime time.Time, +) util.JSONResponse { + // TODO (#267): Check room ID doesn't clash with an existing one, and we + // probably shouldn't be using pseudo-random strings, maybe GUIDs? + roomID := fmt.Sprintf("!%s:%s", util.RandomString(16), cfg.Matrix.ServerName) + + logger := util.GetLogger(ctx) + userID := device.UserID // Clobber keys: creator, room_version @@ -200,16 +201,16 @@ func createRoom( "roomVersion": roomVersion, }).Info("Creating new room") - profile, err := appserviceAPI.RetrieveUserProfile(req.Context(), userID, asAPI, accountDB) + profile, err := appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, accountDB) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("appserviceAPI.RetrieveUserProfile failed") + util.GetLogger(ctx).WithError(err).Error("appserviceAPI.RetrieveUserProfile failed") return jsonerror.InternalServerError() } createContent := map[string]interface{}{} if len(r.CreationContent) > 0 { if err = json.Unmarshal(r.CreationContent, &createContent); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal for creation_content failed") + util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for creation_content failed") return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.BadJSON("invalid create content"), @@ -230,7 +231,7 @@ func createRoom( // Merge powerLevelContentOverride fields by unmarshalling it atop the defaults err = json.Unmarshal(r.PowerLevelContentOverride, &powerLevelContent) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal for power_level_content_override failed") + util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for power_level_content_override failed") return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.BadJSON("malformed power_level_content_override"), @@ -319,9 +320,9 @@ func createRoom( } var aliasResp roomserverAPI.GetRoomIDForAliasResponse - err = rsAPI.GetRoomIDForAlias(req.Context(), &hasAliasReq, &aliasResp) + err = rsAPI.GetRoomIDForAlias(ctx, &hasAliasReq, &aliasResp) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed") + util.GetLogger(ctx).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed") return jsonerror.InternalServerError() } if aliasResp.RoomID != "" { @@ -426,7 +427,7 @@ func createRoom( } err = builder.SetContent(e.Content) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("builder.SetContent failed") + util.GetLogger(ctx).WithError(err).Error("builder.SetContent failed") return jsonerror.InternalServerError() } if i > 0 { @@ -435,12 +436,12 @@ func createRoom( var ev *gomatrixserverlib.Event ev, err = buildEvent(&builder, &authEvents, cfg, evTime, roomVersion) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("buildEvent failed") + util.GetLogger(ctx).WithError(err).Error("buildEvent failed") return jsonerror.InternalServerError() } if err = gomatrixserverlib.Allowed(ev, &authEvents); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.Allowed failed") + util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed") return jsonerror.InternalServerError() } @@ -448,7 +449,7 @@ func createRoom( builtEvents = append(builtEvents, ev.Headered(roomVersion)) err = authEvents.AddEvent(ev) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("authEvents.AddEvent failed") + util.GetLogger(ctx).WithError(err).Error("authEvents.AddEvent failed") return jsonerror.InternalServerError() } } @@ -462,8 +463,8 @@ func createRoom( SendAsServer: roomserverAPI.DoNotSendToOtherServers, }) } - if err = roomserverAPI.SendInputRoomEvents(req.Context(), rsAPI, inputs, false); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") + if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, inputs, false); err != nil { + util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") return jsonerror.InternalServerError() } @@ -478,9 +479,9 @@ func createRoom( } var aliasResp roomserverAPI.SetRoomAliasResponse - err = rsAPI.SetRoomAlias(req.Context(), &aliasReq, &aliasResp) + err = rsAPI.SetRoomAlias(ctx, &aliasReq, &aliasResp) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.SetRoomAlias failed") + util.GetLogger(ctx).WithError(err).Error("aliasAPI.SetRoomAlias failed") return jsonerror.InternalServerError() } @@ -519,11 +520,11 @@ func createRoom( for _, invitee := range r.Invite { // Build the invite event. inviteEvent, err := buildMembershipEvent( - req.Context(), invitee, "", accountDB, device, gomatrixserverlib.Invite, + ctx, invitee, "", accountDB, device, gomatrixserverlib.Invite, roomID, true, cfg, evTime, rsAPI, asAPI, ) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvent failed") + util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed") continue } inviteStrippedState := append( @@ -532,7 +533,7 @@ func createRoom( ) // Send the invite event to the roomserver. err = roomserverAPI.SendInvite( - req.Context(), + ctx, rsAPI, inviteEvent.Headered(roomVersion), inviteStrippedState, // invite room state @@ -544,7 +545,7 @@ func createRoom( return e.JSONResponse() case nil: default: - util.GetLogger(req.Context()).WithError(err).Error("roomserverAPI.SendInvite failed") + util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInvite failed") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError(), @@ -556,13 +557,13 @@ func createRoom( if r.Visibility == "public" { // expose this room in the published room list var pubRes roomserverAPI.PerformPublishResponse - rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{ + rsAPI.PerformPublish(ctx, &roomserverAPI.PerformPublishRequest{ RoomID: roomID, Visibility: "public", }, &pubRes) if pubRes.Error != nil { // treat as non-fatal since the room is already made by this point - util.GetLogger(req.Context()).WithError(pubRes.Error).Error("failed to visibility:public") + util.GetLogger(ctx).WithError(pubRes.Error).Error("failed to visibility:public") } } diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index 578aaec56..d30a87a57 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -32,7 +32,7 @@ func JoinRoomByIDOrAlias( req *http.Request, device *api.Device, rsAPI roomserverAPI.RoomserverInternalAPI, - accountDB accounts.Database, + accountDB userdb.Database, roomIDOrAlias string, ) util.JSONResponse { // Prepare to ask the roomserver to perform the room join. diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index 7b9d8acd2..7ecab9d4e 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -24,7 +24,7 @@ import ( "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/util" ) @@ -36,7 +36,7 @@ type crossSigningRequest struct { func UploadCrossSigningDeviceKeys( req *http.Request, userInteractiveAuth *auth.UserInteractive, keyserverAPI api.KeyInternalAPI, device *userapi.Device, - accountDB accounts.Database, cfg *config.ClientAPI, + accountDB userdb.Database, cfg *config.ClientAPI, ) util.JSONResponse { uploadReq := &crossSigningRequest{} uploadRes := &api.PerformUploadDeviceKeysResponse{} diff --git a/clientapi/routing/leaveroom.go b/clientapi/routing/leaveroom.go index 38cef118e..a34dd02d3 100644 --- a/clientapi/routing/leaveroom.go +++ b/clientapi/routing/leaveroom.go @@ -38,6 +38,12 @@ func LeaveRoomByID( // Ask the roomserver to perform the leave. if err := rsAPI.PerformLeave(req.Context(), &leaveReq, &leaveRes); err != nil { + if leaveRes.Code != 0 { + return util.JSONResponse{ + Code: leaveRes.Code, + JSON: jsonerror.LeaveServerNoticeError(), + } + } return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.Unknown(err.Error()), diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index b48b9e93b..ec5c998be 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -54,7 +54,7 @@ func passwordLogin() flows { // Login implements GET and POST /login func Login( - req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI, + req *http.Request, accountDB userdb.Database, userAPI userapi.UserInternalAPI, cfg *config.ClientAPI, ) util.JSONResponse { if req.Method == http.MethodGet { diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 58f187608..ffe8da136 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -30,7 +30,7 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -39,7 +39,7 @@ import ( var errMissingUserID = errors.New("'user_id' must be supplied") func SendBan( - req *http.Request, accountDB accounts.Database, device *userapi.Device, + req *http.Request, accountDB userdb.Database, device *userapi.Device, roomID string, cfg *config.ClientAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { @@ -81,7 +81,7 @@ func SendBan( return sendMembership(req.Context(), accountDB, device, roomID, "ban", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI) } -func sendMembership(ctx context.Context, accountDB accounts.Database, device *userapi.Device, +func sendMembership(ctx context.Context, accountDB userdb.Database, device *userapi.Device, roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time, roomVer gomatrixserverlib.RoomVersion, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI) util.JSONResponse { @@ -125,7 +125,7 @@ func sendMembership(ctx context.Context, accountDB accounts.Database, device *us } func SendKick( - req *http.Request, accountDB accounts.Database, device *userapi.Device, + req *http.Request, accountDB userdb.Database, device *userapi.Device, roomID string, cfg *config.ClientAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { @@ -165,7 +165,7 @@ func SendKick( } func SendUnban( - req *http.Request, accountDB accounts.Database, device *userapi.Device, + req *http.Request, accountDB userdb.Database, device *userapi.Device, roomID string, cfg *config.ClientAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { @@ -200,7 +200,7 @@ func SendUnban( } func SendInvite( - req *http.Request, accountDB accounts.Database, device *userapi.Device, + req *http.Request, accountDB userdb.Database, device *userapi.Device, roomID string, cfg *config.ClientAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { @@ -226,27 +226,42 @@ func SendInvite( } } + // We already received the return value, so no need to check for an error here. + response, _ := sendInvite(req.Context(), accountDB, device, roomID, body.UserID, body.Reason, cfg, rsAPI, asAPI, evTime) + return response +} + +// sendInvite sends an invitation to a user. Returns a JSONResponse and an error +func sendInvite( + ctx context.Context, + accountDB userdb.Database, + device *userapi.Device, + roomID, userID, reason string, + cfg *config.ClientAPI, + rsAPI roomserverAPI.RoomserverInternalAPI, + asAPI appserviceAPI.AppServiceQueryAPI, evTime time.Time, +) (util.JSONResponse, error) { event, err := buildMembershipEvent( - req.Context(), body.UserID, body.Reason, accountDB, device, "invite", + ctx, userID, reason, accountDB, device, "invite", roomID, false, cfg, evTime, rsAPI, asAPI, ) if err == errMissingUserID { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(err.Error()), - } + }, err } else if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, JSON: jsonerror.NotFound(err.Error()), - } + }, err } else if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvent failed") - return jsonerror.InternalServerError() + util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed") + return jsonerror.InternalServerError(), err } err = roomserverAPI.SendInvite( - req.Context(), rsAPI, + ctx, rsAPI, event, nil, // ask the roomserver to draw up invite room state for us cfg.Matrix.ServerName, @@ -254,24 +269,24 @@ func SendInvite( ) switch e := err.(type) { case *roomserverAPI.PerformError: - return e.JSONResponse() + return e.JSONResponse(), err case nil: return util.JSONResponse{ Code: http.StatusOK, JSON: struct{}{}, - } + }, nil default: - util.GetLogger(req.Context()).WithError(err).Error("roomserverAPI.SendInvite failed") + util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInvite failed") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError(), - } + }, err } } func buildMembershipEvent( ctx context.Context, - targetUserID, reason string, accountDB accounts.Database, + targetUserID, reason string, accountDB userdb.Database, device *userapi.Device, membership, roomID string, isDirect bool, cfg *config.ClientAPI, evTime time.Time, @@ -312,7 +327,7 @@ func loadProfile( ctx context.Context, userID string, cfg *config.ClientAPI, - accountDB accounts.Database, + accountDB userdb.Database, asAPI appserviceAPI.AppServiceQueryAPI, ) (*authtypes.Profile, error) { _, serverName, err := gomatrixserverlib.SplitID('@', userID) @@ -366,7 +381,7 @@ func checkAndProcessThreepid( body *threepid.MembershipRequest, cfg *config.ClientAPI, rsAPI roomserverAPI.RoomserverInternalAPI, - accountDB accounts.Database, + accountDB userdb.Database, roomID string, evTime time.Time, ) (inviteStored bool, errRes *util.JSONResponse) { diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index b24424430..499510193 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -9,7 +9,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -29,7 +29,7 @@ type newPasswordAuth struct { func Password( req *http.Request, userAPI api.UserInternalAPI, - accountDB accounts.Database, + accountDB userdb.Database, device *api.Device, cfg *config.ClientAPI, ) util.JSONResponse { diff --git a/clientapi/routing/peekroom.go b/clientapi/routing/peekroom.go index 26aa64ce1..8f89e97f4 100644 --- a/clientapi/routing/peekroom.go +++ b/clientapi/routing/peekroom.go @@ -19,7 +19,7 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -28,7 +28,7 @@ func PeekRoomByIDOrAlias( req *http.Request, device *api.Device, rsAPI roomserverAPI.RoomserverInternalAPI, - accountDB accounts.Database, + accountDB userdb.Database, roomIDOrAlias string, ) util.JSONResponse { // if this is a remote roomIDOrAlias, we have to ask the roomserver (or federation sender?) to @@ -82,7 +82,7 @@ func UnpeekRoomByID( req *http.Request, device *api.Device, rsAPI roomserverAPI.RoomserverInternalAPI, - accountDB accounts.Database, + accountDB userdb.Database, roomID string, ) util.JSONResponse { unpeekReq := roomserverAPI.PerformUnpeekRequest{ diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 017facd20..717cbda75 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -27,7 +27,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrix" @@ -36,7 +36,7 @@ import ( // GetProfile implements GET /profile/{userID} func GetProfile( - req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI, + req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI, federation *gomatrixserverlib.FederationClient, @@ -65,7 +65,7 @@ func GetProfile( // GetAvatarURL implements GET /profile/{userID}/avatar_url func GetAvatarURL( - req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI, + req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI, federation *gomatrixserverlib.FederationClient, ) util.JSONResponse { @@ -92,7 +92,7 @@ func GetAvatarURL( // SetAvatarURL implements PUT /profile/{userID}/avatar_url func SetAvatarURL( - req *http.Request, accountDB accounts.Database, + req *http.Request, accountDB userdb.Database, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, ) util.JSONResponse { if userID != device.UserID { @@ -182,7 +182,7 @@ func SetAvatarURL( // GetDisplayName implements GET /profile/{userID}/displayname func GetDisplayName( - req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI, + req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI, federation *gomatrixserverlib.FederationClient, ) util.JSONResponse { @@ -209,7 +209,7 @@ func GetDisplayName( // SetDisplayName implements PUT /profile/{userID}/displayname func SetDisplayName( - req *http.Request, accountDB accounts.Database, + req *http.Request, accountDB userdb.Database, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, ) util.JSONResponse { if userID != device.UserID { @@ -302,7 +302,7 @@ func SetDisplayName( // Returns an error when something goes wrong or specifically // eventutil.ErrProfileNoExists when the profile doesn't exist. func getProfile( - ctx context.Context, accountDB accounts.Database, cfg *config.ClientAPI, + ctx context.Context, accountDB userdb.Database, cfg *config.ClientAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI, federation *gomatrixserverlib.FederationClient, diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 496c05b39..f517ddf27 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -32,18 +32,19 @@ import ( "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/tokens" + "github.com/matrix-org/util" + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/gomatrixserverlib/tokens" - "github.com/matrix-org/util" - "github.com/prometheus/client_golang/prometheus" - log "github.com/sirupsen/logrus" + userdb "github.com/matrix-org/dendrite/userapi/storage" ) var ( @@ -153,7 +154,7 @@ type authDict struct { // http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#user-interactive-authentication-api type userInteractiveResponse struct { Flows []authtypes.Flow `json:"flows"` - Completed []authtypes.LoginType `json:"completed,omitempty"` + Completed []authtypes.LoginType `json:"completed"` Params map[string]interface{} `json:"params"` Session string `json:"session"` } @@ -447,7 +448,7 @@ func validateApplicationService( func Register( req *http.Request, userAPI userapi.UserInternalAPI, - accountDB accounts.Database, + accountDB userdb.Database, cfg *config.ClientAPI, ) util.JSONResponse { var r registerRequest @@ -531,6 +532,13 @@ func handleGuestRegistration( cfg *config.ClientAPI, userAPI userapi.UserInternalAPI, ) util.JSONResponse { + if cfg.RegistrationDisabled || cfg.GuestsDisabled { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("Guest registration is disabled"), + } + } + var res userapi.PerformAccountCreationResponse err := userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{ AccountType: userapi.AccountTypeGuest, @@ -708,7 +716,7 @@ func handleApplicationServiceRegistration( // application service registration is entirely separate. return completeRegistration( req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), policyVersion, - r.InhibitLogin, r.InitialDisplayName, r.DeviceID, + r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, ) } @@ -732,7 +740,7 @@ func checkAndCompleteFlow( return completeRegistration( req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), policyVersion, - r.InhibitLogin, r.InitialDisplayName, r.DeviceID, + r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, ) } @@ -757,6 +765,7 @@ func completeRegistration( username, password, appserviceID, ipAddr, userAgent, policyVersion string, inhibitLogin eventutil.WeakBoolean, displayName, deviceID *string, + accType userapi.AccountType, ) util.JSONResponse { if username == "" { return util.JSONResponse{ @@ -771,13 +780,12 @@ func completeRegistration( JSON: jsonerror.BadJSON("missing password"), } } - var accRes userapi.PerformAccountCreationResponse err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ AppServiceID: appserviceID, Localpart: username, Password: password, - AccountType: userapi.AccountTypeUser, + AccountType: accType, OnConflict: userapi.ConflictAbort, PolicyVersion: policyVersion, }, &accRes) @@ -904,7 +912,7 @@ type availableResponse struct { func RegisterAvailable( req *http.Request, cfg *config.ClientAPI, - accountDB accounts.Database, + accountDB userdb.Database, ) util.JSONResponse { username := req.URL.Query().Get("username") @@ -976,5 +984,10 @@ func handleSharedSecretRegistration(userAPI userapi.UserInternalAPI, sr *SharedS return *resErr } deviceID := "shared_secret_registration" - return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID) + + accType := userapi.AccountTypeUser + if ssrr.Admin { + accType = userapi.AccountTypeAdmin + } + return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType) } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 6eae02a3d..230053f60 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -15,6 +15,7 @@ package routing import ( + "context" "encoding/json" "net/http" "strings" @@ -34,7 +35,7 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -51,7 +52,7 @@ func Setup( eduAPI eduServerAPI.EDUServerInputAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, - accountDB accounts.Database, + accountDB userdb.Database, userAPI userapi.UserInternalAPI, federation *gomatrixserverlib.FederationClient, syncProducer *producers.SyncAPIProducer, @@ -117,6 +118,58 @@ func Setup( ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) } + // server notifications + if cfg.Matrix.ServerNotices.Enabled { + logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice") + serverNotificationSender, err := getSenderDevice(context.Background(), userAPI, accountDB, cfg) + if err != nil { + logrus.WithError(err).Fatal("unable to get account for sending sending server notices") + } + + synapseAdminRouter.Handle("/admin/v1/send_server_notice/{txnID}", + httputil.MakeAuthAPI("send_server_notice", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { + // not specced, but ensure we're rate limiting requests to this endpoint + if r := rateLimits.Limit(req); r != nil { + return *r + } + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + txnID := vars["txnID"] + return SendServerNotice( + req, &cfg.Matrix.ServerNotices, + cfg, userAPI, rsAPI, accountDB, asAPI, + device, serverNotificationSender, + &txnID, transactionsCache, + ) + }), + ).Methods(http.MethodPut, http.MethodOptions) + + synapseAdminRouter.Handle("/admin/v1/send_server_notice", + httputil.MakeAuthAPI("send_server_notice", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { + // not specced, but ensure we're rate limiting requests to this endpoint + if r := rateLimits.Limit(req); r != nil { + return *r + } + return SendServerNotice( + req, &cfg.Matrix.ServerNotices, + cfg, userAPI, rsAPI, accountDB, asAPI, + device, serverNotificationSender, + nil, transactionsCache, + ) + }), + ).Methods(http.MethodPost, http.MethodOptions) + } + + // You can't just do PathPrefix("/(r0|v3)") because regexps only apply when inside named path variables. + // So make a named path variable called 'apiversion' (which we will never read in handlers) and then do + // (r0|v3) - BUT this is a captured group, which makes no sense because you cannot extract this group + // from a match (gorilla/mux exposes no way to do this) so it demands you make it a non-capturing group + // using ?: so the final regexp becomes what is below. We also need a trailing slash to stop 'v33333' matching. + // Note that 'apiversion' is chosen because it must not collide with a variable used in any of the routing! + v3mux := publicAPIMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter() + // unspecced consent tracking if cfg.Matrix.UserConsentOptions.Enabled { consentAPIMux.Handle("/consent", @@ -129,12 +182,12 @@ func Setup( r0mux := publicAPIMux.PathPrefix("/r0").Subrouter() unstableMux := publicAPIMux.PathPrefix("/unstable").Subrouter() - r0mux.Handle("/createRoom", + v3mux.Handle("/createRoom", httputil.MakeAuthAPI("createRoom", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { return CreateRoom(req, device, cfg, accountDB, rsAPI, asAPI) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/join/{roomIDOrAlias}", + v3mux.Handle("/join/{roomIDOrAlias}", httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -150,7 +203,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) if mscCfg.Enabled("msc2753") { - r0mux.Handle("/peek/{roomIDOrAlias}", + v3mux.Handle("/peek/{roomIDOrAlias}", httputil.MakeAuthAPI(gomatrixserverlib.Peek, userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -165,12 +218,12 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) } - r0mux.Handle("/joined_rooms", + v3mux.Handle("/joined_rooms", httputil.MakeAuthAPI("joined_rooms", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetJoinedRooms(req, device, rsAPI) }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/join", + v3mux.Handle("/rooms/{roomID}/join", httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -184,7 +237,7 @@ func Setup( ) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/leave", + v3mux.Handle("/rooms/{roomID}/leave", httputil.MakeAuthAPI("membership", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -198,7 +251,7 @@ func Setup( ) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/unpeek", + v3mux.Handle("/rooms/{roomID}/unpeek", httputil.MakeAuthAPI("unpeek", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -209,7 +262,7 @@ func Setup( ) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/ban", + v3mux.Handle("/rooms/{roomID}/ban", httputil.MakeAuthAPI("membership", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -218,7 +271,7 @@ func Setup( return SendBan(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/invite", + v3mux.Handle("/rooms/{roomID}/invite", httputil.MakeAuthAPI("membership", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -230,7 +283,7 @@ func Setup( return SendInvite(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/kick", + v3mux.Handle("/rooms/{roomID}/kick", httputil.MakeAuthAPI("membership", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -239,7 +292,7 @@ func Setup( return SendKick(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/unban", + v3mux.Handle("/rooms/{roomID}/unban", httputil.MakeAuthAPI("membership", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -248,7 +301,7 @@ func Setup( return SendUnban(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/send/{eventType}", + v3mux.Handle("/rooms/{roomID}/send/{eventType}", httputil.MakeAuthAPI("send_message", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -257,7 +310,7 @@ func Setup( return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}", + v3mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}", httputil.MakeAuthAPI("send_message", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -268,7 +321,7 @@ func Setup( nil, cfg, rsAPI, transactionsCache) }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/event/{eventID}", + v3mux.Handle("/rooms/{roomID}/event/{eventID}", httputil.MakeAuthAPI("rooms_get_event", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -278,7 +331,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -286,7 +339,7 @@ func Setup( return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"]) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -294,7 +347,7 @@ func Setup( return GetAliases(req, rsAPI, device, vars["roomID"]) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -305,7 +358,7 @@ func Setup( return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -314,7 +367,7 @@ func Setup( return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}", + v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}", httputil.MakeAuthAPI("send_message", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -326,7 +379,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}", + v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}", httputil.MakeAuthAPI("send_message", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -337,21 +390,21 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { + v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r } return Register(req, userAPI, accountDB, cfg) })).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { + v3mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r } return RegisterAvailable(req, cfg, accountDB) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/directory/room/{roomAlias}", + v3mux.Handle("/directory/room/{roomAlias}", httputil.MakeExternalAPI("directory_room", func(req *http.Request) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -361,7 +414,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/directory/room/{roomAlias}", + v3mux.Handle("/directory/room/{roomAlias}", httputil.MakeAuthAPI("directory_room", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -371,7 +424,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/directory/room/{roomAlias}", + v3mux.Handle("/directory/room/{roomAlias}", httputil.MakeAuthAPI("directory_room", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -380,7 +433,7 @@ func Setup( return RemoveLocalAlias(req, device, vars["roomAlias"], rsAPI) }), ).Methods(http.MethodDelete, http.MethodOptions) - r0mux.Handle("/directory/list/room/{roomID}", + v3mux.Handle("/directory/list/room/{roomID}", httputil.MakeExternalAPI("directory_list", func(req *http.Request) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -390,7 +443,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) // TODO: Add AS support - r0mux.Handle("/directory/list/room/{roomID}", + v3mux.Handle("/directory/list/room/{roomID}", httputil.MakeAuthAPI("directory_list", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -399,25 +452,25 @@ func Setup( return SetVisibility(req, rsAPI, device, vars["roomID"]) }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/publicRooms", + v3mux.Handle("/publicRooms", httputil.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse { return GetPostPublicRooms(req, rsAPI, extRoomsProvider, federation, cfg) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) - r0mux.Handle("/logout", + v3mux.Handle("/logout", httputil.MakeAuthAPI("logout", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { return Logout(req, userAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/logout/all", + v3mux.Handle("/logout/all", httputil.MakeAuthAPI("logout", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { return LogoutAll(req, userAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/typing/{userID}", + v3mux.Handle("/rooms/{roomID}/typing/{userID}", httputil.MakeAuthAPI("rooms_typing", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -429,7 +482,7 @@ func Setup( return SendTyping(req, device, vars["roomID"], vars["userID"], accountDB, eduAPI, rsAPI) }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/redact/{eventID}", + v3mux.Handle("/rooms/{roomID}/redact/{eventID}", httputil.MakeAuthAPI("rooms_redact", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -438,7 +491,7 @@ func Setup( return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}", + v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}", httputil.MakeAuthAPI("rooms_redact", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -448,7 +501,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/sendToDevice/{eventType}/{txnID}", + v3mux.Handle("/sendToDevice/{eventType}/{txnID}", httputil.MakeAuthAPI("send_to_device", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -473,7 +526,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/account/whoami", + v3mux.Handle("/account/whoami", httputil.MakeAuthAPI("whoami", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -482,7 +535,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/account/password", + v3mux.Handle("/account/password", httputil.MakeAuthAPI("password", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -491,7 +544,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/account/deactivate", + v3mux.Handle("/account/deactivate", httputil.MakeAuthAPI("deactivate", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -502,7 +555,7 @@ func Setup( // Stub endpoints required by Element - r0mux.Handle("/login", + v3mux.Handle("/login", httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -511,14 +564,14 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) - r0mux.Handle("/auth/{authType}/fallback/web", + v3mux.Handle("/auth/{authType}/fallback/web", httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { vars := mux.Vars(req) return AuthFallback(w, req, vars["authType"], cfg) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) - r0mux.Handle("/pushrules/", + v3mux.Handle("/pushrules/", httputil.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse { // TODO: Implement push rules API res := json.RawMessage(`{ @@ -539,7 +592,7 @@ func Setup( // Element user settings - r0mux.Handle("/profile/{userID}", + v3mux.Handle("/profile/{userID}", httputil.MakeExternalAPI("profile", func(req *http.Request) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -549,7 +602,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/profile/{userID}/avatar_url", + v3mux.Handle("/profile/{userID}/avatar_url", httputil.MakeExternalAPI("profile_avatar_url", func(req *http.Request) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -559,7 +612,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/profile/{userID}/avatar_url", + v3mux.Handle("/profile/{userID}/avatar_url", httputil.MakeAuthAPI("profile_avatar_url", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -574,7 +627,7 @@ func Setup( // Browsers use the OPTIONS HTTP method to check if the CORS policy allows // PUT requests, so we need to allow this method - r0mux.Handle("/profile/{userID}/displayname", + v3mux.Handle("/profile/{userID}/displayname", httputil.MakeExternalAPI("profile_displayname", func(req *http.Request) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -584,7 +637,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/profile/{userID}/displayname", + v3mux.Handle("/profile/{userID}/displayname", httputil.MakeAuthAPI("profile_displayname", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -599,13 +652,13 @@ func Setup( // Browsers use the OPTIONS HTTP method to check if the CORS policy allows // PUT requests, so we need to allow this method - r0mux.Handle("/account/3pid", + v3mux.Handle("/account/3pid", httputil.MakeAuthAPI("account_3pid", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetAssociated3PIDs(req, accountDB, device) }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/account/3pid", + v3mux.Handle("/account/3pid", httputil.MakeAuthAPI("account_3pid", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { return CheckAndSave3PIDAssociation(req, accountDB, device, cfg) }), @@ -617,14 +670,14 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken", + v3mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken", httputil.MakeExternalAPI("account_3pid_request_token", func(req *http.Request) util.JSONResponse { return RequestEmailToken(req, accountDB, cfg) }), ).Methods(http.MethodPost, http.MethodOptions) // Element logs get flooded unless this is handled - r0mux.Handle("/presence/{userID}/status", + v3mux.Handle("/presence/{userID}/status", httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -637,7 +690,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/voip/turnServer", + v3mux.Handle("/voip/turnServer", httputil.MakeAuthAPI("turn_server", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -646,7 +699,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/thirdparty/protocols", + v3mux.Handle("/thirdparty/protocols", httputil.MakeExternalAPI("thirdparty_protocols", func(req *http.Request) util.JSONResponse { // TODO: Return the third party protcols return util.JSONResponse{ @@ -656,7 +709,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/initialSync", + v3mux.Handle("/rooms/{roomID}/initialSync", httputil.MakeExternalAPI("rooms_initial_sync", func(req *http.Request) util.JSONResponse { // TODO: Allow people to peek into rooms. return util.JSONResponse{ @@ -666,7 +719,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/user/{userID}/account_data/{type}", + v3mux.Handle("/user/{userID}/account_data/{type}", httputil.MakeAuthAPI("user_account_data", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -676,7 +729,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", + v3mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", httputil.MakeAuthAPI("user_account_data", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -686,7 +739,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/user/{userID}/account_data/{type}", + v3mux.Handle("/user/{userID}/account_data/{type}", httputil.MakeAuthAPI("user_account_data", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -696,7 +749,7 @@ func Setup( }), ).Methods(http.MethodGet) - r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", + v3mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", httputil.MakeAuthAPI("user_account_data", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -706,7 +759,7 @@ func Setup( }), ).Methods(http.MethodGet) - r0mux.Handle("/admin/whois/{userID}", + v3mux.Handle("/admin/whois/{userID}", httputil.MakeAuthAPI("admin_whois", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -716,7 +769,7 @@ func Setup( }), ).Methods(http.MethodGet) - r0mux.Handle("/user/{userID}/openid/request_token", + v3mux.Handle("/user/{userID}/openid/request_token", httputil.MakeAuthAPI("openid_request_token", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -729,7 +782,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/user_directory/search", + v3mux.Handle("/user_directory/search", httputil.MakeAuthAPI("userdirectory_search", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -754,7 +807,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/members", + v3mux.Handle("/rooms/{roomID}/members", httputil.MakeAuthAPI("rooms_members", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -764,7 +817,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/joined_members", + v3mux.Handle("/rooms/{roomID}/joined_members", httputil.MakeAuthAPI("rooms_members", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -774,7 +827,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/read_markers", + v3mux.Handle("/rooms/{roomID}/read_markers", httputil.MakeAuthAPI("rooms_read_markers", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -787,7 +840,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/forget", + v3mux.Handle("/rooms/{roomID}/forget", httputil.MakeAuthAPI("rooms_forget", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -800,13 +853,13 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/devices", + v3mux.Handle("/devices", httputil.MakeAuthAPI("get_devices", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetDevicesByLocalpart(req, userAPI, device) }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/devices/{deviceID}", + v3mux.Handle("/devices/{deviceID}", httputil.MakeAuthAPI("get_device", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -816,7 +869,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/devices/{deviceID}", + v3mux.Handle("/devices/{deviceID}", httputil.MakeAuthAPI("device_data", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -826,7 +879,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/devices/{deviceID}", + v3mux.Handle("/devices/{deviceID}", httputil.MakeAuthAPI("delete_device", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -836,14 +889,14 @@ func Setup( }), ).Methods(http.MethodDelete, http.MethodOptions) - r0mux.Handle("/delete_devices", + v3mux.Handle("/delete_devices", httputil.MakeAuthAPI("delete_devices", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { return DeleteDevices(req, userAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) // Stub implementations for sytest - r0mux.Handle("/events", + v3mux.Handle("/events", httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse { return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ "chunk": []interface{}{}, @@ -853,7 +906,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/initialSync", + v3mux.Handle("/initialSync", httputil.MakeExternalAPI("initial_sync", func(req *http.Request) util.JSONResponse { return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ "end": "", @@ -861,7 +914,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/user/{userId}/rooms/{roomId}/tags", + v3mux.Handle("/user/{userId}/rooms/{roomId}/tags", httputil.MakeAuthAPI("get_tags", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -871,7 +924,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", + v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", httputil.MakeAuthAPI("put_tag", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -881,7 +934,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", + v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", httputil.MakeAuthAPI("delete_tag", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -891,7 +944,7 @@ func Setup( }), ).Methods(http.MethodDelete, http.MethodOptions) - r0mux.Handle("/capabilities", + v3mux.Handle("/capabilities", httputil.MakeAuthAPI("capabilities", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -934,11 +987,11 @@ func Setup( return CreateKeyBackupVersion(req, userAPI, device) }) - r0mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/room_keys/version/{version}", putBackupKeysVersion).Methods(http.MethodPut) - r0mux.Handle("/room_keys/version/{version}", deleteBackupKeysVersion).Methods(http.MethodDelete) - r0mux.Handle("/room_keys/version", postNewBackupKeysVersion).Methods(http.MethodPost, http.MethodOptions) + v3mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) + v3mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) + v3mux.Handle("/room_keys/version/{version}", putBackupKeysVersion).Methods(http.MethodPut) + v3mux.Handle("/room_keys/version/{version}", deleteBackupKeysVersion).Methods(http.MethodDelete) + v3mux.Handle("/room_keys/version", postNewBackupKeysVersion).Methods(http.MethodPost, http.MethodOptions) unstableMux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) unstableMux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) @@ -1030,9 +1083,9 @@ func Setup( return UploadBackupKeys(req, userAPI, device, version, &keyReq) }) - r0mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut) - r0mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut) - r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", putBackupKeysRoomSession).Methods(http.MethodPut) + v3mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut) + v3mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut) + v3mux.Handle("/room_keys/keys/{roomID}/{sessionID}", putBackupKeysRoomSession).Methods(http.MethodPut) unstableMux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut) unstableMux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut) @@ -1060,9 +1113,9 @@ func Setup( return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"]) }) - r0mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", getBackupKeysRoomSession).Methods(http.MethodGet, http.MethodOptions) + v3mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions) + v3mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions) + v3mux.Handle("/room_keys/keys/{roomID}/{sessionID}", getBackupKeysRoomSession).Methods(http.MethodGet, http.MethodOptions) unstableMux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions) unstableMux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions) @@ -1080,29 +1133,29 @@ func Setup( return UploadCrossSigningDeviceSignatures(req, keyAPI, device) }) - r0mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions) + v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions) + v3mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions) unstableMux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions) unstableMux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions) // Supplying a device ID is deprecated. - r0mux.Handle("/keys/upload/{deviceID}", + v3mux.Handle("/keys/upload/{deviceID}", httputil.MakeAuthAPI("keys_upload", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadKeys(req, keyAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/keys/upload", + v3mux.Handle("/keys/upload", httputil.MakeAuthAPI("keys_upload", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadKeys(req, keyAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/keys/query", + v3mux.Handle("/keys/query", httputil.MakeAuthAPI("keys_query", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { return QueryKeys(req, keyAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/keys/claim", + v3mux.Handle("/keys/claim", httputil.MakeAuthAPI("keys_claim", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { return ClaimKeys(req, keyAPI) }), diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 606107b9f..23935b5d9 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -15,10 +15,16 @@ package routing import ( + "context" "net/http" "sync" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" @@ -26,10 +32,6 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" ) // http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-send-eventtype-txnid @@ -97,7 +99,22 @@ func SendEvent( defer mutex.(*sync.Mutex).Unlock() startedGeneratingEvent := time.Now() - e, resErr := generateSendEvent(req, device, roomID, eventType, stateKey, cfg, rsAPI) + + var r map[string]interface{} // must be a JSON object + resErr := httputil.UnmarshalJSONRequest(req, &r) + if resErr != nil { + return *resErr + } + + evTime, err := httputil.ParseTSParam(req) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue(err.Error()), + } + } + + e, resErr := generateSendEvent(req.Context(), r, device, roomID, eventType, stateKey, cfg, rsAPI, evTime) if resErr != nil { return *resErr } @@ -153,27 +170,16 @@ func SendEvent( } func generateSendEvent( - req *http.Request, + ctx context.Context, + r map[string]interface{}, device *userapi.Device, roomID, eventType string, stateKey *string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, + evTime time.Time, ) (*gomatrixserverlib.Event, *util.JSONResponse) { // parse the incoming http request userID := device.UserID - var r map[string]interface{} // must be a JSON object - resErr := httputil.UnmarshalJSONRequest(req, &r) - if resErr != nil { - return nil, resErr - } - - evTime, err := httputil.ParseTSParam(req) - if err != nil { - return nil, &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue(err.Error()), - } - } // create the new event and set all the fields we can builder := gomatrixserverlib.EventBuilder{ @@ -182,15 +188,15 @@ func generateSendEvent( Type: eventType, StateKey: stateKey, } - err = builder.SetContent(r) + err := builder.SetContent(r) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("builder.SetContent failed") + util.GetLogger(ctx).WithError(err).Error("builder.SetContent failed") resErr := jsonerror.InternalServerError() return nil, &resErr } var queryRes api.QueryLatestEventsAndStateResponse - e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, evTime, rsAPI, &queryRes) + e, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return nil, &util.JSONResponse{ Code: http.StatusNotFound, @@ -213,7 +219,7 @@ func generateSendEvent( JSON: jsonerror.BadJSON(e.Error()), } } else if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("eventutil.BuildEvent failed") + util.GetLogger(ctx).WithError(err).Error("eventutil.BuildEvent failed") resErr := jsonerror.InternalServerError() return nil, &resErr } diff --git a/clientapi/routing/sendtyping.go b/clientapi/routing/sendtyping.go index 3abf3db27..fd214b34b 100644 --- a/clientapi/routing/sendtyping.go +++ b/clientapi/routing/sendtyping.go @@ -20,7 +20,7 @@ import ( "github.com/matrix-org/dendrite/eduserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/util" ) @@ -33,7 +33,7 @@ type typingContentJSON struct { // sends the typing events to client API typingProducer func SendTyping( req *http.Request, device *userapi.Device, roomID string, - userID string, accountDB accounts.Database, + userID string, accountDB userdb.Database, eduAPI api.EDUServerInputAPI, rsAPI roomserverAPI.RoomserverInternalAPI, ) util.JSONResponse { diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go new file mode 100644 index 000000000..42a303a6b --- /dev/null +++ b/clientapi/routing/server_notices.go @@ -0,0 +1,343 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + userdb "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/gomatrix" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/tokens" + "github.com/matrix-org/util" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" + + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/transactions" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + userapi "github.com/matrix-org/dendrite/userapi/api" +) + +// Unspecced server notice request +// https://github.com/matrix-org/synapse/blob/develop/docs/admin_api/server_notices.md +type sendServerNoticeRequest struct { + UserID string `json:"user_id,omitempty"` + Content struct { + MsgType string `json:"msgtype,omitempty"` + Body string `json:"body,omitempty"` + } `json:"content,omitempty"` + Type string `json:"type,omitempty"` + StateKey string `json:"state_key,omitempty"` +} + +// SendServerNotice sends a message to a specific user. It can only be invoked by an admin. +func SendServerNotice( + req *http.Request, + cfgNotices *config.ServerNotices, + cfgClient *config.ClientAPI, + userAPI userapi.UserInternalAPI, + rsAPI api.RoomserverInternalAPI, + accountsDB userdb.Database, + asAPI appserviceAPI.AppServiceQueryAPI, + device *userapi.Device, + senderDevice *userapi.Device, + txnID *string, + txnCache *transactions.Cache, +) util.JSONResponse { + if device.AccountType != userapi.AccountTypeAdmin { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("This API can only be used by admin users."), + } + } + + if txnID != nil { + // Try to fetch response from transactionsCache + if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { + return *res + } + } + + ctx := req.Context() + var r sendServerNoticeRequest + resErr := httputil.UnmarshalJSONRequest(req, &r) + if resErr != nil { + return *resErr + } + + // check that all required fields are set + if !r.valid() { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("Invalid request"), + } + } + + // get rooms for specified user + allUserRooms := []string{} + userRooms := api.QueryRoomsForUserResponse{} + if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ + UserID: r.UserID, + WantMembership: "join", + }, &userRooms); err != nil { + return util.ErrorResponse(err) + } + allUserRooms = append(allUserRooms, userRooms.RoomIDs...) + // get invites for specified user + if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ + UserID: r.UserID, + WantMembership: "invite", + }, &userRooms); err != nil { + return util.ErrorResponse(err) + } + allUserRooms = append(allUserRooms, userRooms.RoomIDs...) + // get left rooms for specified user + if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ + UserID: r.UserID, + WantMembership: "leave", + }, &userRooms); err != nil { + return util.ErrorResponse(err) + } + allUserRooms = append(allUserRooms, userRooms.RoomIDs...) + + // get rooms of the sender + senderUserID := fmt.Sprintf("@%s:%s", cfgNotices.LocalPart, cfgClient.Matrix.ServerName) + senderRooms := api.QueryRoomsForUserResponse{} + if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ + UserID: senderUserID, + WantMembership: "join", + }, &senderRooms); err != nil { + return util.ErrorResponse(err) + } + + // check if we have rooms in common + commonRooms := []string{} + for _, userRoomID := range allUserRooms { + for _, senderRoomID := range senderRooms.RoomIDs { + if userRoomID == senderRoomID { + commonRooms = append(commonRooms, senderRoomID) + } + } + } + + if len(commonRooms) > 1 { + return util.ErrorResponse(fmt.Errorf("expected to find one room, but got %d", len(commonRooms))) + } + + var ( + roomID string + roomVersion = gomatrixserverlib.RoomVersionV6 + ) + + // create a new room for the user + if len(commonRooms) == 0 { + powerLevelContent := eventutil.InitialPowerLevelsContent(senderUserID) + powerLevelContent.Users[r.UserID] = -10 // taken from Synapse + pl, err := json.Marshal(powerLevelContent) + if err != nil { + return util.ErrorResponse(err) + } + createContent := map[string]interface{}{} + createContent["m.federate"] = false + cc, err := json.Marshal(createContent) + if err != nil { + return util.ErrorResponse(err) + } + crReq := createRoomRequest{ + Invite: []string{r.UserID}, + Name: cfgNotices.RoomName, + Visibility: "private", + Preset: presetPrivateChat, + CreationContent: cc, + GuestCanJoin: false, + RoomVersion: roomVersion, + PowerLevelContentOverride: pl, + } + + roomRes := createRoom(ctx, crReq, senderDevice, cfgClient, accountsDB, rsAPI, asAPI, time.Now()) + + switch data := roomRes.JSON.(type) { + case createRoomResponse: + roomID = data.RoomID + + // tag the room, so we can later check if the user tries to reject an invite + serverAlertTag := gomatrix.TagContent{Tags: map[string]gomatrix.TagProperties{ + "m.server_notice": { + Order: 1.0, + }, + }} + if err = saveTagData(req, r.UserID, roomID, userAPI, serverAlertTag); err != nil { + util.GetLogger(ctx).WithError(err).Error("saveTagData failed") + return jsonerror.InternalServerError() + } + + default: + // if we didn't get a createRoomResponse, we probably received an error, so return that. + return roomRes + } + + } else { + // we've found a room in common, check the membership + roomID = commonRooms[0] + // re-invite the user + res, err := sendInvite(ctx, accountsDB, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now()) + if err != nil { + return res + } + } + + startedGeneratingEvent := time.Now() + + request := map[string]interface{}{ + "body": r.Content.Body, + "msgtype": r.Content.MsgType, + } + e, resErr := generateSendEvent(ctx, request, senderDevice, roomID, "m.room.message", nil, cfgClient, rsAPI, time.Now()) + if resErr != nil { + logrus.Errorf("failed to send message: %+v", resErr) + return *resErr + } + timeToGenerateEvent := time.Since(startedGeneratingEvent) + + var txnAndSessionID *api.TransactionID + if txnID != nil { + txnAndSessionID = &api.TransactionID{ + TransactionID: *txnID, + SessionID: device.SessionID, + } + } + + // pass the new event to the roomserver and receive the correct event ID + // event ID in case of duplicate transaction is discarded + startedSubmittingEvent := time.Now() + if err := api.SendEvents( + ctx, rsAPI, + api.KindNew, + []*gomatrixserverlib.HeaderedEvent{ + e.Headered(roomVersion), + }, + cfgClient.Matrix.ServerName, + cfgClient.Matrix.ServerName, + txnAndSessionID, + false, + ); err != nil { + util.GetLogger(ctx).WithError(err).Error("SendEvents failed") + return jsonerror.InternalServerError() + } + util.GetLogger(ctx).WithFields(logrus.Fields{ + "event_id": e.EventID(), + "room_id": roomID, + "room_version": roomVersion, + }).Info("Sent event to roomserver") + timeToSubmitEvent := time.Since(startedSubmittingEvent) + + res := util.JSONResponse{ + Code: http.StatusOK, + JSON: sendEventResponse{e.EventID()}, + } + // Add response to transactionsCache + if txnID != nil { + txnCache.AddTransaction(device.AccessToken, *txnID, &res) + } + + // Take a note of how long it took to generate the event vs submit + // it to the roomserver. + sendEventDuration.With(prometheus.Labels{"action": "build"}).Observe(float64(timeToGenerateEvent.Milliseconds())) + sendEventDuration.With(prometheus.Labels{"action": "submit"}).Observe(float64(timeToSubmitEvent.Milliseconds())) + + return res +} + +func (r sendServerNoticeRequest) valid() (ok bool) { + if r.UserID == "" { + return false + } + if r.Content.MsgType == "" || r.Content.Body == "" { + return false + } + return true +} + +// getSenderDevice creates a user account to be used when sending server notices. +// It returns an userapi.Device, which is used for building the event +func getSenderDevice( + ctx context.Context, + userAPI userapi.UserInternalAPI, + accountDB userdb.Database, + cfg *config.ClientAPI, +) (*userapi.Device, error) { + var accRes userapi.PerformAccountCreationResponse + // create account if it doesn't exist + err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ + AccountType: userapi.AccountTypeUser, + Localpart: cfg.Matrix.ServerNotices.LocalPart, + OnConflict: userapi.ConflictUpdate, + }, &accRes) + if err != nil { + return nil, err + } + + // set the avatarurl for the user + if err = accountDB.SetAvatarURL(ctx, cfg.Matrix.ServerNotices.LocalPart, cfg.Matrix.ServerNotices.AvatarURL); err != nil { + util.GetLogger(ctx).WithError(err).Error("accountDB.SetAvatarURL failed") + return nil, err + } + + // Check if we got existing devices + deviceRes := &userapi.QueryDevicesResponse{} + err = userAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{ + UserID: accRes.Account.UserID, + }, deviceRes) + if err != nil { + return nil, err + } + + if len(deviceRes.Devices) > 0 { + return &deviceRes.Devices[0], nil + } + + // create an AccessToken + token, err := tokens.GenerateLoginToken(tokens.TokenOptions{ + ServerPrivateKey: cfg.Matrix.PrivateKey.Seed(), + ServerName: string(cfg.Matrix.ServerName), + UserID: accRes.Account.UserID, + }) + if err != nil { + return nil, err + } + + // create a new device, if we didn't find any + var devRes userapi.PerformDeviceCreationResponse + err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ + Localpart: cfg.Matrix.ServerNotices.LocalPart, + DeviceDisplayName: &cfg.Matrix.ServerNotices.LocalPart, + AccessToken: token, + NoDeviceListUpdate: true, + }, &devRes) + + if err != nil { + return nil, err + } + return devRes.Device, nil +} diff --git a/clientapi/routing/server_notices_test.go b/clientapi/routing/server_notices_test.go new file mode 100644 index 000000000..2fac072cd --- /dev/null +++ b/clientapi/routing/server_notices_test.go @@ -0,0 +1,83 @@ +package routing + +import ( + "testing" +) + +func Test_sendServerNoticeRequest_validate(t *testing.T) { + type fields struct { + UserID string `json:"user_id,omitempty"` + Content struct { + MsgType string `json:"msgtype,omitempty"` + Body string `json:"body,omitempty"` + } `json:"content,omitempty"` + Type string `json:"type,omitempty"` + StateKey string `json:"state_key,omitempty"` + } + + content := struct { + MsgType string `json:"msgtype,omitempty"` + Body string `json:"body,omitempty"` + }{ + MsgType: "m.text", + Body: "Hello world!", + } + + tests := []struct { + name string + fields fields + wantOk bool + }{ + { + name: "empty request", + fields: fields{}, + }, + { + name: "msgtype empty", + fields: fields{ + UserID: "@alice:localhost", + Content: struct { + MsgType string `json:"msgtype,omitempty"` + Body string `json:"body,omitempty"` + }{ + Body: "Hello world!", + }, + }, + }, + { + name: "msg body empty", + fields: fields{ + UserID: "@alice:localhost", + }, + }, + { + name: "statekey empty", + fields: fields{ + UserID: "@alice:localhost", + Content: content, + }, + wantOk: true, + }, + { + name: "type empty", + fields: fields{ + UserID: "@alice:localhost", + Content: content, + }, + wantOk: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := sendServerNoticeRequest{ + UserID: tt.fields.UserID, + Content: tt.fields.Content, + Type: tt.fields.Type, + StateKey: tt.fields.StateKey, + } + if gotOk := r.valid(); gotOk != tt.wantOk { + t.Errorf("valid() = %v, want %v", gotOk, tt.wantOk) + } + }) + } +} diff --git a/clientapi/routing/threepid.go b/clientapi/routing/threepid.go index f4d233798..d89b62953 100644 --- a/clientapi/routing/threepid.go +++ b/clientapi/routing/threepid.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/threepid" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -40,7 +40,7 @@ type threePIDsResponse struct { // RequestEmailToken implements: // POST /account/3pid/email/requestToken // POST /register/email/requestToken -func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI) util.JSONResponse { +func RequestEmailToken(req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI) util.JSONResponse { var body threepid.EmailAssociationRequest if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { return *reqErr @@ -61,7 +61,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf Code: http.StatusBadRequest, JSON: jsonerror.MatrixError{ ErrCode: "M_THREEPID_IN_USE", - Err: accounts.Err3PIDInUse.Error(), + Err: userdb.Err3PIDInUse.Error(), }, } } @@ -85,7 +85,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf // CheckAndSave3PIDAssociation implements POST /account/3pid func CheckAndSave3PIDAssociation( - req *http.Request, accountDB accounts.Database, device *api.Device, + req *http.Request, accountDB userdb.Database, device *api.Device, cfg *config.ClientAPI, ) util.JSONResponse { var body threepid.EmailAssociationCheckRequest @@ -149,7 +149,7 @@ func CheckAndSave3PIDAssociation( // GetAssociated3PIDs implements GET /account/3pid func GetAssociated3PIDs( - req *http.Request, accountDB accounts.Database, device *api.Device, + req *http.Request, accountDB userdb.Database, device *api.Device, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { @@ -170,7 +170,7 @@ func GetAssociated3PIDs( } // Forget3PID implements POST /account/3pid/delete -func Forget3PID(req *http.Request, accountDB accounts.Database) util.JSONResponse { +func Forget3PID(req *http.Request, accountDB userdb.Database) util.JSONResponse { var body authtypes.ThreePID if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { return *reqErr diff --git a/clientapi/routing/whoami.go b/clientapi/routing/whoami.go index 26280f6cc..a1d9d6675 100644 --- a/clientapi/routing/whoami.go +++ b/clientapi/routing/whoami.go @@ -21,7 +21,9 @@ import ( // whoamiResponse represents an response for a `whoami` request type whoamiResponse struct { - UserID string `json:"user_id"` + UserID string `json:"user_id"` + DeviceID string `json:"device_id"` + IsGuest bool `json:"is_guest"` } // Whoami implements `/account/whoami` which enables client to query their account user id. @@ -29,6 +31,10 @@ type whoamiResponse struct { func Whoami(req *http.Request, device *api.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusOK, - JSON: whoamiResponse{UserID: device.UserID}, + JSON: whoamiResponse{ + UserID: device.UserID, + DeviceID: device.ID, + IsGuest: device.AccountType == api.AccountTypeGuest, + }, } } diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index db62ce060..9d9a2ba7a 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -29,7 +29,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" ) @@ -87,7 +87,7 @@ var ( func CheckAndProcessInvite( ctx context.Context, device *userapi.Device, body *MembershipRequest, cfg *config.ClientAPI, - rsAPI api.RoomserverInternalAPI, db accounts.Database, + rsAPI api.RoomserverInternalAPI, db userdb.Database, roomID string, evTime time.Time, ) (inviteStoredOnIDServer bool, err error) { @@ -137,7 +137,7 @@ func CheckAndProcessInvite( // Returns an error if a check or a request failed. func queryIDServer( ctx context.Context, - db accounts.Database, cfg *config.ClientAPI, device *userapi.Device, + db userdb.Database, cfg *config.ClientAPI, device *userapi.Device, body *MembershipRequest, roomID string, ) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) { if err = isTrusted(body.IDServer, cfg); err != nil { @@ -206,7 +206,7 @@ func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServe // Returns an error if the request failed to send or if the response couldn't be parsed. func queryIDServerStoreInvite( ctx context.Context, - db accounts.Database, cfg *config.ClientAPI, device *userapi.Device, + db userdb.Database, cfg *config.ClientAPI, device *userapi.Device, body *MembershipRequest, roomID string, ) (*idServerStoreInviteResponse, error) { // Retrieve the sender's profile to get their display name diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index d13ef31fa..d6edead80 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -23,12 +23,14 @@ import ( "os" "strings" - "github.com/matrix-org/dendrite/setup" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt" "golang.org/x/term" + + "github.com/matrix-org/dendrite/setup" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/userapi/api" + userdb "github.com/matrix-org/dendrite/userapi/storage" ) const usage = `Usage: %s @@ -57,6 +59,7 @@ var ( pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") askPass = flag.Bool("ask-pass", false, "Ask for the password to use") + isAdmin = flag.Bool("admin", false, "Create an admin account") ) func main() { @@ -74,19 +77,28 @@ func main() { pass := getPassword(password, pwdFile, pwdStdin, askPass, os.Stdin) - accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{ - ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString, - }, cfg.Global.ServerName, bcrypt.DefaultCost, cfg.UserAPI.OpenIDTokenLifetimeMS) + accountDB, err := userdb.NewDatabase( + &config.DatabaseOptions{ + ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString, + }, + cfg.Global.ServerName, bcrypt.DefaultCost, + cfg.UserAPI.OpenIDTokenLifetimeMS, + api.DefaultLoginTokenLifetime, + ) if err != nil { logrus.Fatalln("Failed to connect to the database:", err.Error()) } + accType := api.AccountTypeUser + if *isAdmin { + accType = api.AccountTypeAdmin + } policyVersion := "" if cfg.Global.UserConsentOptions.Enabled { policyVersion = cfg.Global.UserConsentOptions.Version } - _, err = accountDB.CreateAccount(context.Background(), *username, pass, "", policyVersion) + _, err = accountDB.CreateAccount(context.Background(), *username, pass, "", policyVersion, accType) if err != nil { logrus.Fatalln("Failed to create the account:", err.Error()) } diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index 9d4ff6a8f..d09f1cb55 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -126,7 +126,6 @@ func main() { cfg.FederationAPI.FederationMaxRetries = 6 cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName)) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index d6e6dcf80..18c047445 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -160,7 +160,6 @@ func main() { cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName)) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 076485da7..2bd7e95b7 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -79,7 +79,6 @@ func main() { cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName)) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index e73f300c6..7d3e15af0 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -132,6 +132,7 @@ func main() { // dependency. Other components also need updating after their dependencies are up. rsImpl.SetFederationAPI(fsAPI, keyRing) rsImpl.SetAppserviceAPI(asAPI) + rsImpl.SetUserAPI(userAPI) keyImpl.SetUserAPI(userAPI) eduInputAPI := eduserver.NewInternalAPI( diff --git a/cmd/dendritejs-pinecone/main.go b/cmd/dendritejs-pinecone/main.go index b758c740c..3c706a33d 100644 --- a/cmd/dendritejs-pinecone/main.go +++ b/cmd/dendritejs-pinecone/main.go @@ -164,7 +164,6 @@ func startup() { cfg.Defaults(true) cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db" cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db" - cfg.UserAPI.DeviceDatabase.ConnectionString = "file:/idb/dendritejs_device.db" cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db" cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db" cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db" diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go index b68b4b8cd..42ce97b15 100644 --- a/cmd/dendritejs/main.go +++ b/cmd/dendritejs/main.go @@ -167,7 +167,6 @@ func main() { cfg.Defaults(true) cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db" cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db" - cfg.UserAPI.DeviceDatabase.ConnectionString = "file:/idb/dendritejs_device.db" cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db" cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db" cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db" diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index 60729672e..ba5a87a7a 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -32,7 +32,6 @@ func main() { cfg.RoomServer.Database.ConnectionString = config.DataSource(*dbURI) cfg.SyncAPI.Database.ConnectionString = config.DataSource(*dbURI) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(*dbURI) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(*dbURI) } cfg.Global.TrustedIDServers = []string{ "matrix.org", @@ -91,6 +90,7 @@ func main() { cfg.Logging[0].Type = "std" cfg.UserAPI.BCryptCost = bcrypt.MinCost cfg.Global.JetStream.InMemory = true + cfg.ClientAPI.RegistrationSharedSecret = "complement" } j, err := yaml.Marshal(cfg) diff --git a/cmd/goose/main.go b/cmd/goose/main.go index 8ed5cbd91..31a5b0050 100644 --- a/cmd/goose/main.go +++ b/cmd/goose/main.go @@ -8,12 +8,11 @@ import ( "log" "os" - pgaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas" - slaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas" - pgdevices "github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas" - sldevices "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas" "github.com/pressly/goose" + pgusers "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" + slusers "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" + _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" ) @@ -26,8 +25,7 @@ const ( RoomServer = "roomserver" SigningKeyServer = "signingkeyserver" SyncAPI = "syncapi" - UserAPIAccounts = "userapi_accounts" - UserAPIDevices = "userapi_devices" + UserAPI = "userapi" ) var ( @@ -35,7 +33,7 @@ var ( flags = flag.NewFlagSet("goose", flag.ExitOnError) component = flags.String("component", "", "dendrite component name") knownDBs = []string{ - AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPIAccounts, UserAPIDevices, + AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPI, } ) @@ -143,18 +141,14 @@ Commands: func loadSQLiteDeltas(component string) { switch component { - case UserAPIAccounts: - slaccounts.LoadFromGoose() - case UserAPIDevices: - sldevices.LoadFromGoose() + case UserAPI: + slusers.LoadFromGoose() } } func loadPostgresDeltas(component string) { switch component { - case UserAPIAccounts: - pgaccounts.LoadFromGoose() - case UserAPIDevices: - pgdevices.LoadFromGoose() + case UserAPI: + pgusers.LoadFromGoose() } } diff --git a/dendrite-config.yaml b/dendrite-config.yaml index 250b76b5f..e2e29a982 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -68,6 +68,18 @@ global: # to other servers and the federation API will not be exposed. disable_federation: false + # Server notices allows server admins to send messages to all users. + server_notices: + enabled: false + # The server localpart to be used when sending notices, ensure this is not yet taken + local_part: "_server" + # The displayname to be used when sending notices + display_name: "Server alerts" + # The mxid of the avatar to use + avatar_url: "" + # The roomname to be used when creating messages + room_name: "Server Alerts" + # Consent tracking configuration user_consent: # If the user consent tracking is enabled or not @@ -169,6 +181,10 @@ client_api: # using the registration shared secret below. registration_disabled: false + # Prevents new guest accounts from being created. Guest registration is also + # disabled implicitly by setting 'registration_disabled' above. + guests_disabled: true + # If set, allows registration by anyone who knows the shared secret, regardless of # whether registration is otherwise disabled. registration_shared_secret: "" @@ -231,13 +247,6 @@ federation_api: # enable this option in production as it presents a security risk! disable_tls_validation: false - # Use the following proxy server for outbound federation traffic. - proxy_outbound: - enabled: false - protocol: http - host: localhost - port: 8080 - # Perspective keyservers to use as a backup when direct key fetches fail. This may # be required to satisfy key requests for servers that are no longer online when # joining some rooms. diff --git a/go.mod b/go.mod index 11f5b0608..2316096df 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/matrix-org/dendrite -replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.3.3-0.20220104162330-c76d5fd70423 +replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c @@ -45,7 +45,7 @@ require ( github.com/mattn/go-sqlite3 v1.14.10 github.com/morikuni/aec v1.0.0 // indirect github.com/nats-io/nats-server/v2 v2.3.2 - github.com/nats-io/nats.go v1.13.1-0.20211122170419-d7c1d78a50fc + github.com/nats-io/nats.go v1.13.1-0.20220121202836-972a071d373d github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/ngrok/sqlmw v0.0.0-20211220175533-9d16fdc47b31 diff --git a/go.sum b/go.sum index 8732d27ec..e79015e51 100644 --- a/go.sum +++ b/go.sum @@ -1122,8 +1122,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8m github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= -github.com/nats-io/jwt/v2 v2.2.0 h1:Yg/4WFK6vsqMudRg91eBb7Dh6XeVcDMPHycDE8CfltE= -github.com/nats-io/jwt/v2 v2.2.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= +github.com/nats-io/jwt/v2 v2.2.1-0.20220113022732-58e87895b296 h1:vU9tpM3apjYlLLeY23zRWJ9Zktr5jp+mloR942LEOpY= +github.com/nats-io/jwt/v2 v2.2.1-0.20220113022732-58e87895b296/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= @@ -1132,8 +1132,8 @@ github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uY github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= -github.com/neilalexander/nats-server/v2 v2.3.3-0.20220104162330-c76d5fd70423 h1:BLQVdjMH5XD4BYb0fa+c2Oh2Nr1vrO7GKvRnIJDxChc= -github.com/neilalexander/nats-server/v2 v2.3.3-0.20220104162330-c76d5fd70423/go.mod h1:9sdEkBhyZMQG1M9TevnlYUwMusRACn2vlgOeqoHKwVo= +github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad h1:Z2nWMQsXWWqzj89nW6OaLJSdkFknqhaR5whEOz4++Y8= +github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad/go.mod h1:tckmrt0M6bVaDT3kmh9UrIq/CBOBBse+TpXQi5ldaa8= github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c h1:G2qsv7D0rY94HAu8pXmElMluuMHQ85waxIDQBhIzV2Q= github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w= github.com/neilalexander/utp v0.1.1-0.20210622132614-ee9a34a30488/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8= @@ -1508,8 +1508,8 @@ golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= -golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a h1:atOEWVSedO4ksXBe/UrlbSLVxQQ9RxM/tT2Jy10IaHo= golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1735,6 +1735,7 @@ golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220207234003-57398862261d h1:Bm7BNOQt2Qv7ZqysjeLjgCBanX+88Z/OtdvsrEv1Djc= golang.org/x/sys v0.0.0-20220207234003-57398862261d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -1756,10 +1757,10 @@ golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 h1:GZokNIeuVkl3aZHJchRrr13WCsols02MLUcz1U9is6M= +golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/internal/caching/cache_roomservernids.go b/internal/caching/cache_roomservernids.go index bf4fe85ed..6d413093f 100644 --- a/internal/caching/cache_roomservernids.go +++ b/internal/caching/cache_roomservernids.go @@ -7,14 +7,6 @@ import ( ) const ( - RoomServerStateKeyNIDsCacheName = "roomserver_statekey_nids" - RoomServerStateKeyNIDsCacheMaxEntries = 1024 - RoomServerStateKeyNIDsCacheMutable = false - - RoomServerEventTypeNIDsCacheName = "roomserver_eventtype_nids" - RoomServerEventTypeNIDsCacheMaxEntries = 64 - RoomServerEventTypeNIDsCacheMutable = false - RoomServerRoomIDsCacheName = "roomserver_room_ids" RoomServerRoomIDsCacheMaxEntries = 1024 RoomServerRoomIDsCacheMutable = false @@ -29,44 +21,10 @@ type RoomServerCaches interface { // RoomServerNIDsCache contains the subset of functions needed for // a roomserver NID cache. type RoomServerNIDsCache interface { - GetRoomServerStateKeyNID(stateKey string) (types.EventStateKeyNID, bool) - StoreRoomServerStateKeyNID(stateKey string, nid types.EventStateKeyNID) - - GetRoomServerEventTypeNID(eventType string) (types.EventTypeNID, bool) - StoreRoomServerEventTypeNID(eventType string, nid types.EventTypeNID) - GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) StoreRoomServerRoomID(roomNID types.RoomNID, roomID string) } -func (c Caches) GetRoomServerStateKeyNID(stateKey string) (types.EventStateKeyNID, bool) { - val, found := c.RoomServerStateKeyNIDs.Get(stateKey) - if found && val != nil { - if stateKeyNID, ok := val.(types.EventStateKeyNID); ok { - return stateKeyNID, true - } - } - return 0, false -} - -func (c Caches) StoreRoomServerStateKeyNID(stateKey string, nid types.EventStateKeyNID) { - c.RoomServerStateKeyNIDs.Set(stateKey, nid) -} - -func (c Caches) GetRoomServerEventTypeNID(eventType string) (types.EventTypeNID, bool) { - val, found := c.RoomServerEventTypeNIDs.Get(eventType) - if found && val != nil { - if eventTypeNID, ok := val.(types.EventTypeNID); ok { - return eventTypeNID, true - } - } - return 0, false -} - -func (c Caches) StoreRoomServerEventTypeNID(eventType string, nid types.EventTypeNID) { - c.RoomServerEventTypeNIDs.Set(eventType, nid) -} - func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) { val, found := c.RoomServerRoomIDs.Get(strconv.Itoa(int(roomNID))) if found && val != nil { diff --git a/internal/caching/caches.go b/internal/caching/caches.go index f04d05d42..e1642a663 100644 --- a/internal/caching/caches.go +++ b/internal/caching/caches.go @@ -4,14 +4,12 @@ package caching // different implementations as long as they satisfy the Cache // interface. type Caches struct { - RoomVersions Cache // RoomVersionCache - ServerKeys Cache // ServerKeyCache - RoomServerStateKeyNIDs Cache // RoomServerNIDsCache - RoomServerEventTypeNIDs Cache // RoomServerNIDsCache - RoomServerRoomNIDs Cache // RoomServerNIDsCache - RoomServerRoomIDs Cache // RoomServerNIDsCache - RoomInfos Cache // RoomInfoCache - FederationEvents Cache // FederationEventsCache + RoomVersions Cache // RoomVersionCache + ServerKeys Cache // ServerKeyCache + RoomServerRoomNIDs Cache // RoomServerNIDsCache + RoomServerRoomIDs Cache // RoomServerNIDsCache + RoomInfos Cache // RoomInfoCache + FederationEvents Cache // FederationEventsCache } // Cache is the interface that an implementation must satisfy. diff --git a/internal/caching/impl_inmemorylru.go b/internal/caching/impl_inmemorylru.go index f0915d7ca..ccb92852b 100644 --- a/internal/caching/impl_inmemorylru.go +++ b/internal/caching/impl_inmemorylru.go @@ -28,24 +28,6 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { if err != nil { return nil, err } - roomServerStateKeyNIDs, err := NewInMemoryLRUCachePartition( - RoomServerStateKeyNIDsCacheName, - RoomServerStateKeyNIDsCacheMutable, - RoomServerStateKeyNIDsCacheMaxEntries, - enablePrometheus, - ) - if err != nil { - return nil, err - } - roomServerEventTypeNIDs, err := NewInMemoryLRUCachePartition( - RoomServerEventTypeNIDsCacheName, - RoomServerEventTypeNIDsCacheMutable, - RoomServerEventTypeNIDsCacheMaxEntries, - enablePrometheus, - ) - if err != nil { - return nil, err - } roomServerRoomIDs, err := NewInMemoryLRUCachePartition( RoomServerRoomIDsCacheName, RoomServerRoomIDsCacheMutable, @@ -74,18 +56,15 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { return nil, err } go cacheCleaner( - roomVersions, serverKeys, roomServerStateKeyNIDs, - roomServerEventTypeNIDs, roomServerRoomIDs, + roomVersions, serverKeys, roomServerRoomIDs, roomInfos, federationEvents, ) return &Caches{ - RoomVersions: roomVersions, - ServerKeys: serverKeys, - RoomServerStateKeyNIDs: roomServerStateKeyNIDs, - RoomServerEventTypeNIDs: roomServerEventTypeNIDs, - RoomServerRoomIDs: roomServerRoomIDs, - RoomInfos: roomInfos, - FederationEvents: federationEvents, + RoomVersions: roomVersions, + ServerKeys: serverKeys, + RoomServerRoomIDs: roomServerRoomIDs, + RoomInfos: roomInfos, + FederationEvents: federationEvents, }, nil } diff --git a/internal/test/config.go b/internal/test/config.go index 4fb6a946c..0372fb9c6 100644 --- a/internal/test/config.go +++ b/internal/test/config.go @@ -95,7 +95,6 @@ func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*con cfg.RoomServer.Database.ConnectionString = config.DataSource(database) cfg.SyncAPI.Database.ConnectionString = config.DataSource(database) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(database) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(database) cfg.AppServiceAPI.InternalAPI.Listen = assignAddress() cfg.EDUServer.InternalAPI.Listen = assignAddress() diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 1b6e2d428..c5a5d40c7 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -367,10 +367,13 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam waitTime = fcerr.RetryAfter } else if fcerr.Blacklisted { waitTime = time.Hour * 8 + } else { + // For all other errors (DNS resolution, network etc.) wait 1 hour. + waitTime = time.Hour } } else { waitTime = time.Hour - logger.WithError(err).Warn("GetUserDevices returned unknown error type") + logger.WithError(err).WithField("user_id", userID).Warn("GetUserDevices returned unknown error type") } continue } diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 2536c1f76..0c264b718 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -198,7 +198,7 @@ func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOne } func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) { - msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil) + msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query DB for device keys: %s", err), @@ -244,7 +244,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques domain := string(serverName) // query local devices if serverName == a.ThisServer { - deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs) + deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query local device keys: %s", err), @@ -513,6 +513,11 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( // drop the error as it's already a failure at this point _ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, dkeys) } + + // Sytest expects no failures, if we still could retrieve keys, e.g. from local cache + if len(res.DeviceKeys) > 0 { + delete(res.Failures, serverName) + } respMu.Unlock() } @@ -520,7 +525,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string, ) error { - keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs) + keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) // if we can't query the db or there are fewer keys than requested, fetch from remote. if err != nil { return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err) @@ -549,10 +554,58 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( } func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { + // get a list of devices from the user API that actually exist, as + // we won't store keys for devices that don't exist + uapidevices := &userapi.QueryDevicesResponse{} + if err := a.UserAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil { + res.Error = &api.KeyError{ + Err: err.Error(), + } + return + } + if !uapidevices.UserExists { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("user %q does not exist", req.UserID), + } + return + } + existingDeviceMap := make(map[string]struct{}, len(uapidevices.Devices)) + for _, key := range uapidevices.Devices { + existingDeviceMap[key.ID] = struct{}{} + } + + // Get all of the user existing device keys so we can check for changes. + existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, true) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()), + } + return + } + + // Work out whether we have device keys in the keyserver for devices that + // no longer exist in the user API. This is mostly an exercise to ensure + // that we keep some integrity between the two. + var toClean []gomatrixserverlib.KeyID + for _, k := range existingKeys { + if _, ok := existingDeviceMap[k.DeviceID]; !ok { + toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID)) + } + } + + if len(toClean) > 0 { + if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil { + logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean)) + } else { + logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean)) + } + } + var keysToStore []api.DeviceMessage // assert that the user ID / device ID are not lying for each key for _, key := range req.DeviceKeys { - _, serverName, err := gomatrixserverlib.SplitID('@', key.UserID) + var serverName gomatrixserverlib.ServerName + _, serverName, err = gomatrixserverlib.SplitID('@', key.UserID) if err != nil { continue // ignore invalid users } @@ -563,6 +616,11 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per keysToStore = append(keysToStore, key.WithStreamID(0)) continue // deleted keys don't need sanity checking } + // check that the device in question actually exists in the user + // API before we try and store a key for it + if _, ok := existingDeviceMap[key.DeviceID]; !ok { + continue + } gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str if gotUserID == key.UserID && gotDeviceID == key.DeviceID { @@ -578,29 +636,12 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per }) } - // get existing device keys so we can check for changes - existingKeys := make([]api.DeviceMessage, len(keysToStore)) - for i := range keysToStore { - existingKeys[i] = api.DeviceMessage{ - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - UserID: keysToStore[i].UserID, - DeviceID: keysToStore[i].DeviceID, - }, - } - } - if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()), - } - return - } if req.OnlyDisplayNameUpdates { // add the display name field from keysToStore into existingKeys keysToStore = appendDisplayNames(existingKeys, keysToStore) } // store the device keys and emit changes - err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore) + err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to store device keys: %s", err.Error()), diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 0110860ea..4dffe695c 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -53,7 +53,7 @@ type Database interface { // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. - DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) + DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying // cross-signing signatures relating to that device. diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index 5ae0da969..628301cf7 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -56,6 +56,9 @@ const selectDeviceKeysSQL = "" + const selectBatchDeviceKeysSQL = "" + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" +const selectBatchDeviceKeysWithEmptiesSQL = "" + + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" + const selectMaxStreamForUserSQL = "" + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" @@ -69,14 +72,15 @@ const deleteAllDeviceKeysSQL = "" + "DELETE FROM keyserver_device_keys WHERE user_id=$1" type deviceKeysStatements struct { - db *sql.DB - upsertDeviceKeysStmt *sql.Stmt - selectDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysStmt *sql.Stmt - selectMaxStreamForUserStmt *sql.Stmt - countStreamIDsForUserStmt *sql.Stmt - deleteDeviceKeysStmt *sql.Stmt - deleteAllDeviceKeysStmt *sql.Stmt + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt + selectMaxStreamForUserStmt *sql.Stmt + countStreamIDsForUserStmt *sql.Stmt + deleteDeviceKeysStmt *sql.Stmt + deleteAllDeviceKeysStmt *sql.Stmt } func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -96,6 +100,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { return nil, err } + if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil { + return nil, err + } if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { return nil, err } @@ -180,8 +187,14 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql return err } -func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { - rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { + var stmt *sql.Stmt + if includeEmpty { + stmt = s.selectBatchDeviceKeysWithEmptiesStmt + } else { + stmt = s.selectBatchDeviceKeysStmt + } + rows, err := stmt.QueryContext(ctx, userID) if err != nil { return nil, err } diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 5914d28e1..deee76eb4 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -108,8 +108,8 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe }) } -func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { - return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs) +func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { + return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty) } func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) { diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index fa1c930db..b461424c6 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -52,6 +52,9 @@ const selectDeviceKeysSQL = "" + const selectBatchDeviceKeysSQL = "" + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" +const selectBatchDeviceKeysWithEmptiesSQL = "" + + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" + const selectMaxStreamForUserSQL = "" + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" @@ -65,13 +68,14 @@ const deleteAllDeviceKeysSQL = "" + "DELETE FROM keyserver_device_keys WHERE user_id=$1" type deviceKeysStatements struct { - db *sql.DB - upsertDeviceKeysStmt *sql.Stmt - selectDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysStmt *sql.Stmt - selectMaxStreamForUserStmt *sql.Stmt - deleteDeviceKeysStmt *sql.Stmt - deleteAllDeviceKeysStmt *sql.Stmt + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt + selectMaxStreamForUserStmt *sql.Stmt + deleteDeviceKeysStmt *sql.Stmt + deleteAllDeviceKeysStmt *sql.Stmt } func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -91,6 +95,9 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { return nil, err } + if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil { + return nil, err + } if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { return nil, err } @@ -113,12 +120,18 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql return err } -func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { deviceIDMap := make(map[string]bool) for _, d := range deviceIDs { deviceIDMap[d] = true } - rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) + var stmt *sql.Stmt + if includeEmpty { + stmt = s.selectBatchDeviceKeysWithEmptiesStmt + } else { + stmt = s.selectBatchDeviceKeysStmt + } + rows, err := stmt.QueryContext(ctx, userID) if err != nil { return nil, err } diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go index c4c99d8c4..4d5137249 100644 --- a/keyserver/storage/storage_test.go +++ b/keyserver/storage/storage_test.go @@ -173,7 +173,7 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) { } // Querying for device keys returns the latest stream IDs - msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}) + msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false) if err != nil { t.Fatalf("DeviceKeysForUser returned error: %s", err) } diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index e44757e1a..ff70a2366 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -38,7 +38,7 @@ type DeviceKeys interface { InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) - SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) + SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error } diff --git a/roomserver/api/api.go b/roomserver/api/api.go index e6d37e8f1..bcbf0e4f9 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -3,9 +3,11 @@ package api import ( "context" + "github.com/matrix-org/gomatrixserverlib" + asAPI "github.com/matrix-org/dendrite/appservice/api" fsAPI "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/gomatrixserverlib" + userapi "github.com/matrix-org/dendrite/userapi/api" ) // RoomserverInputAPI is used to write events to the room server. @@ -14,6 +16,7 @@ type RoomserverInternalAPI interface { // interdependencies between the roomserver and other input APIs SetFederationAPI(fsAPI fsAPI.FederationInternalAPI, keyRing *gomatrixserverlib.KeyRing) SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI) + SetUserAPI(userAPI userapi.UserInternalAPI) InputRoomEvents( ctx context.Context, diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 16f52abb7..88b372154 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -5,10 +5,12 @@ import ( "encoding/json" "fmt" - asAPI "github.com/matrix-org/dendrite/appservice/api" - fsAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + + asAPI "github.com/matrix-org/dendrite/appservice/api" + fsAPI "github.com/matrix-org/dendrite/federationapi/api" + userapi "github.com/matrix-org/dendrite/userapi/api" ) // RoomserverInternalAPITrace wraps a RoomserverInternalAPI and logs the @@ -25,6 +27,10 @@ func (t *RoomserverInternalAPITrace) SetAppserviceAPI(asAPI asAPI.AppServiceQuer t.Impl.SetAppserviceAPI(asAPI) } +func (t *RoomserverInternalAPITrace) SetUserAPI(userAPI userapi.UserInternalAPI) { + t.Impl.SetUserAPI(userAPI) +} + func (t *RoomserverInternalAPITrace) InputRoomEvents( ctx context.Context, req *InputRoomEventsRequest, diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index 51cbcb1ad..d640858a6 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -95,6 +95,8 @@ type PerformLeaveRequest struct { } type PerformLeaveResponse struct { + Code int `json:"code,omitempty"` + Message interface{} `json:"message,omitempty"` } type PerformInviteRequest struct { diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index fd963ad83..10c8c844e 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -14,6 +14,8 @@ import ( "github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" @@ -32,6 +34,7 @@ type RoomserverInternalAPI struct { *perform.Publisher *perform.Backfiller *perform.Forgetter + ProcessContext *process.ProcessContext DB storage.Database Cfg *config.RoomServer Cache caching.RoomServerCaches @@ -48,12 +51,13 @@ type RoomserverInternalAPI struct { } func NewRoomserverAPI( - cfg *config.RoomServer, roomserverDB storage.Database, consumer nats.JetStreamContext, - inputRoomEventTopic, outputRoomEventTopic string, caches caching.RoomServerCaches, - perspectiveServerNames []gomatrixserverlib.ServerName, + processCtx *process.ProcessContext, cfg *config.RoomServer, roomserverDB storage.Database, + consumer nats.JetStreamContext, inputRoomEventTopic, outputRoomEventTopic string, + caches caching.RoomServerCaches, perspectiveServerNames []gomatrixserverlib.ServerName, ) *RoomserverInternalAPI { serverACLs := acls.NewServerACLs(roomserverDB) a := &RoomserverInternalAPI{ + ProcessContext: processCtx, DB: roomserverDB, Cfg: cfg, Cache: caches, @@ -83,6 +87,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.FederationInternalA r.KeyRing = keyRing r.Inputer = &input.Inputer{ + ProcessContext: r.ProcessContext, DB: r.DB, InputRoomEventTopic: r.InputRoomEventTopic, OutputRoomEventTopic: r.OutputRoomEventTopic, @@ -155,6 +160,10 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.FederationInternalA } } +func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.UserInternalAPI) { + r.Leaver.UserAPI = userAPI +} + func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI) { r.asAPI = asAPI } diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 5bdec0a24..22e4b67a0 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -31,6 +31,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" "github.com/prometheus/client_golang/prometheus" @@ -59,6 +60,7 @@ var keyContentFields = map[string]string{ } type Inputer struct { + ProcessContext *process.ProcessContext DB storage.Database JetStream nats.JetStreamContext Durable nats.SubOpt @@ -115,7 +117,7 @@ func (r *Inputer) Start() error { _ = msg.InProgress() // resets the acknowledgement wait timer defer eventsInProgress.Delete(index) defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() - action, err := r.processRoomEventUsingUpdater(context.Background(), roomID, &inputRoomEvent) + action, err := r.processRoomEventUsingUpdater(r.ProcessContext.Context(), roomID, &inputRoomEvent) if err != nil { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { sentry.CaptureException(err) diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index 5173d3ab2..ae28ebefa 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -405,7 +405,7 @@ func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.Ro if len(extraEventIDs) == 0 { return nil, nil } - extraEvents, err := u.updater.EventsFromIDs(u.ctx, extraEventIDs) + extraEvents, err := u.updater.UnsentEventsFromIDs(u.ctx, extraEventIDs) if err != nil { return nil, err } diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 12784e5f5..49ddd4810 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -16,25 +16,29 @@ package perform import ( "context" + "encoding/json" "fmt" "strings" + "github.com/matrix-org/gomatrix" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" + fsAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/sirupsen/logrus" + userapi "github.com/matrix-org/dendrite/userapi/api" ) type Leaver struct { - Cfg *config.RoomServer - DB storage.Database - FSAPI fsAPI.FederationInternalAPI - + Cfg *config.RoomServer + DB storage.Database + FSAPI fsAPI.FederationInternalAPI + UserAPI userapi.UserInternalAPI Inputer *input.Inputer } @@ -85,6 +89,31 @@ func (r *Leaver) performLeaveRoomByID( if host != r.Cfg.Matrix.ServerName { return r.performFederatedRejectInvite(ctx, req, res, senderUser, eventID) } + // check that this is not a "server notice room" + accData := &userapi.QueryAccountDataResponse{} + if err := r.UserAPI.QueryAccountData(ctx, &userapi.QueryAccountDataRequest{ + UserID: req.UserID, + RoomID: req.RoomID, + DataType: "m.tag", + }, accData); err != nil { + return nil, fmt.Errorf("unable to query account data") + } + + if roomData, ok := accData.RoomAccountData[req.RoomID]; ok { + tagData, ok := roomData["m.tag"] + if ok { + tags := gomatrix.TagContent{} + if err = json.Unmarshal(tagData, &tags); err != nil { + return nil, fmt.Errorf("unable to unmarshal tag content") + } + if _, ok = tags.Tags["m.server_notice"]; ok { + // mimic the returned values from Synapse + res.Message = "You cannot reject this invite" + res.Code = 403 + return nil, fmt.Errorf("You cannot reject this invite") + } + } + } } // There's no invite pending, so first of all we want to find out diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index a61404efe..99c596606 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -11,6 +11,8 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" "github.com/opentracing/opentracing-go" ) @@ -90,6 +92,10 @@ func (h *httpRoomserverInternalAPI) SetFederationAPI(fsAPI fsInputAPI.Federation func (h *httpRoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI) { } +// SetUserAPI no-ops in HTTP client mode as there is no chicken/egg scenario +func (h *httpRoomserverInternalAPI) SetUserAPI(userAPI userapi.UserInternalAPI) { +} + // SetRoomAlias implements RoomserverAliasAPI func (h *httpRoomserverInternalAPI) SetRoomAlias( ctx context.Context, diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go index e1b84b80c..950c6b4e7 100644 --- a/roomserver/roomserver.go +++ b/roomserver/roomserver.go @@ -53,7 +53,7 @@ func NewInternalAPI( js := jetstream.Prepare(&cfg.Matrix.JetStream) return internal.NewRoomserverAPI( - cfg, roomserverDB, js, + base.ProcessContext, cfg, roomserverDB, js, cfg.Matrix.JetStream.TopicFor(jetstream.InputRoomEvent), cfg.Matrix.JetStream.TopicFor(jetstream.OutputRoomEvent), base.Caches, perspectiveServerNames, diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index c136f039a..8012174a0 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -127,6 +127,9 @@ const bulkSelectEventIDSQL = "" + const bulkSelectEventNIDSQL = "" + "SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1)" +const bulkSelectUnsentEventNIDSQL = "" + + "SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1) AND sent_to_output = FALSE" + const selectMaxEventDepthSQL = "" + "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)" @@ -147,6 +150,7 @@ type eventStatements struct { bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt bulkSelectEventNIDStmt *sql.Stmt + bulkSelectUnsentEventNIDStmt *sql.Stmt selectMaxEventDepthStmt *sql.Stmt selectRoomNIDsForEventNIDsStmt *sql.Stmt } @@ -173,6 +177,7 @@ func prepareEventsTable(db *sql.DB) (tables.Events, error) { {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, + {&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL}, {&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL}, {&s.selectRoomNIDsForEventNIDsStmt, selectRoomNIDsForEventNIDsSQL}, }.Prepare(db) @@ -458,10 +463,28 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev return results, nil } -// bulkSelectEventNIDs returns a map from string event ID to numeric event ID. +// BulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { - stmt := sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt) + return s.bulkSelectEventNID(ctx, txn, eventIDs, false) +} + +// BulkSelectEventNIDs returns a map from string event ID to numeric event ID +// only for events that haven't already been sent to the roomserver output. +// If an event ID is not in the database then it is omitted from the map. +func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { + return s.bulkSelectEventNID(ctx, txn, eventIDs, true) +} + +// bulkSelectEventNIDs returns a map from string event ID to numeric event ID. +// If an event ID is not in the database then it is omitted from the map. +func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) { + var stmt *sql.Stmt + if onlyUnsent { + stmt = sqlutil.TxStmt(txn, s.bulkSelectUnsentEventNIDStmt) + } else { + stmt = sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt) + } rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index 66ac2f5b6..8f3f3d631 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -136,7 +136,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd } // Look up the NID of the new join event - nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}) + nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false) if err != nil { return fmt.Errorf("u.d.EventNIDs: %w", err) } @@ -170,7 +170,7 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s } // Look up the NID of the new leave event - nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}) + nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false) if err != nil { return fmt.Errorf("u.d.EventNIDs: %w", err) } @@ -196,7 +196,7 @@ func (u *MembershipUpdater) SetToKnock(event *gomatrixserverlib.Event) (bool, er } if u.membership != tables.MembershipStateKnock { // Look up the NID of the new knock event - nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{event.EventID()}) + nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{event.EventID()}, false) if err != nil { return fmt.Errorf("u.d.EventNIDs: %w", err) } diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 89b878b9d..810a18ef2 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -215,7 +215,13 @@ func (u *RoomUpdater) EventIDs( func (u *RoomUpdater) EventNIDs( ctx context.Context, eventIDs []string, ) (map[string]types.EventNID, error) { - return u.d.eventNIDs(ctx, u.txn, eventIDs) + return u.d.eventNIDs(ctx, u.txn, eventIDs, NoFilter) +} + +func (u *RoomUpdater) UnsentEventNIDs( + ctx context.Context, eventIDs []string, +) (map[string]types.EventNID, error) { + return u.d.eventNIDs(ctx, u.txn, eventIDs, FilterUnsentOnly) } func (u *RoomUpdater) StateAtEventIDs( @@ -231,7 +237,11 @@ func (u *RoomUpdater) StateEntriesForEventIDs( } func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - return u.d.eventsFromIDs(ctx, u.txn, eventIDs) + return u.d.eventsFromIDs(ctx, u.txn, eventIDs, false) +} + +func (u *RoomUpdater) UnsentEventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { + return u.d.eventsFromIDs(ctx, u.txn, eventIDs, true) } func (u *RoomUpdater) GetMembershipEventNIDsForRoom( diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index e96c77afa..b255cfb3f 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -59,23 +59,12 @@ func (d *Database) eventTypeNIDs( ctx context.Context, txn *sql.Tx, eventTypes []string, ) (map[string]types.EventTypeNID, error) { result := make(map[string]types.EventTypeNID) - remaining := []string{} - for _, eventType := range eventTypes { - if nid, ok := d.Cache.GetRoomServerEventTypeNID(eventType); ok { - result[eventType] = nid - } else { - remaining = append(remaining, eventType) - } + nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, eventTypes) + if err != nil { + return nil, err } - if len(remaining) > 0 { - nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, remaining) - if err != nil { - return nil, err - } - for eventType, nid := range nids { - result[eventType] = nid - d.Cache.StoreRoomServerEventTypeNID(eventType, nid) - } + for eventType, nid := range nids { + result[eventType] = nid } return result, nil } @@ -96,23 +85,12 @@ func (d *Database) eventStateKeyNIDs( ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { result := make(map[string]types.EventStateKeyNID) - remaining := []string{} - for _, eventStateKey := range eventStateKeys { - if nid, ok := d.Cache.GetRoomServerStateKeyNID(eventStateKey); ok { - result[eventStateKey] = nid - } else { - remaining = append(remaining, eventStateKey) - } + nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, eventStateKeys) + if err != nil { + return nil, err } - if len(remaining) > 0 { - nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, remaining) - if err != nil { - return nil, err - } - for eventStateKey, nid := range nids { - result[eventStateKey] = nid - d.Cache.StoreRoomServerStateKeyNID(eventStateKey, nid) - } + for eventStateKey, nid := range nids { + result[eventStateKey] = nid } return result, nil } @@ -238,13 +216,27 @@ func (d *Database) addState( func (d *Database) EventNIDs( ctx context.Context, eventIDs []string, ) (map[string]types.EventNID, error) { - return d.eventNIDs(ctx, nil, eventIDs) + return d.eventNIDs(ctx, nil, eventIDs, NoFilter) } +type UnsentFilter bool + +const ( + NoFilter UnsentFilter = false + FilterUnsentOnly UnsentFilter = true +) + func (d *Database) eventNIDs( - ctx context.Context, txn *sql.Tx, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter, ) (map[string]types.EventNID, error) { - return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs) + switch filter { + case FilterUnsentOnly: + return d.EventsTable.BulkSelectUnsentEventNID(ctx, txn, eventIDs) + case NoFilter: + return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs) + default: + panic("impossible case") + } } func (d *Database) SetState( @@ -281,11 +273,11 @@ func (d *Database) EventIDs( } func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - return d.eventsFromIDs(ctx, nil, eventIDs) + return d.eventsFromIDs(ctx, nil, eventIDs, NoFilter) } -func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.Event, error) { - nidMap, err := d.eventNIDs(ctx, txn, eventIDs) +func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter) ([]types.Event, error) { + nidMap, err := d.eventNIDs(ctx, txn, eventIDs, filter) if err != nil { return nil, err } @@ -704,9 +696,6 @@ func (d *Database) assignRoomNID( func (d *Database) assignEventTypeNID( ctx context.Context, txn *sql.Tx, eventType string, ) (types.EventTypeNID, error) { - if eventTypeNID, ok := d.Cache.GetRoomServerEventTypeNID(eventType); ok { - return eventTypeNID, nil - } // Check if we already have a numeric ID in the database. eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType) if err == sql.ErrNoRows { @@ -717,18 +706,12 @@ func (d *Database) assignEventTypeNID( eventTypeNID, err = d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType) } } - if err == nil { - d.Cache.StoreRoomServerEventTypeNID(eventType, eventTypeNID) - } return eventTypeNID, err } func (d *Database) assignStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { - if eventStateKeyNID, ok := d.Cache.GetRoomServerStateKeyNID(eventStateKey); ok { - return eventStateKeyNID, nil - } // Check if we already have a numeric ID in the database. eventStateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey) if err == sql.ErrNoRows { @@ -739,9 +722,6 @@ func (d *Database) assignStateKeyNID( eventStateKeyNID, err = d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey) } } - if err == nil { - d.Cache.StoreRoomServerStateKeyNID(eventStateKey, eventStateKeyNID) - } return eventStateKeyNID, err } diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index cef09fe60..969a10ce5 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -99,6 +99,9 @@ const bulkSelectEventIDSQL = "" + const bulkSelectEventNIDSQL = "" + "SELECT event_id, event_nid FROM roomserver_events WHERE event_id IN ($1)" +const bulkSelectUnsentEventNIDSQL = "" + + "SELECT event_id, event_nid FROM roomserver_events WHERE sent_to_output = 0 AND event_id IN ($1)" + const selectMaxEventDepthSQL = "" + "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" @@ -118,8 +121,9 @@ type eventStatements struct { bulkSelectStateAtEventAndReferenceStmt *sql.Stmt bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt - bulkSelectEventNIDStmt *sql.Stmt - //selectRoomNIDsForEventNIDsStmt *sql.Stmt + //bulkSelectEventNIDStmt *sql.Stmt + //bulkSelectUnsentEventNIDStmt *sql.Stmt + //selectRoomNIDsForEventNIDsStmt *sql.Stmt } func createEventsTable(db *sql.DB) error { @@ -144,7 +148,8 @@ func prepareEventsTable(db *sql.DB) (tables.Events, error) { {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL}, {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, - {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, + //{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, + //{&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL}, //{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, }.Prepare(db) } @@ -494,15 +499,33 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev return results, nil } -// bulkSelectEventNIDs returns a map from string event ID to numeric event ID. +// BulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { + return s.bulkSelectEventNID(ctx, txn, eventIDs, false) +} + +// BulkSelectEventNIDs returns a map from string event ID to numeric event ID +// only for events that haven't already been sent to the roomserver output. +// If an event ID is not in the database then it is omitted from the map. +func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { + return s.bulkSelectEventNID(ctx, txn, eventIDs, true) +} + +// bulkSelectEventNIDs returns a map from string event ID to numeric event ID. +// If an event ID is not in the database then it is omitted from the map. +func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) for k, v := range eventIDs { iEventIDs[k] = v } - selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) + var selectOrig string + if onlyUnsent { + selectOrig = strings.Replace(bulkSelectUnsentEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) + } else { + selectOrig = strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) + } selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index fed39b944..e3fed700b 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -59,6 +59,7 @@ type Events interface { // BulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) + BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error) } diff --git a/setup/base/base.go b/setup/base/base.go index 8300850f9..5b6d1c14b 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -39,7 +39,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/gorilla/mux" @@ -274,8 +274,14 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI { // CreateAccountsDB creates a new instance of the accounts database. Should only // be called once per component. -func (b *BaseDendrite) CreateAccountsDB() accounts.Database { - db, err := accounts.NewDatabase(&b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost, b.Cfg.UserAPI.OpenIDTokenLifetimeMS) +func (b *BaseDendrite) CreateAccountsDB() userdb.Database { + db, err := userdb.NewDatabase( + &b.Cfg.UserAPI.AccountDatabase, + b.Cfg.Global.ServerName, + b.Cfg.UserAPI.BCryptCost, + b.Cfg.UserAPI.OpenIDTokenLifetimeMS, + userapi.DefaultLoginTokenLifetime, + ) if err != nil { logrus.WithError(err).Panicf("failed to connect to accounts db") } diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index 75f5e3df3..4590e752b 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -18,6 +18,10 @@ type ClientAPI struct { // If set, allows registration by anyone who also has the shared // secret, even if registration is otherwise disabled. RegistrationSharedSecret string `yaml:"registration_shared_secret"` + // If set, prevents guest accounts from being created. Only takes + // effect if registration is enabled, otherwise guests registration + // is forbidden either way. + GuestsDisabled bool `yaml:"guests_disabled"` // Boolean stating whether catpcha registration is enabled // and required diff --git a/setup/config/config_federationapi.go b/setup/config/config_federationapi.go index 4f5f49de8..95e705033 100644 --- a/setup/config/config_federationapi.go +++ b/setup/config/config_federationapi.go @@ -29,8 +29,6 @@ type FederationAPI struct { // on remote federation endpoints. This is not recommended in production! DisableTLSValidation bool `yaml:"disable_tls_validation"` - Proxy Proxy `yaml:"proxy_outbound"` - // Perspective keyservers, to use as a backup when direct key fetch // requests don't succeed KeyPerspectives KeyPerspectives `yaml:"key_perspectives"` @@ -50,8 +48,6 @@ func (c *FederationAPI) Defaults(generate bool) { c.FederationMaxRetries = 16 c.DisableTLSValidation = false - - c.Proxy.Defaults() } func (c *FederationAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { diff --git a/setup/config/config_global.go b/setup/config/config_global.go index cada7bc9a..5c81ddcb2 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -65,6 +65,9 @@ type Global struct { // DNS caching options for all outbound HTTP requests DNSCache DNSCacheOptions `yaml:"dns_cache"` + // ServerNotices configuration used for sending server notices + ServerNotices ServerNotices `yaml:"server_notices"` + // Consent tracking options UserConsentOptions UserConsentOptions `yaml:"user_consent"` } @@ -84,6 +87,7 @@ func (c *Global) Defaults(generate bool) { c.DNSCache.Defaults() c.Sentry.Defaults() c.UserConsentOptions.Defaults(c.BaseURL) + c.ServerNotices.Defaults(generate) } func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) { @@ -95,6 +99,7 @@ func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) { c.Sentry.Verify(configErrs, isMonolith) c.DNSCache.Verify(configErrs, isMonolith) c.UserConsentOptions.Verify(configErrs, isMonolith) + c.ServerNotices.Verify(configErrs, isMonolith) } type OldVerifyKeys struct { @@ -136,6 +141,31 @@ func (c *Metrics) Defaults(generate bool) { func (c *Metrics) Verify(configErrs *ConfigErrors, isMonolith bool) { } +// ServerNotices defines the configuration used for sending server notices +type ServerNotices struct { + Enabled bool `yaml:"enabled"` + // The localpart to be used when sending notices + LocalPart string `yaml:"local_part"` + // The displayname to be used when sending notices + DisplayName string `yaml:"display_name"` + // The avatar of this user + AvatarURL string `yaml:"avatar"` + // The roomname to be used when creating messages + RoomName string `yaml:"room_name"` +} + +func (c *ServerNotices) Defaults(generate bool) { + if generate { + c.Enabled = true + c.LocalPart = "_server" + c.DisplayName = "Server Alert" + c.RoomName = "Server Alert" + c.AvatarURL = "" + } +} + +func (c *ServerNotices) Verify(errors *ConfigErrors, isMonolith bool) {} + // The configuration to use for Sentry error reporting type Sentry struct { Enabled bool `yaml:"enabled"` diff --git a/setup/config/config_test.go b/setup/config/config_test.go index 5aa54929e..8f7611f0a 100644 --- a/setup/config/config_test.go +++ b/setup/config/config_test.go @@ -58,6 +58,11 @@ global: basic_auth: username: metrics password: metrics + server_notices: + local_part: "_server" + display_name: "Server alerts" + avatar: "" + room_name: "Server Alerts" app_service_api: internal_api: listen: http://localhost:7777 @@ -118,11 +123,6 @@ federation_sender: conn_max_lifetime: -1 send_max_retries: 16 disable_tls_validation: false - proxy_outbound: - enabled: false - protocol: http - host: localhost - port: 8080 key_server: internal_api: listen: http://localhost:7779 diff --git a/setup/config/config_userapi.go b/setup/config/config_userapi.go index b2cde2e96..1cb5eba18 100644 --- a/setup/config/config_userapi.go +++ b/setup/config/config_userapi.go @@ -16,9 +16,6 @@ type UserAPI struct { // The Account database stores the login details and account information // for local users. It is accessed by the UserAPI. AccountDatabase DatabaseOptions `yaml:"account_database"` - // The Device database stores session information for the devices of logged - // in local users. It is accessed by the UserAPI. - DeviceDatabase DatabaseOptions `yaml:"device_database"` } const DefaultOpenIDTokenLifetimeMS = 3600000 // 60 minutes @@ -27,10 +24,8 @@ func (c *UserAPI) Defaults(generate bool) { c.InternalAPI.Listen = "http://localhost:7781" c.InternalAPI.Connect = "http://localhost:7781" c.AccountDatabase.Defaults(10) - c.DeviceDatabase.Defaults(10) if generate { c.AccountDatabase.ConnectionString = "file:userapi_accounts.db" - c.DeviceDatabase.ConnectionString = "file:userapi_devices.db" } c.BCryptCost = bcrypt.DefaultCost c.OpenIDTokenLifetimeMS = DefaultOpenIDTokenLifetimeMS @@ -40,6 +35,5 @@ func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { checkURL(configErrs, "user_api.internal_api.listen", string(c.InternalAPI.Listen)) checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect)) checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString)) - checkNotEmpty(configErrs, "user_api.device_database.connection_string", string(c.DeviceDatabase.ConnectionString)) checkPositive(configErrs, "user_api.openid_token_lifetime_ms", c.OpenIDTokenLifetimeMS) } diff --git a/setup/jetstream/nats.go b/setup/jetstream/nats.go index 77ad2b721..562b0131e 100644 --- a/setup/jetstream/nats.go +++ b/setup/jetstream/nats.go @@ -24,13 +24,12 @@ func Prepare(cfg *config.JetStream) natsclient.JetStreamContext { if natsServer == nil { var err error natsServer, err = natsserver.NewServer(&natsserver.Options{ - ServerName: "monolith", - DontListen: true, - JetStream: true, - StoreDir: string(cfg.StoragePath), - NoSystemAccount: true, - AllowNewAccounts: false, - MaxPayload: 16 * 1024 * 1024, + ServerName: "monolith", + DontListen: true, + JetStream: true, + StoreDir: string(cfg.StoragePath), + NoSystemAccount: true, + MaxPayload: 16 * 1024 * 1024, }) if err != nil { panic(err) diff --git a/setup/monolith.go b/setup/monolith.go index 1ada17fca..d0325ee1e 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -30,7 +30,7 @@ import ( "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" ) @@ -38,7 +38,7 @@ import ( // all components of Dendrite, for use in monolith mode. type Monolith struct { Config *config.Dendrite - AccountDB accounts.Database + AccountDB userdb.Database KeyRing *gomatrixserverlib.KeyRing Client *gomatrixserverlib.Client FedClient *gomatrixserverlib.FederationClient diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 7fe52b728..15485bb35 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -16,6 +16,7 @@ package consumers import ( "context" + "database/sql" "encoding/json" "fmt" @@ -307,7 +308,9 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent( ctx context.Context, msg api.OutputRetireInviteEvent, ) { pduPos, err := s.db.RetireInviteEvent(ctx, msg.EventID) - if err != nil { + // It's possible we just haven't heard of this invite yet, so + // we should not panic if we try to retire it. + if err != nil && err != sql.ErrNoRows { sentry.CaptureException(err) // panic rather than continue with an inconsistent database log.WithFields(log.Fields{ diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index b3c1da88b..d0d3ac4b0 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -39,14 +39,14 @@ func Setup( rsAPI api.RoomserverInternalAPI, cfg *config.SyncAPI, ) { - r0mux := csMux.PathPrefix("/r0").Subrouter() + v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter() // TODO: Add AS support for all handlers below. - r0mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { return srp.OnIncomingSyncRequest(req, device) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -54,7 +54,7 @@ func Setup( return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], device, federation, rsAPI, cfg, srp) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/user/{userId}/filter", + v3mux.Handle("/user/{userId}/filter", httputil.MakeAuthAPI("put_filter", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -64,7 +64,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/user/{userId}/filter/{filterId}", + v3mux.Handle("/user/{userId}/filter/{filterId}", httputil.MakeAuthAPI("get_filter", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -74,7 +74,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { return srp.OnIncomingKeyChangeRequest(req, device) })).Methods(http.MethodGet, http.MethodOptions) } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 68c308d83..c2e8ed01c 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -279,7 +279,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { parts := strings.Split(tok[1:], "_") var positions [7]StreamPosition for i, p := range parts { - if i > len(positions) { + if i >= len(positions) { break } var pos int diff --git a/sytest-whitelist b/sytest-whitelist index 04b1bbf36..d739313ac 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -592,3 +592,4 @@ Forward extremities remain so even after the next events are populated as outlie If a device list update goes missing, the server resyncs on the next one uploading self-signing key notifies over federation uploading signed devices gets propagated over federation +Device list doesn't change if remote server is down \ No newline at end of file diff --git a/userapi/api/api.go b/userapi/api/api.go index c468fc0f6..5140cc5b8 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -18,8 +18,9 @@ import ( "context" "encoding/json" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) // UserInternalAPI is the internal API for information about users and devices. @@ -384,6 +385,7 @@ type Device struct { // If the device is for an appservice user, // this is the appservice ID. AppserviceID string + AccountType AccountType } // Account represents a Matrix account on this home server. @@ -392,7 +394,7 @@ type Account struct { Localpart string ServerName gomatrixserverlib.ServerName AppServiceID string - // TODO: Other flags like IsAdmin, IsGuest + AccountType AccountType // TODO: Associations (e.g. with application services) } @@ -448,4 +450,8 @@ const ( AccountTypeUser AccountType = 1 // AccountTypeGuest indicates this is a guest account AccountTypeGuest AccountType = 2 + // AccountTypeAdmin indicates this is an admin account + AccountTypeAdmin AccountType = 3 + // AccountTypeAppService indicates this is an appservice account + AccountTypeAppService AccountType = 4 ) diff --git a/userapi/api/api_logintoken.go b/userapi/api/api_logintoken.go index f3aa037e4..e2207bb53 100644 --- a/userapi/api/api_logintoken.go +++ b/userapi/api/api_logintoken.go @@ -19,6 +19,13 @@ import ( "time" ) +// DefaultLoginTokenLifetime determines how old a valid token may be. +// +// NOTSPEC: The current spec says "SHOULD be limited to around five +// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low. +// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325). +const DefaultLoginTokenLifetime = 2 * time.Minute + type LoginTokenInternalAPI interface { // PerformLoginTokenCreation creates a new login token and associates it with the provided data. PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 1c306cefa..fdcf796fd 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -21,22 +21,21 @@ import ( "errors" "fmt" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/appservice/types" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/sqlutil" keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/userapi/storage" ) type UserInternalAPI struct { - AccountDB accounts.Database - DeviceDB devices.Database + DB storage.Database ServerName gomatrixserverlib.ServerName // AppServices is the list of all registered AS AppServices []config.ApplicationService @@ -54,10 +53,11 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc if req.DataType == "" { return fmt.Errorf("data type must not be empty") } - return a.AccountDB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData) + return a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData) } func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { + acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) if req.AccountType == api.AccountTypeGuest { acc, err := a.AccountDB.CreateGuestAccount(ctx) if err != nil { @@ -86,11 +86,18 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P Localpart: req.Localpart, ServerName: a.ServerName, UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), + AccountType: req.AccountType, } return nil } - if err = a.AccountDB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { + if req.AccountType == api.AccountTypeGuest { + res.AccountCreated = true + res.Account = acc + return nil + } + + if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { return err } @@ -100,7 +107,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P } func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { - if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil { + if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil { return err } res.PasswordUpdated = true @@ -113,7 +120,7 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe "device_id": req.DeviceID, "display_name": req.DeviceDisplayName, }).Info("PerformDeviceCreation") - dev, err := a.DeviceDB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) + dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) if err != nil { return err } @@ -138,12 +145,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe deletedDeviceIDs := req.DeviceIDs if len(req.DeviceIDs) == 0 { var devices []api.Device - devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID) + devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID) for _, d := range devices { deletedDeviceIDs = append(deletedDeviceIDs, d.ID) } } else { - err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs) + err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs) } if err != nil { return err @@ -197,7 +204,7 @@ func (a *UserInternalAPI) PerformLastSeenUpdate( if err != nil { return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } - if err := a.DeviceDB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil { + if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil { return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err) } return nil @@ -209,7 +216,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") return err } - dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID) + dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID) if err == sql.ErrNoRows { res.DeviceExists = false return nil @@ -224,7 +231,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf return nil } - err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName) + err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName) if err != nil { util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed") return err @@ -262,7 +269,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil if domain != a.ServerName { return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) } - prof, err := a.AccountDB.GetProfileByLocalpart(ctx, local) + prof, err := a.DB.GetProfileByLocalpart(ctx, local) if err != nil { if err == sql.ErrNoRows { return nil @@ -276,7 +283,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil } func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error { - profiles, err := a.AccountDB.SearchProfiles(ctx, req.SearchString, req.Limit) + profiles, err := a.DB.SearchProfiles(ctx, req.SearchString, req.Limit) if err != nil { return err } @@ -285,7 +292,7 @@ func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.Quer } func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error { - devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs) + devices, err := a.DB.GetDevicesByID(ctx, req.DeviceIDs) if err != nil { return err } @@ -313,10 +320,11 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice if domain != a.ServerName { return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName) } - devs, err := a.DeviceDB.GetDevicesByLocalpart(ctx, local) + devs, err := a.DB.GetDevicesByLocalpart(ctx, local) if err != nil { return err } + res.UserExists = true res.Devices = devs return nil } @@ -331,7 +339,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc } if req.DataType != "" { var data json.RawMessage - data, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) + data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) if err != nil { return err } @@ -349,7 +357,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc } return nil } - global, rooms, err := a.AccountDB.GetAccountData(ctx, local) + global, rooms, err := a.DB.GetAccountData(ctx, local) if err != nil { return err } @@ -368,13 +376,22 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc return nil } - device, err := a.DeviceDB.GetDeviceByAccessToken(ctx, req.AccessToken) + device, err := a.DB.GetDeviceByAccessToken(ctx, req.AccessToken) if err != nil { if err == sql.ErrNoRows { return nil } return err } + localPart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + return err + } + acc, err := a.DB.GetAccountByLocalpart(ctx, localPart) + if err != nil { + return err + } + device.AccountType = acc.AccountType res.Device = device return nil } @@ -401,6 +418,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe // AS dummy device has AS's token. AccessToken: token, AppserviceID: appService.ID, + AccountType: api.AccountTypeAppService, } localpart, err := userutil.ParseUsernameParam(appServiceUserID, &a.ServerName) @@ -410,7 +428,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe if localpart != "" { // AS is masquerading as another user // Verify that the user is registered - account, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart) + account, err := a.DB.GetAccountByLocalpart(ctx, localpart) // Verify that the account exists and either appServiceID matches or // it belongs to the appservice user namespaces if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) { @@ -428,7 +446,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe // PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again. func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error { - err := a.AccountDB.DeactivateAccount(ctx, req.Localpart) + err := a.DB.DeactivateAccount(ctx, req.Localpart) res.AccountDeactivated = err == nil return err } @@ -437,7 +455,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error { token := util.RandomString(24) - exp, err := a.AccountDB.CreateOpenIDToken(ctx, token, req.UserID) + exp, err := a.DB.CreateOpenIDToken(ctx, token, req.UserID) res.Token = api.OpenIDToken{ Token: token, @@ -450,7 +468,7 @@ func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *a // QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error { - openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, req.Token) + openIDTokenAttrs, err := a.DB.GetOpenIDTokenAttributes(ctx, req.Token) if err != nil { return err } @@ -472,7 +490,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform } return nil } - exists, err := a.AccountDB.DeleteKeyBackup(ctx, req.UserID, req.Version) + exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version) if err != nil { res.Error = fmt.Sprintf("failed to delete backup: %s", err) } @@ -485,7 +503,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform } // Create metadata if req.Version == "" { - version, err := a.AccountDB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData) + version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData) if err != nil { res.Error = fmt.Sprintf("failed to create backup: %s", err) } @@ -498,7 +516,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform } // Update metadata if len(req.Keys.Rooms) == 0 { - err := a.AccountDB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData) + err := a.DB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData) if err != nil { res.Error = fmt.Sprintf("failed to update backup: %s", err) } @@ -519,7 +537,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { // you can only upload keys for the CURRENT version - version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, "") + version, _, _, _, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, "") if err != nil { res.Error = fmt.Sprintf("failed to query version: %s", err) return @@ -547,7 +565,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform }) } } - count, etag, err := a.AccountDB.UpsertBackupKeys(ctx, version, req.UserID, uploads) + count, etag, err := a.DB.UpsertBackupKeys(ctx, version, req.UserID, uploads) if err != nil { res.Error = fmt.Sprintf("failed to upsert keys: %s", err) return @@ -557,7 +575,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform } func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) { - version, algorithm, authData, etag, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version) + version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version) res.Version = version if err != nil { if err == sql.ErrNoRows { @@ -573,14 +591,14 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB res.Exists = !deleted if !req.ReturnKeys { - res.Count, err = a.AccountDB.CountBackupKeys(ctx, version, req.UserID) + res.Count, err = a.DB.CountBackupKeys(ctx, version, req.UserID) if err != nil { res.Error = fmt.Sprintf("failed to count keys: %s", err) } return } - result, err := a.AccountDB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID) + result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID) if err != nil { res.Error = fmt.Sprintf("failed to query keys: %s", err) return diff --git a/userapi/internal/api_logintoken.go b/userapi/internal/api_logintoken.go index 86ffc58f3..f1bf391e4 100644 --- a/userapi/internal/api_logintoken.go +++ b/userapi/internal/api_logintoken.go @@ -34,7 +34,7 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap if domain != a.ServerName { return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName) } - tokenMeta, err := a.DeviceDB.CreateLoginToken(ctx, &req.Data) + tokenMeta, err := a.DB.CreateLoginToken(ctx, &req.Data) if err != nil { return err } @@ -45,13 +45,13 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap // PerformLoginTokenDeletion ensures the token doesn't exist. func (a *UserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *api.PerformLoginTokenDeletionRequest, res *api.PerformLoginTokenDeletionResponse) error { util.GetLogger(ctx).Info("PerformLoginTokenDeletion") - return a.DeviceDB.RemoveLoginToken(ctx, req.Token) + return a.DB.RemoveLoginToken(ctx, req.Token) } // QueryLoginToken returns the data associated with a login token. If // the token is not valid, success is returned, but res.Data == nil. func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLoginTokenRequest, res *api.QueryLoginTokenResponse) error { - tokenData, err := a.DeviceDB.GetLoginTokenDataByToken(ctx, req.Token) + tokenData, err := a.DB.GetLoginTokenDataByToken(ctx, req.Token) if err != nil { res.Data = nil if err == sql.ErrNoRows { @@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog if domain != a.ServerName { return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName) } - if _, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart); err != nil { + if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil { res.Data = nil if err == sql.ErrNoRows { return nil diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go deleted file mode 100644 index 14ddd0e2f..000000000 --- a/userapi/storage/accounts/postgres/storage.go +++ /dev/null @@ -1,547 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "database/sql" - "encoding/json" - "errors" - "fmt" - "strconv" - "time" - - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas" - "github.com/matrix-org/gomatrixserverlib" - "golang.org/x/crypto/bcrypt" - - // Import the postgres database driver. - _ "github.com/lib/pq" -) - -// Database represents an account database -type Database struct { - db *sql.DB - writer sqlutil.Writer - sqlutil.PartitionOffsetStatements - accounts accountsStatements - profiles profilesStatements - accountDatas accountDataStatements - threepids threepidStatements - openIDTokens tokenStatements - keyBackupVersions keyBackupVersionStatements - keyBackups keyBackupStatements - serverName gomatrixserverlib.ServerName - bcryptCost int - openIDTokenLifetimeMS int64 -} - -// NewDatabase creates a new accounts and profiles database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) { - db, err := sqlutil.Open(dbProperties) - if err != nil { - return nil, err - } - d := &Database{ - serverName: serverName, - db: db, - writer: sqlutil.NewDummyWriter(), - bcryptCost: bcryptCost, - openIDTokenLifetimeMS: openIDTokenLifetimeMS, - } - - // Create tables before executing migrations so we don't fail if the table is missing, - // and THEN prepare statements so we don't fail due to referencing new columns - if err = d.accounts.execSchema(db); err != nil { - return nil, err - } - m := sqlutil.NewMigrations() - deltas.LoadIsActive(m) - deltas.LoadAddPolicyVersion(m) - if err = m.RunDeltas(db, dbProperties); err != nil { - return nil, err - } - - if err = d.PartitionOffsetStatements.Prepare(db, d.writer, "account"); err != nil { - return nil, err - } - if err = d.accounts.prepare(db, serverName); err != nil { - return nil, err - } - if err = d.profiles.prepare(db); err != nil { - return nil, err - } - if err = d.accountDatas.prepare(db); err != nil { - return nil, err - } - if err = d.threepids.prepare(db); err != nil { - return nil, err - } - if err = d.openIDTokens.prepare(db, serverName); err != nil { - return nil, err - } - if err = d.keyBackupVersions.prepare(db); err != nil { - return nil, err - } - if err = d.keyBackups.prepare(db); err != nil { - return nil, err - } - - return d, nil -} - -// GetAccountByPassword returns the account associated with the given localpart and password. -// Returns sql.ErrNoRows if no account exists which matches the given localpart. -func (d *Database) GetAccountByPassword( - ctx context.Context, localpart, plaintextPassword string, -) (*api.Account, error) { - hash, err := d.accounts.selectPasswordHash(ctx, localpart) - if err != nil { - return nil, err - } - if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { - return nil, err - } - return d.accounts.selectAccountByLocalpart(ctx, localpart) -} - -// GetProfileByLocalpart returns the profile associated with the given localpart. -// Returns sql.ErrNoRows if no profile exists which matches the given localpart. -func (d *Database) GetProfileByLocalpart( - ctx context.Context, localpart string, -) (*authtypes.Profile, error) { - return d.profiles.selectProfileByLocalpart(ctx, localpart) -} - -// SetAvatarURL updates the avatar URL of the profile associated with the given -// localpart. Returns an error if something went wrong with the SQL query -func (d *Database) SetAvatarURL( - ctx context.Context, localpart string, avatarURL string, -) error { - return d.profiles.setAvatarURL(ctx, localpart, avatarURL) -} - -// SetDisplayName updates the display name of the profile associated with the given -// localpart. Returns an error if something went wrong with the SQL query -func (d *Database) SetDisplayName( - ctx context.Context, localpart string, displayName string, -) error { - return d.profiles.setDisplayName(ctx, localpart, displayName) -} - -// SetPassword sets the account password to the given hash. -func (d *Database) SetPassword( - ctx context.Context, localpart, plaintextPassword string, -) error { - hash, err := d.hashPassword(plaintextPassword) - if err != nil { - return err - } - return d.accounts.updatePassword(ctx, localpart, hash) -} - -// CreateGuestAccount makes a new guest account and creates an empty profile -// for this account. -func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - var numLocalpart int64 - numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) - if err != nil { - return err - } - localpart := strconv.FormatInt(numLocalpart, 10) - acc, err = d.createAccount(ctx, txn, localpart, "", "", "") - return err - }) - return acc, err -} - -// CreateAccount makes a new account with the given login name and password, and creates an empty profile -// for this account. If no password is supplied, the account will be a passwordless account. If the -// account already exists, it will return nil, sqlutil.ErrUserExists. -func (d *Database) CreateAccount( - ctx context.Context, localpart, plaintextPassword, appserviceID, policyVersion string, -) (acc *api.Account, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, policyVersion) - return err - }) - return -} - -func (d *Database) createAccount( - ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID, policyVersion string, -) (*api.Account, error) { - var account *api.Account - var err error - // Generate a password hash if this is not a password-less user - hash := "" - if plaintextPassword != "" { - hash, err = d.hashPassword(plaintextPassword) - if err != nil { - return nil, err - } - } - if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, policyVersion); err != nil { - if sqlutil.IsUniqueConstraintViolationErr(err) { - return nil, sqlutil.ErrUserExists - } - return nil, err - } - if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil { - return nil, err - } - if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ - "global": { - "content": [], - "override": [], - "room": [], - "sender": [], - "underride": [] - } - }`)); err != nil { - return nil, err - } - return account, nil -} - -// SaveAccountData saves new account data for a given user and a given room. -// If the account data is not specific to a room, the room ID should be an empty string -// If an account data already exists for a given set (user, room, data type), it will -// update the corresponding row with the new content -// Returns a SQL error if there was an issue with the insertion/update -func (d *Database) SaveAccountData( - ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) - }) -} - -// GetAccountData returns account data related to a given localpart -// If no account data could be found, returns an empty arrays -// Returns an error if there was an issue with the retrieval -func (d *Database) GetAccountData(ctx context.Context, localpart string) ( - global map[string]json.RawMessage, - rooms map[string]map[string]json.RawMessage, - err error, -) { - return d.accountDatas.selectAccountData(ctx, localpart) -} - -// GetAccountDataByType returns account data matching a given -// localpart, room ID and type. -// If no account data could be found, returns nil -// Returns an error if there was an issue with the retrieval -func (d *Database) GetAccountDataByType( - ctx context.Context, localpart, roomID, dataType string, -) (data json.RawMessage, err error) { - return d.accountDatas.selectAccountDataByType( - ctx, localpart, roomID, dataType, - ) -} - -// GetNewNumericLocalpart generates and returns a new unused numeric localpart -func (d *Database) GetNewNumericLocalpart( - ctx context.Context, -) (int64, error) { - return d.accounts.selectNewNumericLocalpart(ctx, nil) -} - -func (d *Database) hashPassword(plaintext string) (hash string, err error) { - hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.bcryptCost) - return string(hashBytes), err -} - -// Err3PIDInUse is the error returned when trying to save an association involving -// a third-party identifier which is already associated to a local user. -var Err3PIDInUse = errors.New("this third-party identifier is already in use") - -// SaveThreePIDAssociation saves the association between a third party identifier -// and a local Matrix user (identified by the user's ID's local part). -// If the third-party identifier is already part of an association, returns Err3PIDInUse. -// Returns an error if there was a problem talking to the database. -func (d *Database) SaveThreePIDAssociation( - ctx context.Context, threepid, localpart, medium string, -) (err error) { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - user, err := d.threepids.selectLocalpartForThreePID( - ctx, txn, threepid, medium, - ) - if err != nil { - return err - } - - if len(user) > 0 { - return Err3PIDInUse - } - - return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart) - }) -} - -// RemoveThreePIDAssociation removes the association involving a given third-party -// identifier. -// If no association exists involving this third-party identifier, returns nothing. -// If there was a problem talking to the database, returns an error. -func (d *Database) RemoveThreePIDAssociation( - ctx context.Context, threepid string, medium string, -) (err error) { - return d.threepids.deleteThreePID(ctx, threepid, medium) -} - -// GetLocalpartForThreePID looks up the localpart associated with a given third-party -// identifier. -// If no association involves the given third-party idenfitier, returns an empty -// string. -// Returns an error if there was a problem talking to the database. -func (d *Database) GetLocalpartForThreePID( - ctx context.Context, threepid string, medium string, -) (localpart string, err error) { - return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium) -} - -// GetThreePIDsForLocalpart looks up the third-party identifiers associated with -// a given local user. -// If no association is known for this user, returns an empty slice. -// Returns an error if there was an issue talking to the database. -func (d *Database) GetThreePIDsForLocalpart( - ctx context.Context, localpart string, -) (threepids []authtypes.ThreePID, err error) { - return d.threepids.selectThreePIDsForLocalpart(ctx, localpart) -} - -// CheckAccountAvailability checks if the username/localpart is already present -// in the database. -// If the DB returns sql.ErrNoRows the Localpart isn't taken. -func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { - _, err := d.accounts.selectAccountByLocalpart(ctx, localpart) - if err == sql.ErrNoRows { - return true, nil - } - return false, err -} - -// GetAccountByLocalpart returns the account associated with the given localpart. -// This function assumes the request is authenticated or the account data is used only internally. -// Returns sql.ErrNoRows if no account exists which matches the given localpart. -func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, -) (*api.Account, error) { - return d.accounts.selectAccountByLocalpart(ctx, localpart) -} - -// SearchProfiles returns all profiles where the provided localpart or display name -// match any part of the profiles in the database. -func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int, -) ([]authtypes.Profile, error) { - return d.profiles.selectProfilesBySearch(ctx, searchString, limit) -} - -// DeactivateAccount deactivates the user's account, removing all ability for the user to login again. -func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { - return d.accounts.deactivateAccount(ctx, localpart) -} - -// CreateOpenIDToken persists a new token that was issued through OpenID Connect -func (d *Database) CreateOpenIDToken( - ctx context.Context, - token, localpart string, -) (int64, error) { - expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS - err := sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS) - }) - return expiresAtMS, err -} - -// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token -func (d *Database) GetOpenIDTokenAttributes( - ctx context.Context, - token string, -) (*api.OpenIDTokenAttributes, error) { - return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) -} - -func (d *Database) CreateKeyBackup( - ctx context.Context, userID, algorithm string, authData json.RawMessage, -) (version string, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "") - return err - }) - return -} - -func (d *Database) UpdateKeyBackupAuthData( - ctx context.Context, userID, version string, authData json.RawMessage, -) (err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData) - }) - return -} - -func (d *Database) DeleteKeyBackup( - ctx context.Context, userID, version string, -) (exists bool, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version) - return err - }) - return -} - -func (d *Database) GetKeyBackup( - ctx context.Context, userID, version string, -) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) - return err - }) - return -} - -func (d *Database) GetBackupKeys( - ctx context.Context, version, userID, filterRoomID, filterSessionID string, -) (result map[string]map[string]api.KeyBackupSession, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - if filterSessionID != "" { - result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID) - return err - } - if filterRoomID != "" { - result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID) - return err - } - result, err = d.keyBackups.selectKeys(ctx, txn, userID, version) - return err - }) - return -} - -func (d *Database) CountBackupKeys( - ctx context.Context, version, userID string, -) (count int64, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - count, err = d.keyBackups.countKeys(ctx, txn, userID, version) - if err != nil { - return err - } - return nil - }) - return -} - -// nolint:nakedret -func (d *Database) UpsertBackupKeys( - ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession, -) (count int64, etag string, err error) { - // wrap the following logic in a txn to ensure we atomically upload keys - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - _, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) - if err != nil { - return err - } - if deleted { - return fmt.Errorf("backup was deleted") - } - // pull out all keys for this (user_id, version) - existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version) - if err != nil { - return err - } - - changed := false - // loop over all the new keys (which should be smaller than the set of backed up keys) - for _, newKey := range uploads { - // if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them. - existingRoom := existingKeys[newKey.RoomID] - if existingRoom != nil { - existingSession, ok := existingRoom[newKey.SessionID] - if ok { - if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) { - err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey) - changed = true - if err != nil { - return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err) - } - } - // if we shouldn't replace the key we do nothing with it - continue - } - } - // if we're here, either the room or session are new, either way, we insert - err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey) - changed = true - if err != nil { - return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err) - } - } - - count, err = d.keyBackups.countKeys(ctx, txn, userID, version) - if err != nil { - return err - } - if changed { - // update the etag - var newETag string - if oldETag == "" { - newETag = "1" - } else { - oldETagInt, err := strconv.ParseInt(oldETag, 10, 64) - if err != nil { - return fmt.Errorf("failed to parse old etag: %s", err) - } - newETag = strconv.FormatInt(oldETagInt+1, 10) - } - etag = newETag - return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag) - } else { - etag = oldETag - } - return nil - }) - return -} - -// GetPrivacyPolicy returns the accepted privacy policy version, if any. -func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - policyVersion, err = d.accounts.selectPrivacyPolicy(ctx, txn, localpart) - return err - }) - return -} - -// GetOutdatedPolicy queries all users which didn't accept the current policy version. -func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - userIDs, err = d.accounts.batchSelectPrivacyPolicy(ctx, txn, policyVersion) - return err - }) - return -} - -// UpdatePolicyVersion sets the accepted policy_version for a user. -func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) (err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.accounts.updatePolicyVersion(ctx, txn, policyVersion, localpart) - }) - return -} diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go deleted file mode 100644 index 8ff91cf1c..000000000 --- a/userapi/storage/devices/interface.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package devices - -import ( - "context" - - "github.com/matrix-org/dendrite/userapi/api" -) - -type Database interface { - GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) - GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) - GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error) - GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) - // CreateDevice makes a new device associated with the given user ID localpart. - // If there is already a device with the same device ID for this user, that access token will be revoked - // and replaced with the given accessToken. If the given accessToken is already in use for another device, - // an error will be returned. - // If no device ID is given one is generated. - // Returns the device on success. - CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) - UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error - UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error - RemoveDevice(ctx context.Context, deviceID, localpart string) error - RemoveDevices(ctx context.Context, localpart string, devices []string) error - // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. - RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) - - // CreateLoginToken generates a token, stores and returns it. The lifetime is - // determined by the loginTokenLifetime given to the Database constructor. - CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) - - // RemoveLoginToken removes the named token (and may clean up other expired tokens). - RemoveLoginToken(ctx context.Context, token string) error - - // GetLoginTokenDataByToken returns the data associated with the given token. - // May return sql.ErrNoRows. - GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) -} diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go deleted file mode 100644 index fd9d513f1..000000000 --- a/userapi/storage/devices/postgres/storage.go +++ /dev/null @@ -1,270 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "crypto/rand" - "database/sql" - "encoding/base64" - "time" - - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas" - "github.com/matrix-org/gomatrixserverlib" -) - -const ( - // The length of generated device IDs - deviceIDByteLength = 6 - loginTokenByteLength = 32 -) - -// Database represents a device database. -type Database struct { - db *sql.DB - devices devicesStatements - loginTokens loginTokenStatements - loginTokenLifetime time.Duration -} - -// NewDatabase creates a new device database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) { - db, err := sqlutil.Open(dbProperties) - if err != nil { - return nil, err - } - var d devicesStatements - var lt loginTokenStatements - - // Create tables before executing migrations so we don't fail if the table is missing, - // and THEN prepare statements so we don't fail due to referencing new columns - if err = d.execSchema(db); err != nil { - return nil, err - } - if err = lt.execSchema(db); err != nil { - return nil, err - } - - m := sqlutil.NewMigrations() - deltas.LoadLastSeenTSIP(m) - if err = m.RunDeltas(db, dbProperties); err != nil { - return nil, err - } - - if err = d.prepare(db, serverName); err != nil { - return nil, err - } - if err = lt.prepare(db); err != nil { - return nil, err - } - - return &Database{db, d, lt, loginTokenLifetime}, nil -} - -// GetDeviceByAccessToken returns the device matching the given access token. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByAccessToken( - ctx context.Context, token string, -) (*api.Device, error) { - return d.devices.selectDeviceByToken(ctx, token) -} - -// GetDeviceByID returns the device matching the given ID. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByID( - ctx context.Context, localpart, deviceID string, -) (*api.Device, error) { - return d.devices.selectDeviceByID(ctx, localpart, deviceID) -} - -// GetDevicesByLocalpart returns the devices matching the given localpart. -func (d *Database) GetDevicesByLocalpart( - ctx context.Context, localpart string, -) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") -} - -func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { - return d.devices.selectDevicesByID(ctx, deviceIDs) -} - -// CreateDevice makes a new device associated with the given user ID localpart. -// If there is already a device with the same device ID for this user, that access token will be revoked -// and replaced with the given accessToken. If the given accessToken is already in use for another device, -// an error will be returned. -// If no device ID is given one is generated. -// Returns the device on success. -func (d *Database) CreateDevice( - ctx context.Context, localpart string, deviceID *string, accessToken string, - displayName *string, ipAddr, userAgent string, -) (dev *api.Device, returnErr error) { - if deviceID != nil { - returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - var err error - // Revoke existing tokens for this device - if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { - return err - } - - dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) - } else { - // We generate device IDs in a loop in case its already taken. - // We cap this at going round 5 times to ensure we don't spin forever - var newDeviceID string - for i := 1; i <= 5; i++ { - newDeviceID, returnErr = generateDeviceID() - if returnErr != nil { - return - } - - returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - var err error - dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) - if returnErr == nil { - return - } - } - } - return -} - -// generateDeviceID creates a new device id. Returns an error if failed to generate -// random bytes. -func generateDeviceID() (string, error) { - b := make([]byte, deviceIDByteLength) - _, err := rand.Read(b) - if err != nil { - return "", err - } - // url-safe no padding - return base64.RawURLEncoding.EncodeToString(b), nil -} - -// UpdateDevice updates the given device with the display name. -// Returns SQL error if there are problems and nil on success. -func (d *Database) UpdateDevice( - ctx context.Context, localpart, deviceID string, displayName *string, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) - }) -} - -// RemoveDevice revokes a device by deleting the entry in the database -// matching with the given device ID and user ID localpart. -// If the device doesn't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevice( - ctx context.Context, deviceID, localpart string, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveDevices revokes one or more devices by deleting the entry in the database -// matching with the given device IDs and user ID localpart. -// If the devices don't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevices( - ctx context.Context, localpart string, devices []string, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveAllDevices revokes devices by deleting the entry in the -// database matching the given user ID localpart. -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveAllDevices( - ctx context.Context, localpart, exceptDeviceID string, -) (devices []api.Device, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) - if err != nil { - return err - } - if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { - return err - } - return nil - }) - return -} - -// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address -func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) - }) -} - -// CreateLoginToken generates a token, stores and returns it. The lifetime is -// determined by the loginTokenLifetime given to the Database constructor. -func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { - tok, err := generateLoginToken() - if err != nil { - return nil, err - } - meta := &api.LoginTokenMetadata{ - Token: tok, - Expiration: time.Now().Add(d.loginTokenLifetime), - } - - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.loginTokens.insert(ctx, txn, meta, data) - }) - if err != nil { - return nil, err - } - - return meta, nil -} - -func generateLoginToken() (string, error) { - b := make([]byte, loginTokenByteLength) - _, err := rand.Read(b) - if err != nil { - return "", err - } - return base64.RawURLEncoding.EncodeToString(b), nil -} - -// RemoveLoginToken removes the named token (and may clean up other expired tokens). -func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.loginTokens.deleteByToken(ctx, txn, token) - }) -} - -// GetLoginTokenDataByToken returns the data associated with the given token. -// May return sql.ErrNoRows. -func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { - return d.loginTokens.selectByToken(ctx, token) -} diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go deleted file mode 100644 index 6e90413be..000000000 --- a/userapi/storage/devices/sqlite3/storage.go +++ /dev/null @@ -1,271 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "context" - "crypto/rand" - "database/sql" - "encoding/base64" - "time" - - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas" - "github.com/matrix-org/gomatrixserverlib" -) - -const ( - // The length of generated device IDs - deviceIDByteLength = 6 - - loginTokenByteLength = 32 -) - -// Database represents a device database. -type Database struct { - db *sql.DB - writer sqlutil.Writer - devices devicesStatements - loginTokens loginTokenStatements - loginTokenLifetime time.Duration -} - -// NewDatabase creates a new device database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) { - db, err := sqlutil.Open(dbProperties) - if err != nil { - return nil, err - } - writer := sqlutil.NewExclusiveWriter() - var d devicesStatements - var lt loginTokenStatements - - // Create tables before executing migrations so we don't fail if the table is missing, - // and THEN prepare statements so we don't fail due to referencing new columns - if err = d.execSchema(db); err != nil { - return nil, err - } - if err = lt.execSchema(db); err != nil { - return nil, err - } - - m := sqlutil.NewMigrations() - deltas.LoadLastSeenTSIP(m) - if err = m.RunDeltas(db, dbProperties); err != nil { - return nil, err - } - if err = d.prepare(db, writer, serverName); err != nil { - return nil, err - } - if err = lt.prepare(db); err != nil { - return nil, err - } - return &Database{db, writer, d, lt, loginTokenLifetime}, nil -} - -// GetDeviceByAccessToken returns the device matching the given access token. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByAccessToken( - ctx context.Context, token string, -) (*api.Device, error) { - return d.devices.selectDeviceByToken(ctx, token) -} - -// GetDeviceByID returns the device matching the given ID. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByID( - ctx context.Context, localpart, deviceID string, -) (*api.Device, error) { - return d.devices.selectDeviceByID(ctx, localpart, deviceID) -} - -// GetDevicesByLocalpart returns the devices matching the given localpart. -func (d *Database) GetDevicesByLocalpart( - ctx context.Context, localpart string, -) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") -} - -func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { - return d.devices.selectDevicesByID(ctx, deviceIDs) -} - -// CreateDevice makes a new device associated with the given user ID localpart. -// If there is already a device with the same device ID for this user, that access token will be revoked -// and replaced with the given accessToken. If the given accessToken is already in use for another device, -// an error will be returned. -// If no device ID is given one is generated. -// Returns the device on success. -func (d *Database) CreateDevice( - ctx context.Context, localpart string, deviceID *string, accessToken string, - displayName *string, ipAddr, userAgent string, -) (dev *api.Device, returnErr error) { - if deviceID != nil { - returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - var err error - // Revoke existing tokens for this device - if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { - return err - } - - dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) - } else { - // We generate device IDs in a loop in case its already taken. - // We cap this at going round 5 times to ensure we don't spin forever - var newDeviceID string - for i := 1; i <= 5; i++ { - newDeviceID, returnErr = generateDeviceID() - if returnErr != nil { - return - } - - returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - var err error - dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) - if returnErr == nil { - return - } - } - } - return -} - -// generateDeviceID creates a new device id. Returns an error if failed to generate -// random bytes. -func generateDeviceID() (string, error) { - b := make([]byte, deviceIDByteLength) - _, err := rand.Read(b) - if err != nil { - return "", err - } - // url-safe no padding - return base64.RawURLEncoding.EncodeToString(b), nil -} - -// UpdateDevice updates the given device with the display name. -// Returns SQL error if there are problems and nil on success. -func (d *Database) UpdateDevice( - ctx context.Context, localpart, deviceID string, displayName *string, -) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) - }) -} - -// RemoveDevice revokes a device by deleting the entry in the database -// matching with the given device ID and user ID localpart. -// If the device doesn't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevice( - ctx context.Context, deviceID, localpart string, -) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveDevices revokes one or more devices by deleting the entry in the database -// matching with the given device IDs and user ID localpart. -// If the devices don't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevices( - ctx context.Context, localpart string, devices []string, -) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveAllDevices revokes devices by deleting the entry in the -// database matching the given user ID localpart. -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveAllDevices( - ctx context.Context, localpart, exceptDeviceID string, -) (devices []api.Device, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) - if err != nil { - return err - } - if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { - return err - } - return nil - }) - return -} - -// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address -func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) - }) -} - -// CreateLoginToken generates a token, stores and returns it. The lifetime is -// determined by the loginTokenLifetime given to the Database constructor. -func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { - tok, err := generateLoginToken() - if err != nil { - return nil, err - } - meta := &api.LoginTokenMetadata{ - Token: tok, - Expiration: time.Now().Add(d.loginTokenLifetime), - } - - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.loginTokens.insert(ctx, txn, meta, data) - }) - if err != nil { - return nil, err - } - - return meta, nil -} - -func generateLoginToken() (string, error) { - b := make([]byte, loginTokenByteLength) - _, err := rand.Read(b) - if err != nil { - return "", err - } - return base64.RawURLEncoding.EncodeToString(b), nil -} - -// RemoveLoginToken removes the named token (and may clean up other expired tokens). -func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.loginTokens.deleteByToken(ctx, txn, token) - }) -} - -// GetLoginTokenDataByToken returns the data associated with the given token. -// May return sql.ErrNoRows. -func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { - return d.loginTokens.selectByToken(ctx, token) -} diff --git a/userapi/storage/devices/storage.go b/userapi/storage/devices/storage.go deleted file mode 100644 index 15cf8150c..000000000 --- a/userapi/storage/devices/storage.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build !wasm -// +build !wasm - -package devices - -import ( - "fmt" - "time" - - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/storage/devices/postgres" - "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3" - "github.com/matrix-org/gomatrixserverlib" -) - -// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) -// and sets postgres connection parameters. loginTokenLifetime determines how long a -// login token from CreateLoginToken is valid. -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (Database, error) { - switch { - case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime) - case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(dbProperties, serverName, loginTokenLifetime) - default: - return nil, fmt.Errorf("unexpected database type") - } -} diff --git a/userapi/storage/devices/storage_wasm.go b/userapi/storage/devices/storage_wasm.go deleted file mode 100644 index 3de7880b9..000000000 --- a/userapi/storage/devices/storage_wasm.go +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package devices - -import ( - "fmt" - "time" - - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3" - "github.com/matrix-org/gomatrixserverlib" -) - -func NewDatabase( - dbProperties *config.DatabaseOptions, - serverName gomatrixserverlib.ServerName, - loginTokenLifetime time.Duration, -) (Database, error) { - switch { - case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime) - case dbProperties.ConnectionString.IsPostgres(): - return nil, fmt.Errorf("can't use Postgres implementation") - default: - return nil, fmt.Errorf("unexpected database type") - } -} diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/interface.go similarity index 66% rename from userapi/storage/accounts/interface.go rename to userapi/storage/interface.go index 81224af1a..9e4ff4b4b 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/interface.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package accounts +package storage import ( "context" @@ -32,8 +32,7 @@ type Database interface { // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, ErrUserExists. - CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID, policyVersion string) (*api.Account, error) - CreateGuestAccount(ctx context.Context) (*api.Account, error) + CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, policyVersion string, accountType api.AccountType) (*api.Account, error) SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) // GetAccountDataByType returns account data matching a given @@ -64,6 +63,35 @@ type Database interface { UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error) GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error) CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error) + + GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) + GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) + GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error) + GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) + // CreateDevice makes a new device associated with the given user ID localpart. + // If there is already a device with the same device ID for this user, that access token will be revoked + // and replaced with the given accessToken. If the given accessToken is already in use for another device, + // an error will be returned. + // If no device ID is given one is generated. + // Returns the device on success. + CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) + UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error + UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error + RemoveDevice(ctx context.Context, deviceID, localpart string) error + RemoveDevices(ctx context.Context, localpart string, devices []string) error + // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. + RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) + + // CreateLoginToken generates a token, stores and returns it. The lifetime is + // determined by the loginTokenLifetime given to the Database constructor. + CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) + + // RemoveLoginToken removes the named token (and may clean up other expired tokens). + RemoveLoginToken(ctx context.Context, token string) error + + // GetLoginTokenDataByToken returns the data associated with the given token. + // May return sql.ErrNoRows. + GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/userapi/storage/accounts/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go similarity index 89% rename from userapi/storage/accounts/postgres/account_data_table.go rename to userapi/storage/postgres/account_data_table.go index 8ba890e75..67113367b 100644 --- a/userapi/storage/accounts/postgres/account_data_table.go +++ b/userapi/storage/postgres/account_data_table.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const accountDataSchema = ` @@ -56,19 +57,20 @@ type accountDataStatements struct { selectAccountDataByTypeStmt *sql.Stmt } -func (s *accountDataStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(accountDataSchema) +func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) { + s := &accountDataStatements{} + _, err := db.Exec(accountDataSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.insertAccountDataStmt, insertAccountDataSQL}, {&s.selectAccountDataStmt, selectAccountDataSQL}, {&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL}, }.Prepare(db) } -func (s *accountDataStatements) insertAccountData( +func (s *accountDataStatements) InsertAccountData( ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt) @@ -76,7 +78,7 @@ func (s *accountDataStatements) insertAccountData( return } -func (s *accountDataStatements) selectAccountData( +func (s *accountDataStatements) SelectAccountData( ctx context.Context, localpart string, ) ( /* global */ map[string]json.RawMessage, @@ -114,7 +116,7 @@ func (s *accountDataStatements) selectAccountData( return global, rooms, rows.Err() } -func (s *accountDataStatements) selectAccountDataByType( +func (s *accountDataStatements) SelectAccountDataByType( ctx context.Context, localpart, roomID, dataType string, ) (data json.RawMessage, err error) { var bytes []byte diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go similarity index 85% rename from userapi/storage/accounts/postgres/accounts_table.go rename to userapi/storage/postgres/accounts_table.go index a74d04ad4..f276e2038 100644 --- a/userapi/storage/accounts/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -19,10 +19,12 @@ import ( "database/sql" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/userapi/storage/tables" log "github.com/sirupsen/logrus" ) @@ -40,17 +42,19 @@ CREATE TABLE IF NOT EXISTS account_accounts ( appservice_id TEXT, -- If the account is currently active is_deactivated BOOLEAN DEFAULT FALSE, + -- The account_type (user = 1, guest = 2, admin = 3, appservice = 4) + account_type SMALLINT NOT NULL, -- The policy version this user has accepted policy_version TEXT -- TODO: - -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? + -- upgraded_ts, devices, any email reset stuff? ); -- Create sequence for autogenerated numeric usernames CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1; ` const insertAccountSQL = "" + - "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, policy_version) VALUES ($1, $2, $3, $4, $5)" + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type, policy_version) VALUES ($1, $2, $3, $4, $5, $6)" const updatePasswordSQL = "" + "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" @@ -59,7 +63,7 @@ const deactivateAccountSQL = "" + "UPDATE account_accounts SET is_deactivated = TRUE WHERE localpart = $1" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" + "SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1" const selectPasswordHashSQL = "" + "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE" @@ -89,14 +93,15 @@ type accountsStatements struct { serverName gomatrixserverlib.ServerName } -func (s *accountsStatements) execSchema(db *sql.DB) error { +func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) { + s := &accountsStatements{ + serverName: serverName, + } _, err := db.Exec(accountsSchema) - return err -} - -func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - s.serverName = server - return sqlutil.StatementList{ + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ {&s.insertAccountStmt, insertAccountSQL}, {&s.updatePasswordStmt, updatePasswordSQL}, {&s.deactivateAccountStmt, deactivateAccountSQL}, @@ -112,17 +117,17 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. -func (s *accountsStatements) insertAccount( - ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, +func (s *accountsStatements) InsertAccount( + ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) var err error - if appserviceID == "" { - _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, policyVersion) + if accountType != api.AccountTypeAppService { + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion) } else { - _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, "") + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, "") } if err != nil { return nil, err @@ -133,38 +138,39 @@ func (s *accountsStatements) insertAccount( UserID: userutil.MakeUserID(localpart, s.serverName), ServerName: s.serverName, AppServiceID: appserviceID, + AccountType: accountType, }, nil } -func (s *accountsStatements) updatePassword( +func (s *accountsStatements) UpdatePassword( ctx context.Context, localpart, passwordHash string, ) (err error) { _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) return } -func (s *accountsStatements) deactivateAccount( +func (s *accountsStatements) DeactivateAccount( ctx context.Context, localpart string, ) (err error) { _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) return } -func (s *accountsStatements) selectPasswordHash( +func (s *accountsStatements) SelectPasswordHash( ctx context.Context, localpart string, ) (hash string, err error) { err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) return } -func (s *accountsStatements) selectAccountByLocalpart( +func (s *accountsStatements) SelectAccountByLocalpart( ctx context.Context, localpart string, ) (*api.Account, error) { var appserviceIDPtr sql.NullString var acc api.Account stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) + err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") @@ -181,7 +187,7 @@ func (s *accountsStatements) selectAccountByLocalpart( return &acc, nil } -func (s *accountsStatements) selectNewNumericLocalpart( +func (s *accountsStatements) SelectNewNumericLocalpart( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { stmt := s.selectNewNumericLocalpartStmt diff --git a/userapi/storage/accounts/postgres/deltas/20200929203058_is_active.go b/userapi/storage/postgres/deltas/20200929203058_is_active.go similarity index 93% rename from userapi/storage/accounts/postgres/deltas/20200929203058_is_active.go rename to userapi/storage/postgres/deltas/20200929203058_is_active.go index 1d50d3d5a..c38759975 100644 --- a/userapi/storage/accounts/postgres/deltas/20200929203058_is_active.go +++ b/userapi/storage/postgres/deltas/20200929203058_is_active.go @@ -4,12 +4,14 @@ import ( "database/sql" "fmt" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/pressly/goose" + + "github.com/matrix-org/dendrite/internal/sqlutil" ) func LoadFromGoose() { goose.AddMigration(UpIsActive, DownIsActive) + goose.AddMigration(UpAddAccountType, DownAddAccountType) goose.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion) } diff --git a/userapi/storage/devices/postgres/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go similarity index 89% rename from userapi/storage/devices/postgres/deltas/20201001204705_last_seen_ts_ip.go rename to userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go index 290f854c8..1bbb0a9d3 100644 --- a/userapi/storage/devices/postgres/deltas/20201001204705_last_seen_ts_ip.go +++ b/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go @@ -5,13 +5,8 @@ import ( "fmt" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/pressly/goose" ) -func LoadFromGoose() { - goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP) -} - func LoadLastSeenTSIP(m *sqlutil.Migrations) { m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP) } diff --git a/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go b/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go new file mode 100644 index 000000000..2fae00cb9 --- /dev/null +++ b/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go @@ -0,0 +1,34 @@ +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +func LoadAddAccountType(m *sqlutil.Migrations) { + m.AddMigration(UpAddAccountType, DownAddAccountType) +} + +func UpAddAccountType(tx *sql.Tx) error { + // initially set every account to useraccount, change appservice and guest accounts afterwards + // (user = 1, guest = 2, admin = 3, appservice = 4) + _, err := tx.Exec(`ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS account_type SMALLINT NOT NULL DEFAULT 1; +UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> ''; +UPDATE account_accounts SET account_type = 2 WHERE localpart ~ '^[0-9]+$'; +ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`, + ) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownAddAccountType(tx *sql.Tx) error { + _, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN account_type;") + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/accounts/postgres/deltas/2022021414375800_add_policy_version.go b/userapi/storage/postgres/deltas/2022021414375800_add_policy_version.go similarity index 100% rename from userapi/storage/accounts/postgres/deltas/2022021414375800_add_policy_version.go rename to userapi/storage/postgres/deltas/2022021414375800_add_policy_version.go diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go similarity index 85% rename from userapi/storage/devices/postgres/devices_table.go rename to userapi/storage/postgres/devices_table.go index 7de9f5f9e..7bc5dc69b 100644 --- a/userapi/storage/devices/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) @@ -111,50 +112,32 @@ type devicesStatements struct { serverName gomatrixserverlib.ServerName } -func (s *devicesStatements) execSchema(db *sql.DB) error { +func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) { + s := &devicesStatements{ + serverName: serverName, + } _, err := db.Exec(devicesSchema) - return err -} - -func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil { - return + if err != nil { + return nil, err } - if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil { - return - } - if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil { - return - } - if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil { - return - } - if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil { - return - } - if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil { - return - } - if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { - return - } - if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil { - return - } - if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil { - return - } - if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil { - return - } - s.serverName = server - return + return s, sqlutil.StatementList{ + {&s.insertDeviceStmt, insertDeviceSQL}, + {&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL}, + {&s.selectDeviceByIDStmt, selectDeviceByIDSQL}, + {&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL}, + {&s.updateDeviceNameStmt, updateDeviceNameSQL}, + {&s.deleteDeviceStmt, deleteDeviceSQL}, + {&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL}, + {&s.deleteDevicesStmt, deleteDevicesSQL}, + {&s.selectDevicesByIDStmt, selectDevicesByIDSQL}, + {&s.updateDeviceLastSeenStmt, updateDeviceLastSeen}, + }.Prepare(db) } // insertDevice creates a new device. Returns an error if any device with the same access token already exists. // Returns an error if the user already has a device with the given device ID. // Returns the device on success. -func (s *devicesStatements) insertDevice( +func (s *devicesStatements) InsertDevice( ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string, ) (*api.Device, error) { @@ -176,7 +159,7 @@ func (s *devicesStatements) insertDevice( } // deleteDevice removes a single device by id and user localpart. -func (s *devicesStatements) deleteDevice( +func (s *devicesStatements) DeleteDevice( ctx context.Context, txn *sql.Tx, id, localpart string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) @@ -186,7 +169,7 @@ func (s *devicesStatements) deleteDevice( // deleteDevices removes a single or multiple devices by ids and user localpart. // Returns an error if the execution failed. -func (s *devicesStatements) deleteDevices( +func (s *devicesStatements) DeleteDevices( ctx context.Context, txn *sql.Tx, localpart string, devices []string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt) @@ -196,7 +179,7 @@ func (s *devicesStatements) deleteDevices( // deleteDevicesByLocalpart removes all devices for the // given user localpart. -func (s *devicesStatements) deleteDevicesByLocalpart( +func (s *devicesStatements) DeleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) @@ -204,7 +187,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart( return err } -func (s *devicesStatements) updateDeviceName( +func (s *devicesStatements) UpdateDeviceName( ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ) error { stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) @@ -212,7 +195,7 @@ func (s *devicesStatements) updateDeviceName( return err } -func (s *devicesStatements) selectDeviceByToken( +func (s *devicesStatements) SelectDeviceByToken( ctx context.Context, accessToken string, ) (*api.Device, error) { var dev api.Device @@ -228,7 +211,7 @@ func (s *devicesStatements) selectDeviceByToken( // selectDeviceByID retrieves a device from the database with the given user // localpart and deviceID -func (s *devicesStatements) selectDeviceByID( +func (s *devicesStatements) SelectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*api.Device, error) { var dev api.Device @@ -245,7 +228,7 @@ func (s *devicesStatements) selectDeviceByID( return &dev, err } -func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { +func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs)) if err != nil { return nil, err @@ -268,7 +251,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s return devices, rows.Err() } -func (s *devicesStatements) selectDevicesByLocalpart( +func (s *devicesStatements) SelectDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} @@ -310,7 +293,7 @@ func (s *devicesStatements) selectDevicesByLocalpart( return devices, rows.Err() } -func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { +func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID) diff --git a/userapi/storage/accounts/postgres/key_backup_table.go b/userapi/storage/postgres/key_backup_table.go similarity index 91% rename from userapi/storage/accounts/postgres/key_backup_table.go rename to userapi/storage/postgres/key_backup_table.go index c1402d4d2..ac0e80617 100644 --- a/userapi/storage/accounts/postgres/key_backup_table.go +++ b/userapi/storage/postgres/key_backup_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const keyBackupTableSchema = ` @@ -72,12 +73,13 @@ type keyBackupStatements struct { selectKeysByRoomIDAndSessionIDStmt *sql.Stmt } -func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(keyBackupTableSchema) +func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) { + s := &keyBackupStatements{} + _, err := db.Exec(keyBackupTableSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.countKeysStmt, countKeysSQL}, @@ -87,14 +89,14 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s keyBackupStatements) countKeys( +func (s keyBackupStatements) CountKeys( ctx context.Context, txn *sql.Tx, userID, version string, ) (count int64, err error) { err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count) return } -func (s *keyBackupStatements) insertBackupKey( +func (s *keyBackupStatements) InsertBackupKey( ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ) (err error) { _, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext( @@ -103,7 +105,7 @@ func (s *keyBackupStatements) insertBackupKey( return } -func (s *keyBackupStatements) updateBackupKey( +func (s *keyBackupStatements) UpdateBackupKey( ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ) (err error) { _, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext( @@ -112,7 +114,7 @@ func (s *keyBackupStatements) updateBackupKey( return } -func (s *keyBackupStatements) selectKeys( +func (s *keyBackupStatements) SelectKeys( ctx context.Context, txn *sql.Tx, userID, version string, ) (map[string]map[string]api.KeyBackupSession, error) { rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) @@ -122,7 +124,7 @@ func (s *keyBackupStatements) selectKeys( return unpackKeys(ctx, rows) } -func (s *keyBackupStatements) selectKeysByRoomID( +func (s *keyBackupStatements) SelectKeysByRoomID( ctx context.Context, txn *sql.Tx, userID, version, roomID string, ) (map[string]map[string]api.KeyBackupSession, error) { rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID) @@ -132,7 +134,7 @@ func (s *keyBackupStatements) selectKeysByRoomID( return unpackKeys(ctx, rows) } -func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID( +func (s *keyBackupStatements) SelectKeysByRoomIDAndSessionID( ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string, ) (map[string]map[string]api.KeyBackupSession, error) { rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID) diff --git a/userapi/storage/accounts/postgres/key_backup_version_table.go b/userapi/storage/postgres/key_backup_version_table.go similarity index 90% rename from userapi/storage/accounts/postgres/key_backup_version_table.go rename to userapi/storage/postgres/key_backup_version_table.go index d73447b49..e78e4cd51 100644 --- a/userapi/storage/accounts/postgres/key_backup_version_table.go +++ b/userapi/storage/postgres/key_backup_version_table.go @@ -22,6 +22,7 @@ import ( "strconv" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const keyBackupVersionTableSchema = ` @@ -69,12 +70,13 @@ type keyBackupVersionStatements struct { updateKeyBackupETagStmt *sql.Stmt } -func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(keyBackupVersionTableSchema) +func NewPostgresKeyBackupVersionTable(db *sql.DB) (tables.KeyBackupVersionTable, error) { + s := &keyBackupVersionStatements{} + _, err := db.Exec(keyBackupVersionTableSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.insertKeyBackupStmt, insertKeyBackupSQL}, {&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL}, {&s.deleteKeyBackupStmt, deleteKeyBackupSQL}, @@ -84,7 +86,7 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *keyBackupVersionStatements) insertKeyBackup( +func (s *keyBackupVersionStatements) InsertKeyBackup( ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string, ) (version string, err error) { var versionInt int64 @@ -92,7 +94,7 @@ func (s *keyBackupVersionStatements) insertKeyBackup( return strconv.FormatInt(versionInt, 10), err } -func (s *keyBackupVersionStatements) updateKeyBackupAuthData( +func (s *keyBackupVersionStatements) UpdateKeyBackupAuthData( ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage, ) error { versionInt, err := strconv.ParseInt(version, 10, 64) @@ -103,7 +105,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData( return err } -func (s *keyBackupVersionStatements) updateKeyBackupETag( +func (s *keyBackupVersionStatements) UpdateKeyBackupETag( ctx context.Context, txn *sql.Tx, userID, version, etag string, ) error { versionInt, err := strconv.ParseInt(version, 10, 64) @@ -114,7 +116,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag( return err } -func (s *keyBackupVersionStatements) deleteKeyBackup( +func (s *keyBackupVersionStatements) DeleteKeyBackup( ctx context.Context, txn *sql.Tx, userID, version string, ) (bool, error) { versionInt, err := strconv.ParseInt(version, 10, 64) @@ -132,7 +134,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup( return ra == 1, nil } -func (s *keyBackupVersionStatements) selectKeyBackup( +func (s *keyBackupVersionStatements) SelectKeyBackup( ctx context.Context, txn *sql.Tx, userID, version string, ) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { var versionInt int64 diff --git a/userapi/storage/devices/postgres/logintoken_table.go b/userapi/storage/postgres/logintoken_table.go similarity index 63% rename from userapi/storage/devices/postgres/logintoken_table.go rename to userapi/storage/postgres/logintoken_table.go index f601fc7db..4de96f839 100644 --- a/userapi/storage/devices/postgres/logintoken_table.go +++ b/userapi/storage/postgres/logintoken_table.go @@ -21,18 +21,11 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/util" ) -type loginTokenStatements struct { - insertStmt *sql.Stmt - deleteStmt *sql.Stmt - selectByTokenStmt *sql.Stmt -} - -// execSchema ensures tables and indices exist. -func (s *loginTokenStatements) execSchema(db *sql.DB) error { - _, err := db.Exec(` +const loginTokenSchema = ` CREATE TABLE IF NOT EXISTS login_tokens ( -- The random value of the token issued to a user token TEXT NOT NULL PRIMARY KEY, @@ -45,21 +38,38 @@ CREATE TABLE IF NOT EXISTS login_tokens ( -- This index allows efficient garbage collection of expired tokens. CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); -`) - return err +` + +const insertLoginTokenSQL = "" + + "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)" + +const deleteLoginTokenSQL = "" + + "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2" + +const selectLoginTokenSQL = "" + + "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2" + +type loginTokenStatements struct { + insertStmt *sql.Stmt + deleteStmt *sql.Stmt + selectStmt *sql.Stmt } -// prepare runs statement preparation. -func (s *loginTokenStatements) prepare(db *sql.DB) error { - return sqlutil.StatementList{ - {&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, - {&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"}, - {&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"}, +func NewPostgresLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) { + s := &loginTokenStatements{} + _, err := db.Exec(loginTokenSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertStmt, insertLoginTokenSQL}, + {&s.deleteStmt, deleteLoginTokenSQL}, + {&s.selectStmt, selectLoginTokenSQL}, }.Prepare(db) } // insert adds an already generated token to the database. -func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { +func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { stmt := sqlutil.TxStmt(txn, s.insertStmt) _, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID) return err @@ -69,7 +79,7 @@ func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata // // As a simple way to garbage-collect stale tokens, we also remove all expired tokens. // The login_tokens_expiration_idx index should make that efficient. -func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error { +func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error { stmt := sqlutil.TxStmt(txn, s.deleteStmt) res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) if err != nil { @@ -82,9 +92,9 @@ func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, t } // selectByToken returns the data associated with the given token. May return sql.ErrNoRows. -func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { +func (s *loginTokenStatements) SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) { var data api.LoginTokenData - err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) + err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) if err != nil { return nil, err } diff --git a/userapi/storage/accounts/postgres/openid_table.go b/userapi/storage/postgres/openid_table.go similarity index 74% rename from userapi/storage/accounts/postgres/openid_table.go rename to userapi/storage/postgres/openid_table.go index 190d141b7..29c3ddcb4 100644 --- a/userapi/storage/accounts/postgres/openid_table.go +++ b/userapi/storage/postgres/openid_table.go @@ -6,6 +6,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -22,33 +23,35 @@ CREATE TABLE IF NOT EXISTS open_id_tokens ( ); ` -const insertTokenSQL = "" + +const insertOpenIDTokenSQL = "" + "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" -const selectTokenSQL = "" + +const selectOpenIDTokenSQL = "" + "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" -type tokenStatements struct { +type openIDTokenStatements struct { insertTokenStmt *sql.Stmt selectTokenStmt *sql.Stmt serverName gomatrixserverlib.ServerName } -func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - _, err = db.Exec(openIDTokenSchema) - if err != nil { - return +func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.OpenIDTable, error) { + s := &openIDTokenStatements{ + serverName: serverName, } - s.serverName = server - return sqlutil.StatementList{ - {&s.insertTokenStmt, insertTokenSQL}, - {&s.selectTokenStmt, selectTokenSQL}, + _, err := db.Exec(openIDTokenSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertTokenStmt, insertOpenIDTokenSQL}, + {&s.selectTokenStmt, selectOpenIDTokenSQL}, }.Prepare(db) } // insertToken inserts a new OpenID Connect token to the DB. // Returns new token, otherwise returns error if the token already exists. -func (s *tokenStatements) insertToken( +func (s *openIDTokenStatements) InsertOpenIDToken( ctx context.Context, txn *sql.Tx, token, localpart string, @@ -61,7 +64,7 @@ func (s *tokenStatements) insertToken( // selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB // Returns the existing token's attributes, or err if no token is found -func (s *tokenStatements) selectOpenIDTokenAtrributes( +func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes( ctx context.Context, token string, ) (*api.OpenIDTokenAttributes, error) { diff --git a/userapi/storage/accounts/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go similarity index 85% rename from userapi/storage/accounts/postgres/profile_table.go rename to userapi/storage/postgres/profile_table.go index 9313864be..32a4b5506 100644 --- a/userapi/storage/accounts/postgres/profile_table.go +++ b/userapi/storage/postgres/profile_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const profilesSchema = ` @@ -59,12 +60,13 @@ type profilesStatements struct { selectProfilesBySearchStmt *sql.Stmt } -func (s *profilesStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(profilesSchema) +func NewPostgresProfilesTable(db *sql.DB) (tables.ProfileTable, error) { + s := &profilesStatements{} + _, err := db.Exec(profilesSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.insertProfileStmt, insertProfileSQL}, {&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL}, {&s.setAvatarURLStmt, setAvatarURLSQL}, @@ -73,14 +75,14 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *profilesStatements) insertProfile( +func (s *profilesStatements) InsertProfile( ctx context.Context, txn *sql.Tx, localpart string, ) (err error) { _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") return } -func (s *profilesStatements) selectProfileByLocalpart( +func (s *profilesStatements) SelectProfileByLocalpart( ctx context.Context, localpart string, ) (*authtypes.Profile, error) { var profile authtypes.Profile @@ -93,21 +95,21 @@ func (s *profilesStatements) selectProfileByLocalpart( return &profile, nil } -func (s *profilesStatements) setAvatarURL( - ctx context.Context, localpart string, avatarURL string, +func (s *profilesStatements) SetAvatarURL( + ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, ) (err error) { _, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart) return } -func (s *profilesStatements) setDisplayName( - ctx context.Context, localpart string, displayName string, +func (s *profilesStatements) SetDisplayName( + ctx context.Context, txn *sql.Tx, localpart string, displayName string, ) (err error) { _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) return } -func (s *profilesStatements) selectProfilesBySearch( +func (s *profilesStatements) SelectProfilesBySearch( ctx context.Context, searchString string, limit int, ) ([]authtypes.Profile, error) { var profiles []authtypes.Profile diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go new file mode 100644 index 000000000..ac5c59b81 --- /dev/null +++ b/userapi/storage/postgres/storage.go @@ -0,0 +1,105 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "fmt" + "time" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" + "github.com/matrix-org/dendrite/userapi/storage/shared" + + // Import the postgres database driver. + _ "github.com/lib/pq" +) + +// NewDatabase creates a new accounts and profiles database +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*shared.Database, error) { + db, err := sqlutil.Open(dbProperties) + if err != nil { + return nil, err + } + + m := sqlutil.NewMigrations() + if _, err = db.Exec(accountsSchema); err != nil { + // do this so that the migration can and we don't fail on + // preparing statements for columns that don't exist yet + return nil, err + } + deltas.LoadIsActive(m) + //deltas.LoadLastSeenTSIP(m) + deltas.LoadAddAccountType(m) + if err = m.RunDeltas(db, dbProperties); err != nil { + return nil, err + } + + accountDataTable, err := NewPostgresAccountDataTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err) + } + accountsTable, err := NewPostgresAccountsTable(db, serverName) + if err != nil { + return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err) + } + devicesTable, err := NewPostgresDevicesTable(db, serverName) + if err != nil { + return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err) + } + keyBackupTable, err := NewPostgresKeyBackupTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresKeyBackupTable: %w", err) + } + keyBackupVersionTable, err := NewPostgresKeyBackupVersionTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresKeyBackupVersionTable: %w", err) + } + loginTokenTable, err := NewPostgresLoginTokenTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresLoginTokenTable: %w", err) + } + openIDTable, err := NewPostgresOpenIDTable(db, serverName) + if err != nil { + return nil, fmt.Errorf("NewPostgresOpenIDTable: %w", err) + } + profilesTable, err := NewPostgresProfilesTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresProfilesTable: %w", err) + } + threePIDTable, err := NewPostgresThreePIDTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresThreePIDTable: %w", err) + } + return &shared.Database{ + AccountDatas: accountDataTable, + Accounts: accountsTable, + Devices: devicesTable, + KeyBackups: keyBackupTable, + KeyBackupVersions: keyBackupVersionTable, + LoginTokens: loginTokenTable, + OpenIDTokens: openIDTable, + Profiles: profilesTable, + ThreePIDs: threePIDTable, + ServerName: serverName, + DB: db, + Writer: sqlutil.NewDummyWriter(), + LoginTokenLifetime: loginTokenLifetime, + BcryptCost: bcryptCost, + OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, + }, nil +} diff --git a/userapi/storage/accounts/postgres/threepid_table.go b/userapi/storage/postgres/threepid_table.go similarity index 83% rename from userapi/storage/accounts/postgres/threepid_table.go rename to userapi/storage/postgres/threepid_table.go index 9280fc87c..63c08d61f 100644 --- a/userapi/storage/accounts/postgres/threepid_table.go +++ b/userapi/storage/postgres/threepid_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -58,12 +59,13 @@ type threepidStatements struct { deleteThreePIDStmt *sql.Stmt } -func (s *threepidStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(threepidSchema) +func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) { + s := &threepidStatements{} + _, err := db.Exec(threepidSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL}, {&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL}, {&s.insertThreePIDStmt, insertThreePIDSQL}, @@ -71,7 +73,7 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *threepidStatements) selectLocalpartForThreePID( +func (s *threepidStatements) SelectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, ) (localpart string, err error) { stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) @@ -82,7 +84,7 @@ func (s *threepidStatements) selectLocalpartForThreePID( return } -func (s *threepidStatements) selectThreePIDsForLocalpart( +func (s *threepidStatements) SelectThreePIDsForLocalpart( ctx context.Context, localpart string, ) (threepids []authtypes.ThreePID, err error) { rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) @@ -106,7 +108,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( return } -func (s *threepidStatements) insertThreePID( +func (s *threepidStatements) InsertThreePID( ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) @@ -114,8 +116,9 @@ func (s *threepidStatements) insertThreePID( return } -func (s *threepidStatements) deleteThreePID( - ctx context.Context, threepid string, medium string) (err error) { - _, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium) +func (s *threepidStatements) DeleteThreePID( + ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) { + stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt) + _, err = stmt.ExecContext(ctx, threepid, medium) return } diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/shared/storage.go similarity index 53% rename from userapi/storage/accounts/sqlite3/storage.go rename to userapi/storage/shared/storage.go index 0fcd65347..1d48315cf 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/shared/storage.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package sqlite3 +package shared import ( "context" @@ -21,109 +21,55 @@ import ( "errors" "fmt" "strconv" - "sync" "time" + "github.com/matrix-org/gomatrixserverlib" + "golang.org/x/crypto/bcrypt" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas" - "github.com/matrix-org/gomatrixserverlib" - "golang.org/x/crypto/bcrypt" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) // Database represents an account database type Database struct { - db *sql.DB - writer sqlutil.Writer - - sqlutil.PartitionOffsetStatements - accounts accountsStatements - profiles profilesStatements - accountDatas accountDataStatements - threepids threepidStatements - openIDTokens tokenStatements - keyBackupVersions keyBackupVersionStatements - keyBackups keyBackupStatements - serverName gomatrixserverlib.ServerName - bcryptCost int - openIDTokenLifetimeMS int64 - - accountsMu sync.Mutex - profilesMu sync.Mutex - accountDatasMu sync.Mutex - threepidsMu sync.Mutex + DB *sql.DB + Writer sqlutil.Writer + Accounts tables.AccountsTable + Profiles tables.ProfileTable + AccountDatas tables.AccountDataTable + ThreePIDs tables.ThreePIDTable + OpenIDTokens tables.OpenIDTable + KeyBackups tables.KeyBackupTable + KeyBackupVersions tables.KeyBackupVersionTable + Devices tables.DevicesTable + LoginTokens tables.LoginTokenTable + LoginTokenLifetime time.Duration + ServerName gomatrixserverlib.ServerName + BcryptCost int + OpenIDTokenLifetimeMS int64 } -// NewDatabase creates a new accounts and profiles database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) { - db, err := sqlutil.Open(dbProperties) - if err != nil { - return nil, err - } - d := &Database{ - serverName: serverName, - db: db, - writer: sqlutil.NewExclusiveWriter(), - bcryptCost: bcryptCost, - openIDTokenLifetimeMS: openIDTokenLifetimeMS, - } - - // Create tables before executing migrations so we don't fail if the table is missing, - // and THEN prepare statements so we don't fail due to referencing new columns - if err = d.accounts.execSchema(db); err != nil { - return nil, err - } - m := sqlutil.NewMigrations() - deltas.LoadIsActive(m) - deltas.LoadAddPolicyVersion(m) - if err = m.RunDeltas(db, dbProperties); err != nil { - return nil, err - } - - partitions := sqlutil.PartitionOffsetStatements{} - if err = partitions.Prepare(db, d.writer, "account"); err != nil { - return nil, err - } - if err = d.accounts.prepare(db, serverName); err != nil { - return nil, err - } - if err = d.profiles.prepare(db); err != nil { - return nil, err - } - if err = d.accountDatas.prepare(db); err != nil { - return nil, err - } - if err = d.threepids.prepare(db); err != nil { - return nil, err - } - if err = d.openIDTokens.prepare(db, serverName); err != nil { - return nil, err - } - if err = d.keyBackupVersions.prepare(db); err != nil { - return nil, err - } - if err = d.keyBackups.prepare(db); err != nil { - return nil, err - } - - return d, nil -} +const ( + // The length of generated device IDs + deviceIDByteLength = 6 + loginTokenByteLength = 32 +) // GetAccountByPassword returns the account associated with the given localpart and password. // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByPassword( ctx context.Context, localpart, plaintextPassword string, ) (*api.Account, error) { - hash, err := d.accounts.selectPasswordHash(ctx, localpart) + hash, err := d.Accounts.SelectPasswordHash(ctx, localpart) if err != nil { return nil, err } if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { return nil, err } - return d.accounts.selectAccountByLocalpart(ctx, localpart) + return d.Accounts.SelectAccountByLocalpart(ctx, localpart) } // GetProfileByLocalpart returns the profile associated with the given localpart. @@ -131,7 +77,7 @@ func (d *Database) GetAccountByPassword( func (d *Database) GetProfileByLocalpart( ctx context.Context, localpart string, ) (*authtypes.Profile, error) { - return d.profiles.selectProfileByLocalpart(ctx, localpart) + return d.Profiles.SelectProfileByLocalpart(ctx, localpart) } // SetAvatarURL updates the avatar URL of the profile associated with the given @@ -139,10 +85,8 @@ func (d *Database) GetProfileByLocalpart( func (d *Database) SetAvatarURL( ctx context.Context, localpart string, avatarURL string, ) error { - d.profilesMu.Lock() - defer d.profilesMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.profiles.setAvatarURL(ctx, txn, localpart, avatarURL) + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL) }) } @@ -151,10 +95,8 @@ func (d *Database) SetAvatarURL( func (d *Database) SetDisplayName( ctx context.Context, localpart string, displayName string, ) error { - d.profilesMu.Lock() - defer d.profilesMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.profiles.setDisplayName(ctx, txn, localpart, displayName) + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Profiles.SetDisplayName(ctx, txn, localpart, displayName) }) } @@ -166,53 +108,30 @@ func (d *Database) SetPassword( if err != nil { return err } - return d.writer.Do(nil, nil, func(txn *sql.Tx) error { - return d.accounts.updatePassword(ctx, localpart, hash) + return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { + return d.Accounts.UpdatePassword(ctx, localpart, hash) }) } -// CreateGuestAccount makes a new guest account and creates an empty profile -// for this account. -func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { - // We need to lock so we sequentially create numeric localparts. If we don't, two calls to - // this function will cause the same number to be selected and one will fail with 'database is locked' - // when the first txn upgrades to a write txn. We also need to lock the account creation else we can - // race with CreateAccount - // We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed. - d.profilesMu.Lock() - d.accountDatasMu.Lock() - d.accountsMu.Lock() - defer d.profilesMu.Unlock() - defer d.accountDatasMu.Unlock() - defer d.accountsMu.Unlock() - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - var numLocalpart int64 - numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) - if err != nil { - return err - } - localpart := strconv.FormatInt(numLocalpart, 10) - acc, err = d.createAccount(ctx, txn, localpart, "", "", "") - return err - }) - return acc, err -} - // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, ErrUserExists. func (d *Database) CreateAccount( - ctx context.Context, localpart, plaintextPassword, appserviceID, policyVersion string, + ctx context.Context, localpart, plaintextPassword, appserviceID, policyVersion string, accountType api.AccountType, ) (acc *api.Account, err error) { - // Create one account at a time else we can get 'database is locked'. - d.profilesMu.Lock() - d.accountDatasMu.Lock() - d.accountsMu.Lock() - defer d.profilesMu.Unlock() - defer d.accountDatasMu.Unlock() - defer d.accountsMu.Unlock() - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, policyVersion) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + // For guest accounts, we create a new numeric local part + if accountType == api.AccountTypeGuest { + var numLocalpart int64 + numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn) + if err != nil { + return err + } + localpart = strconv.FormatInt(numLocalpart, 10) + plaintextPassword = "" + appserviceID = "" + } + acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, policyVersion, accountType) return err }) return @@ -221,7 +140,7 @@ func (d *Database) CreateAccount( // WARNING! This function assumes that the relevant mutexes have already // been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). func (d *Database) createAccount( - ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID, policyVersion string, + ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID, policyVersion string, accountType api.AccountType, ) (*api.Account, error) { var err error var account *api.Account @@ -233,13 +152,13 @@ func (d *Database) createAccount( return nil, err } } - if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, policyVersion); err != nil { + if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, policyVersion, accountType); err != nil { return nil, sqlutil.ErrUserExists } - if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil { + if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil { return nil, err } - if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ + if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ "global": { "content": [], "override": [], @@ -261,10 +180,8 @@ func (d *Database) createAccount( func (d *Database) SaveAccountData( ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, ) error { - d.accountDatasMu.Lock() - defer d.accountDatasMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.AccountDatas.InsertAccountData(ctx, txn, localpart, roomID, dataType, content) }) } @@ -276,7 +193,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) ( rooms map[string]map[string]json.RawMessage, err error, ) { - return d.accountDatas.selectAccountData(ctx, localpart) + return d.AccountDatas.SelectAccountData(ctx, localpart) } // GetAccountDataByType returns account data matching a given @@ -286,7 +203,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) ( func (d *Database) GetAccountDataByType( ctx context.Context, localpart, roomID, dataType string, ) (data json.RawMessage, err error) { - return d.accountDatas.selectAccountDataByType( + return d.AccountDatas.SelectAccountDataByType( ctx, localpart, roomID, dataType, ) } @@ -295,11 +212,11 @@ func (d *Database) GetAccountDataByType( func (d *Database) GetNewNumericLocalpart( ctx context.Context, ) (int64, error) { - return d.accounts.selectNewNumericLocalpart(ctx, nil) + return d.Accounts.SelectNewNumericLocalpart(ctx, nil) } func (d *Database) hashPassword(plaintext string) (hash string, err error) { - hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.bcryptCost) + hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.BcryptCost) return string(hashBytes), err } @@ -314,10 +231,8 @@ var Err3PIDInUse = errors.New("this third-party identifier is already in use") func (d *Database) SaveThreePIDAssociation( ctx context.Context, threepid, localpart, medium string, ) (err error) { - d.threepidsMu.Lock() - defer d.threepidsMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - user, err := d.threepids.selectLocalpartForThreePID( + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + user, err := d.ThreePIDs.SelectLocalpartForThreePID( ctx, txn, threepid, medium, ) if err != nil { @@ -328,7 +243,7 @@ func (d *Database) SaveThreePIDAssociation( return Err3PIDInUse } - return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart) + return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart) }) } @@ -339,10 +254,8 @@ func (d *Database) SaveThreePIDAssociation( func (d *Database) RemoveThreePIDAssociation( ctx context.Context, threepid string, medium string, ) (err error) { - d.threepidsMu.Lock() - defer d.threepidsMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.threepids.deleteThreePID(ctx, txn, threepid, medium) + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.ThreePIDs.DeleteThreePID(ctx, txn, threepid, medium) }) } @@ -354,7 +267,7 @@ func (d *Database) RemoveThreePIDAssociation( func (d *Database) GetLocalpartForThreePID( ctx context.Context, threepid string, medium string, ) (localpart string, err error) { - return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium) + return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium) } // GetThreePIDsForLocalpart looks up the third-party identifiers associated with @@ -364,14 +277,14 @@ func (d *Database) GetLocalpartForThreePID( func (d *Database) GetThreePIDsForLocalpart( ctx context.Context, localpart string, ) (threepids []authtypes.ThreePID, err error) { - return d.threepids.selectThreePIDsForLocalpart(ctx, localpart) + return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart) } // CheckAccountAvailability checks if the username/localpart is already present // in the database. // If the DB returns sql.ErrNoRows the Localpart isn't taken. func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { - _, err := d.accounts.selectAccountByLocalpart(ctx, localpart) + _, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart) if err == sql.ErrNoRows { return true, nil } @@ -383,20 +296,20 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, ) (*api.Account, error) { - return d.accounts.selectAccountByLocalpart(ctx, localpart) + return d.Accounts.SelectAccountByLocalpart(ctx, localpart) } // SearchProfiles returns all profiles where the provided localpart or display name // match any part of the profiles in the database. func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int, ) ([]authtypes.Profile, error) { - return d.profiles.selectProfilesBySearch(ctx, searchString, limit) + return d.Profiles.SelectProfilesBySearch(ctx, searchString, limit) } // DeactivateAccount deactivates the user's account, removing all ability for the user to login again. func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { - return d.writer.Do(nil, nil, func(txn *sql.Tx) error { - return d.accounts.deactivateAccount(ctx, localpart) + return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { + return d.Accounts.DeactivateAccount(ctx, localpart) }) } @@ -405,9 +318,9 @@ func (d *Database) CreateOpenIDToken( ctx context.Context, token, localpart string, ) (int64, error) { - expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS - err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS) + expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, expiresAtMS) }) return expiresAtMS, err } @@ -417,14 +330,14 @@ func (d *Database) GetOpenIDTokenAttributes( ctx context.Context, token string, ) (*api.OpenIDTokenAttributes, error) { - return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) + return d.OpenIDTokens.SelectOpenIDTokenAtrributes(ctx, token) } func (d *Database) CreateKeyBackup( ctx context.Context, userID, algorithm string, authData json.RawMessage, ) (version string, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "") + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + version, err = d.KeyBackupVersions.InsertKeyBackup(ctx, txn, userID, algorithm, authData, "") return err }) return @@ -433,8 +346,8 @@ func (d *Database) CreateKeyBackup( func (d *Database) UpdateKeyBackupAuthData( ctx context.Context, userID, version string, authData json.RawMessage, ) (err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.KeyBackupVersions.UpdateKeyBackupAuthData(ctx, txn, userID, version, authData) }) return } @@ -442,8 +355,8 @@ func (d *Database) UpdateKeyBackupAuthData( func (d *Database) DeleteKeyBackup( ctx context.Context, userID, version string, ) (exists bool, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + exists, err = d.KeyBackupVersions.DeleteKeyBackup(ctx, txn, userID, version) return err }) return @@ -452,8 +365,8 @@ func (d *Database) DeleteKeyBackup( func (d *Database) GetKeyBackup( ctx context.Context, userID, version string, ) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + versionResult, algorithm, authData, etag, deleted, err = d.KeyBackupVersions.SelectKeyBackup(ctx, txn, userID, version) return err }) return @@ -462,16 +375,16 @@ func (d *Database) GetKeyBackup( func (d *Database) GetBackupKeys( ctx context.Context, version, userID, filterRoomID, filterSessionID string, ) (result map[string]map[string]api.KeyBackupSession, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if filterSessionID != "" { - result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID) + result, err = d.KeyBackups.SelectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID) return err } if filterRoomID != "" { - result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID) + result, err = d.KeyBackups.SelectKeysByRoomID(ctx, txn, userID, version, filterRoomID) return err } - result, err = d.keyBackups.selectKeys(ctx, txn, userID, version) + result, err = d.KeyBackups.SelectKeys(ctx, txn, userID, version) return err }) return @@ -480,8 +393,8 @@ func (d *Database) GetBackupKeys( func (d *Database) CountBackupKeys( ctx context.Context, version, userID string, ) (count int64, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - count, err = d.keyBackups.countKeys(ctx, txn, userID, version) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + count, err = d.KeyBackups.CountKeys(ctx, txn, userID, version) if err != nil { return err } @@ -495,8 +408,8 @@ func (d *Database) UpsertBackupKeys( ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession, ) (count int64, etag string, err error) { // wrap the following logic in a txn to ensure we atomically upload keys - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - _, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + _, _, _, oldETag, deleted, err := d.KeyBackupVersions.SelectKeyBackup(ctx, txn, userID, version) if err != nil { return err } @@ -504,7 +417,7 @@ func (d *Database) UpsertBackupKeys( return fmt.Errorf("backup was deleted") } // pull out all keys for this (user_id, version) - existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version) + existingKeys, err := d.KeyBackups.SelectKeys(ctx, txn, userID, version) if err != nil { return err } @@ -518,10 +431,10 @@ func (d *Database) UpsertBackupKeys( existingSession, ok := existingRoom[newKey.SessionID] if ok { if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) { - err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey) + err = d.KeyBackups.UpdateBackupKey(ctx, txn, userID, version, newKey) changed = true if err != nil { - return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err) + return fmt.Errorf("d.KeyBackups.UpdateBackupKey: %w", err) } } // if we shouldn't replace the key we do nothing with it @@ -529,14 +442,14 @@ func (d *Database) UpsertBackupKeys( } } // if we're here, either the room or session are new, either way, we insert - err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey) + err = d.KeyBackups.InsertBackupKey(ctx, txn, userID, version, newKey) changed = true if err != nil { - return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err) + return fmt.Errorf("d.KeyBackups.InsertBackupKey: %w", err) } } - count, err = d.keyBackups.countKeys(ctx, txn, userID, version) + count, err = d.KeyBackups.CountKeys(ctx, txn, userID, version) if err != nil { return err } @@ -553,7 +466,7 @@ func (d *Database) UpsertBackupKeys( newETag = strconv.FormatInt(oldETagInt+1, 10) } etag = newETag - return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag) + return d.KeyBackupVersions.UpdateKeyBackupETag(ctx, txn, userID, version, newETag) } else { etag = oldETag } @@ -563,10 +476,203 @@ func (d *Database) UpsertBackupKeys( return } +// GetDeviceByAccessToken returns the device matching the given access token. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByAccessToken( + ctx context.Context, token string, +) (*api.Device, error) { + return d.Devices.SelectDeviceByToken(ctx, token) +} + +// GetDeviceByID returns the device matching the given ID. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByID( + ctx context.Context, localpart, deviceID string, +) (*api.Device, error) { + return d.Devices.SelectDeviceByID(ctx, localpart, deviceID) +} + +// GetDevicesByLocalpart returns the devices matching the given localpart. +func (d *Database) GetDevicesByLocalpart( + ctx context.Context, localpart string, +) ([]api.Device, error) { + return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, "") +} + +func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { + return d.Devices.SelectDevicesByID(ctx, deviceIDs) +} + +// CreateDevice makes a new device associated with the given user ID localpart. +// If there is already a device with the same device ID for this user, that access token will be revoked +// and replaced with the given accessToken. If the given accessToken is already in use for another device, +// an error will be returned. +// If no device ID is given one is generated. +// Returns the device on success. +func (d *Database) CreateDevice( + ctx context.Context, localpart string, deviceID *string, accessToken string, + displayName *string, ipAddr, userAgent string, +) (dev *api.Device, returnErr error) { + if deviceID != nil { + returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var err error + // Revoke existing tokens for this device + if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart); err != nil { + return err + } + + dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) + return err + }) + } else { + // We generate device IDs in a loop in case its already taken. + // We cap this at going round 5 times to ensure we don't spin forever + var newDeviceID string + for i := 1; i <= 5; i++ { + newDeviceID, returnErr = generateDeviceID() + if returnErr != nil { + return + } + + returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var err error + dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) + return err + }) + if returnErr == nil { + return + } + } + } + return +} + +// generateDeviceID creates a new device id. Returns an error if failed to generate +// random bytes. +func generateDeviceID() (string, error) { + b := make([]byte, deviceIDByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + // url-safe no padding + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// UpdateDevice updates the given device with the display name. +// Returns SQL error if there are problems and nil on success. +func (d *Database) UpdateDevice( + ctx context.Context, localpart, deviceID string, displayName *string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Devices.UpdateDeviceName(ctx, txn, localpart, deviceID, displayName) + }) +} + +// RemoveDevice revokes a device by deleting the entry in the database +// matching with the given device ID and user ID localpart. +// If the device doesn't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevice( + ctx context.Context, deviceID, localpart string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if err := d.Devices.DeleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveDevices revokes one or more devices by deleting the entry in the database +// matching with the given device IDs and user ID localpart. +// If the devices don't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevices( + ctx context.Context, localpart string, devices []string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if err := d.Devices.DeleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveAllDevices revokes devices by deleting the entry in the +// database matching the given user ID localpart. +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveAllDevices( + ctx context.Context, localpart, exceptDeviceID string, +) (devices []api.Device, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) + if err != nil { + return err + } + if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { + return err + } + return nil + }) + return +} + +// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address +func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) + }) +} + +// CreateLoginToken generates a token, stores and returns it. The lifetime is +// determined by the loginTokenLifetime given to the Database constructor. +func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { + tok, err := generateLoginToken() + if err != nil { + return nil, err + } + meta := &api.LoginTokenMetadata{ + Token: tok, + Expiration: time.Now().Add(d.LoginTokenLifetime), + } + + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.LoginTokens.InsertLoginToken(ctx, txn, meta, data) + }) + if err != nil { + return nil, err + } + + return meta, nil +} + +func generateLoginToken() (string, error) { + b := make([]byte, loginTokenByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// RemoveLoginToken removes the named token (and may clean up other expired tokens). +func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.LoginTokens.DeleteLoginToken(ctx, txn, token) + }) +} + +// GetLoginTokenDataByToken returns the data associated with the given token. +// May return sql.ErrNoRows. +func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + return d.LoginTokens.SelectLoginToken(ctx, token) +} + // GetPrivacyPolicy returns the accepted privacy policy version, if any. func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) { err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - policyVersion, err = d.accounts.selectPrivacyPolicy(ctx, txn, localpart) + policyVersion, err = d.Accounts.selectPrivacyPolicy(ctx, txn, localpart) return err }) return diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/sqlite3/account_data_table.go similarity index 88% rename from userapi/storage/accounts/sqlite3/account_data_table.go rename to userapi/storage/sqlite3/account_data_table.go index 871f996e0..cfd8568a9 100644 --- a/userapi/storage/accounts/sqlite3/account_data_table.go +++ b/userapi/storage/sqlite3/account_data_table.go @@ -20,6 +20,7 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const accountDataSchema = ` @@ -56,27 +57,29 @@ type accountDataStatements struct { selectAccountDataByTypeStmt *sql.Stmt } -func (s *accountDataStatements) prepare(db *sql.DB) (err error) { - s.db = db - _, err = db.Exec(accountDataSchema) - if err != nil { - return +func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) { + s := &accountDataStatements{ + db: db, } - return sqlutil.StatementList{ + _, err := db.Exec(accountDataSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ {&s.insertAccountDataStmt, insertAccountDataSQL}, {&s.selectAccountDataStmt, selectAccountDataSQL}, {&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL}, }.Prepare(db) } -func (s *accountDataStatements) insertAccountData( +func (s *accountDataStatements) InsertAccountData( ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ) error { _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) return err } -func (s *accountDataStatements) selectAccountData( +func (s *accountDataStatements) SelectAccountData( ctx context.Context, localpart string, ) ( /* global */ map[string]json.RawMessage, @@ -113,7 +116,7 @@ func (s *accountDataStatements) selectAccountData( return global, rooms, nil } -func (s *accountDataStatements) selectAccountDataByType( +func (s *accountDataStatements) SelectAccountDataByType( ctx context.Context, localpart, roomID, dataType string, ) (data json.RawMessage, err error) { var bytes []byte diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go similarity index 85% rename from userapi/storage/accounts/sqlite3/accounts_table.go rename to userapi/storage/sqlite3/accounts_table.go index 95ebb022e..483736a80 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -19,10 +19,12 @@ import ( "database/sql" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/userapi/storage/tables" log "github.com/sirupsen/logrus" ) @@ -40,15 +42,17 @@ CREATE TABLE IF NOT EXISTS account_accounts ( appservice_id TEXT, -- If the account is currently active is_deactivated BOOLEAN DEFAULT 0, + -- The account_type (user = 1, guest = 2, admin = 3, appservice = 4) + account_type INTEGER NOT NULL, -- The policy version this user has accepted policy_version TEXT -- TODO: - -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? + -- upgraded_ts, devices, any email reset stuff? ); ` const insertAccountSQL = "" + - "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, policy_version) VALUES ($1, $2, $3, $4, $5)" + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type, policy_version) VALUES ($1, $2, $3, $4, $5, $6)" const updatePasswordSQL = "" + "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" @@ -57,7 +61,7 @@ const deactivateAccountSQL = "" + "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" + "SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1" const selectPasswordHashSQL = "" + "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0" @@ -88,15 +92,16 @@ type accountsStatements struct { serverName gomatrixserverlib.ServerName } -func (s *accountsStatements) execSchema(db *sql.DB) error { +func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) { + s := &accountsStatements{ + db: db, + serverName: serverName, + } _, err := db.Exec(accountsSchema) - return err -} - -func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - s.db = db - s.serverName = server - return sqlutil.StatementList{ + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ {&s.insertAccountStmt, insertAccountSQL}, {&s.updatePasswordStmt, updatePasswordSQL}, {&s.deactivateAccountStmt, deactivateAccountSQL}, @@ -112,17 +117,17 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. -func (s *accountsStatements) insertAccount( - ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, +func (s *accountsStatements) InsertAccount( + ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := s.insertAccountStmt var err error - if appserviceID == "" { - _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, policyVersion) + if accountType != api.AccountTypeAppService { + _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion) } else { - _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, "") + _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, "") } if err != nil { return nil, err @@ -136,35 +141,35 @@ func (s *accountsStatements) insertAccount( }, nil } -func (s *accountsStatements) updatePassword( +func (s *accountsStatements) UpdatePassword( ctx context.Context, localpart, passwordHash string, ) (err error) { _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) return } -func (s *accountsStatements) deactivateAccount( +func (s *accountsStatements) DeactivateAccount( ctx context.Context, localpart string, ) (err error) { _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) return } -func (s *accountsStatements) selectPasswordHash( +func (s *accountsStatements) SelectPasswordHash( ctx context.Context, localpart string, ) (hash string, err error) { err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) return } -func (s *accountsStatements) selectAccountByLocalpart( +func (s *accountsStatements) SelectAccountByLocalpart( ctx context.Context, localpart string, ) (*api.Account, error) { var appserviceIDPtr sql.NullString var acc api.Account stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) + err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") @@ -181,7 +186,7 @@ func (s *accountsStatements) selectAccountByLocalpart( return &acc, nil } -func (s *accountsStatements) selectNewNumericLocalpart( +func (s *accountsStatements) SelectNewNumericLocalpart( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { stmt := s.selectNewNumericLocalpartStmt diff --git a/userapi/storage/accounts/sqlite3/constraint_wasm.go b/userapi/storage/sqlite3/constraint_wasm.go similarity index 100% rename from userapi/storage/accounts/sqlite3/constraint_wasm.go rename to userapi/storage/sqlite3/constraint_wasm.go diff --git a/userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.go b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go similarity index 96% rename from userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.go rename to userapi/storage/sqlite3/deltas/20200929203058_is_active.go index 0b95b4996..24ef265e7 100644 --- a/userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.go +++ b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go @@ -4,12 +4,14 @@ import ( "database/sql" "fmt" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/pressly/goose" + + "github.com/matrix-org/dendrite/internal/sqlutil" ) func LoadFromGoose() { goose.AddMigration(UpIsActive, DownIsActive) + goose.AddMigration(UpAddAccountType, DownAddAccountType) goose.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion) } diff --git a/userapi/storage/devices/sqlite3/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go similarity index 94% rename from userapi/storage/devices/sqlite3/deltas/20201001204705_last_seen_ts_ip.go rename to userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go index 262098265..ebf908001 100644 --- a/userapi/storage/devices/sqlite3/deltas/20201001204705_last_seen_ts_ip.go +++ b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go @@ -5,13 +5,8 @@ import ( "fmt" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/pressly/goose" ) -func LoadFromGoose() { - goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP) -} - func LoadLastSeenTSIP(m *sqlutil.Migrations) { m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP) } diff --git a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go new file mode 100644 index 000000000..9b058dedd --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go @@ -0,0 +1,54 @@ +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/pressly/goose" + + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +func init() { + goose.AddMigration(UpAddAccountType, DownAddAccountType) +} + +func LoadAddAccountType(m *sqlutil.Migrations) { + m.AddMigration(UpAddAccountType, DownAddAccountType) +} + +func UpAddAccountType(tx *sql.Tx) error { + // initially set every account to useraccount, change appservice and guest accounts afterwards + // (user = 1, guest = 2, admin = 3, appservice = 4) + _, err := tx.Exec(`ALTER TABLE account_accounts RENAME TO account_accounts_tmp; +CREATE TABLE account_accounts ( + localpart TEXT NOT NULL PRIMARY KEY, + created_ts BIGINT NOT NULL, + password_hash TEXT, + appservice_id TEXT, + is_deactivated BOOLEAN DEFAULT 0, + account_type INTEGER NOT NULL +); +INSERT + INTO account_accounts ( + localpart, created_ts, password_hash, appservice_id, account_type + ) SELECT + localpart, created_ts, password_hash, appservice_id, 1 + FROM account_accounts_tmp +; +UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> ''; +UPDATE account_accounts SET account_type = 2 WHERE localpart GLOB '[0-9]*'; +DROP TABLE account_accounts_tmp;`) + if err != nil { + return fmt.Errorf("failed to add column: %w", err) + } + return nil +} + +func DownAddAccountType(tx *sql.Tx) error { + _, err := tx.Exec(`ALTER TABLE account_accounts DROP COLUMN account_type;`) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/accounts/sqlite3/deltas/2022021414375800_add_policy_version.go b/userapi/storage/sqlite3/deltas/2022021414375800_add_policy_version.go similarity index 100% rename from userapi/storage/accounts/sqlite3/deltas/2022021414375800_add_policy_version.go rename to userapi/storage/sqlite3/deltas/2022021414375800_add_policy_version.go diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go similarity index 82% rename from userapi/storage/devices/sqlite3/devices_table.go rename to userapi/storage/sqlite3/devices_table.go index 955d8ac7f..423640e90 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/gomatrixserverlib" @@ -84,7 +85,6 @@ const updateDeviceLastSeen = "" + type devicesStatements struct { db *sql.DB - writer sqlutil.Writer insertDeviceStmt *sql.Stmt selectDevicesCountStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt @@ -98,52 +98,33 @@ type devicesStatements struct { serverName gomatrixserverlib.ServerName } -func (s *devicesStatements) execSchema(db *sql.DB) error { +func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) { + s := &devicesStatements{ + db: db, + serverName: serverName, + } _, err := db.Exec(devicesSchema) - return err -} - -func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) { - s.db = db - s.writer = writer - if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil { - return + if err != nil { + return nil, err } - if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil { - return - } - if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil { - return - } - if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil { - return - } - if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil { - return - } - if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil { - return - } - if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil { - return - } - if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { - return - } - if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil { - return - } - if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil { - return - } - s.serverName = server - return + return s, sqlutil.StatementList{ + {&s.insertDeviceStmt, insertDeviceSQL}, + {&s.selectDevicesCountStmt, selectDevicesCountSQL}, + {&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL}, + {&s.selectDeviceByIDStmt, selectDeviceByIDSQL}, + {&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL}, + {&s.updateDeviceNameStmt, updateDeviceNameSQL}, + {&s.deleteDeviceStmt, deleteDeviceSQL}, + {&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL}, + {&s.selectDevicesByIDStmt, selectDevicesByIDSQL}, + {&s.updateDeviceLastSeenStmt, updateDeviceLastSeen}, + }.Prepare(db) } // insertDevice creates a new device. Returns an error if any device with the same access token already exists. // Returns an error if the user already has a device with the given device ID. // Returns the device on success. -func (s *devicesStatements) insertDevice( +func (s *devicesStatements) InsertDevice( ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string, ) (*api.Device, error) { @@ -169,7 +150,7 @@ func (s *devicesStatements) insertDevice( }, nil } -func (s *devicesStatements) deleteDevice( +func (s *devicesStatements) DeleteDevice( ctx context.Context, txn *sql.Tx, id, localpart string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) @@ -177,7 +158,7 @@ func (s *devicesStatements) deleteDevice( return err } -func (s *devicesStatements) deleteDevices( +func (s *devicesStatements) DeleteDevices( ctx context.Context, txn *sql.Tx, localpart string, devices []string, ) error { orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1) @@ -195,7 +176,7 @@ func (s *devicesStatements) deleteDevices( return err } -func (s *devicesStatements) deleteDevicesByLocalpart( +func (s *devicesStatements) DeleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) @@ -203,7 +184,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart( return err } -func (s *devicesStatements) updateDeviceName( +func (s *devicesStatements) UpdateDeviceName( ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ) error { stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) @@ -211,7 +192,7 @@ func (s *devicesStatements) updateDeviceName( return err } -func (s *devicesStatements) selectDeviceByToken( +func (s *devicesStatements) SelectDeviceByToken( ctx context.Context, accessToken string, ) (*api.Device, error) { var dev api.Device @@ -227,7 +208,7 @@ func (s *devicesStatements) selectDeviceByToken( // selectDeviceByID retrieves a device from the database with the given user // localpart and deviceID -func (s *devicesStatements) selectDeviceByID( +func (s *devicesStatements) SelectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*api.Device, error) { var dev api.Device @@ -244,7 +225,7 @@ func (s *devicesStatements) selectDeviceByID( return &dev, err } -func (s *devicesStatements) selectDevicesByLocalpart( +func (s *devicesStatements) SelectDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} @@ -285,7 +266,7 @@ func (s *devicesStatements) selectDevicesByLocalpart( return devices, nil } -func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { +func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1) iDeviceIDs := make([]interface{}, len(deviceIDs)) for i := range deviceIDs { @@ -314,7 +295,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s return devices, rows.Err() } -func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { +func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID) diff --git a/userapi/storage/accounts/sqlite3/key_backup_table.go b/userapi/storage/sqlite3/key_backup_table.go similarity index 91% rename from userapi/storage/accounts/sqlite3/key_backup_table.go rename to userapi/storage/sqlite3/key_backup_table.go index 837d38cf1..81726edf9 100644 --- a/userapi/storage/accounts/sqlite3/key_backup_table.go +++ b/userapi/storage/sqlite3/key_backup_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const keyBackupTableSchema = ` @@ -72,12 +73,13 @@ type keyBackupStatements struct { selectKeysByRoomIDAndSessionIDStmt *sql.Stmt } -func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(keyBackupTableSchema) +func NewSQLiteKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) { + s := &keyBackupStatements{} + _, err := db.Exec(keyBackupTableSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.countKeysStmt, countKeysSQL}, @@ -87,14 +89,14 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s keyBackupStatements) countKeys( +func (s keyBackupStatements) CountKeys( ctx context.Context, txn *sql.Tx, userID, version string, ) (count int64, err error) { err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count) return } -func (s *keyBackupStatements) insertBackupKey( +func (s *keyBackupStatements) InsertBackupKey( ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ) (err error) { _, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext( @@ -103,7 +105,7 @@ func (s *keyBackupStatements) insertBackupKey( return } -func (s *keyBackupStatements) updateBackupKey( +func (s *keyBackupStatements) UpdateBackupKey( ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ) (err error) { _, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext( @@ -112,7 +114,7 @@ func (s *keyBackupStatements) updateBackupKey( return } -func (s *keyBackupStatements) selectKeys( +func (s *keyBackupStatements) SelectKeys( ctx context.Context, txn *sql.Tx, userID, version string, ) (map[string]map[string]api.KeyBackupSession, error) { rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) @@ -122,7 +124,7 @@ func (s *keyBackupStatements) selectKeys( return unpackKeys(ctx, rows) } -func (s *keyBackupStatements) selectKeysByRoomID( +func (s *keyBackupStatements) SelectKeysByRoomID( ctx context.Context, txn *sql.Tx, userID, version, roomID string, ) (map[string]map[string]api.KeyBackupSession, error) { rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID) @@ -132,7 +134,7 @@ func (s *keyBackupStatements) selectKeysByRoomID( return unpackKeys(ctx, rows) } -func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID( +func (s *keyBackupStatements) SelectKeysByRoomIDAndSessionID( ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string, ) (map[string]map[string]api.KeyBackupSession, error) { rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID) diff --git a/userapi/storage/accounts/sqlite3/key_backup_version_table.go b/userapi/storage/sqlite3/key_backup_version_table.go similarity index 89% rename from userapi/storage/accounts/sqlite3/key_backup_version_table.go rename to userapi/storage/sqlite3/key_backup_version_table.go index 4211ed0f1..e85e6f08b 100644 --- a/userapi/storage/accounts/sqlite3/key_backup_version_table.go +++ b/userapi/storage/sqlite3/key_backup_version_table.go @@ -22,6 +22,7 @@ import ( "strconv" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const keyBackupVersionTableSchema = ` @@ -67,12 +68,13 @@ type keyBackupVersionStatements struct { updateKeyBackupETagStmt *sql.Stmt } -func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(keyBackupVersionTableSchema) +func NewSQLiteKeyBackupVersionTable(db *sql.DB) (tables.KeyBackupVersionTable, error) { + s := &keyBackupVersionStatements{} + _, err := db.Exec(keyBackupVersionTableSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.insertKeyBackupStmt, insertKeyBackupSQL}, {&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL}, {&s.deleteKeyBackupStmt, deleteKeyBackupSQL}, @@ -82,7 +84,7 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *keyBackupVersionStatements) insertKeyBackup( +func (s *keyBackupVersionStatements) InsertKeyBackup( ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string, ) (version string, err error) { var versionInt int64 @@ -90,7 +92,7 @@ func (s *keyBackupVersionStatements) insertKeyBackup( return strconv.FormatInt(versionInt, 10), err } -func (s *keyBackupVersionStatements) updateKeyBackupAuthData( +func (s *keyBackupVersionStatements) UpdateKeyBackupAuthData( ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage, ) error { versionInt, err := strconv.ParseInt(version, 10, 64) @@ -101,7 +103,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData( return err } -func (s *keyBackupVersionStatements) updateKeyBackupETag( +func (s *keyBackupVersionStatements) UpdateKeyBackupETag( ctx context.Context, txn *sql.Tx, userID, version, etag string, ) error { versionInt, err := strconv.ParseInt(version, 10, 64) @@ -112,7 +114,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag( return err } -func (s *keyBackupVersionStatements) deleteKeyBackup( +func (s *keyBackupVersionStatements) DeleteKeyBackup( ctx context.Context, txn *sql.Tx, userID, version string, ) (bool, error) { versionInt, err := strconv.ParseInt(version, 10, 64) @@ -130,7 +132,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup( return ra == 1, nil } -func (s *keyBackupVersionStatements) selectKeyBackup( +func (s *keyBackupVersionStatements) SelectKeyBackup( ctx context.Context, txn *sql.Tx, userID, version string, ) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { var versionInt int64 diff --git a/userapi/storage/devices/sqlite3/logintoken_table.go b/userapi/storage/sqlite3/logintoken_table.go similarity index 63% rename from userapi/storage/devices/sqlite3/logintoken_table.go rename to userapi/storage/sqlite3/logintoken_table.go index 75ef272f8..78d42029a 100644 --- a/userapi/storage/devices/sqlite3/logintoken_table.go +++ b/userapi/storage/sqlite3/logintoken_table.go @@ -21,18 +21,17 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/util" ) type loginTokenStatements struct { - insertStmt *sql.Stmt - deleteStmt *sql.Stmt - selectByTokenStmt *sql.Stmt + insertStmt *sql.Stmt + deleteStmt *sql.Stmt + selectStmt *sql.Stmt } -// execSchema ensures tables and indices exist. -func (s *loginTokenStatements) execSchema(db *sql.DB) error { - _, err := db.Exec(` +const loginTokenSchema = ` CREATE TABLE IF NOT EXISTS login_tokens ( -- The random value of the token issued to a user token TEXT NOT NULL PRIMARY KEY, @@ -45,21 +44,32 @@ CREATE TABLE IF NOT EXISTS login_tokens ( -- This index allows efficient garbage collection of expired tokens. CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); -`) - return err -} +` -// prepare runs statement preparation. -func (s *loginTokenStatements) prepare(db *sql.DB) error { - return sqlutil.StatementList{ - {&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, - {&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"}, - {&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"}, +const insertLoginTokenSQL = "" + + "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)" + +const deleteLoginTokenSQL = "" + + "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2" + +const selectLoginTokenSQL = "" + + "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2" + +func NewSQLiteLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) { + s := &loginTokenStatements{} + _, err := db.Exec(loginTokenSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertStmt, insertLoginTokenSQL}, + {&s.deleteStmt, deleteLoginTokenSQL}, + {&s.selectStmt, selectLoginTokenSQL}, }.Prepare(db) } // insert adds an already generated token to the database. -func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { +func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { stmt := sqlutil.TxStmt(txn, s.insertStmt) _, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID) return err @@ -69,7 +79,7 @@ func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata // // As a simple way to garbage-collect stale tokens, we also remove all expired tokens. // The login_tokens_expiration_idx index should make that efficient. -func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error { +func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error { stmt := sqlutil.TxStmt(txn, s.deleteStmt) res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) if err != nil { @@ -82,9 +92,9 @@ func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, t } // selectByToken returns the data associated with the given token. May return sql.ErrNoRows. -func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { +func (s *loginTokenStatements) SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) { var data api.LoginTokenData - err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) + err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) if err != nil { return nil, err } diff --git a/userapi/storage/accounts/sqlite3/openid_table.go b/userapi/storage/sqlite3/openid_table.go similarity index 74% rename from userapi/storage/accounts/sqlite3/openid_table.go rename to userapi/storage/sqlite3/openid_table.go index 98c0488b1..d6090e0da 100644 --- a/userapi/storage/accounts/sqlite3/openid_table.go +++ b/userapi/storage/sqlite3/openid_table.go @@ -6,6 +6,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -22,35 +23,37 @@ CREATE TABLE IF NOT EXISTS open_id_tokens ( ); ` -const insertTokenSQL = "" + +const insertOpenIDTokenSQL = "" + "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" -const selectTokenSQL = "" + +const selectOpenIDTokenSQL = "" + "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" -type tokenStatements struct { +type openIDTokenStatements struct { db *sql.DB insertTokenStmt *sql.Stmt selectTokenStmt *sql.Stmt serverName gomatrixserverlib.ServerName } -func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - s.db = db - _, err = db.Exec(openIDTokenSchema) - if err != nil { - return err +func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.OpenIDTable, error) { + s := &openIDTokenStatements{ + db: db, + serverName: serverName, } - s.serverName = server - return sqlutil.StatementList{ - {&s.insertTokenStmt, insertTokenSQL}, - {&s.selectTokenStmt, selectTokenSQL}, + _, err := db.Exec(openIDTokenSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertTokenStmt, insertOpenIDTokenSQL}, + {&s.selectTokenStmt, selectOpenIDTokenSQL}, }.Prepare(db) } // insertToken inserts a new OpenID Connect token to the DB. // Returns new token, otherwise returns error if the token already exists. -func (s *tokenStatements) insertToken( +func (s *openIDTokenStatements) InsertOpenIDToken( ctx context.Context, txn *sql.Tx, token, localpart string, @@ -63,7 +66,7 @@ func (s *tokenStatements) insertToken( // selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB // Returns the existing token's attributes, or err if no token is found -func (s *tokenStatements) selectOpenIDTokenAtrributes( +func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes( ctx context.Context, token string, ) (*api.OpenIDTokenAttributes, error) { diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go similarity index 89% rename from userapi/storage/accounts/sqlite3/profile_table.go rename to userapi/storage/sqlite3/profile_table.go index a92e95663..d85b19c7b 100644 --- a/userapi/storage/accounts/sqlite3/profile_table.go +++ b/userapi/storage/sqlite3/profile_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const profilesSchema = ` @@ -60,13 +61,15 @@ type profilesStatements struct { selectProfilesBySearchStmt *sql.Stmt } -func (s *profilesStatements) prepare(db *sql.DB) (err error) { - s.db = db - _, err = db.Exec(profilesSchema) - if err != nil { - return +func NewSQLiteProfilesTable(db *sql.DB) (tables.ProfileTable, error) { + s := &profilesStatements{ + db: db, } - return sqlutil.StatementList{ + _, err := db.Exec(profilesSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ {&s.insertProfileStmt, insertProfileSQL}, {&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL}, {&s.setAvatarURLStmt, setAvatarURLSQL}, @@ -75,14 +78,14 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *profilesStatements) insertProfile( +func (s *profilesStatements) InsertProfile( ctx context.Context, txn *sql.Tx, localpart string, ) error { _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") return err } -func (s *profilesStatements) selectProfileByLocalpart( +func (s *profilesStatements) SelectProfileByLocalpart( ctx context.Context, localpart string, ) (*authtypes.Profile, error) { var profile authtypes.Profile @@ -95,7 +98,7 @@ func (s *profilesStatements) selectProfileByLocalpart( return &profile, nil } -func (s *profilesStatements) setAvatarURL( +func (s *profilesStatements) SetAvatarURL( ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, ) (err error) { stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) @@ -103,7 +106,7 @@ func (s *profilesStatements) setAvatarURL( return } -func (s *profilesStatements) setDisplayName( +func (s *profilesStatements) SetDisplayName( ctx context.Context, txn *sql.Tx, localpart string, displayName string, ) (err error) { stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) @@ -111,7 +114,7 @@ func (s *profilesStatements) setDisplayName( return } -func (s *profilesStatements) selectProfilesBySearch( +func (s *profilesStatements) SelectProfilesBySearch( ctx context.Context, searchString string, limit int, ) ([]authtypes.Profile, error) { var profiles []authtypes.Profile diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go new file mode 100644 index 000000000..98c244977 --- /dev/null +++ b/userapi/storage/sqlite3/storage.go @@ -0,0 +1,106 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "fmt" + "time" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + + "github.com/matrix-org/dendrite/userapi/storage/shared" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" + + // Import the postgres database driver. + _ "github.com/lib/pq" +) + +// NewDatabase creates a new accounts and profiles database +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*shared.Database, error) { + db, err := sqlutil.Open(dbProperties) + if err != nil { + return nil, err + } + + m := sqlutil.NewMigrations() + if _, err = db.Exec(accountsSchema); err != nil { + // do this so that the migration can and we don't fail on + // preparing statements for columns that don't exist yet + return nil, err + } + deltas.LoadIsActive(m) + //deltas.LoadLastSeenTSIP(m) + deltas.LoadAddAccountType(m) + if err = m.RunDeltas(db, dbProperties); err != nil { + return nil, err + } + + accountDataTable, err := NewSQLiteAccountDataTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err) + } + accountsTable, err := NewSQLiteAccountsTable(db, serverName) + if err != nil { + return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err) + } + devicesTable, err := NewSQLiteDevicesTable(db, serverName) + if err != nil { + return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err) + } + keyBackupTable, err := NewSQLiteKeyBackupTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteKeyBackupTable: %w", err) + } + keyBackupVersionTable, err := NewSQLiteKeyBackupVersionTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteKeyBackupVersionTable: %w", err) + } + loginTokenTable, err := NewSQLiteLoginTokenTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteLoginTokenTable: %w", err) + } + openIDTable, err := NewSQLiteOpenIDTable(db, serverName) + if err != nil { + return nil, fmt.Errorf("NewSQLiteOpenIDTable: %w", err) + } + profilesTable, err := NewSQLiteProfilesTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteProfilesTable: %w", err) + } + threePIDTable, err := NewSQLiteThreePIDTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteThreePIDTable: %w", err) + } + return &shared.Database{ + AccountDatas: accountDataTable, + Accounts: accountsTable, + Devices: devicesTable, + KeyBackups: keyBackupTable, + KeyBackupVersions: keyBackupVersionTable, + LoginTokens: loginTokenTable, + OpenIDTokens: openIDTable, + Profiles: profilesTable, + ThreePIDs: threePIDTable, + ServerName: serverName, + DB: db, + Writer: sqlutil.NewExclusiveWriter(), + LoginTokenLifetime: loginTokenLifetime, + BcryptCost: bcryptCost, + OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, + }, nil +} diff --git a/userapi/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/sqlite3/threepid_table.go similarity index 88% rename from userapi/storage/accounts/sqlite3/threepid_table.go rename to userapi/storage/sqlite3/threepid_table.go index 9dc0e2d22..fa174eed5 100644 --- a/userapi/storage/accounts/sqlite3/threepid_table.go +++ b/userapi/storage/sqlite3/threepid_table.go @@ -20,6 +20,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -60,13 +61,15 @@ type threepidStatements struct { deleteThreePIDStmt *sql.Stmt } -func (s *threepidStatements) prepare(db *sql.DB) (err error) { - s.db = db - _, err = db.Exec(threepidSchema) - if err != nil { - return +func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) { + s := &threepidStatements{ + db: db, } - return sqlutil.StatementList{ + _, err := db.Exec(threepidSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ {&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL}, {&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL}, {&s.insertThreePIDStmt, insertThreePIDSQL}, @@ -74,7 +77,7 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *threepidStatements) selectLocalpartForThreePID( +func (s *threepidStatements) SelectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, ) (localpart string, err error) { stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) @@ -85,7 +88,7 @@ func (s *threepidStatements) selectLocalpartForThreePID( return } -func (s *threepidStatements) selectThreePIDsForLocalpart( +func (s *threepidStatements) SelectThreePIDsForLocalpart( ctx context.Context, localpart string, ) (threepids []authtypes.ThreePID, err error) { rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) @@ -109,7 +112,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( return threepids, rows.Err() } -func (s *threepidStatements) insertThreePID( +func (s *threepidStatements) InsertThreePID( ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) @@ -117,7 +120,7 @@ func (s *threepidStatements) insertThreePID( return err } -func (s *threepidStatements) deleteThreePID( +func (s *threepidStatements) DeleteThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) { stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt) _, err = stmt.ExecContext(ctx, threepid, medium) diff --git a/userapi/storage/accounts/storage.go b/userapi/storage/storage.go similarity index 81% rename from userapi/storage/accounts/storage.go rename to userapi/storage/storage.go index a21f7d94e..4711439af 100644 --- a/userapi/storage/accounts/storage.go +++ b/userapi/storage/storage.go @@ -15,25 +15,27 @@ //go:build !wasm // +build !wasm -package accounts +package storage import ( "fmt" + "time" + + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres" - "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/userapi/storage/postgres" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3" ) // NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // and sets postgres connection parameters -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS) + return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS) + return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/userapi/storage/accounts/storage_wasm.go b/userapi/storage/storage_wasm.go similarity index 87% rename from userapi/storage/accounts/storage_wasm.go rename to userapi/storage/storage_wasm.go index 11a88a20a..701dcd833 100644 --- a/userapi/storage/accounts/storage_wasm.go +++ b/userapi/storage/storage_wasm.go @@ -12,13 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package accounts +package storage import ( "fmt" + "time" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3" "github.com/matrix-org/gomatrixserverlib" ) @@ -27,10 +28,11 @@ func NewDatabase( serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, + loginTokenLifetime time.Duration, ) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS) + return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go new file mode 100644 index 000000000..12939ced5 --- /dev/null +++ b/userapi/storage/tables/interface.go @@ -0,0 +1,95 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/userapi/api" +) + +type AccountDataTable interface { + InsertAccountData(ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage) error + SelectAccountData(ctx context.Context, localpart string) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error) + SelectAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error) +} + +type AccountsTable interface { + InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error) + UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error) + DeactivateAccount(ctx context.Context, localpart string) (err error) + SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error) + SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) + SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error) +} + +type DevicesTable interface { + InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error) + DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error + DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error + DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error + UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string) error + SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error) + SelectDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) + SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) ([]api.Device, error) + SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) + UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error +} + +type KeyBackupTable interface { + CountKeys(ctx context.Context, txn *sql.Tx, userID, version string) (count int64, err error) + InsertBackupKey(ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession) (err error) + UpdateBackupKey(ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession) (err error) + SelectKeys(ctx context.Context, txn *sql.Tx, userID, version string) (map[string]map[string]api.KeyBackupSession, error) + SelectKeysByRoomID(ctx context.Context, txn *sql.Tx, userID, version, roomID string) (map[string]map[string]api.KeyBackupSession, error) + SelectKeysByRoomIDAndSessionID(ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string) (map[string]map[string]api.KeyBackupSession, error) +} + +type KeyBackupVersionTable interface { + InsertKeyBackup(ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string) (version string, err error) + UpdateKeyBackupAuthData(ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage) error + UpdateKeyBackupETag(ctx context.Context, txn *sql.Tx, userID, version, etag string) error + DeleteKeyBackup(ctx context.Context, txn *sql.Tx, userID, version string) (bool, error) + SelectKeyBackup(ctx context.Context, txn *sql.Tx, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) +} + +type LoginTokenTable interface { + InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error + DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error + SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) +} + +type OpenIDTable interface { + InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, expiresAtMS int64) (err error) + SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) +} + +type ProfileTable interface { + InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error + SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) + SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (err error) + SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (err error) + SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) +} + +type ThreePIDTable interface { + SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, err error) + SelectThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) + InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error) + DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) +} diff --git a/userapi/userapi.go b/userapi/userapi.go index c7e1f6674..4a5793abb 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -23,18 +23,10 @@ import ( "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/inthttp" - "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" + "github.com/matrix-org/dendrite/userapi/storage" "github.com/sirupsen/logrus" ) -// defaultLoginTokenLifetime determines how old a valid token may be. -// -// NOTSPEC: The current spec says "SHOULD be limited to around five -// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low. -// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325). -const defaultLoginTokenLifetime = 2 * time.Minute - // AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions // on the given input API. func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { @@ -44,26 +36,24 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - accountDB accounts.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, + accountDB storage.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, ) api.UserInternalAPI { - deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName, defaultLoginTokenLifetime) + db, err := storage.NewDatabase(&cfg.AccountDatabase, cfg.Matrix.ServerName, cfg.BCryptCost, int64(api.DefaultLoginTokenLifetime*time.Millisecond), api.DefaultLoginTokenLifetime) if err != nil { logrus.WithError(err).Panicf("failed to connect to device db") } - return newInternalAPI(accountDB, deviceDB, cfg, appServices, keyAPI) + return newInternalAPI(db, cfg, appServices, keyAPI) } func newInternalAPI( - accountDB accounts.Database, - deviceDB devices.Database, + db storage.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, ) api.UserInternalAPI { return &internal.UserInternalAPI{ - AccountDB: accountDB, - DeviceDB: deviceDB, + DB: db, ServerName: cfg.Matrix.ServerName, AppServices: appServices, KeyAPI: keyAPI, diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 5dd217828..fac6aefbe 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -23,15 +23,15 @@ import ( "time" "github.com/gorilla/mux" + "github.com/matrix-org/gomatrixserverlib" + "golang.org/x/crypto/bcrypt" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/inthttp" - "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" - "github.com/matrix-org/gomatrixserverlib" - "golang.org/x/crypto/bcrypt" + "github.com/matrix-org/dendrite/userapi/storage" ) const ( @@ -42,23 +42,19 @@ type apiTestOpts struct { loginTokenLifetime time.Duration } -func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, accounts.Database) { +func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, storage.Database) { if opts.loginTokenLifetime == 0 { - opts.loginTokenLifetime = defaultLoginTokenLifetime + opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond } dbopts := &config.DatabaseOptions{ ConnectionString: "file::memory:", MaxOpenConnections: 1, MaxIdleConnections: 1, } - accountDB, err := accounts.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS) + accountDB, err := storage.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime) if err != nil { t.Fatalf("failed to create account DB: %s", err) } - deviceDB, err := devices.NewDatabase(dbopts, serverName, opts.loginTokenLifetime) - if err != nil { - t.Fatalf("failed to create device DB: %s", err) - } cfg := &config.UserAPI{ Matrix: &config.Global{ @@ -66,14 +62,14 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, a }, } - return newInternalAPI(accountDB, deviceDB, cfg, nil, nil), accountDB + return newInternalAPI(accountDB, cfg, nil, nil), accountDB } func TestQueryProfile(t *testing.T) { aliceAvatarURL := "mxc://example.com/alice" aliceDisplayName := "Alice" userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{}) - _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", "") + _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", "", api.AccountTypeUser) if err != nil { t.Fatalf("failed to make account: %s", err) } @@ -151,7 +147,7 @@ func TestLoginToken(t *testing.T) { t.Run("tokenLoginFlow", func(t *testing.T) { userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{}) - _, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", "") + _, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", "", api.AccountTypeUser) if err != nil { t.Fatalf("failed to make account: %s", err) }