Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking

This commit is contained in:
Till Faelligen 2022-02-21 12:08:03 +01:00
commit 9c3a1cfd47
122 changed files with 2344 additions and 2320 deletions

View file

@ -23,7 +23,7 @@ import (
"errors"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
)
@ -85,7 +85,7 @@ func RetrieveUserProfile(
ctx context.Context,
userID string,
asAPI AppServiceQueryAPI,
accountDB accounts.Database,
accountDB userdb.Database,
) (*authtypes.Profile, error) {
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {

View file

@ -22,6 +22,8 @@ import (
"time"
"github.com/gorilla/mux"
"github.com/sirupsen/logrus"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/appservice/consumers"
"github.com/matrix-org/dendrite/appservice/inthttp"
@ -34,7 +36,6 @@ import (
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/sirupsen/logrus"
)
// AddInternalRoutes registers HTTP handlers for internal API calls
@ -121,7 +122,7 @@ func generateAppServiceAccount(
) error {
var accRes userapi.PerformAccountCreationResponse
err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeUser,
AccountType: userapi.AccountTypeAppService,
Localpart: as.SenderLocalpart,
AppServiceID: as.ID,
OnConflict: userapi.ConflictUpdate,

View file

@ -283,8 +283,7 @@ func (m *DendriteMonolith) Start() {
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/%s", m.StorageDirectory, prefix))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-account.db", m.StorageDirectory, prefix))
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-device.db", m.StorageDirectory, prefix))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-mediaapi.db", m.CacheDirectory, prefix))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-syncapi.db", m.StorageDirectory, prefix))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-roomserver.db", m.StorageDirectory, prefix))
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-keyserver.db", m.StorageDirectory, prefix))

View file

@ -88,7 +88,6 @@ func (m *DendriteMonolith) Start() {
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", m.StorageDirectory))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-account.db", m.StorageDirectory))
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-device.db", m.StorageDirectory))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-syncapi.db", m.StorageDirectory))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-roomserver.db", m.StorageDirectory))

View file

@ -28,7 +28,7 @@ import (
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
)
@ -38,7 +38,7 @@ func AddPublicRoutes(
synapseAdminRouter *mux.Router,
consentAPIMux *mux.Router,
cfg *config.ClientAPI,
accountsDB accounts.Database,
accountsDB userdb.Database,
federation *gomatrixserverlib.FederationClient,
rsAPI roomserverAPI.RoomserverInternalAPI,
eduInputAPI eduServerAPI.EDUServerInputAPI,

View file

@ -156,6 +156,15 @@ func MissingParam(msg string) *MatrixError {
return &MatrixError{"M_MISSING_PARAM", msg}
}
// LeaveServerNoticeError is an error returned when trying to reject an invite
// for a server notice room.
func LeaveServerNoticeError() *MatrixError {
return &MatrixError{
ErrCode: "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM",
Err: "You cannot reject this invite",
}
}
type IncompatibleRoomVersionError struct {
RoomVersion string `json:"room_version"`
Error string `json:"error"`

View file

@ -47,8 +47,8 @@ func GetAdminWhois(
req *http.Request, userAPI api.UserInternalAPI, device *api.Device,
userID string,
) util.JSONResponse {
if userID != device.UserID {
// TODO: Still allow if user is admin
allowed := device.AccountType == api.AccountTypeAdmin || userID == device.UserID
if !allowed {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("userID does not match the current user"),

View file

@ -15,6 +15,7 @@
package routing
import (
"context"
"encoding/json"
"fmt"
"net/http"
@ -30,7 +31,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
log "github.com/sirupsen/logrus"
@ -137,36 +138,17 @@ type fledglingEvent struct {
func CreateRoom(
req *http.Request, device *api.Device,
cfg *config.ClientAPI,
accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
// TODO (#267): Check room ID doesn't clash with an existing one, and we
// probably shouldn't be using pseudo-random strings, maybe GUIDs?
roomID := fmt.Sprintf("!%s:%s", util.RandomString(16), cfg.Matrix.ServerName)
return createRoom(req, device, cfg, roomID, accountDB, rsAPI, asAPI)
}
// createRoom implements /createRoom
// nolint: gocyclo
func createRoom(
req *http.Request, device *api.Device,
cfg *config.ClientAPI, roomID string,
accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
logger := util.GetLogger(req.Context())
userID := device.UserID
var r createRoomRequest
resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {
return *resErr
}
// TODO: apply rate-limit
if resErr = r.Validate(); resErr != nil {
return *resErr
}
evTime, err := httputil.ParseTSParam(req)
if err != nil {
return util.JSONResponse{
@ -174,6 +156,25 @@ func createRoom(
JSON: jsonerror.InvalidArgumentValue(err.Error()),
}
}
return createRoom(req.Context(), r, device, cfg, accountDB, rsAPI, asAPI, evTime)
}
// createRoom implements /createRoom
// nolint: gocyclo
func createRoom(
ctx context.Context,
r createRoomRequest, device *api.Device,
cfg *config.ClientAPI,
accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI,
evTime time.Time,
) util.JSONResponse {
// TODO (#267): Check room ID doesn't clash with an existing one, and we
// probably shouldn't be using pseudo-random strings, maybe GUIDs?
roomID := fmt.Sprintf("!%s:%s", util.RandomString(16), cfg.Matrix.ServerName)
logger := util.GetLogger(ctx)
userID := device.UserID
// Clobber keys: creator, room_version
@ -200,16 +201,16 @@ func createRoom(
"roomVersion": roomVersion,
}).Info("Creating new room")
profile, err := appserviceAPI.RetrieveUserProfile(req.Context(), userID, asAPI, accountDB)
profile, err := appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, accountDB)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("appserviceAPI.RetrieveUserProfile failed")
util.GetLogger(ctx).WithError(err).Error("appserviceAPI.RetrieveUserProfile failed")
return jsonerror.InternalServerError()
}
createContent := map[string]interface{}{}
if len(r.CreationContent) > 0 {
if err = json.Unmarshal(r.CreationContent, &createContent); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal for creation_content failed")
util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for creation_content failed")
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("invalid create content"),
@ -230,7 +231,7 @@ func createRoom(
// Merge powerLevelContentOverride fields by unmarshalling it atop the defaults
err = json.Unmarshal(r.PowerLevelContentOverride, &powerLevelContent)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal for power_level_content_override failed")
util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for power_level_content_override failed")
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("malformed power_level_content_override"),
@ -319,9 +320,9 @@ func createRoom(
}
var aliasResp roomserverAPI.GetRoomIDForAliasResponse
err = rsAPI.GetRoomIDForAlias(req.Context(), &hasAliasReq, &aliasResp)
err = rsAPI.GetRoomIDForAlias(ctx, &hasAliasReq, &aliasResp)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed")
util.GetLogger(ctx).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed")
return jsonerror.InternalServerError()
}
if aliasResp.RoomID != "" {
@ -426,7 +427,7 @@ func createRoom(
}
err = builder.SetContent(e.Content)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("builder.SetContent failed")
util.GetLogger(ctx).WithError(err).Error("builder.SetContent failed")
return jsonerror.InternalServerError()
}
if i > 0 {
@ -435,12 +436,12 @@ func createRoom(
var ev *gomatrixserverlib.Event
ev, err = buildEvent(&builder, &authEvents, cfg, evTime, roomVersion)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("buildEvent failed")
util.GetLogger(ctx).WithError(err).Error("buildEvent failed")
return jsonerror.InternalServerError()
}
if err = gomatrixserverlib.Allowed(ev, &authEvents); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.Allowed failed")
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed")
return jsonerror.InternalServerError()
}
@ -448,7 +449,7 @@ func createRoom(
builtEvents = append(builtEvents, ev.Headered(roomVersion))
err = authEvents.AddEvent(ev)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("authEvents.AddEvent failed")
util.GetLogger(ctx).WithError(err).Error("authEvents.AddEvent failed")
return jsonerror.InternalServerError()
}
}
@ -462,8 +463,8 @@ func createRoom(
SendAsServer: roomserverAPI.DoNotSendToOtherServers,
})
}
if err = roomserverAPI.SendInputRoomEvents(req.Context(), rsAPI, inputs, false); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed")
if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, inputs, false); err != nil {
util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed")
return jsonerror.InternalServerError()
}
@ -478,9 +479,9 @@ func createRoom(
}
var aliasResp roomserverAPI.SetRoomAliasResponse
err = rsAPI.SetRoomAlias(req.Context(), &aliasReq, &aliasResp)
err = rsAPI.SetRoomAlias(ctx, &aliasReq, &aliasResp)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.SetRoomAlias failed")
util.GetLogger(ctx).WithError(err).Error("aliasAPI.SetRoomAlias failed")
return jsonerror.InternalServerError()
}
@ -519,11 +520,11 @@ func createRoom(
for _, invitee := range r.Invite {
// Build the invite event.
inviteEvent, err := buildMembershipEvent(
req.Context(), invitee, "", accountDB, device, gomatrixserverlib.Invite,
ctx, invitee, "", accountDB, device, gomatrixserverlib.Invite,
roomID, true, cfg, evTime, rsAPI, asAPI,
)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvent failed")
util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed")
continue
}
inviteStrippedState := append(
@ -532,7 +533,7 @@ func createRoom(
)
// Send the invite event to the roomserver.
err = roomserverAPI.SendInvite(
req.Context(),
ctx,
rsAPI,
inviteEvent.Headered(roomVersion),
inviteStrippedState, // invite room state
@ -544,7 +545,7 @@ func createRoom(
return e.JSONResponse()
case nil:
default:
util.GetLogger(req.Context()).WithError(err).Error("roomserverAPI.SendInvite failed")
util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInvite failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.InternalServerError(),
@ -556,13 +557,13 @@ func createRoom(
if r.Visibility == "public" {
// expose this room in the published room list
var pubRes roomserverAPI.PerformPublishResponse
rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{
rsAPI.PerformPublish(ctx, &roomserverAPI.PerformPublishRequest{
RoomID: roomID,
Visibility: "public",
}, &pubRes)
if pubRes.Error != nil {
// treat as non-fatal since the room is already made by this point
util.GetLogger(req.Context()).WithError(pubRes.Error).Error("failed to visibility:public")
util.GetLogger(ctx).WithError(pubRes.Error).Error("failed to visibility:public")
}
}

View file

@ -23,7 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
@ -32,7 +32,7 @@ func JoinRoomByIDOrAlias(
req *http.Request,
device *api.Device,
rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB accounts.Database,
accountDB userdb.Database,
roomIDOrAlias string,
) util.JSONResponse {
// Prepare to ask the roomserver to perform the room join.

View file

@ -24,7 +24,7 @@ import (
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/util"
)
@ -36,7 +36,7 @@ type crossSigningRequest struct {
func UploadCrossSigningDeviceKeys(
req *http.Request, userInteractiveAuth *auth.UserInteractive,
keyserverAPI api.KeyInternalAPI, device *userapi.Device,
accountDB accounts.Database, cfg *config.ClientAPI,
accountDB userdb.Database, cfg *config.ClientAPI,
) util.JSONResponse {
uploadReq := &crossSigningRequest{}
uploadRes := &api.PerformUploadDeviceKeysResponse{}

View file

@ -38,6 +38,12 @@ func LeaveRoomByID(
// Ask the roomserver to perform the leave.
if err := rsAPI.PerformLeave(req.Context(), &leaveReq, &leaveRes); err != nil {
if leaveRes.Code != 0 {
return util.JSONResponse{
Code: leaveRes.Code,
JSON: jsonerror.LeaveServerNoticeError(),
}
}
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown(err.Error()),

View file

@ -23,7 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
@ -54,7 +54,7 @@ func passwordLogin() flows {
// Login implements GET and POST /login
func Login(
req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI,
req *http.Request, accountDB userdb.Database, userAPI userapi.UserInternalAPI,
cfg *config.ClientAPI,
) util.JSONResponse {
if req.Method == http.MethodGet {

View file

@ -30,7 +30,7 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
@ -39,7 +39,7 @@ import (
var errMissingUserID = errors.New("'user_id' must be supplied")
func SendBan(
req *http.Request, accountDB accounts.Database, device *userapi.Device,
req *http.Request, accountDB userdb.Database, device *userapi.Device,
roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
@ -81,7 +81,7 @@ func SendBan(
return sendMembership(req.Context(), accountDB, device, roomID, "ban", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI)
}
func sendMembership(ctx context.Context, accountDB accounts.Database, device *userapi.Device,
func sendMembership(ctx context.Context, accountDB userdb.Database, device *userapi.Device,
roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time,
roomVer gomatrixserverlib.RoomVersion,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI) util.JSONResponse {
@ -125,7 +125,7 @@ func sendMembership(ctx context.Context, accountDB accounts.Database, device *us
}
func SendKick(
req *http.Request, accountDB accounts.Database, device *userapi.Device,
req *http.Request, accountDB userdb.Database, device *userapi.Device,
roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
@ -165,7 +165,7 @@ func SendKick(
}
func SendUnban(
req *http.Request, accountDB accounts.Database, device *userapi.Device,
req *http.Request, accountDB userdb.Database, device *userapi.Device,
roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
@ -200,7 +200,7 @@ func SendUnban(
}
func SendInvite(
req *http.Request, accountDB accounts.Database, device *userapi.Device,
req *http.Request, accountDB userdb.Database, device *userapi.Device,
roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
@ -226,27 +226,42 @@ func SendInvite(
}
}
// We already received the return value, so no need to check for an error here.
response, _ := sendInvite(req.Context(), accountDB, device, roomID, body.UserID, body.Reason, cfg, rsAPI, asAPI, evTime)
return response
}
// sendInvite sends an invitation to a user. Returns a JSONResponse and an error
func sendInvite(
ctx context.Context,
accountDB userdb.Database,
device *userapi.Device,
roomID, userID, reason string,
cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI, evTime time.Time,
) (util.JSONResponse, error) {
event, err := buildMembershipEvent(
req.Context(), body.UserID, body.Reason, accountDB, device, "invite",
ctx, userID, reason, accountDB, device, "invite",
roomID, false, cfg, evTime, rsAPI, asAPI,
)
if err == errMissingUserID {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(err.Error()),
}
}, err
} else if err == eventutil.ErrRoomNoExists {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound(err.Error()),
}
}, err
} else if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvent failed")
return jsonerror.InternalServerError()
util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed")
return jsonerror.InternalServerError(), err
}
err = roomserverAPI.SendInvite(
req.Context(), rsAPI,
ctx, rsAPI,
event,
nil, // ask the roomserver to draw up invite room state for us
cfg.Matrix.ServerName,
@ -254,24 +269,24 @@ func SendInvite(
)
switch e := err.(type) {
case *roomserverAPI.PerformError:
return e.JSONResponse()
return e.JSONResponse(), err
case nil:
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}, nil
default:
util.GetLogger(req.Context()).WithError(err).Error("roomserverAPI.SendInvite failed")
util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInvite failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.InternalServerError(),
}
}, err
}
}
func buildMembershipEvent(
ctx context.Context,
targetUserID, reason string, accountDB accounts.Database,
targetUserID, reason string, accountDB userdb.Database,
device *userapi.Device,
membership, roomID string, isDirect bool,
cfg *config.ClientAPI, evTime time.Time,
@ -312,7 +327,7 @@ func loadProfile(
ctx context.Context,
userID string,
cfg *config.ClientAPI,
accountDB accounts.Database,
accountDB userdb.Database,
asAPI appserviceAPI.AppServiceQueryAPI,
) (*authtypes.Profile, error) {
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
@ -366,7 +381,7 @@ func checkAndProcessThreepid(
body *threepid.MembershipRequest,
cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB accounts.Database,
accountDB userdb.Database,
roomID string,
evTime time.Time,
) (inviteStored bool, errRes *util.JSONResponse) {

View file

@ -9,7 +9,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
@ -29,7 +29,7 @@ type newPasswordAuth struct {
func Password(
req *http.Request,
userAPI api.UserInternalAPI,
accountDB accounts.Database,
accountDB userdb.Database,
device *api.Device,
cfg *config.ClientAPI,
) util.JSONResponse {

View file

@ -19,7 +19,7 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
@ -28,7 +28,7 @@ func PeekRoomByIDOrAlias(
req *http.Request,
device *api.Device,
rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB accounts.Database,
accountDB userdb.Database,
roomIDOrAlias string,
) util.JSONResponse {
// if this is a remote roomIDOrAlias, we have to ask the roomserver (or federation sender?) to
@ -82,7 +82,7 @@ func UnpeekRoomByID(
req *http.Request,
device *api.Device,
rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB accounts.Database,
accountDB userdb.Database,
roomID string,
) util.JSONResponse {
unpeekReq := roomserverAPI.PerformUnpeekRequest{

View file

@ -27,7 +27,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrix"
@ -36,7 +36,7 @@ import (
// GetProfile implements GET /profile/{userID}
func GetProfile(
req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI,
req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI,
userID string,
asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient,
@ -65,7 +65,7 @@ func GetProfile(
// GetAvatarURL implements GET /profile/{userID}/avatar_url
func GetAvatarURL(
req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI,
req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI,
userID string, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient,
) util.JSONResponse {
@ -92,7 +92,7 @@ func GetAvatarURL(
// SetAvatarURL implements PUT /profile/{userID}/avatar_url
func SetAvatarURL(
req *http.Request, accountDB accounts.Database,
req *http.Request, accountDB userdb.Database,
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI,
) util.JSONResponse {
if userID != device.UserID {
@ -182,7 +182,7 @@ func SetAvatarURL(
// GetDisplayName implements GET /profile/{userID}/displayname
func GetDisplayName(
req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI,
req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI,
userID string, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient,
) util.JSONResponse {
@ -209,7 +209,7 @@ func GetDisplayName(
// SetDisplayName implements PUT /profile/{userID}/displayname
func SetDisplayName(
req *http.Request, accountDB accounts.Database,
req *http.Request, accountDB userdb.Database,
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI,
) util.JSONResponse {
if userID != device.UserID {
@ -302,7 +302,7 @@ func SetDisplayName(
// Returns an error when something goes wrong or specifically
// eventutil.ErrProfileNoExists when the profile doesn't exist.
func getProfile(
ctx context.Context, accountDB accounts.Database, cfg *config.ClientAPI,
ctx context.Context, accountDB userdb.Database, cfg *config.ClientAPI,
userID string,
asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient,

View file

@ -32,18 +32,19 @@ import (
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/tokens"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/userutil"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/tokens"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
userdb "github.com/matrix-org/dendrite/userapi/storage"
)
var (
@ -153,7 +154,7 @@ type authDict struct {
// http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#user-interactive-authentication-api
type userInteractiveResponse struct {
Flows []authtypes.Flow `json:"flows"`
Completed []authtypes.LoginType `json:"completed,omitempty"`
Completed []authtypes.LoginType `json:"completed"`
Params map[string]interface{} `json:"params"`
Session string `json:"session"`
}
@ -447,7 +448,7 @@ func validateApplicationService(
func Register(
req *http.Request,
userAPI userapi.UserInternalAPI,
accountDB accounts.Database,
accountDB userdb.Database,
cfg *config.ClientAPI,
) util.JSONResponse {
var r registerRequest
@ -531,6 +532,13 @@ func handleGuestRegistration(
cfg *config.ClientAPI,
userAPI userapi.UserInternalAPI,
) util.JSONResponse {
if cfg.RegistrationDisabled || cfg.GuestsDisabled {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Guest registration is disabled"),
}
}
var res userapi.PerformAccountCreationResponse
err := userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeGuest,
@ -708,7 +716,7 @@ func handleApplicationServiceRegistration(
// application service registration is entirely separate.
return completeRegistration(
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), policyVersion,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService,
)
}
@ -732,7 +740,7 @@ func checkAndCompleteFlow(
return completeRegistration(
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), policyVersion,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser,
)
}
@ -757,6 +765,7 @@ func completeRegistration(
username, password, appserviceID, ipAddr, userAgent, policyVersion string,
inhibitLogin eventutil.WeakBoolean,
displayName, deviceID *string,
accType userapi.AccountType,
) util.JSONResponse {
if username == "" {
return util.JSONResponse{
@ -771,13 +780,12 @@ func completeRegistration(
JSON: jsonerror.BadJSON("missing password"),
}
}
var accRes userapi.PerformAccountCreationResponse
err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{
AppServiceID: appserviceID,
Localpart: username,
Password: password,
AccountType: userapi.AccountTypeUser,
AccountType: accType,
OnConflict: userapi.ConflictAbort,
PolicyVersion: policyVersion,
}, &accRes)
@ -904,7 +912,7 @@ type availableResponse struct {
func RegisterAvailable(
req *http.Request,
cfg *config.ClientAPI,
accountDB accounts.Database,
accountDB userdb.Database,
) util.JSONResponse {
username := req.URL.Query().Get("username")
@ -976,5 +984,10 @@ func handleSharedSecretRegistration(userAPI userapi.UserInternalAPI, sr *SharedS
return *resErr
}
deviceID := "shared_secret_registration"
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID)
accType := userapi.AccountTypeUser
if ssrr.Admin {
accType = userapi.AccountTypeAdmin
}
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType)
}

View file

@ -15,6 +15,7 @@
package routing
import (
"context"
"encoding/json"
"net/http"
"strings"
@ -34,7 +35,7 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
@ -51,7 +52,7 @@ func Setup(
eduAPI eduServerAPI.EDUServerInputAPI,
rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI,
accountDB accounts.Database,
accountDB userdb.Database,
userAPI userapi.UserInternalAPI,
federation *gomatrixserverlib.FederationClient,
syncProducer *producers.SyncAPIProducer,
@ -117,6 +118,58 @@ func Setup(
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
}
// server notifications
if cfg.Matrix.ServerNotices.Enabled {
logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice")
serverNotificationSender, err := getSenderDevice(context.Background(), userAPI, accountDB, cfg)
if err != nil {
logrus.WithError(err).Fatal("unable to get account for sending sending server notices")
}
synapseAdminRouter.Handle("/admin/v1/send_server_notice/{txnID}",
httputil.MakeAuthAPI("send_server_notice", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
// not specced, but ensure we're rate limiting requests to this endpoint
if r := rateLimits.Limit(req); r != nil {
return *r
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
txnID := vars["txnID"]
return SendServerNotice(
req, &cfg.Matrix.ServerNotices,
cfg, userAPI, rsAPI, accountDB, asAPI,
device, serverNotificationSender,
&txnID, transactionsCache,
)
}),
).Methods(http.MethodPut, http.MethodOptions)
synapseAdminRouter.Handle("/admin/v1/send_server_notice",
httputil.MakeAuthAPI("send_server_notice", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
// not specced, but ensure we're rate limiting requests to this endpoint
if r := rateLimits.Limit(req); r != nil {
return *r
}
return SendServerNotice(
req, &cfg.Matrix.ServerNotices,
cfg, userAPI, rsAPI, accountDB, asAPI,
device, serverNotificationSender,
nil, transactionsCache,
)
}),
).Methods(http.MethodPost, http.MethodOptions)
}
// You can't just do PathPrefix("/(r0|v3)") because regexps only apply when inside named path variables.
// So make a named path variable called 'apiversion' (which we will never read in handlers) and then do
// (r0|v3) - BUT this is a captured group, which makes no sense because you cannot extract this group
// from a match (gorilla/mux exposes no way to do this) so it demands you make it a non-capturing group
// using ?: so the final regexp becomes what is below. We also need a trailing slash to stop 'v33333' matching.
// Note that 'apiversion' is chosen because it must not collide with a variable used in any of the routing!
v3mux := publicAPIMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter()
// unspecced consent tracking
if cfg.Matrix.UserConsentOptions.Enabled {
consentAPIMux.Handle("/consent",
@ -129,12 +182,12 @@ func Setup(
r0mux := publicAPIMux.PathPrefix("/r0").Subrouter()
unstableMux := publicAPIMux.PathPrefix("/unstable").Subrouter()
r0mux.Handle("/createRoom",
v3mux.Handle("/createRoom",
httputil.MakeAuthAPI("createRoom", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return CreateRoom(req, device, cfg, accountDB, rsAPI, asAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/join/{roomIDOrAlias}",
v3mux.Handle("/join/{roomIDOrAlias}",
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -150,7 +203,7 @@ func Setup(
).Methods(http.MethodPost, http.MethodOptions)
if mscCfg.Enabled("msc2753") {
r0mux.Handle("/peek/{roomIDOrAlias}",
v3mux.Handle("/peek/{roomIDOrAlias}",
httputil.MakeAuthAPI(gomatrixserverlib.Peek, userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -165,12 +218,12 @@ func Setup(
}),
).Methods(http.MethodPost, http.MethodOptions)
}
r0mux.Handle("/joined_rooms",
v3mux.Handle("/joined_rooms",
httputil.MakeAuthAPI("joined_rooms", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetJoinedRooms(req, device, rsAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/join",
v3mux.Handle("/rooms/{roomID}/join",
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -184,7 +237,7 @@ func Setup(
)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/leave",
v3mux.Handle("/rooms/{roomID}/leave",
httputil.MakeAuthAPI("membership", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -198,7 +251,7 @@ func Setup(
)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/unpeek",
v3mux.Handle("/rooms/{roomID}/unpeek",
httputil.MakeAuthAPI("unpeek", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -209,7 +262,7 @@ func Setup(
)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/ban",
v3mux.Handle("/rooms/{roomID}/ban",
httputil.MakeAuthAPI("membership", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -218,7 +271,7 @@ func Setup(
return SendBan(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/invite",
v3mux.Handle("/rooms/{roomID}/invite",
httputil.MakeAuthAPI("membership", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -230,7 +283,7 @@ func Setup(
return SendInvite(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/kick",
v3mux.Handle("/rooms/{roomID}/kick",
httputil.MakeAuthAPI("membership", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -239,7 +292,7 @@ func Setup(
return SendKick(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/unban",
v3mux.Handle("/rooms/{roomID}/unban",
httputil.MakeAuthAPI("membership", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -248,7 +301,7 @@ func Setup(
return SendUnban(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/send/{eventType}",
v3mux.Handle("/rooms/{roomID}/send/{eventType}",
httputil.MakeAuthAPI("send_message", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -257,7 +310,7 @@ func Setup(
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}",
v3mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}",
httputil.MakeAuthAPI("send_message", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -268,7 +321,7 @@ func Setup(
nil, cfg, rsAPI, transactionsCache)
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/event/{eventID}",
v3mux.Handle("/rooms/{roomID}/event/{eventID}",
httputil.MakeAuthAPI("rooms_get_event", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -278,7 +331,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@ -286,7 +339,7 @@ func Setup(
return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"])
})).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@ -294,7 +347,7 @@ func Setup(
return GetAliases(req, rsAPI, device, vars["roomID"])
})).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
v3mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@ -305,7 +358,7 @@ func Setup(
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat)
})).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@ -314,7 +367,7 @@ func Setup(
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat)
})).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}",
v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}",
httputil.MakeAuthAPI("send_message", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -326,7 +379,7 @@ func Setup(
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}",
v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}",
httputil.MakeAuthAPI("send_message", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -337,21 +390,21 @@ func Setup(
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
}
return Register(req, userAPI, accountDB, cfg)
})).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse {
v3mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
}
return RegisterAvailable(req, cfg, accountDB)
})).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/directory/room/{roomAlias}",
v3mux.Handle("/directory/room/{roomAlias}",
httputil.MakeExternalAPI("directory_room", func(req *http.Request) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -361,7 +414,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/directory/room/{roomAlias}",
v3mux.Handle("/directory/room/{roomAlias}",
httputil.MakeAuthAPI("directory_room", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -371,7 +424,7 @@ func Setup(
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/directory/room/{roomAlias}",
v3mux.Handle("/directory/room/{roomAlias}",
httputil.MakeAuthAPI("directory_room", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -380,7 +433,7 @@ func Setup(
return RemoveLocalAlias(req, device, vars["roomAlias"], rsAPI)
}),
).Methods(http.MethodDelete, http.MethodOptions)
r0mux.Handle("/directory/list/room/{roomID}",
v3mux.Handle("/directory/list/room/{roomID}",
httputil.MakeExternalAPI("directory_list", func(req *http.Request) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -390,7 +443,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
// TODO: Add AS support
r0mux.Handle("/directory/list/room/{roomID}",
v3mux.Handle("/directory/list/room/{roomID}",
httputil.MakeAuthAPI("directory_list", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -399,25 +452,25 @@ func Setup(
return SetVisibility(req, rsAPI, device, vars["roomID"])
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/publicRooms",
v3mux.Handle("/publicRooms",
httputil.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse {
return GetPostPublicRooms(req, rsAPI, extRoomsProvider, federation, cfg)
}),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
r0mux.Handle("/logout",
v3mux.Handle("/logout",
httputil.MakeAuthAPI("logout", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return Logout(req, userAPI, device)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/logout/all",
v3mux.Handle("/logout/all",
httputil.MakeAuthAPI("logout", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return LogoutAll(req, userAPI, device)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/typing/{userID}",
v3mux.Handle("/rooms/{roomID}/typing/{userID}",
httputil.MakeAuthAPI("rooms_typing", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -429,7 +482,7 @@ func Setup(
return SendTyping(req, device, vars["roomID"], vars["userID"], accountDB, eduAPI, rsAPI)
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/redact/{eventID}",
v3mux.Handle("/rooms/{roomID}/redact/{eventID}",
httputil.MakeAuthAPI("rooms_redact", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -438,7 +491,7 @@ func Setup(
return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}",
v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}",
httputil.MakeAuthAPI("rooms_redact", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -448,7 +501,7 @@ func Setup(
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/sendToDevice/{eventType}/{txnID}",
v3mux.Handle("/sendToDevice/{eventType}/{txnID}",
httputil.MakeAuthAPI("send_to_device", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -473,7 +526,7 @@ func Setup(
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/account/whoami",
v3mux.Handle("/account/whoami",
httputil.MakeAuthAPI("whoami", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -482,7 +535,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/account/password",
v3mux.Handle("/account/password",
httputil.MakeAuthAPI("password", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -491,7 +544,7 @@ func Setup(
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/account/deactivate",
v3mux.Handle("/account/deactivate",
httputil.MakeAuthAPI("deactivate", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -502,7 +555,7 @@ func Setup(
// Stub endpoints required by Element
r0mux.Handle("/login",
v3mux.Handle("/login",
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -511,14 +564,14 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
r0mux.Handle("/auth/{authType}/fallback/web",
v3mux.Handle("/auth/{authType}/fallback/web",
httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse {
vars := mux.Vars(req)
return AuthFallback(w, req, vars["authType"], cfg)
}),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
r0mux.Handle("/pushrules/",
v3mux.Handle("/pushrules/",
httputil.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse {
// TODO: Implement push rules API
res := json.RawMessage(`{
@ -539,7 +592,7 @@ func Setup(
// Element user settings
r0mux.Handle("/profile/{userID}",
v3mux.Handle("/profile/{userID}",
httputil.MakeExternalAPI("profile", func(req *http.Request) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -549,7 +602,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/profile/{userID}/avatar_url",
v3mux.Handle("/profile/{userID}/avatar_url",
httputil.MakeExternalAPI("profile_avatar_url", func(req *http.Request) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -559,7 +612,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/profile/{userID}/avatar_url",
v3mux.Handle("/profile/{userID}/avatar_url",
httputil.MakeAuthAPI("profile_avatar_url", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -574,7 +627,7 @@ func Setup(
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows
// PUT requests, so we need to allow this method
r0mux.Handle("/profile/{userID}/displayname",
v3mux.Handle("/profile/{userID}/displayname",
httputil.MakeExternalAPI("profile_displayname", func(req *http.Request) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -584,7 +637,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/profile/{userID}/displayname",
v3mux.Handle("/profile/{userID}/displayname",
httputil.MakeAuthAPI("profile_displayname", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -599,13 +652,13 @@ func Setup(
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows
// PUT requests, so we need to allow this method
r0mux.Handle("/account/3pid",
v3mux.Handle("/account/3pid",
httputil.MakeAuthAPI("account_3pid", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetAssociated3PIDs(req, accountDB, device)
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/account/3pid",
v3mux.Handle("/account/3pid",
httputil.MakeAuthAPI("account_3pid", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return CheckAndSave3PIDAssociation(req, accountDB, device, cfg)
}),
@ -617,14 +670,14 @@ func Setup(
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken",
v3mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken",
httputil.MakeExternalAPI("account_3pid_request_token", func(req *http.Request) util.JSONResponse {
return RequestEmailToken(req, accountDB, cfg)
}),
).Methods(http.MethodPost, http.MethodOptions)
// Element logs get flooded unless this is handled
r0mux.Handle("/presence/{userID}/status",
v3mux.Handle("/presence/{userID}/status",
httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -637,7 +690,7 @@ func Setup(
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/voip/turnServer",
v3mux.Handle("/voip/turnServer",
httputil.MakeAuthAPI("turn_server", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -646,7 +699,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/thirdparty/protocols",
v3mux.Handle("/thirdparty/protocols",
httputil.MakeExternalAPI("thirdparty_protocols", func(req *http.Request) util.JSONResponse {
// TODO: Return the third party protcols
return util.JSONResponse{
@ -656,7 +709,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/initialSync",
v3mux.Handle("/rooms/{roomID}/initialSync",
httputil.MakeExternalAPI("rooms_initial_sync", func(req *http.Request) util.JSONResponse {
// TODO: Allow people to peek into rooms.
return util.JSONResponse{
@ -666,7 +719,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userID}/account_data/{type}",
v3mux.Handle("/user/{userID}/account_data/{type}",
httputil.MakeAuthAPI("user_account_data", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -676,7 +729,7 @@ func Setup(
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}",
v3mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}",
httputil.MakeAuthAPI("user_account_data", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -686,7 +739,7 @@ func Setup(
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/user/{userID}/account_data/{type}",
v3mux.Handle("/user/{userID}/account_data/{type}",
httputil.MakeAuthAPI("user_account_data", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -696,7 +749,7 @@ func Setup(
}),
).Methods(http.MethodGet)
r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}",
v3mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}",
httputil.MakeAuthAPI("user_account_data", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -706,7 +759,7 @@ func Setup(
}),
).Methods(http.MethodGet)
r0mux.Handle("/admin/whois/{userID}",
v3mux.Handle("/admin/whois/{userID}",
httputil.MakeAuthAPI("admin_whois", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -716,7 +769,7 @@ func Setup(
}),
).Methods(http.MethodGet)
r0mux.Handle("/user/{userID}/openid/request_token",
v3mux.Handle("/user/{userID}/openid/request_token",
httputil.MakeAuthAPI("openid_request_token", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -729,7 +782,7 @@ func Setup(
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/user_directory/search",
v3mux.Handle("/user_directory/search",
httputil.MakeAuthAPI("userdirectory_search", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -754,7 +807,7 @@ func Setup(
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/members",
v3mux.Handle("/rooms/{roomID}/members",
httputil.MakeAuthAPI("rooms_members", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -764,7 +817,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/joined_members",
v3mux.Handle("/rooms/{roomID}/joined_members",
httputil.MakeAuthAPI("rooms_members", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -774,7 +827,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/read_markers",
v3mux.Handle("/rooms/{roomID}/read_markers",
httputil.MakeAuthAPI("rooms_read_markers", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -787,7 +840,7 @@ func Setup(
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/forget",
v3mux.Handle("/rooms/{roomID}/forget",
httputil.MakeAuthAPI("rooms_forget", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -800,13 +853,13 @@ func Setup(
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/devices",
v3mux.Handle("/devices",
httputil.MakeAuthAPI("get_devices", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetDevicesByLocalpart(req, userAPI, device)
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/devices/{deviceID}",
v3mux.Handle("/devices/{deviceID}",
httputil.MakeAuthAPI("get_device", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -816,7 +869,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/devices/{deviceID}",
v3mux.Handle("/devices/{deviceID}",
httputil.MakeAuthAPI("device_data", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -826,7 +879,7 @@ func Setup(
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/devices/{deviceID}",
v3mux.Handle("/devices/{deviceID}",
httputil.MakeAuthAPI("delete_device", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -836,14 +889,14 @@ func Setup(
}),
).Methods(http.MethodDelete, http.MethodOptions)
r0mux.Handle("/delete_devices",
v3mux.Handle("/delete_devices",
httputil.MakeAuthAPI("delete_devices", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return DeleteDevices(req, userAPI, device)
}),
).Methods(http.MethodPost, http.MethodOptions)
// Stub implementations for sytest
r0mux.Handle("/events",
v3mux.Handle("/events",
httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {
return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{
"chunk": []interface{}{},
@ -853,7 +906,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/initialSync",
v3mux.Handle("/initialSync",
httputil.MakeExternalAPI("initial_sync", func(req *http.Request) util.JSONResponse {
return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{
"end": "",
@ -861,7 +914,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userId}/rooms/{roomId}/tags",
v3mux.Handle("/user/{userId}/rooms/{roomId}/tags",
httputil.MakeAuthAPI("get_tags", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -871,7 +924,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
httputil.MakeAuthAPI("put_tag", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -881,7 +934,7 @@ func Setup(
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
httputil.MakeAuthAPI("delete_tag", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -891,7 +944,7 @@ func Setup(
}),
).Methods(http.MethodDelete, http.MethodOptions)
r0mux.Handle("/capabilities",
v3mux.Handle("/capabilities",
httputil.MakeAuthAPI("capabilities", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
@ -934,11 +987,11 @@ func Setup(
return CreateKeyBackupVersion(req, userAPI, device)
})
r0mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/room_keys/version/{version}", putBackupKeysVersion).Methods(http.MethodPut)
r0mux.Handle("/room_keys/version/{version}", deleteBackupKeysVersion).Methods(http.MethodDelete)
r0mux.Handle("/room_keys/version", postNewBackupKeysVersion).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/room_keys/version/{version}", putBackupKeysVersion).Methods(http.MethodPut)
v3mux.Handle("/room_keys/version/{version}", deleteBackupKeysVersion).Methods(http.MethodDelete)
v3mux.Handle("/room_keys/version", postNewBackupKeysVersion).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
unstableMux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
@ -1030,9 +1083,9 @@ func Setup(
return UploadBackupKeys(req, userAPI, device, version, &keyReq)
})
r0mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut)
r0mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut)
r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", putBackupKeysRoomSession).Methods(http.MethodPut)
v3mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut)
v3mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut)
v3mux.Handle("/room_keys/keys/{roomID}/{sessionID}", putBackupKeysRoomSession).Methods(http.MethodPut)
unstableMux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut)
unstableMux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut)
@ -1060,9 +1113,9 @@ func Setup(
return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"])
})
r0mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", getBackupKeysRoomSession).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/room_keys/keys/{roomID}/{sessionID}", getBackupKeysRoomSession).Methods(http.MethodGet, http.MethodOptions)
unstableMux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions)
unstableMux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions)
@ -1080,29 +1133,29 @@ func Setup(
return UploadCrossSigningDeviceSignatures(req, keyAPI, device)
})
r0mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions)
// Supplying a device ID is deprecated.
r0mux.Handle("/keys/upload/{deviceID}",
v3mux.Handle("/keys/upload/{deviceID}",
httputil.MakeAuthAPI("keys_upload", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadKeys(req, keyAPI, device)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/keys/upload",
v3mux.Handle("/keys/upload",
httputil.MakeAuthAPI("keys_upload", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadKeys(req, keyAPI, device)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/keys/query",
v3mux.Handle("/keys/query",
httputil.MakeAuthAPI("keys_query", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return QueryKeys(req, keyAPI, device)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/keys/claim",
v3mux.Handle("/keys/claim",
httputil.MakeAuthAPI("keys_claim", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return ClaimKeys(req, keyAPI)
}),

View file

@ -15,10 +15,16 @@
package routing
import (
"context"
"net/http"
"sync"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/eventutil"
@ -26,10 +32,6 @@ import (
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
)
// http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-send-eventtype-txnid
@ -97,7 +99,22 @@ func SendEvent(
defer mutex.(*sync.Mutex).Unlock()
startedGeneratingEvent := time.Now()
e, resErr := generateSendEvent(req, device, roomID, eventType, stateKey, cfg, rsAPI)
var r map[string]interface{} // must be a JSON object
resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {
return *resErr
}
evTime, err := httputil.ParseTSParam(req)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue(err.Error()),
}
}
e, resErr := generateSendEvent(req.Context(), r, device, roomID, eventType, stateKey, cfg, rsAPI, evTime)
if resErr != nil {
return *resErr
}
@ -153,27 +170,16 @@ func SendEvent(
}
func generateSendEvent(
req *http.Request,
ctx context.Context,
r map[string]interface{},
device *userapi.Device,
roomID, eventType string, stateKey *string,
cfg *config.ClientAPI,
rsAPI api.RoomserverInternalAPI,
evTime time.Time,
) (*gomatrixserverlib.Event, *util.JSONResponse) {
// parse the incoming http request
userID := device.UserID
var r map[string]interface{} // must be a JSON object
resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {
return nil, resErr
}
evTime, err := httputil.ParseTSParam(req)
if err != nil {
return nil, &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue(err.Error()),
}
}
// create the new event and set all the fields we can
builder := gomatrixserverlib.EventBuilder{
@ -182,15 +188,15 @@ func generateSendEvent(
Type: eventType,
StateKey: stateKey,
}
err = builder.SetContent(r)
err := builder.SetContent(r)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("builder.SetContent failed")
util.GetLogger(ctx).WithError(err).Error("builder.SetContent failed")
resErr := jsonerror.InternalServerError()
return nil, &resErr
}
var queryRes api.QueryLatestEventsAndStateResponse
e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, evTime, rsAPI, &queryRes)
e, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, &queryRes)
if err == eventutil.ErrRoomNoExists {
return nil, &util.JSONResponse{
Code: http.StatusNotFound,
@ -213,7 +219,7 @@ func generateSendEvent(
JSON: jsonerror.BadJSON(e.Error()),
}
} else if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("eventutil.BuildEvent failed")
util.GetLogger(ctx).WithError(err).Error("eventutil.BuildEvent failed")
resErr := jsonerror.InternalServerError()
return nil, &resErr
}

View file

@ -20,7 +20,7 @@ import (
"github.com/matrix-org/dendrite/eduserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/util"
)
@ -33,7 +33,7 @@ type typingContentJSON struct {
// sends the typing events to client API typingProducer
func SendTyping(
req *http.Request, device *userapi.Device, roomID string,
userID string, accountDB accounts.Database,
userID string, accountDB userdb.Database,
eduAPI api.EDUServerInputAPI,
rsAPI roomserverAPI.RoomserverInternalAPI,
) util.JSONResponse {

View file

@ -0,0 +1,343 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/tokens"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/transactions"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
// Unspecced server notice request
// https://github.com/matrix-org/synapse/blob/develop/docs/admin_api/server_notices.md
type sendServerNoticeRequest struct {
UserID string `json:"user_id,omitempty"`
Content struct {
MsgType string `json:"msgtype,omitempty"`
Body string `json:"body,omitempty"`
} `json:"content,omitempty"`
Type string `json:"type,omitempty"`
StateKey string `json:"state_key,omitempty"`
}
// SendServerNotice sends a message to a specific user. It can only be invoked by an admin.
func SendServerNotice(
req *http.Request,
cfgNotices *config.ServerNotices,
cfgClient *config.ClientAPI,
userAPI userapi.UserInternalAPI,
rsAPI api.RoomserverInternalAPI,
accountsDB userdb.Database,
asAPI appserviceAPI.AppServiceQueryAPI,
device *userapi.Device,
senderDevice *userapi.Device,
txnID *string,
txnCache *transactions.Cache,
) util.JSONResponse {
if device.AccountType != userapi.AccountTypeAdmin {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("This API can only be used by admin users."),
}
}
if txnID != nil {
// Try to fetch response from transactionsCache
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok {
return *res
}
}
ctx := req.Context()
var r sendServerNoticeRequest
resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {
return *resErr
}
// check that all required fields are set
if !r.valid() {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("Invalid request"),
}
}
// get rooms for specified user
allUserRooms := []string{}
userRooms := api.QueryRoomsForUserResponse{}
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
UserID: r.UserID,
WantMembership: "join",
}, &userRooms); err != nil {
return util.ErrorResponse(err)
}
allUserRooms = append(allUserRooms, userRooms.RoomIDs...)
// get invites for specified user
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
UserID: r.UserID,
WantMembership: "invite",
}, &userRooms); err != nil {
return util.ErrorResponse(err)
}
allUserRooms = append(allUserRooms, userRooms.RoomIDs...)
// get left rooms for specified user
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
UserID: r.UserID,
WantMembership: "leave",
}, &userRooms); err != nil {
return util.ErrorResponse(err)
}
allUserRooms = append(allUserRooms, userRooms.RoomIDs...)
// get rooms of the sender
senderUserID := fmt.Sprintf("@%s:%s", cfgNotices.LocalPart, cfgClient.Matrix.ServerName)
senderRooms := api.QueryRoomsForUserResponse{}
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
UserID: senderUserID,
WantMembership: "join",
}, &senderRooms); err != nil {
return util.ErrorResponse(err)
}
// check if we have rooms in common
commonRooms := []string{}
for _, userRoomID := range allUserRooms {
for _, senderRoomID := range senderRooms.RoomIDs {
if userRoomID == senderRoomID {
commonRooms = append(commonRooms, senderRoomID)
}
}
}
if len(commonRooms) > 1 {
return util.ErrorResponse(fmt.Errorf("expected to find one room, but got %d", len(commonRooms)))
}
var (
roomID string
roomVersion = gomatrixserverlib.RoomVersionV6
)
// create a new room for the user
if len(commonRooms) == 0 {
powerLevelContent := eventutil.InitialPowerLevelsContent(senderUserID)
powerLevelContent.Users[r.UserID] = -10 // taken from Synapse
pl, err := json.Marshal(powerLevelContent)
if err != nil {
return util.ErrorResponse(err)
}
createContent := map[string]interface{}{}
createContent["m.federate"] = false
cc, err := json.Marshal(createContent)
if err != nil {
return util.ErrorResponse(err)
}
crReq := createRoomRequest{
Invite: []string{r.UserID},
Name: cfgNotices.RoomName,
Visibility: "private",
Preset: presetPrivateChat,
CreationContent: cc,
GuestCanJoin: false,
RoomVersion: roomVersion,
PowerLevelContentOverride: pl,
}
roomRes := createRoom(ctx, crReq, senderDevice, cfgClient, accountsDB, rsAPI, asAPI, time.Now())
switch data := roomRes.JSON.(type) {
case createRoomResponse:
roomID = data.RoomID
// tag the room, so we can later check if the user tries to reject an invite
serverAlertTag := gomatrix.TagContent{Tags: map[string]gomatrix.TagProperties{
"m.server_notice": {
Order: 1.0,
},
}}
if err = saveTagData(req, r.UserID, roomID, userAPI, serverAlertTag); err != nil {
util.GetLogger(ctx).WithError(err).Error("saveTagData failed")
return jsonerror.InternalServerError()
}
default:
// if we didn't get a createRoomResponse, we probably received an error, so return that.
return roomRes
}
} else {
// we've found a room in common, check the membership
roomID = commonRooms[0]
// re-invite the user
res, err := sendInvite(ctx, accountsDB, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now())
if err != nil {
return res
}
}
startedGeneratingEvent := time.Now()
request := map[string]interface{}{
"body": r.Content.Body,
"msgtype": r.Content.MsgType,
}
e, resErr := generateSendEvent(ctx, request, senderDevice, roomID, "m.room.message", nil, cfgClient, rsAPI, time.Now())
if resErr != nil {
logrus.Errorf("failed to send message: %+v", resErr)
return *resErr
}
timeToGenerateEvent := time.Since(startedGeneratingEvent)
var txnAndSessionID *api.TransactionID
if txnID != nil {
txnAndSessionID = &api.TransactionID{
TransactionID: *txnID,
SessionID: device.SessionID,
}
}
// pass the new event to the roomserver and receive the correct event ID
// event ID in case of duplicate transaction is discarded
startedSubmittingEvent := time.Now()
if err := api.SendEvents(
ctx, rsAPI,
api.KindNew,
[]*gomatrixserverlib.HeaderedEvent{
e.Headered(roomVersion),
},
cfgClient.Matrix.ServerName,
cfgClient.Matrix.ServerName,
txnAndSessionID,
false,
); err != nil {
util.GetLogger(ctx).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError()
}
util.GetLogger(ctx).WithFields(logrus.Fields{
"event_id": e.EventID(),
"room_id": roomID,
"room_version": roomVersion,
}).Info("Sent event to roomserver")
timeToSubmitEvent := time.Since(startedSubmittingEvent)
res := util.JSONResponse{
Code: http.StatusOK,
JSON: sendEventResponse{e.EventID()},
}
// Add response to transactionsCache
if txnID != nil {
txnCache.AddTransaction(device.AccessToken, *txnID, &res)
}
// Take a note of how long it took to generate the event vs submit
// it to the roomserver.
sendEventDuration.With(prometheus.Labels{"action": "build"}).Observe(float64(timeToGenerateEvent.Milliseconds()))
sendEventDuration.With(prometheus.Labels{"action": "submit"}).Observe(float64(timeToSubmitEvent.Milliseconds()))
return res
}
func (r sendServerNoticeRequest) valid() (ok bool) {
if r.UserID == "" {
return false
}
if r.Content.MsgType == "" || r.Content.Body == "" {
return false
}
return true
}
// getSenderDevice creates a user account to be used when sending server notices.
// It returns an userapi.Device, which is used for building the event
func getSenderDevice(
ctx context.Context,
userAPI userapi.UserInternalAPI,
accountDB userdb.Database,
cfg *config.ClientAPI,
) (*userapi.Device, error) {
var accRes userapi.PerformAccountCreationResponse
// create account if it doesn't exist
err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeUser,
Localpart: cfg.Matrix.ServerNotices.LocalPart,
OnConflict: userapi.ConflictUpdate,
}, &accRes)
if err != nil {
return nil, err
}
// set the avatarurl for the user
if err = accountDB.SetAvatarURL(ctx, cfg.Matrix.ServerNotices.LocalPart, cfg.Matrix.ServerNotices.AvatarURL); err != nil {
util.GetLogger(ctx).WithError(err).Error("accountDB.SetAvatarURL failed")
return nil, err
}
// Check if we got existing devices
deviceRes := &userapi.QueryDevicesResponse{}
err = userAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{
UserID: accRes.Account.UserID,
}, deviceRes)
if err != nil {
return nil, err
}
if len(deviceRes.Devices) > 0 {
return &deviceRes.Devices[0], nil
}
// create an AccessToken
token, err := tokens.GenerateLoginToken(tokens.TokenOptions{
ServerPrivateKey: cfg.Matrix.PrivateKey.Seed(),
ServerName: string(cfg.Matrix.ServerName),
UserID: accRes.Account.UserID,
})
if err != nil {
return nil, err
}
// create a new device, if we didn't find any
var devRes userapi.PerformDeviceCreationResponse
err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart,
DeviceDisplayName: &cfg.Matrix.ServerNotices.LocalPart,
AccessToken: token,
NoDeviceListUpdate: true,
}, &devRes)
if err != nil {
return nil, err
}
return devRes.Device, nil
}

View file

@ -0,0 +1,83 @@
package routing
import (
"testing"
)
func Test_sendServerNoticeRequest_validate(t *testing.T) {
type fields struct {
UserID string `json:"user_id,omitempty"`
Content struct {
MsgType string `json:"msgtype,omitempty"`
Body string `json:"body,omitempty"`
} `json:"content,omitempty"`
Type string `json:"type,omitempty"`
StateKey string `json:"state_key,omitempty"`
}
content := struct {
MsgType string `json:"msgtype,omitempty"`
Body string `json:"body,omitempty"`
}{
MsgType: "m.text",
Body: "Hello world!",
}
tests := []struct {
name string
fields fields
wantOk bool
}{
{
name: "empty request",
fields: fields{},
},
{
name: "msgtype empty",
fields: fields{
UserID: "@alice:localhost",
Content: struct {
MsgType string `json:"msgtype,omitempty"`
Body string `json:"body,omitempty"`
}{
Body: "Hello world!",
},
},
},
{
name: "msg body empty",
fields: fields{
UserID: "@alice:localhost",
},
},
{
name: "statekey empty",
fields: fields{
UserID: "@alice:localhost",
Content: content,
},
wantOk: true,
},
{
name: "type empty",
fields: fields{
UserID: "@alice:localhost",
Content: content,
},
wantOk: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := sendServerNoticeRequest{
UserID: tt.fields.UserID,
Content: tt.fields.Content,
Type: tt.fields.Type,
StateKey: tt.fields.StateKey,
}
if gotOk := r.valid(); gotOk != tt.wantOk {
t.Errorf("valid() = %v, want %v", gotOk, tt.wantOk)
}
})
}
}

View file

@ -23,7 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/threepid"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
@ -40,7 +40,7 @@ type threePIDsResponse struct {
// RequestEmailToken implements:
// POST /account/3pid/email/requestToken
// POST /register/email/requestToken
func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI) util.JSONResponse {
func RequestEmailToken(req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI) util.JSONResponse {
var body threepid.EmailAssociationRequest
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr
@ -61,7 +61,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf
Code: http.StatusBadRequest,
JSON: jsonerror.MatrixError{
ErrCode: "M_THREEPID_IN_USE",
Err: accounts.Err3PIDInUse.Error(),
Err: userdb.Err3PIDInUse.Error(),
},
}
}
@ -85,7 +85,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf
// CheckAndSave3PIDAssociation implements POST /account/3pid
func CheckAndSave3PIDAssociation(
req *http.Request, accountDB accounts.Database, device *api.Device,
req *http.Request, accountDB userdb.Database, device *api.Device,
cfg *config.ClientAPI,
) util.JSONResponse {
var body threepid.EmailAssociationCheckRequest
@ -149,7 +149,7 @@ func CheckAndSave3PIDAssociation(
// GetAssociated3PIDs implements GET /account/3pid
func GetAssociated3PIDs(
req *http.Request, accountDB accounts.Database, device *api.Device,
req *http.Request, accountDB userdb.Database, device *api.Device,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
@ -170,7 +170,7 @@ func GetAssociated3PIDs(
}
// Forget3PID implements POST /account/3pid/delete
func Forget3PID(req *http.Request, accountDB accounts.Database) util.JSONResponse {
func Forget3PID(req *http.Request, accountDB userdb.Database) util.JSONResponse {
var body authtypes.ThreePID
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr

View file

@ -21,7 +21,9 @@ import (
// whoamiResponse represents an response for a `whoami` request
type whoamiResponse struct {
UserID string `json:"user_id"`
UserID string `json:"user_id"`
DeviceID string `json:"device_id"`
IsGuest bool `json:"is_guest"`
}
// Whoami implements `/account/whoami` which enables client to query their account user id.
@ -29,6 +31,10 @@ type whoamiResponse struct {
func Whoami(req *http.Request, device *api.Device) util.JSONResponse {
return util.JSONResponse{
Code: http.StatusOK,
JSON: whoamiResponse{UserID: device.UserID},
JSON: whoamiResponse{
UserID: device.UserID,
DeviceID: device.ID,
IsGuest: device.AccountType == api.AccountTypeGuest,
},
}
}

View file

@ -29,7 +29,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
)
@ -87,7 +87,7 @@ var (
func CheckAndProcessInvite(
ctx context.Context,
device *userapi.Device, body *MembershipRequest, cfg *config.ClientAPI,
rsAPI api.RoomserverInternalAPI, db accounts.Database,
rsAPI api.RoomserverInternalAPI, db userdb.Database,
roomID string,
evTime time.Time,
) (inviteStoredOnIDServer bool, err error) {
@ -137,7 +137,7 @@ func CheckAndProcessInvite(
// Returns an error if a check or a request failed.
func queryIDServer(
ctx context.Context,
db accounts.Database, cfg *config.ClientAPI, device *userapi.Device,
db userdb.Database, cfg *config.ClientAPI, device *userapi.Device,
body *MembershipRequest, roomID string,
) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) {
if err = isTrusted(body.IDServer, cfg); err != nil {
@ -206,7 +206,7 @@ func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServe
// Returns an error if the request failed to send or if the response couldn't be parsed.
func queryIDServerStoreInvite(
ctx context.Context,
db accounts.Database, cfg *config.ClientAPI, device *userapi.Device,
db userdb.Database, cfg *config.ClientAPI, device *userapi.Device,
body *MembershipRequest, roomID string,
) (*idServerStoreInviteResponse, error) {
// Retrieve the sender's profile to get their display name

View file

@ -23,12 +23,14 @@ import (
"os"
"strings"
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
"golang.org/x/term"
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
userdb "github.com/matrix-org/dendrite/userapi/storage"
)
const usage = `Usage: %s
@ -57,6 +59,7 @@ var (
pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)")
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
askPass = flag.Bool("ask-pass", false, "Ask for the password to use")
isAdmin = flag.Bool("admin", false, "Create an admin account")
)
func main() {
@ -74,19 +77,28 @@ func main() {
pass := getPassword(password, pwdFile, pwdStdin, askPass, os.Stdin)
accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{
ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString,
}, cfg.Global.ServerName, bcrypt.DefaultCost, cfg.UserAPI.OpenIDTokenLifetimeMS)
accountDB, err := userdb.NewDatabase(
&config.DatabaseOptions{
ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString,
},
cfg.Global.ServerName, bcrypt.DefaultCost,
cfg.UserAPI.OpenIDTokenLifetimeMS,
api.DefaultLoginTokenLifetime,
)
if err != nil {
logrus.Fatalln("Failed to connect to the database:", err.Error())
}
accType := api.AccountTypeUser
if *isAdmin {
accType = api.AccountTypeAdmin
}
policyVersion := ""
if cfg.Global.UserConsentOptions.Enabled {
policyVersion = cfg.Global.UserConsentOptions.Version
}
_, err = accountDB.CreateAccount(context.Background(), *username, pass, "", policyVersion)
_, err = accountDB.CreateAccount(context.Background(), *username, pass, "", policyVersion, accType)
if err != nil {
logrus.Fatalln("Failed to create the account:", err.Error())
}

View file

@ -126,7 +126,6 @@ func main() {
cfg.FederationAPI.FederationMaxRetries = 6
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))

View file

@ -160,7 +160,6 @@ func main() {
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))

View file

@ -79,7 +79,6 @@ func main() {
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))

View file

@ -132,6 +132,7 @@ func main() {
// dependency. Other components also need updating after their dependencies are up.
rsImpl.SetFederationAPI(fsAPI, keyRing)
rsImpl.SetAppserviceAPI(asAPI)
rsImpl.SetUserAPI(userAPI)
keyImpl.SetUserAPI(userAPI)
eduInputAPI := eduserver.NewInternalAPI(

View file

@ -164,7 +164,6 @@ func startup() {
cfg.Defaults(true)
cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db"
cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db"
cfg.UserAPI.DeviceDatabase.ConnectionString = "file:/idb/dendritejs_device.db"
cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db"
cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db"
cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db"

View file

@ -167,7 +167,6 @@ func main() {
cfg.Defaults(true)
cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db"
cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db"
cfg.UserAPI.DeviceDatabase.ConnectionString = "file:/idb/dendritejs_device.db"
cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db"
cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db"
cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db"

View file

@ -32,7 +32,6 @@ func main() {
cfg.RoomServer.Database.ConnectionString = config.DataSource(*dbURI)
cfg.SyncAPI.Database.ConnectionString = config.DataSource(*dbURI)
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(*dbURI)
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(*dbURI)
}
cfg.Global.TrustedIDServers = []string{
"matrix.org",
@ -91,6 +90,7 @@ func main() {
cfg.Logging[0].Type = "std"
cfg.UserAPI.BCryptCost = bcrypt.MinCost
cfg.Global.JetStream.InMemory = true
cfg.ClientAPI.RegistrationSharedSecret = "complement"
}
j, err := yaml.Marshal(cfg)

View file

@ -8,12 +8,11 @@ import (
"log"
"os"
pgaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas"
slaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas"
pgdevices "github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas"
sldevices "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas"
"github.com/pressly/goose"
pgusers "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
slusers "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
)
@ -26,8 +25,7 @@ const (
RoomServer = "roomserver"
SigningKeyServer = "signingkeyserver"
SyncAPI = "syncapi"
UserAPIAccounts = "userapi_accounts"
UserAPIDevices = "userapi_devices"
UserAPI = "userapi"
)
var (
@ -35,7 +33,7 @@ var (
flags = flag.NewFlagSet("goose", flag.ExitOnError)
component = flags.String("component", "", "dendrite component name")
knownDBs = []string{
AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPIAccounts, UserAPIDevices,
AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPI,
}
)
@ -143,18 +141,14 @@ Commands:
func loadSQLiteDeltas(component string) {
switch component {
case UserAPIAccounts:
slaccounts.LoadFromGoose()
case UserAPIDevices:
sldevices.LoadFromGoose()
case UserAPI:
slusers.LoadFromGoose()
}
}
func loadPostgresDeltas(component string) {
switch component {
case UserAPIAccounts:
pgaccounts.LoadFromGoose()
case UserAPIDevices:
pgdevices.LoadFromGoose()
case UserAPI:
pgusers.LoadFromGoose()
}
}

View file

@ -68,6 +68,18 @@ global:
# to other servers and the federation API will not be exposed.
disable_federation: false
# Server notices allows server admins to send messages to all users.
server_notices:
enabled: false
# The server localpart to be used when sending notices, ensure this is not yet taken
local_part: "_server"
# The displayname to be used when sending notices
display_name: "Server alerts"
# The mxid of the avatar to use
avatar_url: ""
# The roomname to be used when creating messages
room_name: "Server Alerts"
# Consent tracking configuration
user_consent:
# If the user consent tracking is enabled or not
@ -169,6 +181,10 @@ client_api:
# using the registration shared secret below.
registration_disabled: false
# Prevents new guest accounts from being created. Guest registration is also
# disabled implicitly by setting 'registration_disabled' above.
guests_disabled: true
# If set, allows registration by anyone who knows the shared secret, regardless of
# whether registration is otherwise disabled.
registration_shared_secret: ""
@ -231,13 +247,6 @@ federation_api:
# enable this option in production as it presents a security risk!
disable_tls_validation: false
# Use the following proxy server for outbound federation traffic.
proxy_outbound:
enabled: false
protocol: http
host: localhost
port: 8080
# Perspective keyservers to use as a backup when direct key fetches fail. This may
# be required to satisfy key requests for servers that are no longer online when
# joining some rooms.

4
go.mod
View file

@ -1,6 +1,6 @@
module github.com/matrix-org/dendrite
replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.3.3-0.20220104162330-c76d5fd70423
replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad
replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c
@ -45,7 +45,7 @@ require (
github.com/mattn/go-sqlite3 v1.14.10
github.com/morikuni/aec v1.0.0 // indirect
github.com/nats-io/nats-server/v2 v2.3.2
github.com/nats-io/nats.go v1.13.1-0.20211122170419-d7c1d78a50fc
github.com/nats-io/nats.go v1.13.1-0.20220121202836-972a071d373d
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/ngrok/sqlmw v0.0.0-20211220175533-9d16fdc47b31

15
go.sum
View file

@ -1122,8 +1122,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8m
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
github.com/nats-io/jwt/v2 v2.2.0 h1:Yg/4WFK6vsqMudRg91eBb7Dh6XeVcDMPHycDE8CfltE=
github.com/nats-io/jwt/v2 v2.2.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k=
github.com/nats-io/jwt/v2 v2.2.1-0.20220113022732-58e87895b296 h1:vU9tpM3apjYlLLeY23zRWJ9Zktr5jp+mloR942LEOpY=
github.com/nats-io/jwt/v2 v2.2.1-0.20220113022732-58e87895b296/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k=
github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8=
github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4=
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
@ -1132,8 +1132,8 @@ github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uY
github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM=
github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo=
github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
github.com/neilalexander/nats-server/v2 v2.3.3-0.20220104162330-c76d5fd70423 h1:BLQVdjMH5XD4BYb0fa+c2Oh2Nr1vrO7GKvRnIJDxChc=
github.com/neilalexander/nats-server/v2 v2.3.3-0.20220104162330-c76d5fd70423/go.mod h1:9sdEkBhyZMQG1M9TevnlYUwMusRACn2vlgOeqoHKwVo=
github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad h1:Z2nWMQsXWWqzj89nW6OaLJSdkFknqhaR5whEOz4++Y8=
github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad/go.mod h1:tckmrt0M6bVaDT3kmh9UrIq/CBOBBse+TpXQi5ldaa8=
github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c h1:G2qsv7D0rY94HAu8pXmElMluuMHQ85waxIDQBhIzV2Q=
github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w=
github.com/neilalexander/utp v0.1.1-0.20210622132614-ee9a34a30488/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8=
@ -1508,8 +1508,8 @@ golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a h1:atOEWVSedO4ksXBe/UrlbSLVxQQ9RxM/tT2Jy10IaHo=
golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -1735,6 +1735,7 @@ golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220207234003-57398862261d h1:Bm7BNOQt2Qv7ZqysjeLjgCBanX+88Z/OtdvsrEv1Djc=
golang.org/x/sys v0.0.0-20220207234003-57398862261d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@ -1756,10 +1757,10 @@ golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxb
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 h1:GZokNIeuVkl3aZHJchRrr13WCsols02MLUcz1U9is6M=
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View file

@ -7,14 +7,6 @@ import (
)
const (
RoomServerStateKeyNIDsCacheName = "roomserver_statekey_nids"
RoomServerStateKeyNIDsCacheMaxEntries = 1024
RoomServerStateKeyNIDsCacheMutable = false
RoomServerEventTypeNIDsCacheName = "roomserver_eventtype_nids"
RoomServerEventTypeNIDsCacheMaxEntries = 64
RoomServerEventTypeNIDsCacheMutable = false
RoomServerRoomIDsCacheName = "roomserver_room_ids"
RoomServerRoomIDsCacheMaxEntries = 1024
RoomServerRoomIDsCacheMutable = false
@ -29,44 +21,10 @@ type RoomServerCaches interface {
// RoomServerNIDsCache contains the subset of functions needed for
// a roomserver NID cache.
type RoomServerNIDsCache interface {
GetRoomServerStateKeyNID(stateKey string) (types.EventStateKeyNID, bool)
StoreRoomServerStateKeyNID(stateKey string, nid types.EventStateKeyNID)
GetRoomServerEventTypeNID(eventType string) (types.EventTypeNID, bool)
StoreRoomServerEventTypeNID(eventType string, nid types.EventTypeNID)
GetRoomServerRoomID(roomNID types.RoomNID) (string, bool)
StoreRoomServerRoomID(roomNID types.RoomNID, roomID string)
}
func (c Caches) GetRoomServerStateKeyNID(stateKey string) (types.EventStateKeyNID, bool) {
val, found := c.RoomServerStateKeyNIDs.Get(stateKey)
if found && val != nil {
if stateKeyNID, ok := val.(types.EventStateKeyNID); ok {
return stateKeyNID, true
}
}
return 0, false
}
func (c Caches) StoreRoomServerStateKeyNID(stateKey string, nid types.EventStateKeyNID) {
c.RoomServerStateKeyNIDs.Set(stateKey, nid)
}
func (c Caches) GetRoomServerEventTypeNID(eventType string) (types.EventTypeNID, bool) {
val, found := c.RoomServerEventTypeNIDs.Get(eventType)
if found && val != nil {
if eventTypeNID, ok := val.(types.EventTypeNID); ok {
return eventTypeNID, true
}
}
return 0, false
}
func (c Caches) StoreRoomServerEventTypeNID(eventType string, nid types.EventTypeNID) {
c.RoomServerEventTypeNIDs.Set(eventType, nid)
}
func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) {
val, found := c.RoomServerRoomIDs.Get(strconv.Itoa(int(roomNID)))
if found && val != nil {

View file

@ -4,14 +4,12 @@ package caching
// different implementations as long as they satisfy the Cache
// interface.
type Caches struct {
RoomVersions Cache // RoomVersionCache
ServerKeys Cache // ServerKeyCache
RoomServerStateKeyNIDs Cache // RoomServerNIDsCache
RoomServerEventTypeNIDs Cache // RoomServerNIDsCache
RoomServerRoomNIDs Cache // RoomServerNIDsCache
RoomServerRoomIDs Cache // RoomServerNIDsCache
RoomInfos Cache // RoomInfoCache
FederationEvents Cache // FederationEventsCache
RoomVersions Cache // RoomVersionCache
ServerKeys Cache // ServerKeyCache
RoomServerRoomNIDs Cache // RoomServerNIDsCache
RoomServerRoomIDs Cache // RoomServerNIDsCache
RoomInfos Cache // RoomInfoCache
FederationEvents Cache // FederationEventsCache
}
// Cache is the interface that an implementation must satisfy.

View file

@ -28,24 +28,6 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
if err != nil {
return nil, err
}
roomServerStateKeyNIDs, err := NewInMemoryLRUCachePartition(
RoomServerStateKeyNIDsCacheName,
RoomServerStateKeyNIDsCacheMutable,
RoomServerStateKeyNIDsCacheMaxEntries,
enablePrometheus,
)
if err != nil {
return nil, err
}
roomServerEventTypeNIDs, err := NewInMemoryLRUCachePartition(
RoomServerEventTypeNIDsCacheName,
RoomServerEventTypeNIDsCacheMutable,
RoomServerEventTypeNIDsCacheMaxEntries,
enablePrometheus,
)
if err != nil {
return nil, err
}
roomServerRoomIDs, err := NewInMemoryLRUCachePartition(
RoomServerRoomIDsCacheName,
RoomServerRoomIDsCacheMutable,
@ -74,18 +56,15 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
return nil, err
}
go cacheCleaner(
roomVersions, serverKeys, roomServerStateKeyNIDs,
roomServerEventTypeNIDs, roomServerRoomIDs,
roomVersions, serverKeys, roomServerRoomIDs,
roomInfos, federationEvents,
)
return &Caches{
RoomVersions: roomVersions,
ServerKeys: serverKeys,
RoomServerStateKeyNIDs: roomServerStateKeyNIDs,
RoomServerEventTypeNIDs: roomServerEventTypeNIDs,
RoomServerRoomIDs: roomServerRoomIDs,
RoomInfos: roomInfos,
FederationEvents: federationEvents,
RoomVersions: roomVersions,
ServerKeys: serverKeys,
RoomServerRoomIDs: roomServerRoomIDs,
RoomInfos: roomInfos,
FederationEvents: federationEvents,
}, nil
}

View file

@ -95,7 +95,6 @@ func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*con
cfg.RoomServer.Database.ConnectionString = config.DataSource(database)
cfg.SyncAPI.Database.ConnectionString = config.DataSource(database)
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(database)
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(database)
cfg.AppServiceAPI.InternalAPI.Listen = assignAddress()
cfg.EDUServer.InternalAPI.Listen = assignAddress()

View file

@ -367,10 +367,13 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
waitTime = fcerr.RetryAfter
} else if fcerr.Blacklisted {
waitTime = time.Hour * 8
} else {
// For all other errors (DNS resolution, network etc.) wait 1 hour.
waitTime = time.Hour
}
} else {
waitTime = time.Hour
logger.WithError(err).Warn("GetUserDevices returned unknown error type")
logger.WithError(err).WithField("user_id", userID).Warn("GetUserDevices returned unknown error type")
}
continue
}

View file

@ -198,7 +198,7 @@ func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOne
}
func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) {
msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil)
msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
@ -244,7 +244,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
domain := string(serverName)
// query local devices
if serverName == a.ThisServer {
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query local device keys: %s", err),
@ -513,6 +513,11 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
// drop the error as it's already a failure at this point
_ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, dkeys)
}
// Sytest expects no failures, if we still could retrieve keys, e.g. from local cache
if len(res.DeviceKeys) > 0 {
delete(res.Failures, serverName)
}
respMu.Unlock()
}
@ -520,7 +525,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string,
) error {
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
// if we can't query the db or there are fewer keys than requested, fetch from remote.
if err != nil {
return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
@ -549,10 +554,58 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
}
func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
// get a list of devices from the user API that actually exist, as
// we won't store keys for devices that don't exist
uapidevices := &userapi.QueryDevicesResponse{}
if err := a.UserAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil {
res.Error = &api.KeyError{
Err: err.Error(),
}
return
}
if !uapidevices.UserExists {
res.Error = &api.KeyError{
Err: fmt.Sprintf("user %q does not exist", req.UserID),
}
return
}
existingDeviceMap := make(map[string]struct{}, len(uapidevices.Devices))
for _, key := range uapidevices.Devices {
existingDeviceMap[key.ID] = struct{}{}
}
// Get all of the user existing device keys so we can check for changes.
existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, true)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
}
return
}
// Work out whether we have device keys in the keyserver for devices that
// no longer exist in the user API. This is mostly an exercise to ensure
// that we keep some integrity between the two.
var toClean []gomatrixserverlib.KeyID
for _, k := range existingKeys {
if _, ok := existingDeviceMap[k.DeviceID]; !ok {
toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID))
}
}
if len(toClean) > 0 {
if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil {
logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean))
} else {
logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean))
}
}
var keysToStore []api.DeviceMessage
// assert that the user ID / device ID are not lying for each key
for _, key := range req.DeviceKeys {
_, serverName, err := gomatrixserverlib.SplitID('@', key.UserID)
var serverName gomatrixserverlib.ServerName
_, serverName, err = gomatrixserverlib.SplitID('@', key.UserID)
if err != nil {
continue // ignore invalid users
}
@ -563,6 +616,11 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
keysToStore = append(keysToStore, key.WithStreamID(0))
continue // deleted keys don't need sanity checking
}
// check that the device in question actually exists in the user
// API before we try and store a key for it
if _, ok := existingDeviceMap[key.DeviceID]; !ok {
continue
}
gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
@ -578,29 +636,12 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
})
}
// get existing device keys so we can check for changes
existingKeys := make([]api.DeviceMessage, len(keysToStore))
for i := range keysToStore {
existingKeys[i] = api.DeviceMessage{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
UserID: keysToStore[i].UserID,
DeviceID: keysToStore[i].DeviceID,
},
}
}
if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
}
return
}
if req.OnlyDisplayNameUpdates {
// add the display name field from keysToStore into existingKeys
keysToStore = appendDisplayNames(existingKeys, keysToStore)
}
// store the device keys and emit changes
err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),

View file

@ -53,7 +53,7 @@ type Database interface {
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
// cross-signing signatures relating to that device.

View file

@ -56,6 +56,9 @@ const selectDeviceKeysSQL = "" +
const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
const selectBatchDeviceKeysWithEmptiesSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
@ -69,14 +72,15 @@ const deleteAllDeviceKeysSQL = "" +
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct {
db *sql.DB
upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt
countStreamIDsForUserStmt *sql.Stmt
deleteDeviceKeysStmt *sql.Stmt
deleteAllDeviceKeysStmt *sql.Stmt
db *sql.DB
upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt
countStreamIDsForUserStmt *sql.Stmt
deleteDeviceKeysStmt *sql.Stmt
deleteAllDeviceKeysStmt *sql.Stmt
}
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@ -96,6 +100,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err
}
if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
return nil, err
}
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err
}
@ -180,8 +187,14 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql
return err
}
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
var stmt *sql.Stmt
if includeEmpty {
stmt = s.selectBatchDeviceKeysWithEmptiesStmt
} else {
stmt = s.selectBatchDeviceKeysStmt
}
rows, err := stmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
}

View file

@ -108,8 +108,8 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe
})
}
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs)
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty)
}
func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {

View file

@ -52,6 +52,9 @@ const selectDeviceKeysSQL = "" +
const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
const selectBatchDeviceKeysWithEmptiesSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
@ -65,13 +68,14 @@ const deleteAllDeviceKeysSQL = "" +
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct {
db *sql.DB
upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt
deleteDeviceKeysStmt *sql.Stmt
deleteAllDeviceKeysStmt *sql.Stmt
db *sql.DB
upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt
deleteDeviceKeysStmt *sql.Stmt
deleteAllDeviceKeysStmt *sql.Stmt
}
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@ -91,6 +95,9 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err
}
if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
return nil, err
}
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err
}
@ -113,12 +120,18 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql
return err
}
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
deviceIDMap := make(map[string]bool)
for _, d := range deviceIDs {
deviceIDMap[d] = true
}
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
var stmt *sql.Stmt
if includeEmpty {
stmt = s.selectBatchDeviceKeysWithEmptiesStmt
} else {
stmt = s.selectBatchDeviceKeysStmt
}
rows, err := stmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
}

View file

@ -173,7 +173,7 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
}
// Querying for device keys returns the latest stream IDs
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"})
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false)
if err != nil {
t.Fatalf("DeviceKeysForUser returned error: %s", err)
}

View file

@ -38,7 +38,7 @@ type DeviceKeys interface {
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
}

View file

@ -3,9 +3,11 @@ package api
import (
"context"
"github.com/matrix-org/gomatrixserverlib"
asAPI "github.com/matrix-org/dendrite/appservice/api"
fsAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/gomatrixserverlib"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
// RoomserverInputAPI is used to write events to the room server.
@ -14,6 +16,7 @@ type RoomserverInternalAPI interface {
// interdependencies between the roomserver and other input APIs
SetFederationAPI(fsAPI fsAPI.FederationInternalAPI, keyRing *gomatrixserverlib.KeyRing)
SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI)
SetUserAPI(userAPI userapi.UserInternalAPI)
InputRoomEvents(
ctx context.Context,

View file

@ -5,10 +5,12 @@ import (
"encoding/json"
"fmt"
asAPI "github.com/matrix-org/dendrite/appservice/api"
fsAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
asAPI "github.com/matrix-org/dendrite/appservice/api"
fsAPI "github.com/matrix-org/dendrite/federationapi/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
// RoomserverInternalAPITrace wraps a RoomserverInternalAPI and logs the
@ -25,6 +27,10 @@ func (t *RoomserverInternalAPITrace) SetAppserviceAPI(asAPI asAPI.AppServiceQuer
t.Impl.SetAppserviceAPI(asAPI)
}
func (t *RoomserverInternalAPITrace) SetUserAPI(userAPI userapi.UserInternalAPI) {
t.Impl.SetUserAPI(userAPI)
}
func (t *RoomserverInternalAPITrace) InputRoomEvents(
ctx context.Context,
req *InputRoomEventsRequest,

View file

@ -95,6 +95,8 @@ type PerformLeaveRequest struct {
}
type PerformLeaveResponse struct {
Code int `json:"code,omitempty"`
Message interface{} `json:"message,omitempty"`
}
type PerformInviteRequest struct {

View file

@ -14,6 +14,8 @@ import (
"github.com/matrix-org/dendrite/roomserver/internal/query"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
@ -32,6 +34,7 @@ type RoomserverInternalAPI struct {
*perform.Publisher
*perform.Backfiller
*perform.Forgetter
ProcessContext *process.ProcessContext
DB storage.Database
Cfg *config.RoomServer
Cache caching.RoomServerCaches
@ -48,12 +51,13 @@ type RoomserverInternalAPI struct {
}
func NewRoomserverAPI(
cfg *config.RoomServer, roomserverDB storage.Database, consumer nats.JetStreamContext,
inputRoomEventTopic, outputRoomEventTopic string, caches caching.RoomServerCaches,
perspectiveServerNames []gomatrixserverlib.ServerName,
processCtx *process.ProcessContext, cfg *config.RoomServer, roomserverDB storage.Database,
consumer nats.JetStreamContext, inputRoomEventTopic, outputRoomEventTopic string,
caches caching.RoomServerCaches, perspectiveServerNames []gomatrixserverlib.ServerName,
) *RoomserverInternalAPI {
serverACLs := acls.NewServerACLs(roomserverDB)
a := &RoomserverInternalAPI{
ProcessContext: processCtx,
DB: roomserverDB,
Cfg: cfg,
Cache: caches,
@ -83,6 +87,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.FederationInternalA
r.KeyRing = keyRing
r.Inputer = &input.Inputer{
ProcessContext: r.ProcessContext,
DB: r.DB,
InputRoomEventTopic: r.InputRoomEventTopic,
OutputRoomEventTopic: r.OutputRoomEventTopic,
@ -155,6 +160,10 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.FederationInternalA
}
}
func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.UserInternalAPI) {
r.Leaver.UserAPI = userAPI
}
func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI) {
r.asAPI = asAPI
}

View file

@ -31,6 +31,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/internal/query"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/prometheus/client_golang/prometheus"
@ -59,6 +60,7 @@ var keyContentFields = map[string]string{
}
type Inputer struct {
ProcessContext *process.ProcessContext
DB storage.Database
JetStream nats.JetStreamContext
Durable nats.SubOpt
@ -115,7 +117,7 @@ func (r *Inputer) Start() error {
_ = msg.InProgress() // resets the acknowledgement wait timer
defer eventsInProgress.Delete(index)
defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec()
action, err := r.processRoomEventUsingUpdater(context.Background(), roomID, &inputRoomEvent)
action, err := r.processRoomEventUsingUpdater(r.ProcessContext.Context(), roomID, &inputRoomEvent)
if err != nil {
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
sentry.CaptureException(err)

View file

@ -405,7 +405,7 @@ func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.Ro
if len(extraEventIDs) == 0 {
return nil, nil
}
extraEvents, err := u.updater.EventsFromIDs(u.ctx, extraEventIDs)
extraEvents, err := u.updater.UnsentEventsFromIDs(u.ctx, extraEventIDs)
if err != nil {
return nil, err
}

View file

@ -16,25 +16,29 @@ package perform
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
fsAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/dendrite/roomserver/internal/input"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
type Leaver struct {
Cfg *config.RoomServer
DB storage.Database
FSAPI fsAPI.FederationInternalAPI
Cfg *config.RoomServer
DB storage.Database
FSAPI fsAPI.FederationInternalAPI
UserAPI userapi.UserInternalAPI
Inputer *input.Inputer
}
@ -85,6 +89,31 @@ func (r *Leaver) performLeaveRoomByID(
if host != r.Cfg.Matrix.ServerName {
return r.performFederatedRejectInvite(ctx, req, res, senderUser, eventID)
}
// check that this is not a "server notice room"
accData := &userapi.QueryAccountDataResponse{}
if err := r.UserAPI.QueryAccountData(ctx, &userapi.QueryAccountDataRequest{
UserID: req.UserID,
RoomID: req.RoomID,
DataType: "m.tag",
}, accData); err != nil {
return nil, fmt.Errorf("unable to query account data")
}
if roomData, ok := accData.RoomAccountData[req.RoomID]; ok {
tagData, ok := roomData["m.tag"]
if ok {
tags := gomatrix.TagContent{}
if err = json.Unmarshal(tagData, &tags); err != nil {
return nil, fmt.Errorf("unable to unmarshal tag content")
}
if _, ok = tags.Tags["m.server_notice"]; ok {
// mimic the returned values from Synapse
res.Message = "You cannot reject this invite"
res.Code = 403
return nil, fmt.Errorf("You cannot reject this invite")
}
}
}
}
// There's no invite pending, so first of all we want to find out

View file

@ -11,6 +11,8 @@ import (
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/opentracing/opentracing-go"
)
@ -90,6 +92,10 @@ func (h *httpRoomserverInternalAPI) SetFederationAPI(fsAPI fsInputAPI.Federation
func (h *httpRoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI) {
}
// SetUserAPI no-ops in HTTP client mode as there is no chicken/egg scenario
func (h *httpRoomserverInternalAPI) SetUserAPI(userAPI userapi.UserInternalAPI) {
}
// SetRoomAlias implements RoomserverAliasAPI
func (h *httpRoomserverInternalAPI) SetRoomAlias(
ctx context.Context,

View file

@ -53,7 +53,7 @@ func NewInternalAPI(
js := jetstream.Prepare(&cfg.Matrix.JetStream)
return internal.NewRoomserverAPI(
cfg, roomserverDB, js,
base.ProcessContext, cfg, roomserverDB, js,
cfg.Matrix.JetStream.TopicFor(jetstream.InputRoomEvent),
cfg.Matrix.JetStream.TopicFor(jetstream.OutputRoomEvent),
base.Caches, perspectiveServerNames,

View file

@ -127,6 +127,9 @@ const bulkSelectEventIDSQL = "" +
const bulkSelectEventNIDSQL = "" +
"SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1)"
const bulkSelectUnsentEventNIDSQL = "" +
"SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1) AND sent_to_output = FALSE"
const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)"
@ -147,6 +150,7 @@ type eventStatements struct {
bulkSelectEventReferenceStmt *sql.Stmt
bulkSelectEventIDStmt *sql.Stmt
bulkSelectEventNIDStmt *sql.Stmt
bulkSelectUnsentEventNIDStmt *sql.Stmt
selectMaxEventDepthStmt *sql.Stmt
selectRoomNIDsForEventNIDsStmt *sql.Stmt
}
@ -173,6 +177,7 @@ func prepareEventsTable(db *sql.DB) (tables.Events, error) {
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
{&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL},
{&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL},
{&s.selectRoomNIDsForEventNIDsStmt, selectRoomNIDsForEventNIDsSQL},
}.Prepare(db)
@ -458,10 +463,28 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev
return results, nil
}
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt)
return s.bulkSelectEventNID(ctx, txn, eventIDs, false)
}
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID
// only for events that haven't already been sent to the roomserver output.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
return s.bulkSelectEventNID(ctx, txn, eventIDs, true)
}
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) {
var stmt *sql.Stmt
if onlyUnsent {
stmt = sqlutil.TxStmt(txn, s.bulkSelectUnsentEventNIDStmt)
} else {
stmt = sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt)
}
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil {
return nil, err

View file

@ -136,7 +136,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
}
// Look up the NID of the new join event
nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID})
nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false)
if err != nil {
return fmt.Errorf("u.d.EventNIDs: %w", err)
}
@ -170,7 +170,7 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s
}
// Look up the NID of the new leave event
nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID})
nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false)
if err != nil {
return fmt.Errorf("u.d.EventNIDs: %w", err)
}
@ -196,7 +196,7 @@ func (u *MembershipUpdater) SetToKnock(event *gomatrixserverlib.Event) (bool, er
}
if u.membership != tables.MembershipStateKnock {
// Look up the NID of the new knock event
nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{event.EventID()})
nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{event.EventID()}, false)
if err != nil {
return fmt.Errorf("u.d.EventNIDs: %w", err)
}

View file

@ -215,7 +215,13 @@ func (u *RoomUpdater) EventIDs(
func (u *RoomUpdater) EventNIDs(
ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) {
return u.d.eventNIDs(ctx, u.txn, eventIDs)
return u.d.eventNIDs(ctx, u.txn, eventIDs, NoFilter)
}
func (u *RoomUpdater) UnsentEventNIDs(
ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) {
return u.d.eventNIDs(ctx, u.txn, eventIDs, FilterUnsentOnly)
}
func (u *RoomUpdater) StateAtEventIDs(
@ -231,7 +237,11 @@ func (u *RoomUpdater) StateEntriesForEventIDs(
}
func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, eventIDs)
return u.d.eventsFromIDs(ctx, u.txn, eventIDs, false)
}
func (u *RoomUpdater) UnsentEventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, eventIDs, true)
}
func (u *RoomUpdater) GetMembershipEventNIDsForRoom(

View file

@ -59,23 +59,12 @@ func (d *Database) eventTypeNIDs(
ctx context.Context, txn *sql.Tx, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
result := make(map[string]types.EventTypeNID)
remaining := []string{}
for _, eventType := range eventTypes {
if nid, ok := d.Cache.GetRoomServerEventTypeNID(eventType); ok {
result[eventType] = nid
} else {
remaining = append(remaining, eventType)
}
nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, eventTypes)
if err != nil {
return nil, err
}
if len(remaining) > 0 {
nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, remaining)
if err != nil {
return nil, err
}
for eventType, nid := range nids {
result[eventType] = nid
d.Cache.StoreRoomServerEventTypeNID(eventType, nid)
}
for eventType, nid := range nids {
result[eventType] = nid
}
return result, nil
}
@ -96,23 +85,12 @@ func (d *Database) eventStateKeyNIDs(
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
result := make(map[string]types.EventStateKeyNID)
remaining := []string{}
for _, eventStateKey := range eventStateKeys {
if nid, ok := d.Cache.GetRoomServerStateKeyNID(eventStateKey); ok {
result[eventStateKey] = nid
} else {
remaining = append(remaining, eventStateKey)
}
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, eventStateKeys)
if err != nil {
return nil, err
}
if len(remaining) > 0 {
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, remaining)
if err != nil {
return nil, err
}
for eventStateKey, nid := range nids {
result[eventStateKey] = nid
d.Cache.StoreRoomServerStateKeyNID(eventStateKey, nid)
}
for eventStateKey, nid := range nids {
result[eventStateKey] = nid
}
return result, nil
}
@ -238,13 +216,27 @@ func (d *Database) addState(
func (d *Database) EventNIDs(
ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) {
return d.eventNIDs(ctx, nil, eventIDs)
return d.eventNIDs(ctx, nil, eventIDs, NoFilter)
}
type UnsentFilter bool
const (
NoFilter UnsentFilter = false
FilterUnsentOnly UnsentFilter = true
)
func (d *Database) eventNIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string,
ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter,
) (map[string]types.EventNID, error) {
return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs)
switch filter {
case FilterUnsentOnly:
return d.EventsTable.BulkSelectUnsentEventNID(ctx, txn, eventIDs)
case NoFilter:
return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs)
default:
panic("impossible case")
}
}
func (d *Database) SetState(
@ -281,11 +273,11 @@ func (d *Database) EventIDs(
}
func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return d.eventsFromIDs(ctx, nil, eventIDs)
return d.eventsFromIDs(ctx, nil, eventIDs, NoFilter)
}
func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.Event, error) {
nidMap, err := d.eventNIDs(ctx, txn, eventIDs)
func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter) ([]types.Event, error) {
nidMap, err := d.eventNIDs(ctx, txn, eventIDs, filter)
if err != nil {
return nil, err
}
@ -704,9 +696,6 @@ func (d *Database) assignRoomNID(
func (d *Database) assignEventTypeNID(
ctx context.Context, txn *sql.Tx, eventType string,
) (types.EventTypeNID, error) {
if eventTypeNID, ok := d.Cache.GetRoomServerEventTypeNID(eventType); ok {
return eventTypeNID, nil
}
// Check if we already have a numeric ID in the database.
eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType)
if err == sql.ErrNoRows {
@ -717,18 +706,12 @@ func (d *Database) assignEventTypeNID(
eventTypeNID, err = d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType)
}
}
if err == nil {
d.Cache.StoreRoomServerEventTypeNID(eventType, eventTypeNID)
}
return eventTypeNID, err
}
func (d *Database) assignStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) {
if eventStateKeyNID, ok := d.Cache.GetRoomServerStateKeyNID(eventStateKey); ok {
return eventStateKeyNID, nil
}
// Check if we already have a numeric ID in the database.
eventStateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey)
if err == sql.ErrNoRows {
@ -739,9 +722,6 @@ func (d *Database) assignStateKeyNID(
eventStateKeyNID, err = d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey)
}
}
if err == nil {
d.Cache.StoreRoomServerStateKeyNID(eventStateKey, eventStateKeyNID)
}
return eventStateKeyNID, err
}

View file

@ -99,6 +99,9 @@ const bulkSelectEventIDSQL = "" +
const bulkSelectEventNIDSQL = "" +
"SELECT event_id, event_nid FROM roomserver_events WHERE event_id IN ($1)"
const bulkSelectUnsentEventNIDSQL = "" +
"SELECT event_id, event_nid FROM roomserver_events WHERE sent_to_output = 0 AND event_id IN ($1)"
const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)"
@ -118,8 +121,9 @@ type eventStatements struct {
bulkSelectStateAtEventAndReferenceStmt *sql.Stmt
bulkSelectEventReferenceStmt *sql.Stmt
bulkSelectEventIDStmt *sql.Stmt
bulkSelectEventNIDStmt *sql.Stmt
//selectRoomNIDsForEventNIDsStmt *sql.Stmt
//bulkSelectEventNIDStmt *sql.Stmt
//bulkSelectUnsentEventNIDStmt *sql.Stmt
//selectRoomNIDsForEventNIDsStmt *sql.Stmt
}
func createEventsTable(db *sql.DB) error {
@ -144,7 +148,8 @@ func prepareEventsTable(db *sql.DB) (tables.Events, error) {
{&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL},
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
//{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
//{&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL},
//{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL},
}.Prepare(db)
}
@ -494,15 +499,33 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev
return results, nil
}
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
return s.bulkSelectEventNID(ctx, txn, eventIDs, false)
}
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID
// only for events that haven't already been sent to the roomserver output.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
return s.bulkSelectEventNID(ctx, txn, eventIDs, true)
}
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) {
///////////////
iEventIDs := make([]interface{}, len(eventIDs))
for k, v := range eventIDs {
iEventIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
var selectOrig string
if onlyUnsent {
selectOrig = strings.Replace(bulkSelectUnsentEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
} else {
selectOrig = strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
}
selectStmt, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err

View file

@ -59,6 +59,7 @@ type Events interface {
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error)
BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error)
SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error)
SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
}

View file

@ -39,7 +39,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/gorilla/mux"
@ -274,8 +274,14 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI {
// CreateAccountsDB creates a new instance of the accounts database. Should only
// be called once per component.
func (b *BaseDendrite) CreateAccountsDB() accounts.Database {
db, err := accounts.NewDatabase(&b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost, b.Cfg.UserAPI.OpenIDTokenLifetimeMS)
func (b *BaseDendrite) CreateAccountsDB() userdb.Database {
db, err := userdb.NewDatabase(
&b.Cfg.UserAPI.AccountDatabase,
b.Cfg.Global.ServerName,
b.Cfg.UserAPI.BCryptCost,
b.Cfg.UserAPI.OpenIDTokenLifetimeMS,
userapi.DefaultLoginTokenLifetime,
)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to accounts db")
}

View file

@ -18,6 +18,10 @@ type ClientAPI struct {
// If set, allows registration by anyone who also has the shared
// secret, even if registration is otherwise disabled.
RegistrationSharedSecret string `yaml:"registration_shared_secret"`
// If set, prevents guest accounts from being created. Only takes
// effect if registration is enabled, otherwise guests registration
// is forbidden either way.
GuestsDisabled bool `yaml:"guests_disabled"`
// Boolean stating whether catpcha registration is enabled
// and required

View file

@ -29,8 +29,6 @@ type FederationAPI struct {
// on remote federation endpoints. This is not recommended in production!
DisableTLSValidation bool `yaml:"disable_tls_validation"`
Proxy Proxy `yaml:"proxy_outbound"`
// Perspective keyservers, to use as a backup when direct key fetch
// requests don't succeed
KeyPerspectives KeyPerspectives `yaml:"key_perspectives"`
@ -50,8 +48,6 @@ func (c *FederationAPI) Defaults(generate bool) {
c.FederationMaxRetries = 16
c.DisableTLSValidation = false
c.Proxy.Defaults()
}
func (c *FederationAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {

View file

@ -65,6 +65,9 @@ type Global struct {
// DNS caching options for all outbound HTTP requests
DNSCache DNSCacheOptions `yaml:"dns_cache"`
// ServerNotices configuration used for sending server notices
ServerNotices ServerNotices `yaml:"server_notices"`
// Consent tracking options
UserConsentOptions UserConsentOptions `yaml:"user_consent"`
}
@ -84,6 +87,7 @@ func (c *Global) Defaults(generate bool) {
c.DNSCache.Defaults()
c.Sentry.Defaults()
c.UserConsentOptions.Defaults(c.BaseURL)
c.ServerNotices.Defaults(generate)
}
func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) {
@ -95,6 +99,7 @@ func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) {
c.Sentry.Verify(configErrs, isMonolith)
c.DNSCache.Verify(configErrs, isMonolith)
c.UserConsentOptions.Verify(configErrs, isMonolith)
c.ServerNotices.Verify(configErrs, isMonolith)
}
type OldVerifyKeys struct {
@ -136,6 +141,31 @@ func (c *Metrics) Defaults(generate bool) {
func (c *Metrics) Verify(configErrs *ConfigErrors, isMonolith bool) {
}
// ServerNotices defines the configuration used for sending server notices
type ServerNotices struct {
Enabled bool `yaml:"enabled"`
// The localpart to be used when sending notices
LocalPart string `yaml:"local_part"`
// The displayname to be used when sending notices
DisplayName string `yaml:"display_name"`
// The avatar of this user
AvatarURL string `yaml:"avatar"`
// The roomname to be used when creating messages
RoomName string `yaml:"room_name"`
}
func (c *ServerNotices) Defaults(generate bool) {
if generate {
c.Enabled = true
c.LocalPart = "_server"
c.DisplayName = "Server Alert"
c.RoomName = "Server Alert"
c.AvatarURL = ""
}
}
func (c *ServerNotices) Verify(errors *ConfigErrors, isMonolith bool) {}
// The configuration to use for Sentry error reporting
type Sentry struct {
Enabled bool `yaml:"enabled"`

View file

@ -58,6 +58,11 @@ global:
basic_auth:
username: metrics
password: metrics
server_notices:
local_part: "_server"
display_name: "Server alerts"
avatar: ""
room_name: "Server Alerts"
app_service_api:
internal_api:
listen: http://localhost:7777
@ -118,11 +123,6 @@ federation_sender:
conn_max_lifetime: -1
send_max_retries: 16
disable_tls_validation: false
proxy_outbound:
enabled: false
protocol: http
host: localhost
port: 8080
key_server:
internal_api:
listen: http://localhost:7779

View file

@ -16,9 +16,6 @@ type UserAPI struct {
// The Account database stores the login details and account information
// for local users. It is accessed by the UserAPI.
AccountDatabase DatabaseOptions `yaml:"account_database"`
// The Device database stores session information for the devices of logged
// in local users. It is accessed by the UserAPI.
DeviceDatabase DatabaseOptions `yaml:"device_database"`
}
const DefaultOpenIDTokenLifetimeMS = 3600000 // 60 minutes
@ -27,10 +24,8 @@ func (c *UserAPI) Defaults(generate bool) {
c.InternalAPI.Listen = "http://localhost:7781"
c.InternalAPI.Connect = "http://localhost:7781"
c.AccountDatabase.Defaults(10)
c.DeviceDatabase.Defaults(10)
if generate {
c.AccountDatabase.ConnectionString = "file:userapi_accounts.db"
c.DeviceDatabase.ConnectionString = "file:userapi_devices.db"
}
c.BCryptCost = bcrypt.DefaultCost
c.OpenIDTokenLifetimeMS = DefaultOpenIDTokenLifetimeMS
@ -40,6 +35,5 @@ func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
checkURL(configErrs, "user_api.internal_api.listen", string(c.InternalAPI.Listen))
checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect))
checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString))
checkNotEmpty(configErrs, "user_api.device_database.connection_string", string(c.DeviceDatabase.ConnectionString))
checkPositive(configErrs, "user_api.openid_token_lifetime_ms", c.OpenIDTokenLifetimeMS)
}

View file

@ -24,13 +24,12 @@ func Prepare(cfg *config.JetStream) natsclient.JetStreamContext {
if natsServer == nil {
var err error
natsServer, err = natsserver.NewServer(&natsserver.Options{
ServerName: "monolith",
DontListen: true,
JetStream: true,
StoreDir: string(cfg.StoragePath),
NoSystemAccount: true,
AllowNewAccounts: false,
MaxPayload: 16 * 1024 * 1024,
ServerName: "monolith",
DontListen: true,
JetStream: true,
StoreDir: string(cfg.StoragePath),
NoSystemAccount: true,
MaxPayload: 16 * 1024 * 1024,
})
if err != nil {
panic(err)

View file

@ -30,7 +30,7 @@ import (
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
)
@ -38,7 +38,7 @@ import (
// all components of Dendrite, for use in monolith mode.
type Monolith struct {
Config *config.Dendrite
AccountDB accounts.Database
AccountDB userdb.Database
KeyRing *gomatrixserverlib.KeyRing
Client *gomatrixserverlib.Client
FedClient *gomatrixserverlib.FederationClient

View file

@ -16,6 +16,7 @@ package consumers
import (
"context"
"database/sql"
"encoding/json"
"fmt"
@ -307,7 +308,9 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
ctx context.Context, msg api.OutputRetireInviteEvent,
) {
pduPos, err := s.db.RetireInviteEvent(ctx, msg.EventID)
if err != nil {
// It's possible we just haven't heard of this invite yet, so
// we should not panic if we try to retire it.
if err != nil && err != sql.ErrNoRows {
sentry.CaptureException(err)
// panic rather than continue with an inconsistent database
log.WithFields(log.Fields{

View file

@ -39,14 +39,14 @@ func Setup(
rsAPI api.RoomserverInternalAPI,
cfg *config.SyncAPI,
) {
r0mux := csMux.PathPrefix("/r0").Subrouter()
v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter()
// TODO: Add AS support for all handlers below.
r0mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
v3mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return srp.OnIncomingSyncRequest(req, device)
})).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
v3mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@ -54,7 +54,7 @@ func Setup(
return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], device, federation, rsAPI, cfg, srp)
})).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userId}/filter",
v3mux.Handle("/user/{userId}/filter",
httputil.MakeAuthAPI("put_filter", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -64,7 +64,7 @@ func Setup(
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/user/{userId}/filter/{filterId}",
v3mux.Handle("/user/{userId}/filter/{filterId}",
httputil.MakeAuthAPI("get_filter", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
@ -74,7 +74,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
v3mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return srp.OnIncomingKeyChangeRequest(req, device)
})).Methods(http.MethodGet, http.MethodOptions)
}

View file

@ -279,7 +279,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
parts := strings.Split(tok[1:], "_")
var positions [7]StreamPosition
for i, p := range parts {
if i > len(positions) {
if i >= len(positions) {
break
}
var pos int

View file

@ -592,3 +592,4 @@ Forward extremities remain so even after the next events are populated as outlie
If a device list update goes missing, the server resyncs on the next one
uploading self-signing key notifies over federation
uploading signed devices gets propagated over federation
Device list doesn't change if remote server is down

View file

@ -18,8 +18,9 @@ import (
"context"
"encoding/json"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
// UserInternalAPI is the internal API for information about users and devices.
@ -384,6 +385,7 @@ type Device struct {
// If the device is for an appservice user,
// this is the appservice ID.
AppserviceID string
AccountType AccountType
}
// Account represents a Matrix account on this home server.
@ -392,7 +394,7 @@ type Account struct {
Localpart string
ServerName gomatrixserverlib.ServerName
AppServiceID string
// TODO: Other flags like IsAdmin, IsGuest
AccountType AccountType
// TODO: Associations (e.g. with application services)
}
@ -448,4 +450,8 @@ const (
AccountTypeUser AccountType = 1
// AccountTypeGuest indicates this is a guest account
AccountTypeGuest AccountType = 2
// AccountTypeAdmin indicates this is an admin account
AccountTypeAdmin AccountType = 3
// AccountTypeAppService indicates this is an appservice account
AccountTypeAppService AccountType = 4
)

View file

@ -19,6 +19,13 @@ import (
"time"
)
// DefaultLoginTokenLifetime determines how old a valid token may be.
//
// NOTSPEC: The current spec says "SHOULD be limited to around five
// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low.
// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325).
const DefaultLoginTokenLifetime = 2 * time.Minute
type LoginTokenInternalAPI interface {
// PerformLoginTokenCreation creates a new login token and associates it with the provided data.
PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error

View file

@ -21,22 +21,21 @@ import (
"errors"
"fmt"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/dendrite/userapi/storage/devices"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/userapi/storage"
)
type UserInternalAPI struct {
AccountDB accounts.Database
DeviceDB devices.Database
DB storage.Database
ServerName gomatrixserverlib.ServerName
// AppServices is the list of all registered AS
AppServices []config.ApplicationService
@ -54,10 +53,11 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
if req.DataType == "" {
return fmt.Errorf("data type must not be empty")
}
return a.AccountDB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
return a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
}
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
if req.AccountType == api.AccountTypeGuest {
acc, err := a.AccountDB.CreateGuestAccount(ctx)
if err != nil {
@ -86,11 +86,18 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
Localpart: req.Localpart,
ServerName: a.ServerName,
UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName),
AccountType: req.AccountType,
}
return nil
}
if err = a.AccountDB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
if req.AccountType == api.AccountTypeGuest {
res.AccountCreated = true
res.Account = acc
return nil
}
if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
return err
}
@ -100,7 +107,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
}
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
return err
}
res.PasswordUpdated = true
@ -113,7 +120,7 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
"device_id": req.DeviceID,
"display_name": req.DeviceDisplayName,
}).Info("PerformDeviceCreation")
dev, err := a.DeviceDB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
if err != nil {
return err
}
@ -138,12 +145,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 {
var devices []api.Device
devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
for _, d := range devices {
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
}
} else {
err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs)
err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs)
}
if err != nil {
return err
@ -197,7 +204,7 @@ func (a *UserInternalAPI) PerformLastSeenUpdate(
if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
if err := a.DeviceDB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil {
if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil {
return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err)
}
return nil
@ -209,7 +216,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
return err
}
dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID)
dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID)
if err == sql.ErrNoRows {
res.DeviceExists = false
return nil
@ -224,7 +231,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
return nil
}
err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
return err
@ -262,7 +269,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
if domain != a.ServerName {
return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName)
}
prof, err := a.AccountDB.GetProfileByLocalpart(ctx, local)
prof, err := a.DB.GetProfileByLocalpart(ctx, local)
if err != nil {
if err == sql.ErrNoRows {
return nil
@ -276,7 +283,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
}
func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error {
profiles, err := a.AccountDB.SearchProfiles(ctx, req.SearchString, req.Limit)
profiles, err := a.DB.SearchProfiles(ctx, req.SearchString, req.Limit)
if err != nil {
return err
}
@ -285,7 +292,7 @@ func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.Quer
}
func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error {
devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs)
devices, err := a.DB.GetDevicesByID(ctx, req.DeviceIDs)
if err != nil {
return err
}
@ -313,10 +320,11 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
if domain != a.ServerName {
return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName)
}
devs, err := a.DeviceDB.GetDevicesByLocalpart(ctx, local)
devs, err := a.DB.GetDevicesByLocalpart(ctx, local)
if err != nil {
return err
}
res.UserExists = true
res.Devices = devs
return nil
}
@ -331,7 +339,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
}
if req.DataType != "" {
var data json.RawMessage
data, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
if err != nil {
return err
}
@ -349,7 +357,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
}
return nil
}
global, rooms, err := a.AccountDB.GetAccountData(ctx, local)
global, rooms, err := a.DB.GetAccountData(ctx, local)
if err != nil {
return err
}
@ -368,13 +376,22 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
return nil
}
device, err := a.DeviceDB.GetDeviceByAccessToken(ctx, req.AccessToken)
device, err := a.DB.GetDeviceByAccessToken(ctx, req.AccessToken)
if err != nil {
if err == sql.ErrNoRows {
return nil
}
return err
}
localPart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
return err
}
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart)
if err != nil {
return err
}
device.AccountType = acc.AccountType
res.Device = device
return nil
}
@ -401,6 +418,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
// AS dummy device has AS's token.
AccessToken: token,
AppserviceID: appService.ID,
AccountType: api.AccountTypeAppService,
}
localpart, err := userutil.ParseUsernameParam(appServiceUserID, &a.ServerName)
@ -410,7 +428,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
if localpart != "" { // AS is masquerading as another user
// Verify that the user is registered
account, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart)
account, err := a.DB.GetAccountByLocalpart(ctx, localpart)
// Verify that the account exists and either appServiceID matches or
// it belongs to the appservice user namespaces
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
@ -428,7 +446,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
// PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again.
func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error {
err := a.AccountDB.DeactivateAccount(ctx, req.Localpart)
err := a.DB.DeactivateAccount(ctx, req.Localpart)
res.AccountDeactivated = err == nil
return err
}
@ -437,7 +455,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error {
token := util.RandomString(24)
exp, err := a.AccountDB.CreateOpenIDToken(ctx, token, req.UserID)
exp, err := a.DB.CreateOpenIDToken(ctx, token, req.UserID)
res.Token = api.OpenIDToken{
Token: token,
@ -450,7 +468,7 @@ func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *a
// QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation
func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, req.Token)
openIDTokenAttrs, err := a.DB.GetOpenIDTokenAttributes(ctx, req.Token)
if err != nil {
return err
}
@ -472,7 +490,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
return nil
}
exists, err := a.AccountDB.DeleteKeyBackup(ctx, req.UserID, req.Version)
exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version)
if err != nil {
res.Error = fmt.Sprintf("failed to delete backup: %s", err)
}
@ -485,7 +503,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
// Create metadata
if req.Version == "" {
version, err := a.AccountDB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
if err != nil {
res.Error = fmt.Sprintf("failed to create backup: %s", err)
}
@ -498,7 +516,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
// Update metadata
if len(req.Keys.Rooms) == 0 {
err := a.AccountDB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
err := a.DB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
if err != nil {
res.Error = fmt.Sprintf("failed to update backup: %s", err)
}
@ -519,7 +537,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) {
// you can only upload keys for the CURRENT version
version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, "")
version, _, _, _, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, "")
if err != nil {
res.Error = fmt.Sprintf("failed to query version: %s", err)
return
@ -547,7 +565,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
})
}
}
count, etag, err := a.AccountDB.UpsertBackupKeys(ctx, version, req.UserID, uploads)
count, etag, err := a.DB.UpsertBackupKeys(ctx, version, req.UserID, uploads)
if err != nil {
res.Error = fmt.Sprintf("failed to upsert keys: %s", err)
return
@ -557,7 +575,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
}
func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) {
version, algorithm, authData, etag, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version)
version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version)
res.Version = version
if err != nil {
if err == sql.ErrNoRows {
@ -573,14 +591,14 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
res.Exists = !deleted
if !req.ReturnKeys {
res.Count, err = a.AccountDB.CountBackupKeys(ctx, version, req.UserID)
res.Count, err = a.DB.CountBackupKeys(ctx, version, req.UserID)
if err != nil {
res.Error = fmt.Sprintf("failed to count keys: %s", err)
}
return
}
result, err := a.AccountDB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
if err != nil {
res.Error = fmt.Sprintf("failed to query keys: %s", err)
return

View file

@ -34,7 +34,7 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap
if domain != a.ServerName {
return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName)
}
tokenMeta, err := a.DeviceDB.CreateLoginToken(ctx, &req.Data)
tokenMeta, err := a.DB.CreateLoginToken(ctx, &req.Data)
if err != nil {
return err
}
@ -45,13 +45,13 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap
// PerformLoginTokenDeletion ensures the token doesn't exist.
func (a *UserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *api.PerformLoginTokenDeletionRequest, res *api.PerformLoginTokenDeletionResponse) error {
util.GetLogger(ctx).Info("PerformLoginTokenDeletion")
return a.DeviceDB.RemoveLoginToken(ctx, req.Token)
return a.DB.RemoveLoginToken(ctx, req.Token)
}
// QueryLoginToken returns the data associated with a login token. If
// the token is not valid, success is returned, but res.Data == nil.
func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLoginTokenRequest, res *api.QueryLoginTokenResponse) error {
tokenData, err := a.DeviceDB.GetLoginTokenDataByToken(ctx, req.Token)
tokenData, err := a.DB.GetLoginTokenDataByToken(ctx, req.Token)
if err != nil {
res.Data = nil
if err == sql.ErrNoRows {
@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog
if domain != a.ServerName {
return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName)
}
if _, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart); err != nil {
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil {
res.Data = nil
if err == sql.ErrNoRows {
return nil

View file

@ -1,547 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strconv"
"time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
// Import the postgres database driver.
_ "github.com/lib/pq"
)
// Database represents an account database
type Database struct {
db *sql.DB
writer sqlutil.Writer
sqlutil.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
accountDatas accountDataStatements
threepids threepidStatements
openIDTokens tokenStatements
keyBackupVersions keyBackupVersionStatements
keyBackups keyBackupStatements
serverName gomatrixserverlib.ServerName
bcryptCost int
openIDTokenLifetimeMS int64
}
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
d := &Database{
serverName: serverName,
db: db,
writer: sqlutil.NewDummyWriter(),
bcryptCost: bcryptCost,
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
}
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
if err = d.accounts.execSchema(db); err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
deltas.LoadIsActive(m)
deltas.LoadAddPolicyVersion(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
if err = d.PartitionOffsetStatements.Prepare(db, d.writer, "account"); err != nil {
return nil, err
}
if err = d.accounts.prepare(db, serverName); err != nil {
return nil, err
}
if err = d.profiles.prepare(db); err != nil {
return nil, err
}
if err = d.accountDatas.prepare(db); err != nil {
return nil, err
}
if err = d.threepids.prepare(db); err != nil {
return nil, err
}
if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err
}
if err = d.keyBackupVersions.prepare(db); err != nil {
return nil, err
}
if err = d.keyBackups.prepare(db); err != nil {
return nil, err
}
return d, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(
ctx context.Context, localpart, plaintextPassword string,
) (*api.Account, error) {
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
if err != nil {
return nil, err
}
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
return nil, err
}
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
return d.profiles.selectProfileByLocalpart(ctx, localpart)
}
// SetAvatarURL updates the avatar URL of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) error {
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
}
// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
) error {
return d.profiles.setDisplayName(ctx, localpart, displayName)
}
// SetPassword sets the account password to the given hash.
func (d *Database) SetPassword(
ctx context.Context, localpart, plaintextPassword string,
) error {
hash, err := d.hashPassword(plaintextPassword)
if err != nil {
return err
}
return d.accounts.updatePassword(ctx, localpart, hash)
}
// CreateGuestAccount makes a new guest account and creates an empty profile
// for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var numLocalpart int64
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
if err != nil {
return err
}
localpart := strconv.FormatInt(numLocalpart, 10)
acc, err = d.createAccount(ctx, txn, localpart, "", "", "")
return err
})
return acc, err
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, sqlutil.ErrUserExists.
func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID, policyVersion string,
) (acc *api.Account, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, policyVersion)
return err
})
return
}
func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID, policyVersion string,
) (*api.Account, error) {
var account *api.Account
var err error
// Generate a password hash if this is not a password-less user
hash := ""
if plaintextPassword != "" {
hash, err = d.hashPassword(plaintextPassword)
if err != nil {
return nil, err
}
}
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, policyVersion); err != nil {
if sqlutil.IsUniqueConstraintViolationErr(err) {
return nil, sqlutil.ErrUserExists
}
return nil, err
}
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
return nil, err
}
if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
"global": {
"content": [],
"override": [],
"room": [],
"sender": [],
"underride": []
}
}`)); err != nil {
return nil, err
}
return account, nil
}
// SaveAccountData saves new account data for a given user and a given room.
// If the account data is not specific to a room, the room ID should be an empty string
// If an account data already exists for a given set (user, room, data type), it will
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
})
}
// GetAccountData returns account data related to a given localpart
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global map[string]json.RawMessage,
rooms map[string]map[string]json.RawMessage,
err error,
) {
return d.accountDatas.selectAccountData(ctx, localpart)
}
// GetAccountDataByType returns account data matching a given
// localpart, room ID and type.
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data json.RawMessage, err error) {
return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType,
)
}
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx, nil)
}
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.bcryptCost)
return string(hashBytes), err
}
// Err3PIDInUse is the error returned when trying to save an association involving
// a third-party identifier which is already associated to a local user.
var Err3PIDInUse = errors.New("this third-party identifier is already in use")
// SaveThreePIDAssociation saves the association between a third party identifier
// and a local Matrix user (identified by the user's ID's local part).
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string,
) (err error) {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
user, err := d.threepids.selectLocalpartForThreePID(
ctx, txn, threepid, medium,
)
if err != nil {
return err
}
if len(user) > 0 {
return Err3PIDInUse
}
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
})
}
// RemoveThreePIDAssociation removes the association involving a given third-party
// identifier.
// If no association exists involving this third-party identifier, returns nothing.
// If there was a problem talking to the database, returns an error.
func (d *Database) RemoveThreePIDAssociation(
ctx context.Context, threepid string, medium string,
) (err error) {
return d.threepids.deleteThreePID(ctx, threepid, medium)
}
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
// identifier.
// If no association involves the given third-party idenfitier, returns an empty
// string.
// Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
}
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
// a given local user.
// If no association is known for this user, returns an empty slice.
// Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
}
// CheckAccountAvailability checks if the username/localpart is already present
// in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
if err == sql.ErrNoRows {
return true, nil
}
return false, err
}
// GetAccountByLocalpart returns the account associated with the given localpart.
// This function assumes the request is authenticated or the account data is used only internally.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
) (*api.Account, error) {
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// SearchProfiles returns all profiles where the provided localpart or display name
// match any part of the profiles in the database.
func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) {
return d.profiles.selectProfilesBySearch(ctx, searchString, limit)
}
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
return d.accounts.deactivateAccount(ctx, localpart)
}
// CreateOpenIDToken persists a new token that was issued through OpenID Connect
func (d *Database) CreateOpenIDToken(
ctx context.Context,
token, localpart string,
) (int64, error) {
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
err := sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
})
return expiresAtMS, err
}
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
func (d *Database) GetOpenIDTokenAttributes(
ctx context.Context,
token string,
) (*api.OpenIDTokenAttributes, error) {
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
}
func (d *Database) CreateKeyBackup(
ctx context.Context, userID, algorithm string, authData json.RawMessage,
) (version string, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "")
return err
})
return
}
func (d *Database) UpdateKeyBackupAuthData(
ctx context.Context, userID, version string, authData json.RawMessage,
) (err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData)
})
return
}
func (d *Database) DeleteKeyBackup(
ctx context.Context, userID, version string,
) (exists bool, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) GetKeyBackup(
ctx context.Context, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) GetBackupKeys(
ctx context.Context, version, userID, filterRoomID, filterSessionID string,
) (result map[string]map[string]api.KeyBackupSession, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if filterSessionID != "" {
result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
return err
}
if filterRoomID != "" {
result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
return err
}
result, err = d.keyBackups.selectKeys(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) CountBackupKeys(
ctx context.Context, version, userID string,
) (count int64, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
if err != nil {
return err
}
return nil
})
return
}
// nolint:nakedret
func (d *Database) UpsertBackupKeys(
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
) (count int64, etag string, err error) {
// wrap the following logic in a txn to ensure we atomically upload keys
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
_, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
if err != nil {
return err
}
if deleted {
return fmt.Errorf("backup was deleted")
}
// pull out all keys for this (user_id, version)
existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version)
if err != nil {
return err
}
changed := false
// loop over all the new keys (which should be smaller than the set of backed up keys)
for _, newKey := range uploads {
// if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them.
existingRoom := existingKeys[newKey.RoomID]
if existingRoom != nil {
existingSession, ok := existingRoom[newKey.SessionID]
if ok {
if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
changed = true
if err != nil {
return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err)
}
}
// if we shouldn't replace the key we do nothing with it
continue
}
}
// if we're here, either the room or session are new, either way, we insert
err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey)
changed = true
if err != nil {
return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err)
}
}
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
if err != nil {
return err
}
if changed {
// update the etag
var newETag string
if oldETag == "" {
newETag = "1"
} else {
oldETagInt, err := strconv.ParseInt(oldETag, 10, 64)
if err != nil {
return fmt.Errorf("failed to parse old etag: %s", err)
}
newETag = strconv.FormatInt(oldETagInt+1, 10)
}
etag = newETag
return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag)
} else {
etag = oldETag
}
return nil
})
return
}
// GetPrivacyPolicy returns the accepted privacy policy version, if any.
func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
policyVersion, err = d.accounts.selectPrivacyPolicy(ctx, txn, localpart)
return err
})
return
}
// GetOutdatedPolicy queries all users which didn't accept the current policy version.
func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
userIDs, err = d.accounts.batchSelectPrivacyPolicy(ctx, txn, policyVersion)
return err
})
return
}
// UpdatePolicyVersion sets the accepted policy_version for a user.
func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) (err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.accounts.updatePolicyVersion(ctx, txn, policyVersion, localpart)
})
return
}

View file

@ -1,52 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package devices
import (
"context"
"github.com/matrix-org/dendrite/userapi/api"
)
type Database interface {
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
RemoveLoginToken(ctx context.Context, token string) error
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
}

View file

@ -1,270 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas"
"github.com/matrix-org/gomatrixserverlib"
)
const (
// The length of generated device IDs
deviceIDByteLength = 6
loginTokenByteLength = 32
)
// Database represents a device database.
type Database struct {
db *sql.DB
devices devicesStatements
loginTokens loginTokenStatements
loginTokenLifetime time.Duration
}
// NewDatabase creates a new device database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
var d devicesStatements
var lt loginTokenStatements
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
if err = d.execSchema(db); err != nil {
return nil, err
}
if err = lt.execSchema(db); err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
deltas.LoadLastSeenTSIP(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
if err = d.prepare(db, serverName); err != nil {
return nil, err
}
if err = lt.prepare(db); err != nil {
return nil, err
}
return &Database{db, d, lt, loginTokenLifetime}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*api.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.devices.selectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil
})
return
}
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
})
}
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
tok, err := generateLoginToken()
if err != nil {
return nil, err
}
meta := &api.LoginTokenMetadata{
Token: tok,
Expiration: time.Now().Add(d.loginTokenLifetime),
}
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.loginTokens.insert(ctx, txn, meta, data)
})
if err != nil {
return nil, err
}
return meta, nil
}
func generateLoginToken() (string, error) {
b := make([]byte, loginTokenByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.loginTokens.deleteByToken(ctx, txn, token)
})
}
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
return d.loginTokens.selectByToken(ctx, token)
}

View file

@ -1,271 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas"
"github.com/matrix-org/gomatrixserverlib"
)
const (
// The length of generated device IDs
deviceIDByteLength = 6
loginTokenByteLength = 32
)
// Database represents a device database.
type Database struct {
db *sql.DB
writer sqlutil.Writer
devices devicesStatements
loginTokens loginTokenStatements
loginTokenLifetime time.Duration
}
// NewDatabase creates a new device database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
writer := sqlutil.NewExclusiveWriter()
var d devicesStatements
var lt loginTokenStatements
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
if err = d.execSchema(db); err != nil {
return nil, err
}
if err = lt.execSchema(db); err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
deltas.LoadLastSeenTSIP(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
if err = d.prepare(db, writer, serverName); err != nil {
return nil, err
}
if err = lt.prepare(db); err != nil {
return nil, err
}
return &Database{db, writer, d, lt, loginTokenLifetime}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*api.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.devices.selectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil
})
return
}
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
})
}
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
tok, err := generateLoginToken()
if err != nil {
return nil, err
}
meta := &api.LoginTokenMetadata{
Token: tok,
Expiration: time.Now().Add(d.loginTokenLifetime),
}
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.loginTokens.insert(ctx, txn, meta, data)
})
if err != nil {
return nil, err
}
return meta, nil
}
func generateLoginToken() (string, error) {
b := make([]byte, loginTokenByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.loginTokens.deleteByToken(ctx, txn, token)
})
}
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
return d.loginTokens.selectByToken(ctx, token)
}

View file

@ -1,42 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !wasm
// +build !wasm
package devices
import (
"fmt"
"time"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/devices/postgres"
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
)
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets postgres connection parameters. loginTokenLifetime determines how long a
// login token from CreateLoginToken is valid.
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime)
case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(dbProperties, serverName, loginTokenLifetime)
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View file

@ -1,39 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package devices
import (
"fmt"
"time"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
)
func NewDatabase(
dbProperties *config.DatabaseOptions,
serverName gomatrixserverlib.ServerName,
loginTokenLifetime time.Duration,
) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime)
case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation")
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package accounts
package storage
import (
"context"
@ -32,8 +32,7 @@ type Database interface {
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists.
CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID, policyVersion string) (*api.Account, error)
CreateGuestAccount(ctx context.Context) (*api.Account, error)
CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, policyVersion string, accountType api.AccountType) (*api.Account, error)
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
// GetAccountDataByType returns account data matching a given
@ -64,6 +63,35 @@ type Database interface {
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
RemoveLoginToken(ctx context.Context, token string) error
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
}
// Err3PIDInUse is the error returned when trying to save an association involving

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
const accountDataSchema = `
@ -56,19 +57,20 @@ type accountDataStatements struct {
selectAccountDataByTypeStmt *sql.Stmt
}
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(accountDataSchema)
func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
s := &accountDataStatements{}
_, err := db.Exec(accountDataSchema)
if err != nil {
return
return nil, err
}
return sqlutil.StatementList{
return s, sqlutil.StatementList{
{&s.insertAccountDataStmt, insertAccountDataSQL},
{&s.selectAccountDataStmt, selectAccountDataSQL},
{&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL},
}.Prepare(db)
}
func (s *accountDataStatements) insertAccountData(
func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
@ -76,7 +78,7 @@ func (s *accountDataStatements) insertAccountData(
return
}
func (s *accountDataStatements) selectAccountData(
func (s *accountDataStatements) SelectAccountData(
ctx context.Context, localpart string,
) (
/* global */ map[string]json.RawMessage,
@ -114,7 +116,7 @@ func (s *accountDataStatements) selectAccountData(
return global, rooms, rows.Err()
}
func (s *accountDataStatements) selectAccountDataByType(
func (s *accountDataStatements) SelectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data json.RawMessage, err error) {
var bytes []byte

View file

@ -19,10 +19,12 @@ import (
"database/sql"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/userapi/storage/tables"
log "github.com/sirupsen/logrus"
)
@ -40,17 +42,19 @@ CREATE TABLE IF NOT EXISTS account_accounts (
appservice_id TEXT,
-- If the account is currently active
is_deactivated BOOLEAN DEFAULT FALSE,
-- The account_type (user = 1, guest = 2, admin = 3, appservice = 4)
account_type SMALLINT NOT NULL,
-- The policy version this user has accepted
policy_version TEXT
-- TODO:
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
-- upgraded_ts, devices, any email reset stuff?
);
-- Create sequence for autogenerated numeric usernames
CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
`
const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, policy_version) VALUES ($1, $2, $3, $4, $5)"
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type, policy_version) VALUES ($1, $2, $3, $4, $5, $6)"
const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
@ -59,7 +63,7 @@ const deactivateAccountSQL = "" +
"UPDATE account_accounts SET is_deactivated = TRUE WHERE localpart = $1"
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
"SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1"
const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
@ -89,14 +93,15 @@ type accountsStatements struct {
serverName gomatrixserverlib.ServerName
}
func (s *accountsStatements) execSchema(db *sql.DB) error {
func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) {
s := &accountsStatements{
serverName: serverName,
}
_, err := db.Exec(accountsSchema)
return err
}
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
s.serverName = server
return sqlutil.StatementList{
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertAccountStmt, insertAccountSQL},
{&s.updatePasswordStmt, updatePasswordSQL},
{&s.deactivateAccountStmt, deactivateAccountSQL},
@ -112,17 +117,17 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
// this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success.
func (s *accountsStatements) insertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string,
func (s *accountsStatements) InsertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, accountType api.AccountType,
) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
var err error
if appserviceID == "" {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, policyVersion)
if accountType != api.AccountTypeAppService {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion)
} else {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, "")
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, "")
}
if err != nil {
return nil, err
@ -133,38 +138,39 @@ func (s *accountsStatements) insertAccount(
UserID: userutil.MakeUserID(localpart, s.serverName),
ServerName: s.serverName,
AppServiceID: appserviceID,
AccountType: accountType,
}, nil
}
func (s *accountsStatements) updatePassword(
func (s *accountsStatements) UpdatePassword(
ctx context.Context, localpart, passwordHash string,
) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
return
}
func (s *accountsStatements) deactivateAccount(
func (s *accountsStatements) DeactivateAccount(
ctx context.Context, localpart string,
) (err error) {
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
return
}
func (s *accountsStatements) selectPasswordHash(
func (s *accountsStatements) SelectPasswordHash(
ctx context.Context, localpart string,
) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
return
}
func (s *accountsStatements) selectAccountByLocalpart(
func (s *accountsStatements) SelectAccountByLocalpart(
ctx context.Context, localpart string,
) (*api.Account, error) {
var appserviceIDPtr sql.NullString
var acc api.Account
stmt := s.selectAccountByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr)
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve user from the db")
@ -181,7 +187,7 @@ func (s *accountsStatements) selectAccountByLocalpart(
return &acc, nil
}
func (s *accountsStatements) selectNewNumericLocalpart(
func (s *accountsStatements) SelectNewNumericLocalpart(
ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt

View file

@ -4,12 +4,14 @@ import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
func LoadFromGoose() {
goose.AddMigration(UpIsActive, DownIsActive)
goose.AddMigration(UpAddAccountType, DownAddAccountType)
goose.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion)
}

View file

@ -5,13 +5,8 @@ import (
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
)
func LoadFromGoose() {
goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
}
func LoadLastSeenTSIP(m *sqlutil.Migrations) {
m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
}

View file

@ -0,0 +1,34 @@
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
func LoadAddAccountType(m *sqlutil.Migrations) {
m.AddMigration(UpAddAccountType, DownAddAccountType)
}
func UpAddAccountType(tx *sql.Tx) error {
// initially set every account to useraccount, change appservice and guest accounts afterwards
// (user = 1, guest = 2, admin = 3, appservice = 4)
_, err := tx.Exec(`ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS account_type SMALLINT NOT NULL DEFAULT 1;
UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> '';
UPDATE account_accounts SET account_type = 2 WHERE localpart ~ '^[0-9]+$';
ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`,
)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownAddAccountType(tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN account_type;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
@ -111,50 +112,32 @@ type devicesStatements struct {
serverName gomatrixserverlib.ServerName
}
func (s *devicesStatements) execSchema(db *sql.DB) error {
func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) {
s := &devicesStatements{
serverName: serverName,
}
_, err := db.Exec(devicesSchema)
return err
}
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
return
if err != nil {
return nil, err
}
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
return
}
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
return
}
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
return
}
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
return
}
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
return
}
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return
}
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
return
}
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
return
}
if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil {
return
}
s.serverName = server
return
return s, sqlutil.StatementList{
{&s.insertDeviceStmt, insertDeviceSQL},
{&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL},
{&s.selectDeviceByIDStmt, selectDeviceByIDSQL},
{&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL},
{&s.updateDeviceNameStmt, updateDeviceNameSQL},
{&s.deleteDeviceStmt, deleteDeviceSQL},
{&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL},
{&s.deleteDevicesStmt, deleteDevicesSQL},
{&s.selectDevicesByIDStmt, selectDevicesByIDSQL},
{&s.updateDeviceLastSeenStmt, updateDeviceLastSeen},
}.Prepare(db)
}
// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
// Returns an error if the user already has a device with the given device ID.
// Returns the device on success.
func (s *devicesStatements) insertDevice(
func (s *devicesStatements) InsertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string, ipAddr, userAgent string,
) (*api.Device, error) {
@ -176,7 +159,7 @@ func (s *devicesStatements) insertDevice(
}
// deleteDevice removes a single device by id and user localpart.
func (s *devicesStatements) deleteDevice(
func (s *devicesStatements) DeleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
@ -186,7 +169,7 @@ func (s *devicesStatements) deleteDevice(
// deleteDevices removes a single or multiple devices by ids and user localpart.
// Returns an error if the execution failed.
func (s *devicesStatements) deleteDevices(
func (s *devicesStatements) DeleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt)
@ -196,7 +179,7 @@ func (s *devicesStatements) deleteDevices(
// deleteDevicesByLocalpart removes all devices for the
// given user localpart.
func (s *devicesStatements) deleteDevicesByLocalpart(
func (s *devicesStatements) DeleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
@ -204,7 +187,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
return err
}
func (s *devicesStatements) updateDeviceName(
func (s *devicesStatements) UpdateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
@ -212,7 +195,7 @@ func (s *devicesStatements) updateDeviceName(
return err
}
func (s *devicesStatements) selectDeviceByToken(
func (s *devicesStatements) SelectDeviceByToken(
ctx context.Context, accessToken string,
) (*api.Device, error) {
var dev api.Device
@ -228,7 +211,7 @@ func (s *devicesStatements) selectDeviceByToken(
// selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID
func (s *devicesStatements) selectDeviceByID(
func (s *devicesStatements) SelectDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
var dev api.Device
@ -245,7 +228,7 @@ func (s *devicesStatements) selectDeviceByID(
return &dev, err
}
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs))
if err != nil {
return nil, err
@ -268,7 +251,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
return devices, rows.Err()
}
func (s *devicesStatements) selectDevicesByLocalpart(
func (s *devicesStatements) SelectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) {
devices := []api.Device{}
@ -310,7 +293,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
return devices, rows.Err()
}
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
lastSeenTs := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID)

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
const keyBackupTableSchema = `
@ -72,12 +73,13 @@ type keyBackupStatements struct {
selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
}
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupTableSchema)
func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
s := &keyBackupStatements{}
_, err := db.Exec(keyBackupTableSchema)
if err != nil {
return
return nil, err
}
return sqlutil.StatementList{
return s, sqlutil.StatementList{
{&s.insertBackupKeyStmt, insertBackupKeySQL},
{&s.updateBackupKeyStmt, updateBackupKeySQL},
{&s.countKeysStmt, countKeysSQL},
@ -87,14 +89,14 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db)
}
func (s keyBackupStatements) countKeys(
func (s keyBackupStatements) CountKeys(
ctx context.Context, txn *sql.Tx, userID, version string,
) (count int64, err error) {
err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count)
return
}
func (s *keyBackupStatements) insertBackupKey(
func (s *keyBackupStatements) InsertBackupKey(
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
) (err error) {
_, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext(
@ -103,7 +105,7 @@ func (s *keyBackupStatements) insertBackupKey(
return
}
func (s *keyBackupStatements) updateBackupKey(
func (s *keyBackupStatements) UpdateBackupKey(
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
) (err error) {
_, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext(
@ -112,7 +114,7 @@ func (s *keyBackupStatements) updateBackupKey(
return
}
func (s *keyBackupStatements) selectKeys(
func (s *keyBackupStatements) SelectKeys(
ctx context.Context, txn *sql.Tx, userID, version string,
) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
@ -122,7 +124,7 @@ func (s *keyBackupStatements) selectKeys(
return unpackKeys(ctx, rows)
}
func (s *keyBackupStatements) selectKeysByRoomID(
func (s *keyBackupStatements) SelectKeysByRoomID(
ctx context.Context, txn *sql.Tx, userID, version, roomID string,
) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
@ -132,7 +134,7 @@ func (s *keyBackupStatements) selectKeysByRoomID(
return unpackKeys(ctx, rows)
}
func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID(
func (s *keyBackupStatements) SelectKeysByRoomIDAndSessionID(
ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)

View file

@ -22,6 +22,7 @@ import (
"strconv"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
const keyBackupVersionTableSchema = `
@ -69,12 +70,13 @@ type keyBackupVersionStatements struct {
updateKeyBackupETagStmt *sql.Stmt
}
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupVersionTableSchema)
func NewPostgresKeyBackupVersionTable(db *sql.DB) (tables.KeyBackupVersionTable, error) {
s := &keyBackupVersionStatements{}
_, err := db.Exec(keyBackupVersionTableSchema)
if err != nil {
return
return nil, err
}
return sqlutil.StatementList{
return s, sqlutil.StatementList{
{&s.insertKeyBackupStmt, insertKeyBackupSQL},
{&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL},
{&s.deleteKeyBackupStmt, deleteKeyBackupSQL},
@ -84,7 +86,7 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db)
}
func (s *keyBackupVersionStatements) insertKeyBackup(
func (s *keyBackupVersionStatements) InsertKeyBackup(
ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string,
) (version string, err error) {
var versionInt int64
@ -92,7 +94,7 @@ func (s *keyBackupVersionStatements) insertKeyBackup(
return strconv.FormatInt(versionInt, 10), err
}
func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
func (s *keyBackupVersionStatements) UpdateKeyBackupAuthData(
ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage,
) error {
versionInt, err := strconv.ParseInt(version, 10, 64)
@ -103,7 +105,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
return err
}
func (s *keyBackupVersionStatements) updateKeyBackupETag(
func (s *keyBackupVersionStatements) UpdateKeyBackupETag(
ctx context.Context, txn *sql.Tx, userID, version, etag string,
) error {
versionInt, err := strconv.ParseInt(version, 10, 64)
@ -114,7 +116,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag(
return err
}
func (s *keyBackupVersionStatements) deleteKeyBackup(
func (s *keyBackupVersionStatements) DeleteKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string,
) (bool, error) {
versionInt, err := strconv.ParseInt(version, 10, 64)
@ -132,7 +134,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup(
return ra == 1, nil
}
func (s *keyBackupVersionStatements) selectKeyBackup(
func (s *keyBackupVersionStatements) SelectKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
var versionInt int64

View file

@ -21,18 +21,11 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/util"
)
type loginTokenStatements struct {
insertStmt *sql.Stmt
deleteStmt *sql.Stmt
selectByTokenStmt *sql.Stmt
}
// execSchema ensures tables and indices exist.
func (s *loginTokenStatements) execSchema(db *sql.DB) error {
_, err := db.Exec(`
const loginTokenSchema = `
CREATE TABLE IF NOT EXISTS login_tokens (
-- The random value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY,
@ -45,21 +38,38 @@ CREATE TABLE IF NOT EXISTS login_tokens (
-- This index allows efficient garbage collection of expired tokens.
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at);
`)
return err
`
const insertLoginTokenSQL = "" +
"INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
const deleteLoginTokenSQL = "" +
"DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"
const selectLoginTokenSQL = "" +
"SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"
type loginTokenStatements struct {
insertStmt *sql.Stmt
deleteStmt *sql.Stmt
selectStmt *sql.Stmt
}
// prepare runs statement preparation.
func (s *loginTokenStatements) prepare(db *sql.DB) error {
return sqlutil.StatementList{
{&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"},
{&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"},
{&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"},
func NewPostgresLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) {
s := &loginTokenStatements{}
_, err := db.Exec(loginTokenSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertStmt, insertLoginTokenSQL},
{&s.deleteStmt, deleteLoginTokenSQL},
{&s.selectStmt, selectLoginTokenSQL},
}.Prepare(db)
}
// insert adds an already generated token to the database.
func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error {
func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error {
stmt := sqlutil.TxStmt(txn, s.insertStmt)
_, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID)
return err
@ -69,7 +79,7 @@ func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata
//
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
// The login_tokens_expiration_idx index should make that efficient.
func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error {
func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error {
stmt := sqlutil.TxStmt(txn, s.deleteStmt)
res, err := stmt.ExecContext(ctx, token, time.Now().UTC())
if err != nil {
@ -82,9 +92,9 @@ func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, t
}
// selectByToken returns the data associated with the given token. May return sql.ErrNoRows.
func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
func (s *loginTokenStatements) SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
var data api.LoginTokenData
err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID)
err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID)
if err != nil {
return nil, err
}

View file

@ -6,6 +6,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
@ -22,33 +23,35 @@ CREATE TABLE IF NOT EXISTS open_id_tokens (
);
`
const insertTokenSQL = "" +
const insertOpenIDTokenSQL = "" +
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
const selectTokenSQL = "" +
const selectOpenIDTokenSQL = "" +
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
type tokenStatements struct {
type openIDTokenStatements struct {
insertTokenStmt *sql.Stmt
selectTokenStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
_, err = db.Exec(openIDTokenSchema)
if err != nil {
return
func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.OpenIDTable, error) {
s := &openIDTokenStatements{
serverName: serverName,
}
s.serverName = server
return sqlutil.StatementList{
{&s.insertTokenStmt, insertTokenSQL},
{&s.selectTokenStmt, selectTokenSQL},
_, err := db.Exec(openIDTokenSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertTokenStmt, insertOpenIDTokenSQL},
{&s.selectTokenStmt, selectOpenIDTokenSQL},
}.Prepare(db)
}
// insertToken inserts a new OpenID Connect token to the DB.
// Returns new token, otherwise returns error if the token already exists.
func (s *tokenStatements) insertToken(
func (s *openIDTokenStatements) InsertOpenIDToken(
ctx context.Context,
txn *sql.Tx,
token, localpart string,
@ -61,7 +64,7 @@ func (s *tokenStatements) insertToken(
// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
// Returns the existing token's attributes, or err if no token is found
func (s *tokenStatements) selectOpenIDTokenAtrributes(
func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
ctx context.Context,
token string,
) (*api.OpenIDTokenAttributes, error) {

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
const profilesSchema = `
@ -59,12 +60,13 @@ type profilesStatements struct {
selectProfilesBySearchStmt *sql.Stmt
}
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(profilesSchema)
func NewPostgresProfilesTable(db *sql.DB) (tables.ProfileTable, error) {
s := &profilesStatements{}
_, err := db.Exec(profilesSchema)
if err != nil {
return
return nil, err
}
return sqlutil.StatementList{
return s, sqlutil.StatementList{
{&s.insertProfileStmt, insertProfileSQL},
{&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL},
{&s.setAvatarURLStmt, setAvatarURLSQL},
@ -73,14 +75,14 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db)
}
func (s *profilesStatements) insertProfile(
func (s *profilesStatements) InsertProfile(
ctx context.Context, txn *sql.Tx, localpart string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
return
}
func (s *profilesStatements) selectProfileByLocalpart(
func (s *profilesStatements) SelectProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
var profile authtypes.Profile
@ -93,21 +95,21 @@ func (s *profilesStatements) selectProfileByLocalpart(
return &profile, nil
}
func (s *profilesStatements) setAvatarURL(
ctx context.Context, localpart string, avatarURL string,
func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
) (err error) {
_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart)
return
}
func (s *profilesStatements) setDisplayName(
ctx context.Context, localpart string, displayName string,
func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
) (err error) {
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
return
}
func (s *profilesStatements) selectProfilesBySearch(
func (s *profilesStatements) SelectProfilesBySearch(
ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) {
var profiles []authtypes.Profile

View file

@ -0,0 +1,105 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"fmt"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/userapi/storage/shared"
// Import the postgres database driver.
_ "github.com/lib/pq"
)
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*shared.Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
if _, err = db.Exec(accountsSchema); err != nil {
// do this so that the migration can and we don't fail on
// preparing statements for columns that don't exist yet
return nil, err
}
deltas.LoadIsActive(m)
//deltas.LoadLastSeenTSIP(m)
deltas.LoadAddAccountType(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
accountDataTable, err := NewPostgresAccountDataTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)
}
accountsTable, err := NewPostgresAccountsTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err)
}
devicesTable, err := NewPostgresDevicesTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err)
}
keyBackupTable, err := NewPostgresKeyBackupTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresKeyBackupTable: %w", err)
}
keyBackupVersionTable, err := NewPostgresKeyBackupVersionTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresKeyBackupVersionTable: %w", err)
}
loginTokenTable, err := NewPostgresLoginTokenTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresLoginTokenTable: %w", err)
}
openIDTable, err := NewPostgresOpenIDTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewPostgresOpenIDTable: %w", err)
}
profilesTable, err := NewPostgresProfilesTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresProfilesTable: %w", err)
}
threePIDTable, err := NewPostgresThreePIDTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresThreePIDTable: %w", err)
}
return &shared.Database{
AccountDatas: accountDataTable,
Accounts: accountsTable,
Devices: devicesTable,
KeyBackups: keyBackupTable,
KeyBackupVersions: keyBackupVersionTable,
LoginTokens: loginTokenTable,
OpenIDTokens: openIDTable,
Profiles: profilesTable,
ThreePIDs: threePIDTable,
ServerName: serverName,
DB: db,
Writer: sqlutil.NewDummyWriter(),
LoginTokenLifetime: loginTokenLifetime,
BcryptCost: bcryptCost,
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
}, nil
}

Some files were not shown because too many files have changed in this diff Show more