Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking

This commit is contained in:
Till Faelligen 2022-02-21 12:08:03 +01:00
commit 9c3a1cfd47
122 changed files with 2344 additions and 2320 deletions

View file

@ -23,7 +23,7 @@ import (
"errors" "errors"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -85,7 +85,7 @@ func RetrieveUserProfile(
ctx context.Context, ctx context.Context,
userID string, userID string,
asAPI AppServiceQueryAPI, asAPI AppServiceQueryAPI,
accountDB accounts.Database, accountDB userdb.Database,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {

View file

@ -22,6 +22,8 @@ import (
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/sirupsen/logrus"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/appservice/consumers" "github.com/matrix-org/dendrite/appservice/consumers"
"github.com/matrix-org/dendrite/appservice/inthttp" "github.com/matrix-org/dendrite/appservice/inthttp"
@ -34,7 +36,6 @@ import (
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/sirupsen/logrus"
) )
// AddInternalRoutes registers HTTP handlers for internal API calls // AddInternalRoutes registers HTTP handlers for internal API calls
@ -121,7 +122,7 @@ func generateAppServiceAccount(
) error { ) error {
var accRes userapi.PerformAccountCreationResponse var accRes userapi.PerformAccountCreationResponse
err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{ err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeUser, AccountType: userapi.AccountTypeAppService,
Localpart: as.SenderLocalpart, Localpart: as.SenderLocalpart,
AppServiceID: as.ID, AppServiceID: as.ID,
OnConflict: userapi.ConflictUpdate, OnConflict: userapi.ConflictUpdate,

View file

@ -283,8 +283,7 @@ func (m *DendriteMonolith) Start() {
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/%s", m.StorageDirectory, prefix)) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/%s", m.StorageDirectory, prefix))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-account.db", m.StorageDirectory, prefix)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-account.db", m.StorageDirectory, prefix))
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-device.db", m.StorageDirectory, prefix)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-mediaapi.db", m.CacheDirectory, prefix))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-syncapi.db", m.StorageDirectory, prefix)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-syncapi.db", m.StorageDirectory, prefix))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-roomserver.db", m.StorageDirectory, prefix)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-roomserver.db", m.StorageDirectory, prefix))
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-keyserver.db", m.StorageDirectory, prefix)) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-keyserver.db", m.StorageDirectory, prefix))

View file

@ -88,7 +88,6 @@ func (m *DendriteMonolith) Start() {
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", m.StorageDirectory)) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", m.StorageDirectory))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-account.db", m.StorageDirectory)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-account.db", m.StorageDirectory))
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-device.db", m.StorageDirectory))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-syncapi.db", m.StorageDirectory)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-syncapi.db", m.StorageDirectory))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-roomserver.db", m.StorageDirectory)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-roomserver.db", m.StorageDirectory))

View file

@ -28,7 +28,7 @@ import (
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -38,7 +38,7 @@ func AddPublicRoutes(
synapseAdminRouter *mux.Router, synapseAdminRouter *mux.Router,
consentAPIMux *mux.Router, consentAPIMux *mux.Router,
cfg *config.ClientAPI, cfg *config.ClientAPI,
accountsDB accounts.Database, accountsDB userdb.Database,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
eduInputAPI eduServerAPI.EDUServerInputAPI, eduInputAPI eduServerAPI.EDUServerInputAPI,

View file

@ -156,6 +156,15 @@ func MissingParam(msg string) *MatrixError {
return &MatrixError{"M_MISSING_PARAM", msg} return &MatrixError{"M_MISSING_PARAM", msg}
} }
// LeaveServerNoticeError is an error returned when trying to reject an invite
// for a server notice room.
func LeaveServerNoticeError() *MatrixError {
return &MatrixError{
ErrCode: "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM",
Err: "You cannot reject this invite",
}
}
type IncompatibleRoomVersionError struct { type IncompatibleRoomVersionError struct {
RoomVersion string `json:"room_version"` RoomVersion string `json:"room_version"`
Error string `json:"error"` Error string `json:"error"`

View file

@ -47,8 +47,8 @@ func GetAdminWhois(
req *http.Request, userAPI api.UserInternalAPI, device *api.Device, req *http.Request, userAPI api.UserInternalAPI, device *api.Device,
userID string, userID string,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { allowed := device.AccountType == api.AccountTypeAdmin || userID == device.UserID
// TODO: Still allow if user is admin if !allowed {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("userID does not match the current user"), JSON: jsonerror.Forbidden("userID does not match the current user"),

View file

@ -15,6 +15,7 @@
package routing package routing
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@ -30,7 +31,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -137,36 +138,17 @@ type fledglingEvent struct {
func CreateRoom( func CreateRoom(
req *http.Request, device *api.Device, req *http.Request, device *api.Device,
cfg *config.ClientAPI, cfg *config.ClientAPI,
accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI, accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
// TODO (#267): Check room ID doesn't clash with an existing one, and we
// probably shouldn't be using pseudo-random strings, maybe GUIDs?
roomID := fmt.Sprintf("!%s:%s", util.RandomString(16), cfg.Matrix.ServerName)
return createRoom(req, device, cfg, roomID, accountDB, rsAPI, asAPI)
}
// createRoom implements /createRoom
// nolint: gocyclo
func createRoom(
req *http.Request, device *api.Device,
cfg *config.ClientAPI, roomID string,
accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
logger := util.GetLogger(req.Context())
userID := device.UserID
var r createRoomRequest var r createRoomRequest
resErr := httputil.UnmarshalJSONRequest(req, &r) resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }
// TODO: apply rate-limit
if resErr = r.Validate(); resErr != nil { if resErr = r.Validate(); resErr != nil {
return *resErr return *resErr
} }
evTime, err := httputil.ParseTSParam(req) evTime, err := httputil.ParseTSParam(req)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -174,6 +156,25 @@ func createRoom(
JSON: jsonerror.InvalidArgumentValue(err.Error()), JSON: jsonerror.InvalidArgumentValue(err.Error()),
} }
} }
return createRoom(req.Context(), r, device, cfg, accountDB, rsAPI, asAPI, evTime)
}
// createRoom implements /createRoom
// nolint: gocyclo
func createRoom(
ctx context.Context,
r createRoomRequest, device *api.Device,
cfg *config.ClientAPI,
accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI,
evTime time.Time,
) util.JSONResponse {
// TODO (#267): Check room ID doesn't clash with an existing one, and we
// probably shouldn't be using pseudo-random strings, maybe GUIDs?
roomID := fmt.Sprintf("!%s:%s", util.RandomString(16), cfg.Matrix.ServerName)
logger := util.GetLogger(ctx)
userID := device.UserID
// Clobber keys: creator, room_version // Clobber keys: creator, room_version
@ -200,16 +201,16 @@ func createRoom(
"roomVersion": roomVersion, "roomVersion": roomVersion,
}).Info("Creating new room") }).Info("Creating new room")
profile, err := appserviceAPI.RetrieveUserProfile(req.Context(), userID, asAPI, accountDB) profile, err := appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, accountDB)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("appserviceAPI.RetrieveUserProfile failed") util.GetLogger(ctx).WithError(err).Error("appserviceAPI.RetrieveUserProfile failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
createContent := map[string]interface{}{} createContent := map[string]interface{}{}
if len(r.CreationContent) > 0 { if len(r.CreationContent) > 0 {
if err = json.Unmarshal(r.CreationContent, &createContent); err != nil { if err = json.Unmarshal(r.CreationContent, &createContent); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal for creation_content failed") util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for creation_content failed")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("invalid create content"), JSON: jsonerror.BadJSON("invalid create content"),
@ -230,7 +231,7 @@ func createRoom(
// Merge powerLevelContentOverride fields by unmarshalling it atop the defaults // Merge powerLevelContentOverride fields by unmarshalling it atop the defaults
err = json.Unmarshal(r.PowerLevelContentOverride, &powerLevelContent) err = json.Unmarshal(r.PowerLevelContentOverride, &powerLevelContent)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal for power_level_content_override failed") util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for power_level_content_override failed")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("malformed power_level_content_override"), JSON: jsonerror.BadJSON("malformed power_level_content_override"),
@ -319,9 +320,9 @@ func createRoom(
} }
var aliasResp roomserverAPI.GetRoomIDForAliasResponse var aliasResp roomserverAPI.GetRoomIDForAliasResponse
err = rsAPI.GetRoomIDForAlias(req.Context(), &hasAliasReq, &aliasResp) err = rsAPI.GetRoomIDForAlias(ctx, &hasAliasReq, &aliasResp)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed") util.GetLogger(ctx).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
if aliasResp.RoomID != "" { if aliasResp.RoomID != "" {
@ -426,7 +427,7 @@ func createRoom(
} }
err = builder.SetContent(e.Content) err = builder.SetContent(e.Content)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("builder.SetContent failed") util.GetLogger(ctx).WithError(err).Error("builder.SetContent failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
if i > 0 { if i > 0 {
@ -435,12 +436,12 @@ func createRoom(
var ev *gomatrixserverlib.Event var ev *gomatrixserverlib.Event
ev, err = buildEvent(&builder, &authEvents, cfg, evTime, roomVersion) ev, err = buildEvent(&builder, &authEvents, cfg, evTime, roomVersion)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("buildEvent failed") util.GetLogger(ctx).WithError(err).Error("buildEvent failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
if err = gomatrixserverlib.Allowed(ev, &authEvents); err != nil { if err = gomatrixserverlib.Allowed(ev, &authEvents); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.Allowed failed") util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -448,7 +449,7 @@ func createRoom(
builtEvents = append(builtEvents, ev.Headered(roomVersion)) builtEvents = append(builtEvents, ev.Headered(roomVersion))
err = authEvents.AddEvent(ev) err = authEvents.AddEvent(ev)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("authEvents.AddEvent failed") util.GetLogger(ctx).WithError(err).Error("authEvents.AddEvent failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
} }
@ -462,8 +463,8 @@ func createRoom(
SendAsServer: roomserverAPI.DoNotSendToOtherServers, SendAsServer: roomserverAPI.DoNotSendToOtherServers,
}) })
} }
if err = roomserverAPI.SendInputRoomEvents(req.Context(), rsAPI, inputs, false); err != nil { if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, inputs, false); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -478,9 +479,9 @@ func createRoom(
} }
var aliasResp roomserverAPI.SetRoomAliasResponse var aliasResp roomserverAPI.SetRoomAliasResponse
err = rsAPI.SetRoomAlias(req.Context(), &aliasReq, &aliasResp) err = rsAPI.SetRoomAlias(ctx, &aliasReq, &aliasResp)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.SetRoomAlias failed") util.GetLogger(ctx).WithError(err).Error("aliasAPI.SetRoomAlias failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -519,11 +520,11 @@ func createRoom(
for _, invitee := range r.Invite { for _, invitee := range r.Invite {
// Build the invite event. // Build the invite event.
inviteEvent, err := buildMembershipEvent( inviteEvent, err := buildMembershipEvent(
req.Context(), invitee, "", accountDB, device, gomatrixserverlib.Invite, ctx, invitee, "", accountDB, device, gomatrixserverlib.Invite,
roomID, true, cfg, evTime, rsAPI, asAPI, roomID, true, cfg, evTime, rsAPI, asAPI,
) )
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvent failed") util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed")
continue continue
} }
inviteStrippedState := append( inviteStrippedState := append(
@ -532,7 +533,7 @@ func createRoom(
) )
// Send the invite event to the roomserver. // Send the invite event to the roomserver.
err = roomserverAPI.SendInvite( err = roomserverAPI.SendInvite(
req.Context(), ctx,
rsAPI, rsAPI,
inviteEvent.Headered(roomVersion), inviteEvent.Headered(roomVersion),
inviteStrippedState, // invite room state inviteStrippedState, // invite room state
@ -544,7 +545,7 @@ func createRoom(
return e.JSONResponse() return e.JSONResponse()
case nil: case nil:
default: default:
util.GetLogger(req.Context()).WithError(err).Error("roomserverAPI.SendInvite failed") util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInvite failed")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: jsonerror.InternalServerError(), JSON: jsonerror.InternalServerError(),
@ -556,13 +557,13 @@ func createRoom(
if r.Visibility == "public" { if r.Visibility == "public" {
// expose this room in the published room list // expose this room in the published room list
var pubRes roomserverAPI.PerformPublishResponse var pubRes roomserverAPI.PerformPublishResponse
rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{ rsAPI.PerformPublish(ctx, &roomserverAPI.PerformPublishRequest{
RoomID: roomID, RoomID: roomID,
Visibility: "public", Visibility: "public",
}, &pubRes) }, &pubRes)
if pubRes.Error != nil { if pubRes.Error != nil {
// treat as non-fatal since the room is already made by this point // treat as non-fatal since the room is already made by this point
util.GetLogger(req.Context()).WithError(pubRes.Error).Error("failed to visibility:public") util.GetLogger(ctx).WithError(pubRes.Error).Error("failed to visibility:public")
} }
} }

View file

@ -23,7 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -32,7 +32,7 @@ func JoinRoomByIDOrAlias(
req *http.Request, req *http.Request,
device *api.Device, device *api.Device,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB accounts.Database, accountDB userdb.Database,
roomIDOrAlias string, roomIDOrAlias string,
) util.JSONResponse { ) util.JSONResponse {
// Prepare to ask the roomserver to perform the room join. // Prepare to ask the roomserver to perform the room join.

View file

@ -24,7 +24,7 @@ import (
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -36,7 +36,7 @@ type crossSigningRequest struct {
func UploadCrossSigningDeviceKeys( func UploadCrossSigningDeviceKeys(
req *http.Request, userInteractiveAuth *auth.UserInteractive, req *http.Request, userInteractiveAuth *auth.UserInteractive,
keyserverAPI api.KeyInternalAPI, device *userapi.Device, keyserverAPI api.KeyInternalAPI, device *userapi.Device,
accountDB accounts.Database, cfg *config.ClientAPI, accountDB userdb.Database, cfg *config.ClientAPI,
) util.JSONResponse { ) util.JSONResponse {
uploadReq := &crossSigningRequest{} uploadReq := &crossSigningRequest{}
uploadRes := &api.PerformUploadDeviceKeysResponse{} uploadRes := &api.PerformUploadDeviceKeysResponse{}

View file

@ -38,6 +38,12 @@ func LeaveRoomByID(
// Ask the roomserver to perform the leave. // Ask the roomserver to perform the leave.
if err := rsAPI.PerformLeave(req.Context(), &leaveReq, &leaveRes); err != nil { if err := rsAPI.PerformLeave(req.Context(), &leaveReq, &leaveRes); err != nil {
if leaveRes.Code != 0 {
return util.JSONResponse{
Code: leaveRes.Code,
JSON: jsonerror.LeaveServerNoticeError(),
}
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.Unknown(err.Error()), JSON: jsonerror.Unknown(err.Error()),

View file

@ -23,7 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -54,7 +54,7 @@ func passwordLogin() flows {
// Login implements GET and POST /login // Login implements GET and POST /login
func Login( func Login(
req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI, req *http.Request, accountDB userdb.Database, userAPI userapi.UserInternalAPI,
cfg *config.ClientAPI, cfg *config.ClientAPI,
) util.JSONResponse { ) util.JSONResponse {
if req.Method == http.MethodGet { if req.Method == http.MethodGet {

View file

@ -30,7 +30,7 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -39,7 +39,7 @@ import (
var errMissingUserID = errors.New("'user_id' must be supplied") var errMissingUserID = errors.New("'user_id' must be supplied")
func SendBan( func SendBan(
req *http.Request, accountDB accounts.Database, device *userapi.Device, req *http.Request, accountDB userdb.Database, device *userapi.Device,
roomID string, cfg *config.ClientAPI, roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
@ -81,7 +81,7 @@ func SendBan(
return sendMembership(req.Context(), accountDB, device, roomID, "ban", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI) return sendMembership(req.Context(), accountDB, device, roomID, "ban", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI)
} }
func sendMembership(ctx context.Context, accountDB accounts.Database, device *userapi.Device, func sendMembership(ctx context.Context, accountDB userdb.Database, device *userapi.Device,
roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time, roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time,
roomVer gomatrixserverlib.RoomVersion, roomVer gomatrixserverlib.RoomVersion,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI) util.JSONResponse { rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI) util.JSONResponse {
@ -125,7 +125,7 @@ func sendMembership(ctx context.Context, accountDB accounts.Database, device *us
} }
func SendKick( func SendKick(
req *http.Request, accountDB accounts.Database, device *userapi.Device, req *http.Request, accountDB userdb.Database, device *userapi.Device,
roomID string, cfg *config.ClientAPI, roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
@ -165,7 +165,7 @@ func SendKick(
} }
func SendUnban( func SendUnban(
req *http.Request, accountDB accounts.Database, device *userapi.Device, req *http.Request, accountDB userdb.Database, device *userapi.Device,
roomID string, cfg *config.ClientAPI, roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
@ -200,7 +200,7 @@ func SendUnban(
} }
func SendInvite( func SendInvite(
req *http.Request, accountDB accounts.Database, device *userapi.Device, req *http.Request, accountDB userdb.Database, device *userapi.Device,
roomID string, cfg *config.ClientAPI, roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
@ -226,27 +226,42 @@ func SendInvite(
} }
} }
// We already received the return value, so no need to check for an error here.
response, _ := sendInvite(req.Context(), accountDB, device, roomID, body.UserID, body.Reason, cfg, rsAPI, asAPI, evTime)
return response
}
// sendInvite sends an invitation to a user. Returns a JSONResponse and an error
func sendInvite(
ctx context.Context,
accountDB userdb.Database,
device *userapi.Device,
roomID, userID, reason string,
cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI, evTime time.Time,
) (util.JSONResponse, error) {
event, err := buildMembershipEvent( event, err := buildMembershipEvent(
req.Context(), body.UserID, body.Reason, accountDB, device, "invite", ctx, userID, reason, accountDB, device, "invite",
roomID, false, cfg, evTime, rsAPI, asAPI, roomID, false, cfg, evTime, rsAPI, asAPI,
) )
if err == errMissingUserID { if err == errMissingUserID {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(err.Error()), JSON: jsonerror.BadJSON(err.Error()),
} }, err
} else if err == eventutil.ErrRoomNoExists { } else if err == eventutil.ErrRoomNoExists {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,
JSON: jsonerror.NotFound(err.Error()), JSON: jsonerror.NotFound(err.Error()),
} }, err
} else if err != nil { } else if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvent failed") util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError(), err
} }
err = roomserverAPI.SendInvite( err = roomserverAPI.SendInvite(
req.Context(), rsAPI, ctx, rsAPI,
event, event,
nil, // ask the roomserver to draw up invite room state for us nil, // ask the roomserver to draw up invite room state for us
cfg.Matrix.ServerName, cfg.Matrix.ServerName,
@ -254,24 +269,24 @@ func SendInvite(
) )
switch e := err.(type) { switch e := err.(type) {
case *roomserverAPI.PerformError: case *roomserverAPI.PerformError:
return e.JSONResponse() return e.JSONResponse(), err
case nil: case nil:
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: struct{}{}, JSON: struct{}{},
} }, nil
default: default:
util.GetLogger(req.Context()).WithError(err).Error("roomserverAPI.SendInvite failed") util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInvite failed")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: jsonerror.InternalServerError(), JSON: jsonerror.InternalServerError(),
} }, err
} }
} }
func buildMembershipEvent( func buildMembershipEvent(
ctx context.Context, ctx context.Context,
targetUserID, reason string, accountDB accounts.Database, targetUserID, reason string, accountDB userdb.Database,
device *userapi.Device, device *userapi.Device,
membership, roomID string, isDirect bool, membership, roomID string, isDirect bool,
cfg *config.ClientAPI, evTime time.Time, cfg *config.ClientAPI, evTime time.Time,
@ -312,7 +327,7 @@ func loadProfile(
ctx context.Context, ctx context.Context,
userID string, userID string,
cfg *config.ClientAPI, cfg *config.ClientAPI,
accountDB accounts.Database, accountDB userdb.Database,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
_, serverName, err := gomatrixserverlib.SplitID('@', userID) _, serverName, err := gomatrixserverlib.SplitID('@', userID)
@ -366,7 +381,7 @@ func checkAndProcessThreepid(
body *threepid.MembershipRequest, body *threepid.MembershipRequest,
cfg *config.ClientAPI, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB accounts.Database, accountDB userdb.Database,
roomID string, roomID string,
evTime time.Time, evTime time.Time,
) (inviteStored bool, errRes *util.JSONResponse) { ) (inviteStored bool, errRes *util.JSONResponse) {

View file

@ -9,7 +9,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -29,7 +29,7 @@ type newPasswordAuth struct {
func Password( func Password(
req *http.Request, req *http.Request,
userAPI api.UserInternalAPI, userAPI api.UserInternalAPI,
accountDB accounts.Database, accountDB userdb.Database,
device *api.Device, device *api.Device,
cfg *config.ClientAPI, cfg *config.ClientAPI,
) util.JSONResponse { ) util.JSONResponse {

View file

@ -19,7 +19,7 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -28,7 +28,7 @@ func PeekRoomByIDOrAlias(
req *http.Request, req *http.Request,
device *api.Device, device *api.Device,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB accounts.Database, accountDB userdb.Database,
roomIDOrAlias string, roomIDOrAlias string,
) util.JSONResponse { ) util.JSONResponse {
// if this is a remote roomIDOrAlias, we have to ask the roomserver (or federation sender?) to // if this is a remote roomIDOrAlias, we have to ask the roomserver (or federation sender?) to
@ -82,7 +82,7 @@ func UnpeekRoomByID(
req *http.Request, req *http.Request,
device *api.Device, device *api.Device,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB accounts.Database, accountDB userdb.Database,
roomID string, roomID string,
) util.JSONResponse { ) util.JSONResponse {
unpeekReq := roomserverAPI.PerformUnpeekRequest{ unpeekReq := roomserverAPI.PerformUnpeekRequest{

View file

@ -27,7 +27,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
@ -36,7 +36,7 @@ import (
// GetProfile implements GET /profile/{userID} // GetProfile implements GET /profile/{userID}
func GetProfile( func GetProfile(
req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI, req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI,
userID string, userID string,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
@ -65,7 +65,7 @@ func GetProfile(
// GetAvatarURL implements GET /profile/{userID}/avatar_url // GetAvatarURL implements GET /profile/{userID}/avatar_url
func GetAvatarURL( func GetAvatarURL(
req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI, req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI,
userID string, asAPI appserviceAPI.AppServiceQueryAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
) util.JSONResponse { ) util.JSONResponse {
@ -92,7 +92,7 @@ func GetAvatarURL(
// SetAvatarURL implements PUT /profile/{userID}/avatar_url // SetAvatarURL implements PUT /profile/{userID}/avatar_url
func SetAvatarURL( func SetAvatarURL(
req *http.Request, accountDB accounts.Database, req *http.Request, accountDB userdb.Database,
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { if userID != device.UserID {
@ -182,7 +182,7 @@ func SetAvatarURL(
// GetDisplayName implements GET /profile/{userID}/displayname // GetDisplayName implements GET /profile/{userID}/displayname
func GetDisplayName( func GetDisplayName(
req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI, req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI,
userID string, asAPI appserviceAPI.AppServiceQueryAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
) util.JSONResponse { ) util.JSONResponse {
@ -209,7 +209,7 @@ func GetDisplayName(
// SetDisplayName implements PUT /profile/{userID}/displayname // SetDisplayName implements PUT /profile/{userID}/displayname
func SetDisplayName( func SetDisplayName(
req *http.Request, accountDB accounts.Database, req *http.Request, accountDB userdb.Database,
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { if userID != device.UserID {
@ -302,7 +302,7 @@ func SetDisplayName(
// Returns an error when something goes wrong or specifically // Returns an error when something goes wrong or specifically
// eventutil.ErrProfileNoExists when the profile doesn't exist. // eventutil.ErrProfileNoExists when the profile doesn't exist.
func getProfile( func getProfile(
ctx context.Context, accountDB accounts.Database, cfg *config.ClientAPI, ctx context.Context, accountDB userdb.Database, cfg *config.ClientAPI,
userID string, userID string,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,

View file

@ -32,18 +32,19 @@ import (
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/tokens"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/tokens"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
) )
var ( var (
@ -153,7 +154,7 @@ type authDict struct {
// http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#user-interactive-authentication-api // http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#user-interactive-authentication-api
type userInteractiveResponse struct { type userInteractiveResponse struct {
Flows []authtypes.Flow `json:"flows"` Flows []authtypes.Flow `json:"flows"`
Completed []authtypes.LoginType `json:"completed,omitempty"` Completed []authtypes.LoginType `json:"completed"`
Params map[string]interface{} `json:"params"` Params map[string]interface{} `json:"params"`
Session string `json:"session"` Session string `json:"session"`
} }
@ -447,7 +448,7 @@ func validateApplicationService(
func Register( func Register(
req *http.Request, req *http.Request,
userAPI userapi.UserInternalAPI, userAPI userapi.UserInternalAPI,
accountDB accounts.Database, accountDB userdb.Database,
cfg *config.ClientAPI, cfg *config.ClientAPI,
) util.JSONResponse { ) util.JSONResponse {
var r registerRequest var r registerRequest
@ -531,6 +532,13 @@ func handleGuestRegistration(
cfg *config.ClientAPI, cfg *config.ClientAPI,
userAPI userapi.UserInternalAPI, userAPI userapi.UserInternalAPI,
) util.JSONResponse { ) util.JSONResponse {
if cfg.RegistrationDisabled || cfg.GuestsDisabled {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Guest registration is disabled"),
}
}
var res userapi.PerformAccountCreationResponse var res userapi.PerformAccountCreationResponse
err := userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{ err := userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeGuest, AccountType: userapi.AccountTypeGuest,
@ -708,7 +716,7 @@ func handleApplicationServiceRegistration(
// application service registration is entirely separate. // application service registration is entirely separate.
return completeRegistration( return completeRegistration(
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), policyVersion, req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), policyVersion,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService,
) )
} }
@ -732,7 +740,7 @@ func checkAndCompleteFlow(
return completeRegistration( return completeRegistration(
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), policyVersion, req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), policyVersion,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser,
) )
} }
@ -757,6 +765,7 @@ func completeRegistration(
username, password, appserviceID, ipAddr, userAgent, policyVersion string, username, password, appserviceID, ipAddr, userAgent, policyVersion string,
inhibitLogin eventutil.WeakBoolean, inhibitLogin eventutil.WeakBoolean,
displayName, deviceID *string, displayName, deviceID *string,
accType userapi.AccountType,
) util.JSONResponse { ) util.JSONResponse {
if username == "" { if username == "" {
return util.JSONResponse{ return util.JSONResponse{
@ -771,13 +780,12 @@ func completeRegistration(
JSON: jsonerror.BadJSON("missing password"), JSON: jsonerror.BadJSON("missing password"),
} }
} }
var accRes userapi.PerformAccountCreationResponse var accRes userapi.PerformAccountCreationResponse
err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{
AppServiceID: appserviceID, AppServiceID: appserviceID,
Localpart: username, Localpart: username,
Password: password, Password: password,
AccountType: userapi.AccountTypeUser, AccountType: accType,
OnConflict: userapi.ConflictAbort, OnConflict: userapi.ConflictAbort,
PolicyVersion: policyVersion, PolicyVersion: policyVersion,
}, &accRes) }, &accRes)
@ -904,7 +912,7 @@ type availableResponse struct {
func RegisterAvailable( func RegisterAvailable(
req *http.Request, req *http.Request,
cfg *config.ClientAPI, cfg *config.ClientAPI,
accountDB accounts.Database, accountDB userdb.Database,
) util.JSONResponse { ) util.JSONResponse {
username := req.URL.Query().Get("username") username := req.URL.Query().Get("username")
@ -976,5 +984,10 @@ func handleSharedSecretRegistration(userAPI userapi.UserInternalAPI, sr *SharedS
return *resErr return *resErr
} }
deviceID := "shared_secret_registration" deviceID := "shared_secret_registration"
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID)
accType := userapi.AccountTypeUser
if ssrr.Admin {
accType = userapi.AccountTypeAdmin
}
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType)
} }

View file

@ -15,6 +15,7 @@
package routing package routing
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"strings" "strings"
@ -34,7 +35,7 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -51,7 +52,7 @@ func Setup(
eduAPI eduServerAPI.EDUServerInputAPI, eduAPI eduServerAPI.EDUServerInputAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
accountDB accounts.Database, accountDB userdb.Database,
userAPI userapi.UserInternalAPI, userAPI userapi.UserInternalAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
syncProducer *producers.SyncAPIProducer, syncProducer *producers.SyncAPIProducer,
@ -117,6 +118,58 @@ func Setup(
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
} }
// server notifications
if cfg.Matrix.ServerNotices.Enabled {
logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice")
serverNotificationSender, err := getSenderDevice(context.Background(), userAPI, accountDB, cfg)
if err != nil {
logrus.WithError(err).Fatal("unable to get account for sending sending server notices")
}
synapseAdminRouter.Handle("/admin/v1/send_server_notice/{txnID}",
httputil.MakeAuthAPI("send_server_notice", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
// not specced, but ensure we're rate limiting requests to this endpoint
if r := rateLimits.Limit(req); r != nil {
return *r
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
txnID := vars["txnID"]
return SendServerNotice(
req, &cfg.Matrix.ServerNotices,
cfg, userAPI, rsAPI, accountDB, asAPI,
device, serverNotificationSender,
&txnID, transactionsCache,
)
}),
).Methods(http.MethodPut, http.MethodOptions)
synapseAdminRouter.Handle("/admin/v1/send_server_notice",
httputil.MakeAuthAPI("send_server_notice", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
// not specced, but ensure we're rate limiting requests to this endpoint
if r := rateLimits.Limit(req); r != nil {
return *r
}
return SendServerNotice(
req, &cfg.Matrix.ServerNotices,
cfg, userAPI, rsAPI, accountDB, asAPI,
device, serverNotificationSender,
nil, transactionsCache,
)
}),
).Methods(http.MethodPost, http.MethodOptions)
}
// You can't just do PathPrefix("/(r0|v3)") because regexps only apply when inside named path variables.
// So make a named path variable called 'apiversion' (which we will never read in handlers) and then do
// (r0|v3) - BUT this is a captured group, which makes no sense because you cannot extract this group
// from a match (gorilla/mux exposes no way to do this) so it demands you make it a non-capturing group
// using ?: so the final regexp becomes what is below. We also need a trailing slash to stop 'v33333' matching.
// Note that 'apiversion' is chosen because it must not collide with a variable used in any of the routing!
v3mux := publicAPIMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter()
// unspecced consent tracking // unspecced consent tracking
if cfg.Matrix.UserConsentOptions.Enabled { if cfg.Matrix.UserConsentOptions.Enabled {
consentAPIMux.Handle("/consent", consentAPIMux.Handle("/consent",
@ -129,12 +182,12 @@ func Setup(
r0mux := publicAPIMux.PathPrefix("/r0").Subrouter() r0mux := publicAPIMux.PathPrefix("/r0").Subrouter()
unstableMux := publicAPIMux.PathPrefix("/unstable").Subrouter() unstableMux := publicAPIMux.PathPrefix("/unstable").Subrouter()
r0mux.Handle("/createRoom", v3mux.Handle("/createRoom",
httputil.MakeAuthAPI("createRoom", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("createRoom", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return CreateRoom(req, device, cfg, accountDB, rsAPI, asAPI) return CreateRoom(req, device, cfg, accountDB, rsAPI, asAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/join/{roomIDOrAlias}", v3mux.Handle("/join/{roomIDOrAlias}",
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -150,7 +203,7 @@ func Setup(
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
if mscCfg.Enabled("msc2753") { if mscCfg.Enabled("msc2753") {
r0mux.Handle("/peek/{roomIDOrAlias}", v3mux.Handle("/peek/{roomIDOrAlias}",
httputil.MakeAuthAPI(gomatrixserverlib.Peek, userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI(gomatrixserverlib.Peek, userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -165,12 +218,12 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
} }
r0mux.Handle("/joined_rooms", v3mux.Handle("/joined_rooms",
httputil.MakeAuthAPI("joined_rooms", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("joined_rooms", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetJoinedRooms(req, device, rsAPI) return GetJoinedRooms(req, device, rsAPI)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/join", v3mux.Handle("/rooms/{roomID}/join",
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -184,7 +237,7 @@ func Setup(
) )
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/leave", v3mux.Handle("/rooms/{roomID}/leave",
httputil.MakeAuthAPI("membership", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("membership", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -198,7 +251,7 @@ func Setup(
) )
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/unpeek", v3mux.Handle("/rooms/{roomID}/unpeek",
httputil.MakeAuthAPI("unpeek", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("unpeek", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -209,7 +262,7 @@ func Setup(
) )
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/ban", v3mux.Handle("/rooms/{roomID}/ban",
httputil.MakeAuthAPI("membership", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("membership", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -218,7 +271,7 @@ func Setup(
return SendBan(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) return SendBan(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/invite", v3mux.Handle("/rooms/{roomID}/invite",
httputil.MakeAuthAPI("membership", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("membership", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -230,7 +283,7 @@ func Setup(
return SendInvite(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) return SendInvite(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/kick", v3mux.Handle("/rooms/{roomID}/kick",
httputil.MakeAuthAPI("membership", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("membership", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -239,7 +292,7 @@ func Setup(
return SendKick(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) return SendKick(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/unban", v3mux.Handle("/rooms/{roomID}/unban",
httputil.MakeAuthAPI("membership", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("membership", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -248,7 +301,7 @@ func Setup(
return SendUnban(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) return SendUnban(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/send/{eventType}", v3mux.Handle("/rooms/{roomID}/send/{eventType}",
httputil.MakeAuthAPI("send_message", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("send_message", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -257,7 +310,7 @@ func Setup(
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil) return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}", v3mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}",
httputil.MakeAuthAPI("send_message", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("send_message", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -268,7 +321,7 @@ func Setup(
nil, cfg, rsAPI, transactionsCache) nil, cfg, rsAPI, transactionsCache)
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/event/{eventID}", v3mux.Handle("/rooms/{roomID}/event/{eventID}",
httputil.MakeAuthAPI("rooms_get_event", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_get_event", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -278,7 +331,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -286,7 +339,7 @@ func Setup(
return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"]) return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"])
})).Methods(http.MethodGet, http.MethodOptions) })).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -294,7 +347,7 @@ func Setup(
return GetAliases(req, rsAPI, device, vars["roomID"]) return GetAliases(req, rsAPI, device, vars["roomID"])
})).Methods(http.MethodGet, http.MethodOptions) })).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -305,7 +358,7 @@ func Setup(
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat) return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat)
})).Methods(http.MethodGet, http.MethodOptions) })).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -314,7 +367,7 @@ func Setup(
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat) return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat)
})).Methods(http.MethodGet, http.MethodOptions) })).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}", v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}",
httputil.MakeAuthAPI("send_message", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("send_message", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -326,7 +379,7 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}", v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}",
httputil.MakeAuthAPI("send_message", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("send_message", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -337,21 +390,21 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
} }
return Register(req, userAPI, accountDB, cfg) return Register(req, userAPI, accountDB, cfg)
})).Methods(http.MethodPost, http.MethodOptions) })).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { v3mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
} }
return RegisterAvailable(req, cfg, accountDB) return RegisterAvailable(req, cfg, accountDB)
})).Methods(http.MethodGet, http.MethodOptions) })).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/directory/room/{roomAlias}", v3mux.Handle("/directory/room/{roomAlias}",
httputil.MakeExternalAPI("directory_room", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("directory_room", func(req *http.Request) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -361,7 +414,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/directory/room/{roomAlias}", v3mux.Handle("/directory/room/{roomAlias}",
httputil.MakeAuthAPI("directory_room", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("directory_room", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -371,7 +424,7 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/directory/room/{roomAlias}", v3mux.Handle("/directory/room/{roomAlias}",
httputil.MakeAuthAPI("directory_room", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("directory_room", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -380,7 +433,7 @@ func Setup(
return RemoveLocalAlias(req, device, vars["roomAlias"], rsAPI) return RemoveLocalAlias(req, device, vars["roomAlias"], rsAPI)
}), }),
).Methods(http.MethodDelete, http.MethodOptions) ).Methods(http.MethodDelete, http.MethodOptions)
r0mux.Handle("/directory/list/room/{roomID}", v3mux.Handle("/directory/list/room/{roomID}",
httputil.MakeExternalAPI("directory_list", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("directory_list", func(req *http.Request) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -390,7 +443,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
// TODO: Add AS support // TODO: Add AS support
r0mux.Handle("/directory/list/room/{roomID}", v3mux.Handle("/directory/list/room/{roomID}",
httputil.MakeAuthAPI("directory_list", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("directory_list", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -399,25 +452,25 @@ func Setup(
return SetVisibility(req, rsAPI, device, vars["roomID"]) return SetVisibility(req, rsAPI, device, vars["roomID"])
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/publicRooms", v3mux.Handle("/publicRooms",
httputil.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse {
return GetPostPublicRooms(req, rsAPI, extRoomsProvider, federation, cfg) return GetPostPublicRooms(req, rsAPI, extRoomsProvider, federation, cfg)
}), }),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
r0mux.Handle("/logout", v3mux.Handle("/logout",
httputil.MakeAuthAPI("logout", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("logout", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return Logout(req, userAPI, device) return Logout(req, userAPI, device)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/logout/all", v3mux.Handle("/logout/all",
httputil.MakeAuthAPI("logout", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("logout", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return LogoutAll(req, userAPI, device) return LogoutAll(req, userAPI, device)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/typing/{userID}", v3mux.Handle("/rooms/{roomID}/typing/{userID}",
httputil.MakeAuthAPI("rooms_typing", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_typing", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -429,7 +482,7 @@ func Setup(
return SendTyping(req, device, vars["roomID"], vars["userID"], accountDB, eduAPI, rsAPI) return SendTyping(req, device, vars["roomID"], vars["userID"], accountDB, eduAPI, rsAPI)
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/redact/{eventID}", v3mux.Handle("/rooms/{roomID}/redact/{eventID}",
httputil.MakeAuthAPI("rooms_redact", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_redact", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -438,7 +491,7 @@ func Setup(
return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI) return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}", v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}",
httputil.MakeAuthAPI("rooms_redact", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_redact", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -448,7 +501,7 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/sendToDevice/{eventType}/{txnID}", v3mux.Handle("/sendToDevice/{eventType}/{txnID}",
httputil.MakeAuthAPI("send_to_device", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("send_to_device", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -473,7 +526,7 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/account/whoami", v3mux.Handle("/account/whoami",
httputil.MakeAuthAPI("whoami", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("whoami", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -482,7 +535,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/account/password", v3mux.Handle("/account/password",
httputil.MakeAuthAPI("password", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("password", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -491,7 +544,7 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/account/deactivate", v3mux.Handle("/account/deactivate",
httputil.MakeAuthAPI("deactivate", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("deactivate", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -502,7 +555,7 @@ func Setup(
// Stub endpoints required by Element // Stub endpoints required by Element
r0mux.Handle("/login", v3mux.Handle("/login",
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -511,14 +564,14 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
r0mux.Handle("/auth/{authType}/fallback/web", v3mux.Handle("/auth/{authType}/fallback/web",
httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse {
vars := mux.Vars(req) vars := mux.Vars(req)
return AuthFallback(w, req, vars["authType"], cfg) return AuthFallback(w, req, vars["authType"], cfg)
}), }),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
r0mux.Handle("/pushrules/", v3mux.Handle("/pushrules/",
httputil.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse {
// TODO: Implement push rules API // TODO: Implement push rules API
res := json.RawMessage(`{ res := json.RawMessage(`{
@ -539,7 +592,7 @@ func Setup(
// Element user settings // Element user settings
r0mux.Handle("/profile/{userID}", v3mux.Handle("/profile/{userID}",
httputil.MakeExternalAPI("profile", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("profile", func(req *http.Request) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -549,7 +602,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/profile/{userID}/avatar_url", v3mux.Handle("/profile/{userID}/avatar_url",
httputil.MakeExternalAPI("profile_avatar_url", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("profile_avatar_url", func(req *http.Request) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -559,7 +612,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/profile/{userID}/avatar_url", v3mux.Handle("/profile/{userID}/avatar_url",
httputil.MakeAuthAPI("profile_avatar_url", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("profile_avatar_url", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -574,7 +627,7 @@ func Setup(
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows // Browsers use the OPTIONS HTTP method to check if the CORS policy allows
// PUT requests, so we need to allow this method // PUT requests, so we need to allow this method
r0mux.Handle("/profile/{userID}/displayname", v3mux.Handle("/profile/{userID}/displayname",
httputil.MakeExternalAPI("profile_displayname", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("profile_displayname", func(req *http.Request) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -584,7 +637,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/profile/{userID}/displayname", v3mux.Handle("/profile/{userID}/displayname",
httputil.MakeAuthAPI("profile_displayname", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("profile_displayname", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -599,13 +652,13 @@ func Setup(
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows // Browsers use the OPTIONS HTTP method to check if the CORS policy allows
// PUT requests, so we need to allow this method // PUT requests, so we need to allow this method
r0mux.Handle("/account/3pid", v3mux.Handle("/account/3pid",
httputil.MakeAuthAPI("account_3pid", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("account_3pid", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetAssociated3PIDs(req, accountDB, device) return GetAssociated3PIDs(req, accountDB, device)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/account/3pid", v3mux.Handle("/account/3pid",
httputil.MakeAuthAPI("account_3pid", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("account_3pid", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return CheckAndSave3PIDAssociation(req, accountDB, device, cfg) return CheckAndSave3PIDAssociation(req, accountDB, device, cfg)
}), }),
@ -617,14 +670,14 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken", v3mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken",
httputil.MakeExternalAPI("account_3pid_request_token", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("account_3pid_request_token", func(req *http.Request) util.JSONResponse {
return RequestEmailToken(req, accountDB, cfg) return RequestEmailToken(req, accountDB, cfg)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
// Element logs get flooded unless this is handled // Element logs get flooded unless this is handled
r0mux.Handle("/presence/{userID}/status", v3mux.Handle("/presence/{userID}/status",
httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -637,7 +690,7 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/voip/turnServer", v3mux.Handle("/voip/turnServer",
httputil.MakeAuthAPI("turn_server", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("turn_server", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -646,7 +699,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/thirdparty/protocols", v3mux.Handle("/thirdparty/protocols",
httputil.MakeExternalAPI("thirdparty_protocols", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("thirdparty_protocols", func(req *http.Request) util.JSONResponse {
// TODO: Return the third party protcols // TODO: Return the third party protcols
return util.JSONResponse{ return util.JSONResponse{
@ -656,7 +709,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/initialSync", v3mux.Handle("/rooms/{roomID}/initialSync",
httputil.MakeExternalAPI("rooms_initial_sync", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("rooms_initial_sync", func(req *http.Request) util.JSONResponse {
// TODO: Allow people to peek into rooms. // TODO: Allow people to peek into rooms.
return util.JSONResponse{ return util.JSONResponse{
@ -666,7 +719,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userID}/account_data/{type}", v3mux.Handle("/user/{userID}/account_data/{type}",
httputil.MakeAuthAPI("user_account_data", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("user_account_data", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -676,7 +729,7 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", v3mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}",
httputil.MakeAuthAPI("user_account_data", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("user_account_data", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -686,7 +739,7 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/user/{userID}/account_data/{type}", v3mux.Handle("/user/{userID}/account_data/{type}",
httputil.MakeAuthAPI("user_account_data", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("user_account_data", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -696,7 +749,7 @@ func Setup(
}), }),
).Methods(http.MethodGet) ).Methods(http.MethodGet)
r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", v3mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}",
httputil.MakeAuthAPI("user_account_data", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("user_account_data", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -706,7 +759,7 @@ func Setup(
}), }),
).Methods(http.MethodGet) ).Methods(http.MethodGet)
r0mux.Handle("/admin/whois/{userID}", v3mux.Handle("/admin/whois/{userID}",
httputil.MakeAuthAPI("admin_whois", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("admin_whois", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -716,7 +769,7 @@ func Setup(
}), }),
).Methods(http.MethodGet) ).Methods(http.MethodGet)
r0mux.Handle("/user/{userID}/openid/request_token", v3mux.Handle("/user/{userID}/openid/request_token",
httputil.MakeAuthAPI("openid_request_token", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("openid_request_token", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -729,7 +782,7 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/user_directory/search", v3mux.Handle("/user_directory/search",
httputil.MakeAuthAPI("userdirectory_search", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("userdirectory_search", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -754,7 +807,7 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/members", v3mux.Handle("/rooms/{roomID}/members",
httputil.MakeAuthAPI("rooms_members", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_members", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -764,7 +817,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/joined_members", v3mux.Handle("/rooms/{roomID}/joined_members",
httputil.MakeAuthAPI("rooms_members", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_members", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -774,7 +827,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/read_markers", v3mux.Handle("/rooms/{roomID}/read_markers",
httputil.MakeAuthAPI("rooms_read_markers", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_read_markers", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -787,7 +840,7 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/forget", v3mux.Handle("/rooms/{roomID}/forget",
httputil.MakeAuthAPI("rooms_forget", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_forget", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -800,13 +853,13 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/devices", v3mux.Handle("/devices",
httputil.MakeAuthAPI("get_devices", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("get_devices", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetDevicesByLocalpart(req, userAPI, device) return GetDevicesByLocalpart(req, userAPI, device)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/devices/{deviceID}", v3mux.Handle("/devices/{deviceID}",
httputil.MakeAuthAPI("get_device", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("get_device", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -816,7 +869,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/devices/{deviceID}", v3mux.Handle("/devices/{deviceID}",
httputil.MakeAuthAPI("device_data", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("device_data", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -826,7 +879,7 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/devices/{deviceID}", v3mux.Handle("/devices/{deviceID}",
httputil.MakeAuthAPI("delete_device", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("delete_device", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -836,14 +889,14 @@ func Setup(
}), }),
).Methods(http.MethodDelete, http.MethodOptions) ).Methods(http.MethodDelete, http.MethodOptions)
r0mux.Handle("/delete_devices", v3mux.Handle("/delete_devices",
httputil.MakeAuthAPI("delete_devices", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("delete_devices", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return DeleteDevices(req, userAPI, device) return DeleteDevices(req, userAPI, device)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
// Stub implementations for sytest // Stub implementations for sytest
r0mux.Handle("/events", v3mux.Handle("/events",
httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {
return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{
"chunk": []interface{}{}, "chunk": []interface{}{},
@ -853,7 +906,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/initialSync", v3mux.Handle("/initialSync",
httputil.MakeExternalAPI("initial_sync", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("initial_sync", func(req *http.Request) util.JSONResponse {
return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{
"end": "", "end": "",
@ -861,7 +914,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userId}/rooms/{roomId}/tags", v3mux.Handle("/user/{userId}/rooms/{roomId}/tags",
httputil.MakeAuthAPI("get_tags", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("get_tags", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -871,7 +924,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
httputil.MakeAuthAPI("put_tag", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("put_tag", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -881,7 +934,7 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
httputil.MakeAuthAPI("delete_tag", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("delete_tag", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -891,7 +944,7 @@ func Setup(
}), }),
).Methods(http.MethodDelete, http.MethodOptions) ).Methods(http.MethodDelete, http.MethodOptions)
r0mux.Handle("/capabilities", v3mux.Handle("/capabilities",
httputil.MakeAuthAPI("capabilities", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("capabilities", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r
@ -934,11 +987,11 @@ func Setup(
return CreateKeyBackupVersion(req, userAPI, device) return CreateKeyBackupVersion(req, userAPI, device)
}) })
r0mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/room_keys/version/{version}", putBackupKeysVersion).Methods(http.MethodPut) v3mux.Handle("/room_keys/version/{version}", putBackupKeysVersion).Methods(http.MethodPut)
r0mux.Handle("/room_keys/version/{version}", deleteBackupKeysVersion).Methods(http.MethodDelete) v3mux.Handle("/room_keys/version/{version}", deleteBackupKeysVersion).Methods(http.MethodDelete)
r0mux.Handle("/room_keys/version", postNewBackupKeysVersion).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/room_keys/version", postNewBackupKeysVersion).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) unstableMux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
unstableMux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) unstableMux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions)
@ -1030,9 +1083,9 @@ func Setup(
return UploadBackupKeys(req, userAPI, device, version, &keyReq) return UploadBackupKeys(req, userAPI, device, version, &keyReq)
}) })
r0mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut) v3mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut)
r0mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut) v3mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut)
r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", putBackupKeysRoomSession).Methods(http.MethodPut) v3mux.Handle("/room_keys/keys/{roomID}/{sessionID}", putBackupKeysRoomSession).Methods(http.MethodPut)
unstableMux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut) unstableMux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut)
unstableMux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut) unstableMux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut)
@ -1060,9 +1113,9 @@ func Setup(
return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"]) return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"])
}) })
r0mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", getBackupKeysRoomSession).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/room_keys/keys/{roomID}/{sessionID}", getBackupKeysRoomSession).Methods(http.MethodGet, http.MethodOptions)
unstableMux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions) unstableMux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions)
unstableMux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions) unstableMux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions)
@ -1080,29 +1133,29 @@ func Setup(
return UploadCrossSigningDeviceSignatures(req, keyAPI, device) return UploadCrossSigningDeviceSignatures(req, keyAPI, device)
}) })
r0mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions) unstableMux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions) unstableMux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions)
// Supplying a device ID is deprecated. // Supplying a device ID is deprecated.
r0mux.Handle("/keys/upload/{deviceID}", v3mux.Handle("/keys/upload/{deviceID}",
httputil.MakeAuthAPI("keys_upload", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("keys_upload", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadKeys(req, keyAPI, device) return UploadKeys(req, keyAPI, device)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/keys/upload", v3mux.Handle("/keys/upload",
httputil.MakeAuthAPI("keys_upload", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("keys_upload", userAPI, cfg.Matrix.UserConsentOptions, true, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadKeys(req, keyAPI, device) return UploadKeys(req, keyAPI, device)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/keys/query", v3mux.Handle("/keys/query",
httputil.MakeAuthAPI("keys_query", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("keys_query", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return QueryKeys(req, keyAPI, device) return QueryKeys(req, keyAPI, device)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/keys/claim", v3mux.Handle("/keys/claim",
httputil.MakeAuthAPI("keys_claim", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("keys_claim", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return ClaimKeys(req, keyAPI) return ClaimKeys(req, keyAPI)
}), }),

View file

@ -15,10 +15,16 @@
package routing package routing
import ( import (
"context"
"net/http" "net/http"
"sync" "sync"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
@ -26,10 +32,6 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
) )
// http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-send-eventtype-txnid // http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-send-eventtype-txnid
@ -97,7 +99,22 @@ func SendEvent(
defer mutex.(*sync.Mutex).Unlock() defer mutex.(*sync.Mutex).Unlock()
startedGeneratingEvent := time.Now() startedGeneratingEvent := time.Now()
e, resErr := generateSendEvent(req, device, roomID, eventType, stateKey, cfg, rsAPI)
var r map[string]interface{} // must be a JSON object
resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {
return *resErr
}
evTime, err := httputil.ParseTSParam(req)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue(err.Error()),
}
}
e, resErr := generateSendEvent(req.Context(), r, device, roomID, eventType, stateKey, cfg, rsAPI, evTime)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }
@ -153,27 +170,16 @@ func SendEvent(
} }
func generateSendEvent( func generateSendEvent(
req *http.Request, ctx context.Context,
r map[string]interface{},
device *userapi.Device, device *userapi.Device,
roomID, eventType string, stateKey *string, roomID, eventType string, stateKey *string,
cfg *config.ClientAPI, cfg *config.ClientAPI,
rsAPI api.RoomserverInternalAPI, rsAPI api.RoomserverInternalAPI,
evTime time.Time,
) (*gomatrixserverlib.Event, *util.JSONResponse) { ) (*gomatrixserverlib.Event, *util.JSONResponse) {
// parse the incoming http request // parse the incoming http request
userID := device.UserID userID := device.UserID
var r map[string]interface{} // must be a JSON object
resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {
return nil, resErr
}
evTime, err := httputil.ParseTSParam(req)
if err != nil {
return nil, &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue(err.Error()),
}
}
// create the new event and set all the fields we can // create the new event and set all the fields we can
builder := gomatrixserverlib.EventBuilder{ builder := gomatrixserverlib.EventBuilder{
@ -182,15 +188,15 @@ func generateSendEvent(
Type: eventType, Type: eventType,
StateKey: stateKey, StateKey: stateKey,
} }
err = builder.SetContent(r) err := builder.SetContent(r)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("builder.SetContent failed") util.GetLogger(ctx).WithError(err).Error("builder.SetContent failed")
resErr := jsonerror.InternalServerError() resErr := jsonerror.InternalServerError()
return nil, &resErr return nil, &resErr
} }
var queryRes api.QueryLatestEventsAndStateResponse var queryRes api.QueryLatestEventsAndStateResponse
e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, evTime, rsAPI, &queryRes) e, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, &queryRes)
if err == eventutil.ErrRoomNoExists { if err == eventutil.ErrRoomNoExists {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,
@ -213,7 +219,7 @@ func generateSendEvent(
JSON: jsonerror.BadJSON(e.Error()), JSON: jsonerror.BadJSON(e.Error()),
} }
} else if err != nil { } else if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("eventutil.BuildEvent failed") util.GetLogger(ctx).WithError(err).Error("eventutil.BuildEvent failed")
resErr := jsonerror.InternalServerError() resErr := jsonerror.InternalServerError()
return nil, &resErr return nil, &resErr
} }

View file

@ -20,7 +20,7 @@ import (
"github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -33,7 +33,7 @@ type typingContentJSON struct {
// sends the typing events to client API typingProducer // sends the typing events to client API typingProducer
func SendTyping( func SendTyping(
req *http.Request, device *userapi.Device, roomID string, req *http.Request, device *userapi.Device, roomID string,
userID string, accountDB accounts.Database, userID string, accountDB userdb.Database,
eduAPI api.EDUServerInputAPI, eduAPI api.EDUServerInputAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
) util.JSONResponse { ) util.JSONResponse {

View file

@ -0,0 +1,343 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/tokens"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/transactions"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
// Unspecced server notice request
// https://github.com/matrix-org/synapse/blob/develop/docs/admin_api/server_notices.md
type sendServerNoticeRequest struct {
UserID string `json:"user_id,omitempty"`
Content struct {
MsgType string `json:"msgtype,omitempty"`
Body string `json:"body,omitempty"`
} `json:"content,omitempty"`
Type string `json:"type,omitempty"`
StateKey string `json:"state_key,omitempty"`
}
// SendServerNotice sends a message to a specific user. It can only be invoked by an admin.
func SendServerNotice(
req *http.Request,
cfgNotices *config.ServerNotices,
cfgClient *config.ClientAPI,
userAPI userapi.UserInternalAPI,
rsAPI api.RoomserverInternalAPI,
accountsDB userdb.Database,
asAPI appserviceAPI.AppServiceQueryAPI,
device *userapi.Device,
senderDevice *userapi.Device,
txnID *string,
txnCache *transactions.Cache,
) util.JSONResponse {
if device.AccountType != userapi.AccountTypeAdmin {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("This API can only be used by admin users."),
}
}
if txnID != nil {
// Try to fetch response from transactionsCache
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok {
return *res
}
}
ctx := req.Context()
var r sendServerNoticeRequest
resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {
return *resErr
}
// check that all required fields are set
if !r.valid() {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("Invalid request"),
}
}
// get rooms for specified user
allUserRooms := []string{}
userRooms := api.QueryRoomsForUserResponse{}
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
UserID: r.UserID,
WantMembership: "join",
}, &userRooms); err != nil {
return util.ErrorResponse(err)
}
allUserRooms = append(allUserRooms, userRooms.RoomIDs...)
// get invites for specified user
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
UserID: r.UserID,
WantMembership: "invite",
}, &userRooms); err != nil {
return util.ErrorResponse(err)
}
allUserRooms = append(allUserRooms, userRooms.RoomIDs...)
// get left rooms for specified user
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
UserID: r.UserID,
WantMembership: "leave",
}, &userRooms); err != nil {
return util.ErrorResponse(err)
}
allUserRooms = append(allUserRooms, userRooms.RoomIDs...)
// get rooms of the sender
senderUserID := fmt.Sprintf("@%s:%s", cfgNotices.LocalPart, cfgClient.Matrix.ServerName)
senderRooms := api.QueryRoomsForUserResponse{}
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
UserID: senderUserID,
WantMembership: "join",
}, &senderRooms); err != nil {
return util.ErrorResponse(err)
}
// check if we have rooms in common
commonRooms := []string{}
for _, userRoomID := range allUserRooms {
for _, senderRoomID := range senderRooms.RoomIDs {
if userRoomID == senderRoomID {
commonRooms = append(commonRooms, senderRoomID)
}
}
}
if len(commonRooms) > 1 {
return util.ErrorResponse(fmt.Errorf("expected to find one room, but got %d", len(commonRooms)))
}
var (
roomID string
roomVersion = gomatrixserverlib.RoomVersionV6
)
// create a new room for the user
if len(commonRooms) == 0 {
powerLevelContent := eventutil.InitialPowerLevelsContent(senderUserID)
powerLevelContent.Users[r.UserID] = -10 // taken from Synapse
pl, err := json.Marshal(powerLevelContent)
if err != nil {
return util.ErrorResponse(err)
}
createContent := map[string]interface{}{}
createContent["m.federate"] = false
cc, err := json.Marshal(createContent)
if err != nil {
return util.ErrorResponse(err)
}
crReq := createRoomRequest{
Invite: []string{r.UserID},
Name: cfgNotices.RoomName,
Visibility: "private",
Preset: presetPrivateChat,
CreationContent: cc,
GuestCanJoin: false,
RoomVersion: roomVersion,
PowerLevelContentOverride: pl,
}
roomRes := createRoom(ctx, crReq, senderDevice, cfgClient, accountsDB, rsAPI, asAPI, time.Now())
switch data := roomRes.JSON.(type) {
case createRoomResponse:
roomID = data.RoomID
// tag the room, so we can later check if the user tries to reject an invite
serverAlertTag := gomatrix.TagContent{Tags: map[string]gomatrix.TagProperties{
"m.server_notice": {
Order: 1.0,
},
}}
if err = saveTagData(req, r.UserID, roomID, userAPI, serverAlertTag); err != nil {
util.GetLogger(ctx).WithError(err).Error("saveTagData failed")
return jsonerror.InternalServerError()
}
default:
// if we didn't get a createRoomResponse, we probably received an error, so return that.
return roomRes
}
} else {
// we've found a room in common, check the membership
roomID = commonRooms[0]
// re-invite the user
res, err := sendInvite(ctx, accountsDB, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now())
if err != nil {
return res
}
}
startedGeneratingEvent := time.Now()
request := map[string]interface{}{
"body": r.Content.Body,
"msgtype": r.Content.MsgType,
}
e, resErr := generateSendEvent(ctx, request, senderDevice, roomID, "m.room.message", nil, cfgClient, rsAPI, time.Now())
if resErr != nil {
logrus.Errorf("failed to send message: %+v", resErr)
return *resErr
}
timeToGenerateEvent := time.Since(startedGeneratingEvent)
var txnAndSessionID *api.TransactionID
if txnID != nil {
txnAndSessionID = &api.TransactionID{
TransactionID: *txnID,
SessionID: device.SessionID,
}
}
// pass the new event to the roomserver and receive the correct event ID
// event ID in case of duplicate transaction is discarded
startedSubmittingEvent := time.Now()
if err := api.SendEvents(
ctx, rsAPI,
api.KindNew,
[]*gomatrixserverlib.HeaderedEvent{
e.Headered(roomVersion),
},
cfgClient.Matrix.ServerName,
cfgClient.Matrix.ServerName,
txnAndSessionID,
false,
); err != nil {
util.GetLogger(ctx).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError()
}
util.GetLogger(ctx).WithFields(logrus.Fields{
"event_id": e.EventID(),
"room_id": roomID,
"room_version": roomVersion,
}).Info("Sent event to roomserver")
timeToSubmitEvent := time.Since(startedSubmittingEvent)
res := util.JSONResponse{
Code: http.StatusOK,
JSON: sendEventResponse{e.EventID()},
}
// Add response to transactionsCache
if txnID != nil {
txnCache.AddTransaction(device.AccessToken, *txnID, &res)
}
// Take a note of how long it took to generate the event vs submit
// it to the roomserver.
sendEventDuration.With(prometheus.Labels{"action": "build"}).Observe(float64(timeToGenerateEvent.Milliseconds()))
sendEventDuration.With(prometheus.Labels{"action": "submit"}).Observe(float64(timeToSubmitEvent.Milliseconds()))
return res
}
func (r sendServerNoticeRequest) valid() (ok bool) {
if r.UserID == "" {
return false
}
if r.Content.MsgType == "" || r.Content.Body == "" {
return false
}
return true
}
// getSenderDevice creates a user account to be used when sending server notices.
// It returns an userapi.Device, which is used for building the event
func getSenderDevice(
ctx context.Context,
userAPI userapi.UserInternalAPI,
accountDB userdb.Database,
cfg *config.ClientAPI,
) (*userapi.Device, error) {
var accRes userapi.PerformAccountCreationResponse
// create account if it doesn't exist
err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeUser,
Localpart: cfg.Matrix.ServerNotices.LocalPart,
OnConflict: userapi.ConflictUpdate,
}, &accRes)
if err != nil {
return nil, err
}
// set the avatarurl for the user
if err = accountDB.SetAvatarURL(ctx, cfg.Matrix.ServerNotices.LocalPart, cfg.Matrix.ServerNotices.AvatarURL); err != nil {
util.GetLogger(ctx).WithError(err).Error("accountDB.SetAvatarURL failed")
return nil, err
}
// Check if we got existing devices
deviceRes := &userapi.QueryDevicesResponse{}
err = userAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{
UserID: accRes.Account.UserID,
}, deviceRes)
if err != nil {
return nil, err
}
if len(deviceRes.Devices) > 0 {
return &deviceRes.Devices[0], nil
}
// create an AccessToken
token, err := tokens.GenerateLoginToken(tokens.TokenOptions{
ServerPrivateKey: cfg.Matrix.PrivateKey.Seed(),
ServerName: string(cfg.Matrix.ServerName),
UserID: accRes.Account.UserID,
})
if err != nil {
return nil, err
}
// create a new device, if we didn't find any
var devRes userapi.PerformDeviceCreationResponse
err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart,
DeviceDisplayName: &cfg.Matrix.ServerNotices.LocalPart,
AccessToken: token,
NoDeviceListUpdate: true,
}, &devRes)
if err != nil {
return nil, err
}
return devRes.Device, nil
}

View file

@ -0,0 +1,83 @@
package routing
import (
"testing"
)
func Test_sendServerNoticeRequest_validate(t *testing.T) {
type fields struct {
UserID string `json:"user_id,omitempty"`
Content struct {
MsgType string `json:"msgtype,omitempty"`
Body string `json:"body,omitempty"`
} `json:"content,omitempty"`
Type string `json:"type,omitempty"`
StateKey string `json:"state_key,omitempty"`
}
content := struct {
MsgType string `json:"msgtype,omitempty"`
Body string `json:"body,omitempty"`
}{
MsgType: "m.text",
Body: "Hello world!",
}
tests := []struct {
name string
fields fields
wantOk bool
}{
{
name: "empty request",
fields: fields{},
},
{
name: "msgtype empty",
fields: fields{
UserID: "@alice:localhost",
Content: struct {
MsgType string `json:"msgtype,omitempty"`
Body string `json:"body,omitempty"`
}{
Body: "Hello world!",
},
},
},
{
name: "msg body empty",
fields: fields{
UserID: "@alice:localhost",
},
},
{
name: "statekey empty",
fields: fields{
UserID: "@alice:localhost",
Content: content,
},
wantOk: true,
},
{
name: "type empty",
fields: fields{
UserID: "@alice:localhost",
Content: content,
},
wantOk: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := sendServerNoticeRequest{
UserID: tt.fields.UserID,
Content: tt.fields.Content,
Type: tt.fields.Type,
StateKey: tt.fields.StateKey,
}
if gotOk := r.valid(); gotOk != tt.wantOk {
t.Errorf("valid() = %v, want %v", gotOk, tt.wantOk)
}
})
}
}

View file

@ -23,7 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/threepid" "github.com/matrix-org/dendrite/clientapi/threepid"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -40,7 +40,7 @@ type threePIDsResponse struct {
// RequestEmailToken implements: // RequestEmailToken implements:
// POST /account/3pid/email/requestToken // POST /account/3pid/email/requestToken
// POST /register/email/requestToken // POST /register/email/requestToken
func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI) util.JSONResponse { func RequestEmailToken(req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI) util.JSONResponse {
var body threepid.EmailAssociationRequest var body threepid.EmailAssociationRequest
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr return *reqErr
@ -61,7 +61,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.MatrixError{ JSON: jsonerror.MatrixError{
ErrCode: "M_THREEPID_IN_USE", ErrCode: "M_THREEPID_IN_USE",
Err: accounts.Err3PIDInUse.Error(), Err: userdb.Err3PIDInUse.Error(),
}, },
} }
} }
@ -85,7 +85,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf
// CheckAndSave3PIDAssociation implements POST /account/3pid // CheckAndSave3PIDAssociation implements POST /account/3pid
func CheckAndSave3PIDAssociation( func CheckAndSave3PIDAssociation(
req *http.Request, accountDB accounts.Database, device *api.Device, req *http.Request, accountDB userdb.Database, device *api.Device,
cfg *config.ClientAPI, cfg *config.ClientAPI,
) util.JSONResponse { ) util.JSONResponse {
var body threepid.EmailAssociationCheckRequest var body threepid.EmailAssociationCheckRequest
@ -149,7 +149,7 @@ func CheckAndSave3PIDAssociation(
// GetAssociated3PIDs implements GET /account/3pid // GetAssociated3PIDs implements GET /account/3pid
func GetAssociated3PIDs( func GetAssociated3PIDs(
req *http.Request, accountDB accounts.Database, device *api.Device, req *http.Request, accountDB userdb.Database, device *api.Device,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
@ -170,7 +170,7 @@ func GetAssociated3PIDs(
} }
// Forget3PID implements POST /account/3pid/delete // Forget3PID implements POST /account/3pid/delete
func Forget3PID(req *http.Request, accountDB accounts.Database) util.JSONResponse { func Forget3PID(req *http.Request, accountDB userdb.Database) util.JSONResponse {
var body authtypes.ThreePID var body authtypes.ThreePID
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr return *reqErr

View file

@ -21,7 +21,9 @@ import (
// whoamiResponse represents an response for a `whoami` request // whoamiResponse represents an response for a `whoami` request
type whoamiResponse struct { type whoamiResponse struct {
UserID string `json:"user_id"` UserID string `json:"user_id"`
DeviceID string `json:"device_id"`
IsGuest bool `json:"is_guest"`
} }
// Whoami implements `/account/whoami` which enables client to query their account user id. // Whoami implements `/account/whoami` which enables client to query their account user id.
@ -29,6 +31,10 @@ type whoamiResponse struct {
func Whoami(req *http.Request, device *api.Device) util.JSONResponse { func Whoami(req *http.Request, device *api.Device) util.JSONResponse {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: whoamiResponse{UserID: device.UserID}, JSON: whoamiResponse{
UserID: device.UserID,
DeviceID: device.ID,
IsGuest: device.AccountType == api.AccountTypeGuest,
},
} }
} }

View file

@ -29,7 +29,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -87,7 +87,7 @@ var (
func CheckAndProcessInvite( func CheckAndProcessInvite(
ctx context.Context, ctx context.Context,
device *userapi.Device, body *MembershipRequest, cfg *config.ClientAPI, device *userapi.Device, body *MembershipRequest, cfg *config.ClientAPI,
rsAPI api.RoomserverInternalAPI, db accounts.Database, rsAPI api.RoomserverInternalAPI, db userdb.Database,
roomID string, roomID string,
evTime time.Time, evTime time.Time,
) (inviteStoredOnIDServer bool, err error) { ) (inviteStoredOnIDServer bool, err error) {
@ -137,7 +137,7 @@ func CheckAndProcessInvite(
// Returns an error if a check or a request failed. // Returns an error if a check or a request failed.
func queryIDServer( func queryIDServer(
ctx context.Context, ctx context.Context,
db accounts.Database, cfg *config.ClientAPI, device *userapi.Device, db userdb.Database, cfg *config.ClientAPI, device *userapi.Device,
body *MembershipRequest, roomID string, body *MembershipRequest, roomID string,
) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) { ) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) {
if err = isTrusted(body.IDServer, cfg); err != nil { if err = isTrusted(body.IDServer, cfg); err != nil {
@ -206,7 +206,7 @@ func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServe
// Returns an error if the request failed to send or if the response couldn't be parsed. // Returns an error if the request failed to send or if the response couldn't be parsed.
func queryIDServerStoreInvite( func queryIDServerStoreInvite(
ctx context.Context, ctx context.Context,
db accounts.Database, cfg *config.ClientAPI, device *userapi.Device, db userdb.Database, cfg *config.ClientAPI, device *userapi.Device,
body *MembershipRequest, roomID string, body *MembershipRequest, roomID string,
) (*idServerStoreInviteResponse, error) { ) (*idServerStoreInviteResponse, error) {
// Retrieve the sender's profile to get their display name // Retrieve the sender's profile to get their display name

View file

@ -23,12 +23,14 @@ import (
"os" "os"
"strings" "strings"
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"golang.org/x/term" "golang.org/x/term"
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
userdb "github.com/matrix-org/dendrite/userapi/storage"
) )
const usage = `Usage: %s const usage = `Usage: %s
@ -57,6 +59,7 @@ var (
pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)")
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
askPass = flag.Bool("ask-pass", false, "Ask for the password to use") askPass = flag.Bool("ask-pass", false, "Ask for the password to use")
isAdmin = flag.Bool("admin", false, "Create an admin account")
) )
func main() { func main() {
@ -74,19 +77,28 @@ func main() {
pass := getPassword(password, pwdFile, pwdStdin, askPass, os.Stdin) pass := getPassword(password, pwdFile, pwdStdin, askPass, os.Stdin)
accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{ accountDB, err := userdb.NewDatabase(
ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString, &config.DatabaseOptions{
}, cfg.Global.ServerName, bcrypt.DefaultCost, cfg.UserAPI.OpenIDTokenLifetimeMS) ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString,
},
cfg.Global.ServerName, bcrypt.DefaultCost,
cfg.UserAPI.OpenIDTokenLifetimeMS,
api.DefaultLoginTokenLifetime,
)
if err != nil { if err != nil {
logrus.Fatalln("Failed to connect to the database:", err.Error()) logrus.Fatalln("Failed to connect to the database:", err.Error())
} }
accType := api.AccountTypeUser
if *isAdmin {
accType = api.AccountTypeAdmin
}
policyVersion := "" policyVersion := ""
if cfg.Global.UserConsentOptions.Enabled { if cfg.Global.UserConsentOptions.Enabled {
policyVersion = cfg.Global.UserConsentOptions.Version policyVersion = cfg.Global.UserConsentOptions.Version
} }
_, err = accountDB.CreateAccount(context.Background(), *username, pass, "", policyVersion) _, err = accountDB.CreateAccount(context.Background(), *username, pass, "", policyVersion, accType)
if err != nil { if err != nil {
logrus.Fatalln("Failed to create the account:", err.Error()) logrus.Fatalln("Failed to create the account:", err.Error())
} }

View file

@ -126,7 +126,6 @@ func main() {
cfg.FederationAPI.FederationMaxRetries = 6 cfg.FederationAPI.FederationMaxRetries = 6
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName)) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))

View file

@ -160,7 +160,6 @@ func main() {
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName)) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))

View file

@ -79,7 +79,6 @@ func main() {
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName)) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))

View file

@ -132,6 +132,7 @@ func main() {
// dependency. Other components also need updating after their dependencies are up. // dependency. Other components also need updating after their dependencies are up.
rsImpl.SetFederationAPI(fsAPI, keyRing) rsImpl.SetFederationAPI(fsAPI, keyRing)
rsImpl.SetAppserviceAPI(asAPI) rsImpl.SetAppserviceAPI(asAPI)
rsImpl.SetUserAPI(userAPI)
keyImpl.SetUserAPI(userAPI) keyImpl.SetUserAPI(userAPI)
eduInputAPI := eduserver.NewInternalAPI( eduInputAPI := eduserver.NewInternalAPI(

View file

@ -164,7 +164,6 @@ func startup() {
cfg.Defaults(true) cfg.Defaults(true)
cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db" cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db"
cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db" cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db"
cfg.UserAPI.DeviceDatabase.ConnectionString = "file:/idb/dendritejs_device.db"
cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db" cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db"
cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db" cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db"
cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db" cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db"

View file

@ -167,7 +167,6 @@ func main() {
cfg.Defaults(true) cfg.Defaults(true)
cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db" cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db"
cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db" cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db"
cfg.UserAPI.DeviceDatabase.ConnectionString = "file:/idb/dendritejs_device.db"
cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db" cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db"
cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db" cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db"
cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db" cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db"

View file

@ -32,7 +32,6 @@ func main() {
cfg.RoomServer.Database.ConnectionString = config.DataSource(*dbURI) cfg.RoomServer.Database.ConnectionString = config.DataSource(*dbURI)
cfg.SyncAPI.Database.ConnectionString = config.DataSource(*dbURI) cfg.SyncAPI.Database.ConnectionString = config.DataSource(*dbURI)
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(*dbURI) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(*dbURI)
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(*dbURI)
} }
cfg.Global.TrustedIDServers = []string{ cfg.Global.TrustedIDServers = []string{
"matrix.org", "matrix.org",
@ -91,6 +90,7 @@ func main() {
cfg.Logging[0].Type = "std" cfg.Logging[0].Type = "std"
cfg.UserAPI.BCryptCost = bcrypt.MinCost cfg.UserAPI.BCryptCost = bcrypt.MinCost
cfg.Global.JetStream.InMemory = true cfg.Global.JetStream.InMemory = true
cfg.ClientAPI.RegistrationSharedSecret = "complement"
} }
j, err := yaml.Marshal(cfg) j, err := yaml.Marshal(cfg)

View file

@ -8,12 +8,11 @@ import (
"log" "log"
"os" "os"
pgaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas"
slaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas"
pgdevices "github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas"
sldevices "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas"
"github.com/pressly/goose" "github.com/pressly/goose"
pgusers "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
slusers "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
_ "github.com/lib/pq" _ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
@ -26,8 +25,7 @@ const (
RoomServer = "roomserver" RoomServer = "roomserver"
SigningKeyServer = "signingkeyserver" SigningKeyServer = "signingkeyserver"
SyncAPI = "syncapi" SyncAPI = "syncapi"
UserAPIAccounts = "userapi_accounts" UserAPI = "userapi"
UserAPIDevices = "userapi_devices"
) )
var ( var (
@ -35,7 +33,7 @@ var (
flags = flag.NewFlagSet("goose", flag.ExitOnError) flags = flag.NewFlagSet("goose", flag.ExitOnError)
component = flags.String("component", "", "dendrite component name") component = flags.String("component", "", "dendrite component name")
knownDBs = []string{ knownDBs = []string{
AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPIAccounts, UserAPIDevices, AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPI,
} }
) )
@ -143,18 +141,14 @@ Commands:
func loadSQLiteDeltas(component string) { func loadSQLiteDeltas(component string) {
switch component { switch component {
case UserAPIAccounts: case UserAPI:
slaccounts.LoadFromGoose() slusers.LoadFromGoose()
case UserAPIDevices:
sldevices.LoadFromGoose()
} }
} }
func loadPostgresDeltas(component string) { func loadPostgresDeltas(component string) {
switch component { switch component {
case UserAPIAccounts: case UserAPI:
pgaccounts.LoadFromGoose() pgusers.LoadFromGoose()
case UserAPIDevices:
pgdevices.LoadFromGoose()
} }
} }

View file

@ -68,6 +68,18 @@ global:
# to other servers and the federation API will not be exposed. # to other servers and the federation API will not be exposed.
disable_federation: false disable_federation: false
# Server notices allows server admins to send messages to all users.
server_notices:
enabled: false
# The server localpart to be used when sending notices, ensure this is not yet taken
local_part: "_server"
# The displayname to be used when sending notices
display_name: "Server alerts"
# The mxid of the avatar to use
avatar_url: ""
# The roomname to be used when creating messages
room_name: "Server Alerts"
# Consent tracking configuration # Consent tracking configuration
user_consent: user_consent:
# If the user consent tracking is enabled or not # If the user consent tracking is enabled or not
@ -169,6 +181,10 @@ client_api:
# using the registration shared secret below. # using the registration shared secret below.
registration_disabled: false registration_disabled: false
# Prevents new guest accounts from being created. Guest registration is also
# disabled implicitly by setting 'registration_disabled' above.
guests_disabled: true
# If set, allows registration by anyone who knows the shared secret, regardless of # If set, allows registration by anyone who knows the shared secret, regardless of
# whether registration is otherwise disabled. # whether registration is otherwise disabled.
registration_shared_secret: "" registration_shared_secret: ""
@ -231,13 +247,6 @@ federation_api:
# enable this option in production as it presents a security risk! # enable this option in production as it presents a security risk!
disable_tls_validation: false disable_tls_validation: false
# Use the following proxy server for outbound federation traffic.
proxy_outbound:
enabled: false
protocol: http
host: localhost
port: 8080
# Perspective keyservers to use as a backup when direct key fetches fail. This may # Perspective keyservers to use as a backup when direct key fetches fail. This may
# be required to satisfy key requests for servers that are no longer online when # be required to satisfy key requests for servers that are no longer online when
# joining some rooms. # joining some rooms.

4
go.mod
View file

@ -1,6 +1,6 @@
module github.com/matrix-org/dendrite module github.com/matrix-org/dendrite
replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.3.3-0.20220104162330-c76d5fd70423 replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad
replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c
@ -45,7 +45,7 @@ require (
github.com/mattn/go-sqlite3 v1.14.10 github.com/mattn/go-sqlite3 v1.14.10
github.com/morikuni/aec v1.0.0 // indirect github.com/morikuni/aec v1.0.0 // indirect
github.com/nats-io/nats-server/v2 v2.3.2 github.com/nats-io/nats-server/v2 v2.3.2
github.com/nats-io/nats.go v1.13.1-0.20211122170419-d7c1d78a50fc github.com/nats-io/nats.go v1.13.1-0.20220121202836-972a071d373d
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/ngrok/sqlmw v0.0.0-20211220175533-9d16fdc47b31 github.com/ngrok/sqlmw v0.0.0-20211220175533-9d16fdc47b31

15
go.sum
View file

@ -1122,8 +1122,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8m
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
github.com/nats-io/jwt/v2 v2.2.0 h1:Yg/4WFK6vsqMudRg91eBb7Dh6XeVcDMPHycDE8CfltE= github.com/nats-io/jwt/v2 v2.2.1-0.20220113022732-58e87895b296 h1:vU9tpM3apjYlLLeY23zRWJ9Zktr5jp+mloR942LEOpY=
github.com/nats-io/jwt/v2 v2.2.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= github.com/nats-io/jwt/v2 v2.2.1-0.20220113022732-58e87895b296/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k=
github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8=
github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4=
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
@ -1132,8 +1132,8 @@ github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uY
github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM= github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM=
github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo=
github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
github.com/neilalexander/nats-server/v2 v2.3.3-0.20220104162330-c76d5fd70423 h1:BLQVdjMH5XD4BYb0fa+c2Oh2Nr1vrO7GKvRnIJDxChc= github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad h1:Z2nWMQsXWWqzj89nW6OaLJSdkFknqhaR5whEOz4++Y8=
github.com/neilalexander/nats-server/v2 v2.3.3-0.20220104162330-c76d5fd70423/go.mod h1:9sdEkBhyZMQG1M9TevnlYUwMusRACn2vlgOeqoHKwVo= github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad/go.mod h1:tckmrt0M6bVaDT3kmh9UrIq/CBOBBse+TpXQi5ldaa8=
github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c h1:G2qsv7D0rY94HAu8pXmElMluuMHQ85waxIDQBhIzV2Q= github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c h1:G2qsv7D0rY94HAu8pXmElMluuMHQ85waxIDQBhIzV2Q=
github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w= github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w=
github.com/neilalexander/utp v0.1.1-0.20210622132614-ee9a34a30488/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8= github.com/neilalexander/utp v0.1.1-0.20210622132614-ee9a34a30488/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8=
@ -1508,8 +1508,8 @@ golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a h1:atOEWVSedO4ksXBe/UrlbSLVxQQ9RxM/tT2Jy10IaHo= golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a h1:atOEWVSedO4ksXBe/UrlbSLVxQQ9RxM/tT2Jy10IaHo=
golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -1735,6 +1735,7 @@ golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220207234003-57398862261d h1:Bm7BNOQt2Qv7ZqysjeLjgCBanX+88Z/OtdvsrEv1Djc= golang.org/x/sys v0.0.0-20220207234003-57398862261d h1:Bm7BNOQt2Qv7ZqysjeLjgCBanX+88Z/OtdvsrEv1Djc=
golang.org/x/sys v0.0.0-20220207234003-57398862261d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220207234003-57398862261d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@ -1756,10 +1757,10 @@ golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxb
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 h1:GZokNIeuVkl3aZHJchRrr13WCsols02MLUcz1U9is6M=
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View file

@ -7,14 +7,6 @@ import (
) )
const ( const (
RoomServerStateKeyNIDsCacheName = "roomserver_statekey_nids"
RoomServerStateKeyNIDsCacheMaxEntries = 1024
RoomServerStateKeyNIDsCacheMutable = false
RoomServerEventTypeNIDsCacheName = "roomserver_eventtype_nids"
RoomServerEventTypeNIDsCacheMaxEntries = 64
RoomServerEventTypeNIDsCacheMutable = false
RoomServerRoomIDsCacheName = "roomserver_room_ids" RoomServerRoomIDsCacheName = "roomserver_room_ids"
RoomServerRoomIDsCacheMaxEntries = 1024 RoomServerRoomIDsCacheMaxEntries = 1024
RoomServerRoomIDsCacheMutable = false RoomServerRoomIDsCacheMutable = false
@ -29,44 +21,10 @@ type RoomServerCaches interface {
// RoomServerNIDsCache contains the subset of functions needed for // RoomServerNIDsCache contains the subset of functions needed for
// a roomserver NID cache. // a roomserver NID cache.
type RoomServerNIDsCache interface { type RoomServerNIDsCache interface {
GetRoomServerStateKeyNID(stateKey string) (types.EventStateKeyNID, bool)
StoreRoomServerStateKeyNID(stateKey string, nid types.EventStateKeyNID)
GetRoomServerEventTypeNID(eventType string) (types.EventTypeNID, bool)
StoreRoomServerEventTypeNID(eventType string, nid types.EventTypeNID)
GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool)
StoreRoomServerRoomID(roomNID types.RoomNID, roomID string) StoreRoomServerRoomID(roomNID types.RoomNID, roomID string)
} }
func (c Caches) GetRoomServerStateKeyNID(stateKey string) (types.EventStateKeyNID, bool) {
val, found := c.RoomServerStateKeyNIDs.Get(stateKey)
if found && val != nil {
if stateKeyNID, ok := val.(types.EventStateKeyNID); ok {
return stateKeyNID, true
}
}
return 0, false
}
func (c Caches) StoreRoomServerStateKeyNID(stateKey string, nid types.EventStateKeyNID) {
c.RoomServerStateKeyNIDs.Set(stateKey, nid)
}
func (c Caches) GetRoomServerEventTypeNID(eventType string) (types.EventTypeNID, bool) {
val, found := c.RoomServerEventTypeNIDs.Get(eventType)
if found && val != nil {
if eventTypeNID, ok := val.(types.EventTypeNID); ok {
return eventTypeNID, true
}
}
return 0, false
}
func (c Caches) StoreRoomServerEventTypeNID(eventType string, nid types.EventTypeNID) {
c.RoomServerEventTypeNIDs.Set(eventType, nid)
}
func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) { func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) {
val, found := c.RoomServerRoomIDs.Get(strconv.Itoa(int(roomNID))) val, found := c.RoomServerRoomIDs.Get(strconv.Itoa(int(roomNID)))
if found && val != nil { if found && val != nil {

View file

@ -4,14 +4,12 @@ package caching
// different implementations as long as they satisfy the Cache // different implementations as long as they satisfy the Cache
// interface. // interface.
type Caches struct { type Caches struct {
RoomVersions Cache // RoomVersionCache RoomVersions Cache // RoomVersionCache
ServerKeys Cache // ServerKeyCache ServerKeys Cache // ServerKeyCache
RoomServerStateKeyNIDs Cache // RoomServerNIDsCache RoomServerRoomNIDs Cache // RoomServerNIDsCache
RoomServerEventTypeNIDs Cache // RoomServerNIDsCache RoomServerRoomIDs Cache // RoomServerNIDsCache
RoomServerRoomNIDs Cache // RoomServerNIDsCache RoomInfos Cache // RoomInfoCache
RoomServerRoomIDs Cache // RoomServerNIDsCache FederationEvents Cache // FederationEventsCache
RoomInfos Cache // RoomInfoCache
FederationEvents Cache // FederationEventsCache
} }
// Cache is the interface that an implementation must satisfy. // Cache is the interface that an implementation must satisfy.

View file

@ -28,24 +28,6 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
roomServerStateKeyNIDs, err := NewInMemoryLRUCachePartition(
RoomServerStateKeyNIDsCacheName,
RoomServerStateKeyNIDsCacheMutable,
RoomServerStateKeyNIDsCacheMaxEntries,
enablePrometheus,
)
if err != nil {
return nil, err
}
roomServerEventTypeNIDs, err := NewInMemoryLRUCachePartition(
RoomServerEventTypeNIDsCacheName,
RoomServerEventTypeNIDsCacheMutable,
RoomServerEventTypeNIDsCacheMaxEntries,
enablePrometheus,
)
if err != nil {
return nil, err
}
roomServerRoomIDs, err := NewInMemoryLRUCachePartition( roomServerRoomIDs, err := NewInMemoryLRUCachePartition(
RoomServerRoomIDsCacheName, RoomServerRoomIDsCacheName,
RoomServerRoomIDsCacheMutable, RoomServerRoomIDsCacheMutable,
@ -74,18 +56,15 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
return nil, err return nil, err
} }
go cacheCleaner( go cacheCleaner(
roomVersions, serverKeys, roomServerStateKeyNIDs, roomVersions, serverKeys, roomServerRoomIDs,
roomServerEventTypeNIDs, roomServerRoomIDs,
roomInfos, federationEvents, roomInfos, federationEvents,
) )
return &Caches{ return &Caches{
RoomVersions: roomVersions, RoomVersions: roomVersions,
ServerKeys: serverKeys, ServerKeys: serverKeys,
RoomServerStateKeyNIDs: roomServerStateKeyNIDs, RoomServerRoomIDs: roomServerRoomIDs,
RoomServerEventTypeNIDs: roomServerEventTypeNIDs, RoomInfos: roomInfos,
RoomServerRoomIDs: roomServerRoomIDs, FederationEvents: federationEvents,
RoomInfos: roomInfos,
FederationEvents: federationEvents,
}, nil }, nil
} }

View file

@ -95,7 +95,6 @@ func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*con
cfg.RoomServer.Database.ConnectionString = config.DataSource(database) cfg.RoomServer.Database.ConnectionString = config.DataSource(database)
cfg.SyncAPI.Database.ConnectionString = config.DataSource(database) cfg.SyncAPI.Database.ConnectionString = config.DataSource(database)
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(database) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(database)
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(database)
cfg.AppServiceAPI.InternalAPI.Listen = assignAddress() cfg.AppServiceAPI.InternalAPI.Listen = assignAddress()
cfg.EDUServer.InternalAPI.Listen = assignAddress() cfg.EDUServer.InternalAPI.Listen = assignAddress()

View file

@ -367,10 +367,13 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
waitTime = fcerr.RetryAfter waitTime = fcerr.RetryAfter
} else if fcerr.Blacklisted { } else if fcerr.Blacklisted {
waitTime = time.Hour * 8 waitTime = time.Hour * 8
} else {
// For all other errors (DNS resolution, network etc.) wait 1 hour.
waitTime = time.Hour
} }
} else { } else {
waitTime = time.Hour waitTime = time.Hour
logger.WithError(err).Warn("GetUserDevices returned unknown error type") logger.WithError(err).WithField("user_id", userID).Warn("GetUserDevices returned unknown error type")
} }
continue continue
} }

View file

@ -198,7 +198,7 @@ func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOne
} }
func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) { func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) {
msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil) msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false)
if err != nil { if err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query DB for device keys: %s", err), Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
@ -244,7 +244,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
domain := string(serverName) domain := string(serverName)
// query local devices // query local devices
if serverName == a.ThisServer { if serverName == a.ThisServer {
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs) deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
if err != nil { if err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query local device keys: %s", err), Err: fmt.Sprintf("failed to query local device keys: %s", err),
@ -513,6 +513,11 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
// drop the error as it's already a failure at this point // drop the error as it's already a failure at this point
_ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, dkeys) _ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, dkeys)
} }
// Sytest expects no failures, if we still could retrieve keys, e.g. from local cache
if len(res.DeviceKeys) > 0 {
delete(res.Failures, serverName)
}
respMu.Unlock() respMu.Unlock()
} }
@ -520,7 +525,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string, ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string,
) error { ) error {
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs) keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
// if we can't query the db or there are fewer keys than requested, fetch from remote. // if we can't query the db or there are fewer keys than requested, fetch from remote.
if err != nil { if err != nil {
return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err) return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
@ -549,10 +554,58 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
} }
func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
// get a list of devices from the user API that actually exist, as
// we won't store keys for devices that don't exist
uapidevices := &userapi.QueryDevicesResponse{}
if err := a.UserAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil {
res.Error = &api.KeyError{
Err: err.Error(),
}
return
}
if !uapidevices.UserExists {
res.Error = &api.KeyError{
Err: fmt.Sprintf("user %q does not exist", req.UserID),
}
return
}
existingDeviceMap := make(map[string]struct{}, len(uapidevices.Devices))
for _, key := range uapidevices.Devices {
existingDeviceMap[key.ID] = struct{}{}
}
// Get all of the user existing device keys so we can check for changes.
existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, true)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
}
return
}
// Work out whether we have device keys in the keyserver for devices that
// no longer exist in the user API. This is mostly an exercise to ensure
// that we keep some integrity between the two.
var toClean []gomatrixserverlib.KeyID
for _, k := range existingKeys {
if _, ok := existingDeviceMap[k.DeviceID]; !ok {
toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID))
}
}
if len(toClean) > 0 {
if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil {
logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean))
} else {
logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean))
}
}
var keysToStore []api.DeviceMessage var keysToStore []api.DeviceMessage
// assert that the user ID / device ID are not lying for each key // assert that the user ID / device ID are not lying for each key
for _, key := range req.DeviceKeys { for _, key := range req.DeviceKeys {
_, serverName, err := gomatrixserverlib.SplitID('@', key.UserID) var serverName gomatrixserverlib.ServerName
_, serverName, err = gomatrixserverlib.SplitID('@', key.UserID)
if err != nil { if err != nil {
continue // ignore invalid users continue // ignore invalid users
} }
@ -563,6 +616,11 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
keysToStore = append(keysToStore, key.WithStreamID(0)) keysToStore = append(keysToStore, key.WithStreamID(0))
continue // deleted keys don't need sanity checking continue // deleted keys don't need sanity checking
} }
// check that the device in question actually exists in the user
// API before we try and store a key for it
if _, ok := existingDeviceMap[key.DeviceID]; !ok {
continue
}
gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
if gotUserID == key.UserID && gotDeviceID == key.DeviceID { if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
@ -578,29 +636,12 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
}) })
} }
// get existing device keys so we can check for changes
existingKeys := make([]api.DeviceMessage, len(keysToStore))
for i := range keysToStore {
existingKeys[i] = api.DeviceMessage{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
UserID: keysToStore[i].UserID,
DeviceID: keysToStore[i].DeviceID,
},
}
}
if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
}
return
}
if req.OnlyDisplayNameUpdates { if req.OnlyDisplayNameUpdates {
// add the display name field from keysToStore into existingKeys // add the display name field from keysToStore into existingKeys
keysToStore = appendDisplayNames(existingKeys, keysToStore) keysToStore = appendDisplayNames(existingKeys, keysToStore)
} }
// store the device keys and emit changes // store the device keys and emit changes
err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore) err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
if err != nil { if err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()), Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),

View file

@ -53,7 +53,7 @@ type Database interface {
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
// cross-signing signatures relating to that device. // cross-signing signatures relating to that device.

View file

@ -56,6 +56,9 @@ const selectDeviceKeysSQL = "" +
const selectBatchDeviceKeysSQL = "" + const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
const selectBatchDeviceKeysWithEmptiesSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" + const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
@ -69,14 +72,15 @@ const deleteAllDeviceKeysSQL = "" +
"DELETE FROM keyserver_device_keys WHERE user_id=$1" "DELETE FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct { type deviceKeysStatements struct {
db *sql.DB db *sql.DB
upsertDeviceKeysStmt *sql.Stmt upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
countStreamIDsForUserStmt *sql.Stmt selectMaxStreamForUserStmt *sql.Stmt
deleteDeviceKeysStmt *sql.Stmt countStreamIDsForUserStmt *sql.Stmt
deleteAllDeviceKeysStmt *sql.Stmt deleteDeviceKeysStmt *sql.Stmt
deleteAllDeviceKeysStmt *sql.Stmt
} }
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@ -96,6 +100,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err return nil, err
} }
if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
return nil, err
}
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err return nil, err
} }
@ -180,8 +187,14 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql
return err return err
} }
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) var stmt *sql.Stmt
if includeEmpty {
stmt = s.selectBatchDeviceKeysWithEmptiesStmt
} else {
stmt = s.selectBatchDeviceKeysStmt
}
rows, err := stmt.QueryContext(ctx, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -108,8 +108,8 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe
}) })
} }
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs) return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty)
} }
func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) { func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {

View file

@ -52,6 +52,9 @@ const selectDeviceKeysSQL = "" +
const selectBatchDeviceKeysSQL = "" + const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
const selectBatchDeviceKeysWithEmptiesSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" + const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
@ -65,13 +68,14 @@ const deleteAllDeviceKeysSQL = "" +
"DELETE FROM keyserver_device_keys WHERE user_id=$1" "DELETE FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct { type deviceKeysStatements struct {
db *sql.DB db *sql.DB
upsertDeviceKeysStmt *sql.Stmt upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
deleteDeviceKeysStmt *sql.Stmt selectMaxStreamForUserStmt *sql.Stmt
deleteAllDeviceKeysStmt *sql.Stmt deleteDeviceKeysStmt *sql.Stmt
deleteAllDeviceKeysStmt *sql.Stmt
} }
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@ -91,6 +95,9 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err return nil, err
} }
if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
return nil, err
}
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err return nil, err
} }
@ -113,12 +120,18 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql
return err return err
} }
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
deviceIDMap := make(map[string]bool) deviceIDMap := make(map[string]bool)
for _, d := range deviceIDs { for _, d := range deviceIDs {
deviceIDMap[d] = true deviceIDMap[d] = true
} }
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) var stmt *sql.Stmt
if includeEmpty {
stmt = s.selectBatchDeviceKeysWithEmptiesStmt
} else {
stmt = s.selectBatchDeviceKeysStmt
}
rows, err := stmt.QueryContext(ctx, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -173,7 +173,7 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
} }
// Querying for device keys returns the latest stream IDs // Querying for device keys returns the latest stream IDs
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}) msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false)
if err != nil { if err != nil {
t.Fatalf("DeviceKeysForUser returned error: %s", err) t.Fatalf("DeviceKeysForUser returned error: %s", err)
} }

View file

@ -38,7 +38,7 @@ type DeviceKeys interface {
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
} }

View file

@ -3,9 +3,11 @@ package api
import ( import (
"context" "context"
"github.com/matrix-org/gomatrixserverlib"
asAPI "github.com/matrix-org/dendrite/appservice/api" asAPI "github.com/matrix-org/dendrite/appservice/api"
fsAPI "github.com/matrix-org/dendrite/federationapi/api" fsAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/gomatrixserverlib" userapi "github.com/matrix-org/dendrite/userapi/api"
) )
// RoomserverInputAPI is used to write events to the room server. // RoomserverInputAPI is used to write events to the room server.
@ -14,6 +16,7 @@ type RoomserverInternalAPI interface {
// interdependencies between the roomserver and other input APIs // interdependencies between the roomserver and other input APIs
SetFederationAPI(fsAPI fsAPI.FederationInternalAPI, keyRing *gomatrixserverlib.KeyRing) SetFederationAPI(fsAPI fsAPI.FederationInternalAPI, keyRing *gomatrixserverlib.KeyRing)
SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI) SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI)
SetUserAPI(userAPI userapi.UserInternalAPI)
InputRoomEvents( InputRoomEvents(
ctx context.Context, ctx context.Context,

View file

@ -5,10 +5,12 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
asAPI "github.com/matrix-org/dendrite/appservice/api"
fsAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
asAPI "github.com/matrix-org/dendrite/appservice/api"
fsAPI "github.com/matrix-org/dendrite/federationapi/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
) )
// RoomserverInternalAPITrace wraps a RoomserverInternalAPI and logs the // RoomserverInternalAPITrace wraps a RoomserverInternalAPI and logs the
@ -25,6 +27,10 @@ func (t *RoomserverInternalAPITrace) SetAppserviceAPI(asAPI asAPI.AppServiceQuer
t.Impl.SetAppserviceAPI(asAPI) t.Impl.SetAppserviceAPI(asAPI)
} }
func (t *RoomserverInternalAPITrace) SetUserAPI(userAPI userapi.UserInternalAPI) {
t.Impl.SetUserAPI(userAPI)
}
func (t *RoomserverInternalAPITrace) InputRoomEvents( func (t *RoomserverInternalAPITrace) InputRoomEvents(
ctx context.Context, ctx context.Context,
req *InputRoomEventsRequest, req *InputRoomEventsRequest,

View file

@ -95,6 +95,8 @@ type PerformLeaveRequest struct {
} }
type PerformLeaveResponse struct { type PerformLeaveResponse struct {
Code int `json:"code,omitempty"`
Message interface{} `json:"message,omitempty"`
} }
type PerformInviteRequest struct { type PerformInviteRequest struct {

View file

@ -14,6 +14,8 @@ import (
"github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/internal/query"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -32,6 +34,7 @@ type RoomserverInternalAPI struct {
*perform.Publisher *perform.Publisher
*perform.Backfiller *perform.Backfiller
*perform.Forgetter *perform.Forgetter
ProcessContext *process.ProcessContext
DB storage.Database DB storage.Database
Cfg *config.RoomServer Cfg *config.RoomServer
Cache caching.RoomServerCaches Cache caching.RoomServerCaches
@ -48,12 +51,13 @@ type RoomserverInternalAPI struct {
} }
func NewRoomserverAPI( func NewRoomserverAPI(
cfg *config.RoomServer, roomserverDB storage.Database, consumer nats.JetStreamContext, processCtx *process.ProcessContext, cfg *config.RoomServer, roomserverDB storage.Database,
inputRoomEventTopic, outputRoomEventTopic string, caches caching.RoomServerCaches, consumer nats.JetStreamContext, inputRoomEventTopic, outputRoomEventTopic string,
perspectiveServerNames []gomatrixserverlib.ServerName, caches caching.RoomServerCaches, perspectiveServerNames []gomatrixserverlib.ServerName,
) *RoomserverInternalAPI { ) *RoomserverInternalAPI {
serverACLs := acls.NewServerACLs(roomserverDB) serverACLs := acls.NewServerACLs(roomserverDB)
a := &RoomserverInternalAPI{ a := &RoomserverInternalAPI{
ProcessContext: processCtx,
DB: roomserverDB, DB: roomserverDB,
Cfg: cfg, Cfg: cfg,
Cache: caches, Cache: caches,
@ -83,6 +87,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.FederationInternalA
r.KeyRing = keyRing r.KeyRing = keyRing
r.Inputer = &input.Inputer{ r.Inputer = &input.Inputer{
ProcessContext: r.ProcessContext,
DB: r.DB, DB: r.DB,
InputRoomEventTopic: r.InputRoomEventTopic, InputRoomEventTopic: r.InputRoomEventTopic,
OutputRoomEventTopic: r.OutputRoomEventTopic, OutputRoomEventTopic: r.OutputRoomEventTopic,
@ -155,6 +160,10 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.FederationInternalA
} }
} }
func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.UserInternalAPI) {
r.Leaver.UserAPI = userAPI
}
func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI) { func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI) {
r.asAPI = asAPI r.asAPI = asAPI
} }

View file

@ -31,6 +31,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/internal/query"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
@ -59,6 +60,7 @@ var keyContentFields = map[string]string{
} }
type Inputer struct { type Inputer struct {
ProcessContext *process.ProcessContext
DB storage.Database DB storage.Database
JetStream nats.JetStreamContext JetStream nats.JetStreamContext
Durable nats.SubOpt Durable nats.SubOpt
@ -115,7 +117,7 @@ func (r *Inputer) Start() error {
_ = msg.InProgress() // resets the acknowledgement wait timer _ = msg.InProgress() // resets the acknowledgement wait timer
defer eventsInProgress.Delete(index) defer eventsInProgress.Delete(index)
defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec()
action, err := r.processRoomEventUsingUpdater(context.Background(), roomID, &inputRoomEvent) action, err := r.processRoomEventUsingUpdater(r.ProcessContext.Context(), roomID, &inputRoomEvent)
if err != nil { if err != nil {
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
sentry.CaptureException(err) sentry.CaptureException(err)

View file

@ -405,7 +405,7 @@ func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.Ro
if len(extraEventIDs) == 0 { if len(extraEventIDs) == 0 {
return nil, nil return nil, nil
} }
extraEvents, err := u.updater.EventsFromIDs(u.ctx, extraEventIDs) extraEvents, err := u.updater.UnsentEventsFromIDs(u.ctx, extraEventIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -16,25 +16,29 @@ package perform
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"strings" "strings"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
fsAPI "github.com/matrix-org/dendrite/federationapi/api" fsAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/internal/input"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
) )
type Leaver struct { type Leaver struct {
Cfg *config.RoomServer Cfg *config.RoomServer
DB storage.Database DB storage.Database
FSAPI fsAPI.FederationInternalAPI FSAPI fsAPI.FederationInternalAPI
UserAPI userapi.UserInternalAPI
Inputer *input.Inputer Inputer *input.Inputer
} }
@ -85,6 +89,31 @@ func (r *Leaver) performLeaveRoomByID(
if host != r.Cfg.Matrix.ServerName { if host != r.Cfg.Matrix.ServerName {
return r.performFederatedRejectInvite(ctx, req, res, senderUser, eventID) return r.performFederatedRejectInvite(ctx, req, res, senderUser, eventID)
} }
// check that this is not a "server notice room"
accData := &userapi.QueryAccountDataResponse{}
if err := r.UserAPI.QueryAccountData(ctx, &userapi.QueryAccountDataRequest{
UserID: req.UserID,
RoomID: req.RoomID,
DataType: "m.tag",
}, accData); err != nil {
return nil, fmt.Errorf("unable to query account data")
}
if roomData, ok := accData.RoomAccountData[req.RoomID]; ok {
tagData, ok := roomData["m.tag"]
if ok {
tags := gomatrix.TagContent{}
if err = json.Unmarshal(tagData, &tags); err != nil {
return nil, fmt.Errorf("unable to unmarshal tag content")
}
if _, ok = tags.Tags["m.server_notice"]; ok {
// mimic the returned values from Synapse
res.Message = "You cannot reject this invite"
res.Code = 403
return nil, fmt.Errorf("You cannot reject this invite")
}
}
}
} }
// There's no invite pending, so first of all we want to find out // There's no invite pending, so first of all we want to find out

View file

@ -11,6 +11,8 @@ import (
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go"
) )
@ -90,6 +92,10 @@ func (h *httpRoomserverInternalAPI) SetFederationAPI(fsAPI fsInputAPI.Federation
func (h *httpRoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI) { func (h *httpRoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI) {
} }
// SetUserAPI no-ops in HTTP client mode as there is no chicken/egg scenario
func (h *httpRoomserverInternalAPI) SetUserAPI(userAPI userapi.UserInternalAPI) {
}
// SetRoomAlias implements RoomserverAliasAPI // SetRoomAlias implements RoomserverAliasAPI
func (h *httpRoomserverInternalAPI) SetRoomAlias( func (h *httpRoomserverInternalAPI) SetRoomAlias(
ctx context.Context, ctx context.Context,

View file

@ -53,7 +53,7 @@ func NewInternalAPI(
js := jetstream.Prepare(&cfg.Matrix.JetStream) js := jetstream.Prepare(&cfg.Matrix.JetStream)
return internal.NewRoomserverAPI( return internal.NewRoomserverAPI(
cfg, roomserverDB, js, base.ProcessContext, cfg, roomserverDB, js,
cfg.Matrix.JetStream.TopicFor(jetstream.InputRoomEvent), cfg.Matrix.JetStream.TopicFor(jetstream.InputRoomEvent),
cfg.Matrix.JetStream.TopicFor(jetstream.OutputRoomEvent), cfg.Matrix.JetStream.TopicFor(jetstream.OutputRoomEvent),
base.Caches, perspectiveServerNames, base.Caches, perspectiveServerNames,

View file

@ -127,6 +127,9 @@ const bulkSelectEventIDSQL = "" +
const bulkSelectEventNIDSQL = "" + const bulkSelectEventNIDSQL = "" +
"SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1)" "SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1)"
const bulkSelectUnsentEventNIDSQL = "" +
"SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1) AND sent_to_output = FALSE"
const selectMaxEventDepthSQL = "" + const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)" "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)"
@ -147,6 +150,7 @@ type eventStatements struct {
bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventReferenceStmt *sql.Stmt
bulkSelectEventIDStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt
bulkSelectEventNIDStmt *sql.Stmt bulkSelectEventNIDStmt *sql.Stmt
bulkSelectUnsentEventNIDStmt *sql.Stmt
selectMaxEventDepthStmt *sql.Stmt selectMaxEventDepthStmt *sql.Stmt
selectRoomNIDsForEventNIDsStmt *sql.Stmt selectRoomNIDsForEventNIDsStmt *sql.Stmt
} }
@ -173,6 +177,7 @@ func prepareEventsTable(db *sql.DB) (tables.Events, error) {
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
{&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL},
{&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL}, {&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL},
{&s.selectRoomNIDsForEventNIDsStmt, selectRoomNIDsForEventNIDsSQL}, {&s.selectRoomNIDsForEventNIDsStmt, selectRoomNIDsForEventNIDsSQL},
}.Prepare(db) }.Prepare(db)
@ -458,10 +463,28 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev
return results, nil return results, nil
} }
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map. // If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt) return s.bulkSelectEventNID(ctx, txn, eventIDs, false)
}
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID
// only for events that haven't already been sent to the roomserver output.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
return s.bulkSelectEventNID(ctx, txn, eventIDs, true)
}
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) {
var stmt *sql.Stmt
if onlyUnsent {
stmt = sqlutil.TxStmt(txn, s.bulkSelectUnsentEventNIDStmt)
} else {
stmt = sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt)
}
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -136,7 +136,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
} }
// Look up the NID of the new join event // Look up the NID of the new join event
nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}) nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false)
if err != nil { if err != nil {
return fmt.Errorf("u.d.EventNIDs: %w", err) return fmt.Errorf("u.d.EventNIDs: %w", err)
} }
@ -170,7 +170,7 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s
} }
// Look up the NID of the new leave event // Look up the NID of the new leave event
nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}) nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false)
if err != nil { if err != nil {
return fmt.Errorf("u.d.EventNIDs: %w", err) return fmt.Errorf("u.d.EventNIDs: %w", err)
} }
@ -196,7 +196,7 @@ func (u *MembershipUpdater) SetToKnock(event *gomatrixserverlib.Event) (bool, er
} }
if u.membership != tables.MembershipStateKnock { if u.membership != tables.MembershipStateKnock {
// Look up the NID of the new knock event // Look up the NID of the new knock event
nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{event.EventID()}) nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{event.EventID()}, false)
if err != nil { if err != nil {
return fmt.Errorf("u.d.EventNIDs: %w", err) return fmt.Errorf("u.d.EventNIDs: %w", err)
} }

View file

@ -215,7 +215,13 @@ func (u *RoomUpdater) EventIDs(
func (u *RoomUpdater) EventNIDs( func (u *RoomUpdater) EventNIDs(
ctx context.Context, eventIDs []string, ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) { ) (map[string]types.EventNID, error) {
return u.d.eventNIDs(ctx, u.txn, eventIDs) return u.d.eventNIDs(ctx, u.txn, eventIDs, NoFilter)
}
func (u *RoomUpdater) UnsentEventNIDs(
ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) {
return u.d.eventNIDs(ctx, u.txn, eventIDs, FilterUnsentOnly)
} }
func (u *RoomUpdater) StateAtEventIDs( func (u *RoomUpdater) StateAtEventIDs(
@ -231,7 +237,11 @@ func (u *RoomUpdater) StateEntriesForEventIDs(
} }
func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, eventIDs) return u.d.eventsFromIDs(ctx, u.txn, eventIDs, false)
}
func (u *RoomUpdater) UnsentEventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, eventIDs, true)
} }
func (u *RoomUpdater) GetMembershipEventNIDsForRoom( func (u *RoomUpdater) GetMembershipEventNIDsForRoom(

View file

@ -59,23 +59,12 @@ func (d *Database) eventTypeNIDs(
ctx context.Context, txn *sql.Tx, eventTypes []string, ctx context.Context, txn *sql.Tx, eventTypes []string,
) (map[string]types.EventTypeNID, error) { ) (map[string]types.EventTypeNID, error) {
result := make(map[string]types.EventTypeNID) result := make(map[string]types.EventTypeNID)
remaining := []string{} nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, eventTypes)
for _, eventType := range eventTypes { if err != nil {
if nid, ok := d.Cache.GetRoomServerEventTypeNID(eventType); ok { return nil, err
result[eventType] = nid
} else {
remaining = append(remaining, eventType)
}
} }
if len(remaining) > 0 { for eventType, nid := range nids {
nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, remaining) result[eventType] = nid
if err != nil {
return nil, err
}
for eventType, nid := range nids {
result[eventType] = nid
d.Cache.StoreRoomServerEventTypeNID(eventType, nid)
}
} }
return result, nil return result, nil
} }
@ -96,23 +85,12 @@ func (d *Database) eventStateKeyNIDs(
ctx context.Context, txn *sql.Tx, eventStateKeys []string, ctx context.Context, txn *sql.Tx, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) { ) (map[string]types.EventStateKeyNID, error) {
result := make(map[string]types.EventStateKeyNID) result := make(map[string]types.EventStateKeyNID)
remaining := []string{} nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, eventStateKeys)
for _, eventStateKey := range eventStateKeys { if err != nil {
if nid, ok := d.Cache.GetRoomServerStateKeyNID(eventStateKey); ok { return nil, err
result[eventStateKey] = nid
} else {
remaining = append(remaining, eventStateKey)
}
} }
if len(remaining) > 0 { for eventStateKey, nid := range nids {
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, remaining) result[eventStateKey] = nid
if err != nil {
return nil, err
}
for eventStateKey, nid := range nids {
result[eventStateKey] = nid
d.Cache.StoreRoomServerStateKeyNID(eventStateKey, nid)
}
} }
return result, nil return result, nil
} }
@ -238,13 +216,27 @@ func (d *Database) addState(
func (d *Database) EventNIDs( func (d *Database) EventNIDs(
ctx context.Context, eventIDs []string, ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) { ) (map[string]types.EventNID, error) {
return d.eventNIDs(ctx, nil, eventIDs) return d.eventNIDs(ctx, nil, eventIDs, NoFilter)
} }
type UnsentFilter bool
const (
NoFilter UnsentFilter = false
FilterUnsentOnly UnsentFilter = true
)
func (d *Database) eventNIDs( func (d *Database) eventNIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter,
) (map[string]types.EventNID, error) { ) (map[string]types.EventNID, error) {
return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs) switch filter {
case FilterUnsentOnly:
return d.EventsTable.BulkSelectUnsentEventNID(ctx, txn, eventIDs)
case NoFilter:
return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs)
default:
panic("impossible case")
}
} }
func (d *Database) SetState( func (d *Database) SetState(
@ -281,11 +273,11 @@ func (d *Database) EventIDs(
} }
func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return d.eventsFromIDs(ctx, nil, eventIDs) return d.eventsFromIDs(ctx, nil, eventIDs, NoFilter)
} }
func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.Event, error) { func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter) ([]types.Event, error) {
nidMap, err := d.eventNIDs(ctx, txn, eventIDs) nidMap, err := d.eventNIDs(ctx, txn, eventIDs, filter)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -704,9 +696,6 @@ func (d *Database) assignRoomNID(
func (d *Database) assignEventTypeNID( func (d *Database) assignEventTypeNID(
ctx context.Context, txn *sql.Tx, eventType string, ctx context.Context, txn *sql.Tx, eventType string,
) (types.EventTypeNID, error) { ) (types.EventTypeNID, error) {
if eventTypeNID, ok := d.Cache.GetRoomServerEventTypeNID(eventType); ok {
return eventTypeNID, nil
}
// Check if we already have a numeric ID in the database. // Check if we already have a numeric ID in the database.
eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType) eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -717,18 +706,12 @@ func (d *Database) assignEventTypeNID(
eventTypeNID, err = d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType) eventTypeNID, err = d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType)
} }
} }
if err == nil {
d.Cache.StoreRoomServerEventTypeNID(eventType, eventTypeNID)
}
return eventTypeNID, err return eventTypeNID, err
} }
func (d *Database) assignStateKeyNID( func (d *Database) assignStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string, ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) { ) (types.EventStateKeyNID, error) {
if eventStateKeyNID, ok := d.Cache.GetRoomServerStateKeyNID(eventStateKey); ok {
return eventStateKeyNID, nil
}
// Check if we already have a numeric ID in the database. // Check if we already have a numeric ID in the database.
eventStateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey) eventStateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -739,9 +722,6 @@ func (d *Database) assignStateKeyNID(
eventStateKeyNID, err = d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey) eventStateKeyNID, err = d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey)
} }
} }
if err == nil {
d.Cache.StoreRoomServerStateKeyNID(eventStateKey, eventStateKeyNID)
}
return eventStateKeyNID, err return eventStateKeyNID, err
} }

View file

@ -99,6 +99,9 @@ const bulkSelectEventIDSQL = "" +
const bulkSelectEventNIDSQL = "" + const bulkSelectEventNIDSQL = "" +
"SELECT event_id, event_nid FROM roomserver_events WHERE event_id IN ($1)" "SELECT event_id, event_nid FROM roomserver_events WHERE event_id IN ($1)"
const bulkSelectUnsentEventNIDSQL = "" +
"SELECT event_id, event_nid FROM roomserver_events WHERE sent_to_output = 0 AND event_id IN ($1)"
const selectMaxEventDepthSQL = "" + const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)"
@ -118,8 +121,9 @@ type eventStatements struct {
bulkSelectStateAtEventAndReferenceStmt *sql.Stmt bulkSelectStateAtEventAndReferenceStmt *sql.Stmt
bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventReferenceStmt *sql.Stmt
bulkSelectEventIDStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt
bulkSelectEventNIDStmt *sql.Stmt //bulkSelectEventNIDStmt *sql.Stmt
//selectRoomNIDsForEventNIDsStmt *sql.Stmt //bulkSelectUnsentEventNIDStmt *sql.Stmt
//selectRoomNIDsForEventNIDsStmt *sql.Stmt
} }
func createEventsTable(db *sql.DB) error { func createEventsTable(db *sql.DB) error {
@ -144,7 +148,8 @@ func prepareEventsTable(db *sql.DB) (tables.Events, error) {
{&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL}, {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL},
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, //{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
//{&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL},
//{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, //{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -494,15 +499,33 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev
return results, nil return results, nil
} }
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map. // If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
return s.bulkSelectEventNID(ctx, txn, eventIDs, false)
}
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID
// only for events that haven't already been sent to the roomserver output.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
return s.bulkSelectEventNID(ctx, txn, eventIDs, true)
}
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) {
/////////////// ///////////////
iEventIDs := make([]interface{}, len(eventIDs)) iEventIDs := make([]interface{}, len(eventIDs))
for k, v := range eventIDs { for k, v := range eventIDs {
iEventIDs[k] = v iEventIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) var selectOrig string
if onlyUnsent {
selectOrig = strings.Replace(bulkSelectUnsentEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
} else {
selectOrig = strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
}
selectStmt, err := s.db.Prepare(selectOrig) selectStmt, err := s.db.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -59,6 +59,7 @@ type Events interface {
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID. // BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map. // If an event ID is not in the database then it is omitted from the map.
BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error)
BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error)
SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error)
SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error) SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
} }

View file

@ -39,7 +39,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -274,8 +274,14 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI {
// CreateAccountsDB creates a new instance of the accounts database. Should only // CreateAccountsDB creates a new instance of the accounts database. Should only
// be called once per component. // be called once per component.
func (b *BaseDendrite) CreateAccountsDB() accounts.Database { func (b *BaseDendrite) CreateAccountsDB() userdb.Database {
db, err := accounts.NewDatabase(&b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost, b.Cfg.UserAPI.OpenIDTokenLifetimeMS) db, err := userdb.NewDatabase(
&b.Cfg.UserAPI.AccountDatabase,
b.Cfg.Global.ServerName,
b.Cfg.UserAPI.BCryptCost,
b.Cfg.UserAPI.OpenIDTokenLifetimeMS,
userapi.DefaultLoginTokenLifetime,
)
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to connect to accounts db") logrus.WithError(err).Panicf("failed to connect to accounts db")
} }

View file

@ -18,6 +18,10 @@ type ClientAPI struct {
// If set, allows registration by anyone who also has the shared // If set, allows registration by anyone who also has the shared
// secret, even if registration is otherwise disabled. // secret, even if registration is otherwise disabled.
RegistrationSharedSecret string `yaml:"registration_shared_secret"` RegistrationSharedSecret string `yaml:"registration_shared_secret"`
// If set, prevents guest accounts from being created. Only takes
// effect if registration is enabled, otherwise guests registration
// is forbidden either way.
GuestsDisabled bool `yaml:"guests_disabled"`
// Boolean stating whether catpcha registration is enabled // Boolean stating whether catpcha registration is enabled
// and required // and required

View file

@ -29,8 +29,6 @@ type FederationAPI struct {
// on remote federation endpoints. This is not recommended in production! // on remote federation endpoints. This is not recommended in production!
DisableTLSValidation bool `yaml:"disable_tls_validation"` DisableTLSValidation bool `yaml:"disable_tls_validation"`
Proxy Proxy `yaml:"proxy_outbound"`
// Perspective keyservers, to use as a backup when direct key fetch // Perspective keyservers, to use as a backup when direct key fetch
// requests don't succeed // requests don't succeed
KeyPerspectives KeyPerspectives `yaml:"key_perspectives"` KeyPerspectives KeyPerspectives `yaml:"key_perspectives"`
@ -50,8 +48,6 @@ func (c *FederationAPI) Defaults(generate bool) {
c.FederationMaxRetries = 16 c.FederationMaxRetries = 16
c.DisableTLSValidation = false c.DisableTLSValidation = false
c.Proxy.Defaults()
} }
func (c *FederationAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { func (c *FederationAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {

View file

@ -65,6 +65,9 @@ type Global struct {
// DNS caching options for all outbound HTTP requests // DNS caching options for all outbound HTTP requests
DNSCache DNSCacheOptions `yaml:"dns_cache"` DNSCache DNSCacheOptions `yaml:"dns_cache"`
// ServerNotices configuration used for sending server notices
ServerNotices ServerNotices `yaml:"server_notices"`
// Consent tracking options // Consent tracking options
UserConsentOptions UserConsentOptions `yaml:"user_consent"` UserConsentOptions UserConsentOptions `yaml:"user_consent"`
} }
@ -84,6 +87,7 @@ func (c *Global) Defaults(generate bool) {
c.DNSCache.Defaults() c.DNSCache.Defaults()
c.Sentry.Defaults() c.Sentry.Defaults()
c.UserConsentOptions.Defaults(c.BaseURL) c.UserConsentOptions.Defaults(c.BaseURL)
c.ServerNotices.Defaults(generate)
} }
func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) { func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) {
@ -95,6 +99,7 @@ func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) {
c.Sentry.Verify(configErrs, isMonolith) c.Sentry.Verify(configErrs, isMonolith)
c.DNSCache.Verify(configErrs, isMonolith) c.DNSCache.Verify(configErrs, isMonolith)
c.UserConsentOptions.Verify(configErrs, isMonolith) c.UserConsentOptions.Verify(configErrs, isMonolith)
c.ServerNotices.Verify(configErrs, isMonolith)
} }
type OldVerifyKeys struct { type OldVerifyKeys struct {
@ -136,6 +141,31 @@ func (c *Metrics) Defaults(generate bool) {
func (c *Metrics) Verify(configErrs *ConfigErrors, isMonolith bool) { func (c *Metrics) Verify(configErrs *ConfigErrors, isMonolith bool) {
} }
// ServerNotices defines the configuration used for sending server notices
type ServerNotices struct {
Enabled bool `yaml:"enabled"`
// The localpart to be used when sending notices
LocalPart string `yaml:"local_part"`
// The displayname to be used when sending notices
DisplayName string `yaml:"display_name"`
// The avatar of this user
AvatarURL string `yaml:"avatar"`
// The roomname to be used when creating messages
RoomName string `yaml:"room_name"`
}
func (c *ServerNotices) Defaults(generate bool) {
if generate {
c.Enabled = true
c.LocalPart = "_server"
c.DisplayName = "Server Alert"
c.RoomName = "Server Alert"
c.AvatarURL = ""
}
}
func (c *ServerNotices) Verify(errors *ConfigErrors, isMonolith bool) {}
// The configuration to use for Sentry error reporting // The configuration to use for Sentry error reporting
type Sentry struct { type Sentry struct {
Enabled bool `yaml:"enabled"` Enabled bool `yaml:"enabled"`

View file

@ -58,6 +58,11 @@ global:
basic_auth: basic_auth:
username: metrics username: metrics
password: metrics password: metrics
server_notices:
local_part: "_server"
display_name: "Server alerts"
avatar: ""
room_name: "Server Alerts"
app_service_api: app_service_api:
internal_api: internal_api:
listen: http://localhost:7777 listen: http://localhost:7777
@ -118,11 +123,6 @@ federation_sender:
conn_max_lifetime: -1 conn_max_lifetime: -1
send_max_retries: 16 send_max_retries: 16
disable_tls_validation: false disable_tls_validation: false
proxy_outbound:
enabled: false
protocol: http
host: localhost
port: 8080
key_server: key_server:
internal_api: internal_api:
listen: http://localhost:7779 listen: http://localhost:7779

View file

@ -16,9 +16,6 @@ type UserAPI struct {
// The Account database stores the login details and account information // The Account database stores the login details and account information
// for local users. It is accessed by the UserAPI. // for local users. It is accessed by the UserAPI.
AccountDatabase DatabaseOptions `yaml:"account_database"` AccountDatabase DatabaseOptions `yaml:"account_database"`
// The Device database stores session information for the devices of logged
// in local users. It is accessed by the UserAPI.
DeviceDatabase DatabaseOptions `yaml:"device_database"`
} }
const DefaultOpenIDTokenLifetimeMS = 3600000 // 60 minutes const DefaultOpenIDTokenLifetimeMS = 3600000 // 60 minutes
@ -27,10 +24,8 @@ func (c *UserAPI) Defaults(generate bool) {
c.InternalAPI.Listen = "http://localhost:7781" c.InternalAPI.Listen = "http://localhost:7781"
c.InternalAPI.Connect = "http://localhost:7781" c.InternalAPI.Connect = "http://localhost:7781"
c.AccountDatabase.Defaults(10) c.AccountDatabase.Defaults(10)
c.DeviceDatabase.Defaults(10)
if generate { if generate {
c.AccountDatabase.ConnectionString = "file:userapi_accounts.db" c.AccountDatabase.ConnectionString = "file:userapi_accounts.db"
c.DeviceDatabase.ConnectionString = "file:userapi_devices.db"
} }
c.BCryptCost = bcrypt.DefaultCost c.BCryptCost = bcrypt.DefaultCost
c.OpenIDTokenLifetimeMS = DefaultOpenIDTokenLifetimeMS c.OpenIDTokenLifetimeMS = DefaultOpenIDTokenLifetimeMS
@ -40,6 +35,5 @@ func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
checkURL(configErrs, "user_api.internal_api.listen", string(c.InternalAPI.Listen)) checkURL(configErrs, "user_api.internal_api.listen", string(c.InternalAPI.Listen))
checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect)) checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect))
checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString)) checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString))
checkNotEmpty(configErrs, "user_api.device_database.connection_string", string(c.DeviceDatabase.ConnectionString))
checkPositive(configErrs, "user_api.openid_token_lifetime_ms", c.OpenIDTokenLifetimeMS) checkPositive(configErrs, "user_api.openid_token_lifetime_ms", c.OpenIDTokenLifetimeMS)
} }

View file

@ -24,13 +24,12 @@ func Prepare(cfg *config.JetStream) natsclient.JetStreamContext {
if natsServer == nil { if natsServer == nil {
var err error var err error
natsServer, err = natsserver.NewServer(&natsserver.Options{ natsServer, err = natsserver.NewServer(&natsserver.Options{
ServerName: "monolith", ServerName: "monolith",
DontListen: true, DontListen: true,
JetStream: true, JetStream: true,
StoreDir: string(cfg.StoragePath), StoreDir: string(cfg.StoragePath),
NoSystemAccount: true, NoSystemAccount: true,
AllowNewAccounts: false, MaxPayload: 16 * 1024 * 1024,
MaxPayload: 16 * 1024 * 1024,
}) })
if err != nil { if err != nil {
panic(err) panic(err)

View file

@ -30,7 +30,7 @@ import (
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi" "github.com/matrix-org/dendrite/syncapi"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -38,7 +38,7 @@ import (
// all components of Dendrite, for use in monolith mode. // all components of Dendrite, for use in monolith mode.
type Monolith struct { type Monolith struct {
Config *config.Dendrite Config *config.Dendrite
AccountDB accounts.Database AccountDB userdb.Database
KeyRing *gomatrixserverlib.KeyRing KeyRing *gomatrixserverlib.KeyRing
Client *gomatrixserverlib.Client Client *gomatrixserverlib.Client
FedClient *gomatrixserverlib.FederationClient FedClient *gomatrixserverlib.FederationClient

View file

@ -16,6 +16,7 @@ package consumers
import ( import (
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -307,7 +308,9 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
ctx context.Context, msg api.OutputRetireInviteEvent, ctx context.Context, msg api.OutputRetireInviteEvent,
) { ) {
pduPos, err := s.db.RetireInviteEvent(ctx, msg.EventID) pduPos, err := s.db.RetireInviteEvent(ctx, msg.EventID)
if err != nil { // It's possible we just haven't heard of this invite yet, so
// we should not panic if we try to retire it.
if err != nil && err != sql.ErrNoRows {
sentry.CaptureException(err) sentry.CaptureException(err)
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
log.WithFields(log.Fields{ log.WithFields(log.Fields{

View file

@ -39,14 +39,14 @@ func Setup(
rsAPI api.RoomserverInternalAPI, rsAPI api.RoomserverInternalAPI,
cfg *config.SyncAPI, cfg *config.SyncAPI,
) { ) {
r0mux := csMux.PathPrefix("/r0").Subrouter() v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter()
// TODO: Add AS support for all handlers below. // TODO: Add AS support for all handlers below.
r0mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return srp.OnIncomingSyncRequest(req, device) return srp.OnIncomingSyncRequest(req, device)
})).Methods(http.MethodGet, http.MethodOptions) })).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -54,7 +54,7 @@ func Setup(
return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], device, federation, rsAPI, cfg, srp) return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], device, federation, rsAPI, cfg, srp)
})).Methods(http.MethodGet, http.MethodOptions) })).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userId}/filter", v3mux.Handle("/user/{userId}/filter",
httputil.MakeAuthAPI("put_filter", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("put_filter", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -64,7 +64,7 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/user/{userId}/filter/{filterId}", v3mux.Handle("/user/{userId}/filter/{filterId}",
httputil.MakeAuthAPI("get_filter", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("get_filter", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
@ -74,7 +74,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, cfg.Matrix.UserConsentOptions, false, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return srp.OnIncomingKeyChangeRequest(req, device) return srp.OnIncomingKeyChangeRequest(req, device)
})).Methods(http.MethodGet, http.MethodOptions) })).Methods(http.MethodGet, http.MethodOptions)
} }

View file

@ -279,7 +279,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
parts := strings.Split(tok[1:], "_") parts := strings.Split(tok[1:], "_")
var positions [7]StreamPosition var positions [7]StreamPosition
for i, p := range parts { for i, p := range parts {
if i > len(positions) { if i >= len(positions) {
break break
} }
var pos int var pos int

View file

@ -592,3 +592,4 @@ Forward extremities remain so even after the next events are populated as outlie
If a device list update goes missing, the server resyncs on the next one If a device list update goes missing, the server resyncs on the next one
uploading self-signing key notifies over federation uploading self-signing key notifies over federation
uploading signed devices gets propagated over federation uploading signed devices gets propagated over federation
Device list doesn't change if remote server is down

View file

@ -18,8 +18,9 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
) )
// UserInternalAPI is the internal API for information about users and devices. // UserInternalAPI is the internal API for information about users and devices.
@ -384,6 +385,7 @@ type Device struct {
// If the device is for an appservice user, // If the device is for an appservice user,
// this is the appservice ID. // this is the appservice ID.
AppserviceID string AppserviceID string
AccountType AccountType
} }
// Account represents a Matrix account on this home server. // Account represents a Matrix account on this home server.
@ -392,7 +394,7 @@ type Account struct {
Localpart string Localpart string
ServerName gomatrixserverlib.ServerName ServerName gomatrixserverlib.ServerName
AppServiceID string AppServiceID string
// TODO: Other flags like IsAdmin, IsGuest AccountType AccountType
// TODO: Associations (e.g. with application services) // TODO: Associations (e.g. with application services)
} }
@ -448,4 +450,8 @@ const (
AccountTypeUser AccountType = 1 AccountTypeUser AccountType = 1
// AccountTypeGuest indicates this is a guest account // AccountTypeGuest indicates this is a guest account
AccountTypeGuest AccountType = 2 AccountTypeGuest AccountType = 2
// AccountTypeAdmin indicates this is an admin account
AccountTypeAdmin AccountType = 3
// AccountTypeAppService indicates this is an appservice account
AccountTypeAppService AccountType = 4
) )

View file

@ -19,6 +19,13 @@ import (
"time" "time"
) )
// DefaultLoginTokenLifetime determines how old a valid token may be.
//
// NOTSPEC: The current spec says "SHOULD be limited to around five
// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low.
// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325).
const DefaultLoginTokenLifetime = 2 * time.Minute
type LoginTokenInternalAPI interface { type LoginTokenInternalAPI interface {
// PerformLoginTokenCreation creates a new login token and associates it with the provided data. // PerformLoginTokenCreation creates a new login token and associates it with the provided data.
PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error

View file

@ -21,22 +21,21 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/appservice/types" "github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
keyapi "github.com/matrix-org/dendrite/keyserver/api" keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/devices"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
) )
type UserInternalAPI struct { type UserInternalAPI struct {
AccountDB accounts.Database DB storage.Database
DeviceDB devices.Database
ServerName gomatrixserverlib.ServerName ServerName gomatrixserverlib.ServerName
// AppServices is the list of all registered AS // AppServices is the list of all registered AS
AppServices []config.ApplicationService AppServices []config.ApplicationService
@ -54,10 +53,11 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
if req.DataType == "" { if req.DataType == "" {
return fmt.Errorf("data type must not be empty") return fmt.Errorf("data type must not be empty")
} }
return a.AccountDB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData) return a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
} }
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
if req.AccountType == api.AccountTypeGuest { if req.AccountType == api.AccountTypeGuest {
acc, err := a.AccountDB.CreateGuestAccount(ctx) acc, err := a.AccountDB.CreateGuestAccount(ctx)
if err != nil { if err != nil {
@ -86,11 +86,18 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
Localpart: req.Localpart, Localpart: req.Localpart,
ServerName: a.ServerName, ServerName: a.ServerName,
UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName),
AccountType: req.AccountType,
} }
return nil return nil
} }
if err = a.AccountDB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { if req.AccountType == api.AccountTypeGuest {
res.AccountCreated = true
res.Account = acc
return nil
}
if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
return err return err
} }
@ -100,7 +107,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
} }
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil { if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
return err return err
} }
res.PasswordUpdated = true res.PasswordUpdated = true
@ -113,7 +120,7 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
"device_id": req.DeviceID, "device_id": req.DeviceID,
"display_name": req.DeviceDisplayName, "display_name": req.DeviceDisplayName,
}).Info("PerformDeviceCreation") }).Info("PerformDeviceCreation")
dev, err := a.DeviceDB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
if err != nil { if err != nil {
return err return err
} }
@ -138,12 +145,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deletedDeviceIDs := req.DeviceIDs deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 { if len(req.DeviceIDs) == 0 {
var devices []api.Device var devices []api.Device
devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID) devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
for _, d := range devices { for _, d := range devices {
deletedDeviceIDs = append(deletedDeviceIDs, d.ID) deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
} }
} else { } else {
err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs) err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs)
} }
if err != nil { if err != nil {
return err return err
@ -197,7 +204,7 @@ func (a *UserInternalAPI) PerformLastSeenUpdate(
if err != nil { if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
} }
if err := a.DeviceDB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil { if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil {
return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err) return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err)
} }
return nil return nil
@ -209,7 +216,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
return err return err
} }
dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID) dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
res.DeviceExists = false res.DeviceExists = false
return nil return nil
@ -224,7 +231,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
return nil return nil
} }
err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName) err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed") util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
return err return err
@ -262,7 +269,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
if domain != a.ServerName { if domain != a.ServerName {
return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName)
} }
prof, err := a.AccountDB.GetProfileByLocalpart(ctx, local) prof, err := a.DB.GetProfileByLocalpart(ctx, local)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil return nil
@ -276,7 +283,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
} }
func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error { func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error {
profiles, err := a.AccountDB.SearchProfiles(ctx, req.SearchString, req.Limit) profiles, err := a.DB.SearchProfiles(ctx, req.SearchString, req.Limit)
if err != nil { if err != nil {
return err return err
} }
@ -285,7 +292,7 @@ func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.Quer
} }
func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error { func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error {
devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs) devices, err := a.DB.GetDevicesByID(ctx, req.DeviceIDs)
if err != nil { if err != nil {
return err return err
} }
@ -313,10 +320,11 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
if domain != a.ServerName { if domain != a.ServerName {
return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName) return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName)
} }
devs, err := a.DeviceDB.GetDevicesByLocalpart(ctx, local) devs, err := a.DB.GetDevicesByLocalpart(ctx, local)
if err != nil { if err != nil {
return err return err
} }
res.UserExists = true
res.Devices = devs res.Devices = devs
return nil return nil
} }
@ -331,7 +339,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
} }
if req.DataType != "" { if req.DataType != "" {
var data json.RawMessage var data json.RawMessage
data, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
if err != nil { if err != nil {
return err return err
} }
@ -349,7 +357,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
} }
return nil return nil
} }
global, rooms, err := a.AccountDB.GetAccountData(ctx, local) global, rooms, err := a.DB.GetAccountData(ctx, local)
if err != nil { if err != nil {
return err return err
} }
@ -368,13 +376,22 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
return nil return nil
} }
device, err := a.DeviceDB.GetDeviceByAccessToken(ctx, req.AccessToken) device, err := a.DB.GetDeviceByAccessToken(ctx, req.AccessToken)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil return nil
} }
return err return err
} }
localPart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
return err
}
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart)
if err != nil {
return err
}
device.AccountType = acc.AccountType
res.Device = device res.Device = device
return nil return nil
} }
@ -401,6 +418,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
// AS dummy device has AS's token. // AS dummy device has AS's token.
AccessToken: token, AccessToken: token,
AppserviceID: appService.ID, AppserviceID: appService.ID,
AccountType: api.AccountTypeAppService,
} }
localpart, err := userutil.ParseUsernameParam(appServiceUserID, &a.ServerName) localpart, err := userutil.ParseUsernameParam(appServiceUserID, &a.ServerName)
@ -410,7 +428,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
if localpart != "" { // AS is masquerading as another user if localpart != "" { // AS is masquerading as another user
// Verify that the user is registered // Verify that the user is registered
account, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart) account, err := a.DB.GetAccountByLocalpart(ctx, localpart)
// Verify that the account exists and either appServiceID matches or // Verify that the account exists and either appServiceID matches or
// it belongs to the appservice user namespaces // it belongs to the appservice user namespaces
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) { if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
@ -428,7 +446,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
// PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again. // PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again.
func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error { func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error {
err := a.AccountDB.DeactivateAccount(ctx, req.Localpart) err := a.DB.DeactivateAccount(ctx, req.Localpart)
res.AccountDeactivated = err == nil res.AccountDeactivated = err == nil
return err return err
} }
@ -437,7 +455,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error { func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error {
token := util.RandomString(24) token := util.RandomString(24)
exp, err := a.AccountDB.CreateOpenIDToken(ctx, token, req.UserID) exp, err := a.DB.CreateOpenIDToken(ctx, token, req.UserID)
res.Token = api.OpenIDToken{ res.Token = api.OpenIDToken{
Token: token, Token: token,
@ -450,7 +468,7 @@ func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *a
// QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation // QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation
func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error { func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, req.Token) openIDTokenAttrs, err := a.DB.GetOpenIDTokenAttributes(ctx, req.Token)
if err != nil { if err != nil {
return err return err
} }
@ -472,7 +490,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
} }
return nil return nil
} }
exists, err := a.AccountDB.DeleteKeyBackup(ctx, req.UserID, req.Version) exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version)
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to delete backup: %s", err) res.Error = fmt.Sprintf("failed to delete backup: %s", err)
} }
@ -485,7 +503,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
} }
// Create metadata // Create metadata
if req.Version == "" { if req.Version == "" {
version, err := a.AccountDB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData) version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to create backup: %s", err) res.Error = fmt.Sprintf("failed to create backup: %s", err)
} }
@ -498,7 +516,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
} }
// Update metadata // Update metadata
if len(req.Keys.Rooms) == 0 { if len(req.Keys.Rooms) == 0 {
err := a.AccountDB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData) err := a.DB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to update backup: %s", err) res.Error = fmt.Sprintf("failed to update backup: %s", err)
} }
@ -519,7 +537,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) {
// you can only upload keys for the CURRENT version // you can only upload keys for the CURRENT version
version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, "") version, _, _, _, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, "")
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to query version: %s", err) res.Error = fmt.Sprintf("failed to query version: %s", err)
return return
@ -547,7 +565,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
}) })
} }
} }
count, etag, err := a.AccountDB.UpsertBackupKeys(ctx, version, req.UserID, uploads) count, etag, err := a.DB.UpsertBackupKeys(ctx, version, req.UserID, uploads)
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to upsert keys: %s", err) res.Error = fmt.Sprintf("failed to upsert keys: %s", err)
return return
@ -557,7 +575,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
} }
func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) { func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) {
version, algorithm, authData, etag, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version) version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version)
res.Version = version res.Version = version
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -573,14 +591,14 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
res.Exists = !deleted res.Exists = !deleted
if !req.ReturnKeys { if !req.ReturnKeys {
res.Count, err = a.AccountDB.CountBackupKeys(ctx, version, req.UserID) res.Count, err = a.DB.CountBackupKeys(ctx, version, req.UserID)
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to count keys: %s", err) res.Error = fmt.Sprintf("failed to count keys: %s", err)
} }
return return
} }
result, err := a.AccountDB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID) result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to query keys: %s", err) res.Error = fmt.Sprintf("failed to query keys: %s", err)
return return

View file

@ -34,7 +34,7 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap
if domain != a.ServerName { if domain != a.ServerName {
return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName) return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName)
} }
tokenMeta, err := a.DeviceDB.CreateLoginToken(ctx, &req.Data) tokenMeta, err := a.DB.CreateLoginToken(ctx, &req.Data)
if err != nil { if err != nil {
return err return err
} }
@ -45,13 +45,13 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap
// PerformLoginTokenDeletion ensures the token doesn't exist. // PerformLoginTokenDeletion ensures the token doesn't exist.
func (a *UserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *api.PerformLoginTokenDeletionRequest, res *api.PerformLoginTokenDeletionResponse) error { func (a *UserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *api.PerformLoginTokenDeletionRequest, res *api.PerformLoginTokenDeletionResponse) error {
util.GetLogger(ctx).Info("PerformLoginTokenDeletion") util.GetLogger(ctx).Info("PerformLoginTokenDeletion")
return a.DeviceDB.RemoveLoginToken(ctx, req.Token) return a.DB.RemoveLoginToken(ctx, req.Token)
} }
// QueryLoginToken returns the data associated with a login token. If // QueryLoginToken returns the data associated with a login token. If
// the token is not valid, success is returned, but res.Data == nil. // the token is not valid, success is returned, but res.Data == nil.
func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLoginTokenRequest, res *api.QueryLoginTokenResponse) error { func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLoginTokenRequest, res *api.QueryLoginTokenResponse) error {
tokenData, err := a.DeviceDB.GetLoginTokenDataByToken(ctx, req.Token) tokenData, err := a.DB.GetLoginTokenDataByToken(ctx, req.Token)
if err != nil { if err != nil {
res.Data = nil res.Data = nil
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog
if domain != a.ServerName { if domain != a.ServerName {
return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName) return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName)
} }
if _, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart); err != nil { if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil {
res.Data = nil res.Data = nil
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil return nil

View file

@ -1,547 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strconv"
"time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
// Import the postgres database driver.
_ "github.com/lib/pq"
)
// Database represents an account database
type Database struct {
db *sql.DB
writer sqlutil.Writer
sqlutil.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
accountDatas accountDataStatements
threepids threepidStatements
openIDTokens tokenStatements
keyBackupVersions keyBackupVersionStatements
keyBackups keyBackupStatements
serverName gomatrixserverlib.ServerName
bcryptCost int
openIDTokenLifetimeMS int64
}
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
d := &Database{
serverName: serverName,
db: db,
writer: sqlutil.NewDummyWriter(),
bcryptCost: bcryptCost,
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
}
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
if err = d.accounts.execSchema(db); err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
deltas.LoadIsActive(m)
deltas.LoadAddPolicyVersion(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
if err = d.PartitionOffsetStatements.Prepare(db, d.writer, "account"); err != nil {
return nil, err
}
if err = d.accounts.prepare(db, serverName); err != nil {
return nil, err
}
if err = d.profiles.prepare(db); err != nil {
return nil, err
}
if err = d.accountDatas.prepare(db); err != nil {
return nil, err
}
if err = d.threepids.prepare(db); err != nil {
return nil, err
}
if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err
}
if err = d.keyBackupVersions.prepare(db); err != nil {
return nil, err
}
if err = d.keyBackups.prepare(db); err != nil {
return nil, err
}
return d, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(
ctx context.Context, localpart, plaintextPassword string,
) (*api.Account, error) {
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
if err != nil {
return nil, err
}
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
return nil, err
}
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
return d.profiles.selectProfileByLocalpart(ctx, localpart)
}
// SetAvatarURL updates the avatar URL of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) error {
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
}
// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
) error {
return d.profiles.setDisplayName(ctx, localpart, displayName)
}
// SetPassword sets the account password to the given hash.
func (d *Database) SetPassword(
ctx context.Context, localpart, plaintextPassword string,
) error {
hash, err := d.hashPassword(plaintextPassword)
if err != nil {
return err
}
return d.accounts.updatePassword(ctx, localpart, hash)
}
// CreateGuestAccount makes a new guest account and creates an empty profile
// for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var numLocalpart int64
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
if err != nil {
return err
}
localpart := strconv.FormatInt(numLocalpart, 10)
acc, err = d.createAccount(ctx, txn, localpart, "", "", "")
return err
})
return acc, err
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, sqlutil.ErrUserExists.
func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID, policyVersion string,
) (acc *api.Account, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, policyVersion)
return err
})
return
}
func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID, policyVersion string,
) (*api.Account, error) {
var account *api.Account
var err error
// Generate a password hash if this is not a password-less user
hash := ""
if plaintextPassword != "" {
hash, err = d.hashPassword(plaintextPassword)
if err != nil {
return nil, err
}
}
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, policyVersion); err != nil {
if sqlutil.IsUniqueConstraintViolationErr(err) {
return nil, sqlutil.ErrUserExists
}
return nil, err
}
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
return nil, err
}
if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
"global": {
"content": [],
"override": [],
"room": [],
"sender": [],
"underride": []
}
}`)); err != nil {
return nil, err
}
return account, nil
}
// SaveAccountData saves new account data for a given user and a given room.
// If the account data is not specific to a room, the room ID should be an empty string
// If an account data already exists for a given set (user, room, data type), it will
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
})
}
// GetAccountData returns account data related to a given localpart
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global map[string]json.RawMessage,
rooms map[string]map[string]json.RawMessage,
err error,
) {
return d.accountDatas.selectAccountData(ctx, localpart)
}
// GetAccountDataByType returns account data matching a given
// localpart, room ID and type.
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data json.RawMessage, err error) {
return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType,
)
}
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx, nil)
}
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.bcryptCost)
return string(hashBytes), err
}
// Err3PIDInUse is the error returned when trying to save an association involving
// a third-party identifier which is already associated to a local user.
var Err3PIDInUse = errors.New("this third-party identifier is already in use")
// SaveThreePIDAssociation saves the association between a third party identifier
// and a local Matrix user (identified by the user's ID's local part).
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string,
) (err error) {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
user, err := d.threepids.selectLocalpartForThreePID(
ctx, txn, threepid, medium,
)
if err != nil {
return err
}
if len(user) > 0 {
return Err3PIDInUse
}
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
})
}
// RemoveThreePIDAssociation removes the association involving a given third-party
// identifier.
// If no association exists involving this third-party identifier, returns nothing.
// If there was a problem talking to the database, returns an error.
func (d *Database) RemoveThreePIDAssociation(
ctx context.Context, threepid string, medium string,
) (err error) {
return d.threepids.deleteThreePID(ctx, threepid, medium)
}
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
// identifier.
// If no association involves the given third-party idenfitier, returns an empty
// string.
// Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
}
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
// a given local user.
// If no association is known for this user, returns an empty slice.
// Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
}
// CheckAccountAvailability checks if the username/localpart is already present
// in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
if err == sql.ErrNoRows {
return true, nil
}
return false, err
}
// GetAccountByLocalpart returns the account associated with the given localpart.
// This function assumes the request is authenticated or the account data is used only internally.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
) (*api.Account, error) {
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// SearchProfiles returns all profiles where the provided localpart or display name
// match any part of the profiles in the database.
func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) {
return d.profiles.selectProfilesBySearch(ctx, searchString, limit)
}
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
return d.accounts.deactivateAccount(ctx, localpart)
}
// CreateOpenIDToken persists a new token that was issued through OpenID Connect
func (d *Database) CreateOpenIDToken(
ctx context.Context,
token, localpart string,
) (int64, error) {
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
err := sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
})
return expiresAtMS, err
}
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
func (d *Database) GetOpenIDTokenAttributes(
ctx context.Context,
token string,
) (*api.OpenIDTokenAttributes, error) {
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
}
func (d *Database) CreateKeyBackup(
ctx context.Context, userID, algorithm string, authData json.RawMessage,
) (version string, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "")
return err
})
return
}
func (d *Database) UpdateKeyBackupAuthData(
ctx context.Context, userID, version string, authData json.RawMessage,
) (err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData)
})
return
}
func (d *Database) DeleteKeyBackup(
ctx context.Context, userID, version string,
) (exists bool, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) GetKeyBackup(
ctx context.Context, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) GetBackupKeys(
ctx context.Context, version, userID, filterRoomID, filterSessionID string,
) (result map[string]map[string]api.KeyBackupSession, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if filterSessionID != "" {
result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
return err
}
if filterRoomID != "" {
result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
return err
}
result, err = d.keyBackups.selectKeys(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) CountBackupKeys(
ctx context.Context, version, userID string,
) (count int64, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
if err != nil {
return err
}
return nil
})
return
}
// nolint:nakedret
func (d *Database) UpsertBackupKeys(
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
) (count int64, etag string, err error) {
// wrap the following logic in a txn to ensure we atomically upload keys
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
_, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
if err != nil {
return err
}
if deleted {
return fmt.Errorf("backup was deleted")
}
// pull out all keys for this (user_id, version)
existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version)
if err != nil {
return err
}
changed := false
// loop over all the new keys (which should be smaller than the set of backed up keys)
for _, newKey := range uploads {
// if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them.
existingRoom := existingKeys[newKey.RoomID]
if existingRoom != nil {
existingSession, ok := existingRoom[newKey.SessionID]
if ok {
if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
changed = true
if err != nil {
return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err)
}
}
// if we shouldn't replace the key we do nothing with it
continue
}
}
// if we're here, either the room or session are new, either way, we insert
err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey)
changed = true
if err != nil {
return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err)
}
}
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
if err != nil {
return err
}
if changed {
// update the etag
var newETag string
if oldETag == "" {
newETag = "1"
} else {
oldETagInt, err := strconv.ParseInt(oldETag, 10, 64)
if err != nil {
return fmt.Errorf("failed to parse old etag: %s", err)
}
newETag = strconv.FormatInt(oldETagInt+1, 10)
}
etag = newETag
return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag)
} else {
etag = oldETag
}
return nil
})
return
}
// GetPrivacyPolicy returns the accepted privacy policy version, if any.
func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
policyVersion, err = d.accounts.selectPrivacyPolicy(ctx, txn, localpart)
return err
})
return
}
// GetOutdatedPolicy queries all users which didn't accept the current policy version.
func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
userIDs, err = d.accounts.batchSelectPrivacyPolicy(ctx, txn, policyVersion)
return err
})
return
}
// UpdatePolicyVersion sets the accepted policy_version for a user.
func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) (err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.accounts.updatePolicyVersion(ctx, txn, policyVersion, localpart)
})
return
}

View file

@ -1,52 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package devices
import (
"context"
"github.com/matrix-org/dendrite/userapi/api"
)
type Database interface {
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
RemoveLoginToken(ctx context.Context, token string) error
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
}

View file

@ -1,270 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas"
"github.com/matrix-org/gomatrixserverlib"
)
const (
// The length of generated device IDs
deviceIDByteLength = 6
loginTokenByteLength = 32
)
// Database represents a device database.
type Database struct {
db *sql.DB
devices devicesStatements
loginTokens loginTokenStatements
loginTokenLifetime time.Duration
}
// NewDatabase creates a new device database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
var d devicesStatements
var lt loginTokenStatements
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
if err = d.execSchema(db); err != nil {
return nil, err
}
if err = lt.execSchema(db); err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
deltas.LoadLastSeenTSIP(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
if err = d.prepare(db, serverName); err != nil {
return nil, err
}
if err = lt.prepare(db); err != nil {
return nil, err
}
return &Database{db, d, lt, loginTokenLifetime}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*api.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.devices.selectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil
})
return
}
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
})
}
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
tok, err := generateLoginToken()
if err != nil {
return nil, err
}
meta := &api.LoginTokenMetadata{
Token: tok,
Expiration: time.Now().Add(d.loginTokenLifetime),
}
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.loginTokens.insert(ctx, txn, meta, data)
})
if err != nil {
return nil, err
}
return meta, nil
}
func generateLoginToken() (string, error) {
b := make([]byte, loginTokenByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.loginTokens.deleteByToken(ctx, txn, token)
})
}
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
return d.loginTokens.selectByToken(ctx, token)
}

View file

@ -1,271 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas"
"github.com/matrix-org/gomatrixserverlib"
)
const (
// The length of generated device IDs
deviceIDByteLength = 6
loginTokenByteLength = 32
)
// Database represents a device database.
type Database struct {
db *sql.DB
writer sqlutil.Writer
devices devicesStatements
loginTokens loginTokenStatements
loginTokenLifetime time.Duration
}
// NewDatabase creates a new device database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
writer := sqlutil.NewExclusiveWriter()
var d devicesStatements
var lt loginTokenStatements
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
if err = d.execSchema(db); err != nil {
return nil, err
}
if err = lt.execSchema(db); err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
deltas.LoadLastSeenTSIP(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
if err = d.prepare(db, writer, serverName); err != nil {
return nil, err
}
if err = lt.prepare(db); err != nil {
return nil, err
}
return &Database{db, writer, d, lt, loginTokenLifetime}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*api.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.devices.selectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil
})
return
}
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
})
}
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
tok, err := generateLoginToken()
if err != nil {
return nil, err
}
meta := &api.LoginTokenMetadata{
Token: tok,
Expiration: time.Now().Add(d.loginTokenLifetime),
}
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.loginTokens.insert(ctx, txn, meta, data)
})
if err != nil {
return nil, err
}
return meta, nil
}
func generateLoginToken() (string, error) {
b := make([]byte, loginTokenByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.loginTokens.deleteByToken(ctx, txn, token)
})
}
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
return d.loginTokens.selectByToken(ctx, token)
}

View file

@ -1,42 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !wasm
// +build !wasm
package devices
import (
"fmt"
"time"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/devices/postgres"
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
)
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets postgres connection parameters. loginTokenLifetime determines how long a
// login token from CreateLoginToken is valid.
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime)
case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(dbProperties, serverName, loginTokenLifetime)
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View file

@ -1,39 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package devices
import (
"fmt"
"time"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
)
func NewDatabase(
dbProperties *config.DatabaseOptions,
serverName gomatrixserverlib.ServerName,
loginTokenLifetime time.Duration,
) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime)
case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation")
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package accounts package storage
import ( import (
"context" "context"
@ -32,8 +32,7 @@ type Database interface {
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists. // account already exists, it will return nil, ErrUserExists.
CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID, policyVersion string) (*api.Account, error) CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, policyVersion string, accountType api.AccountType) (*api.Account, error)
CreateGuestAccount(ctx context.Context) (*api.Account, error)
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
// GetAccountDataByType returns account data matching a given // GetAccountDataByType returns account data matching a given
@ -64,6 +63,35 @@ type Database interface {
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error) UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error) GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error) CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
RemoveLoginToken(ctx context.Context, token string) error
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
} }
// Err3PIDInUse is the error returned when trying to save an association involving // Err3PIDInUse is the error returned when trying to save an association involving

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
) )
const accountDataSchema = ` const accountDataSchema = `
@ -56,19 +57,20 @@ type accountDataStatements struct {
selectAccountDataByTypeStmt *sql.Stmt selectAccountDataByTypeStmt *sql.Stmt
} }
func (s *accountDataStatements) prepare(db *sql.DB) (err error) { func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
_, err = db.Exec(accountDataSchema) s := &accountDataStatements{}
_, err := db.Exec(accountDataSchema)
if err != nil { if err != nil {
return return nil, err
} }
return sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertAccountDataStmt, insertAccountDataSQL}, {&s.insertAccountDataStmt, insertAccountDataSQL},
{&s.selectAccountDataStmt, selectAccountDataSQL}, {&s.selectAccountDataStmt, selectAccountDataSQL},
{&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL}, {&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL},
}.Prepare(db) }.Prepare(db)
} }
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt) stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
@ -76,7 +78,7 @@ func (s *accountDataStatements) insertAccountData(
return return
} }
func (s *accountDataStatements) selectAccountData( func (s *accountDataStatements) SelectAccountData(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) ( ) (
/* global */ map[string]json.RawMessage, /* global */ map[string]json.RawMessage,
@ -114,7 +116,7 @@ func (s *accountDataStatements) selectAccountData(
return global, rooms, rows.Err() return global, rooms, rows.Err()
} }
func (s *accountDataStatements) selectAccountDataByType( func (s *accountDataStatements) SelectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string, ctx context.Context, localpart, roomID, dataType string,
) (data json.RawMessage, err error) { ) (data json.RawMessage, err error) {
var bytes []byte var bytes []byte

View file

@ -19,10 +19,12 @@ import (
"database/sql" "database/sql"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/userapi/storage/tables"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -40,17 +42,19 @@ CREATE TABLE IF NOT EXISTS account_accounts (
appservice_id TEXT, appservice_id TEXT,
-- If the account is currently active -- If the account is currently active
is_deactivated BOOLEAN DEFAULT FALSE, is_deactivated BOOLEAN DEFAULT FALSE,
-- The account_type (user = 1, guest = 2, admin = 3, appservice = 4)
account_type SMALLINT NOT NULL,
-- The policy version this user has accepted -- The policy version this user has accepted
policy_version TEXT policy_version TEXT
-- TODO: -- TODO:
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? -- upgraded_ts, devices, any email reset stuff?
); );
-- Create sequence for autogenerated numeric usernames -- Create sequence for autogenerated numeric usernames
CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1; CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
` `
const insertAccountSQL = "" + const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, policy_version) VALUES ($1, $2, $3, $4, $5)" "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type, policy_version) VALUES ($1, $2, $3, $4, $5, $6)"
const updatePasswordSQL = "" + const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
@ -59,7 +63,7 @@ const deactivateAccountSQL = "" +
"UPDATE account_accounts SET is_deactivated = TRUE WHERE localpart = $1" "UPDATE account_accounts SET is_deactivated = TRUE WHERE localpart = $1"
const selectAccountByLocalpartSQL = "" + const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" "SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1"
const selectPasswordHashSQL = "" + const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE" "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
@ -89,14 +93,15 @@ type accountsStatements struct {
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
func (s *accountsStatements) execSchema(db *sql.DB) error { func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) {
s := &accountsStatements{
serverName: serverName,
}
_, err := db.Exec(accountsSchema) _, err := db.Exec(accountsSchema)
return err if err != nil {
} return nil, err
}
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { return s, sqlutil.StatementList{
s.serverName = server
return sqlutil.StatementList{
{&s.insertAccountStmt, insertAccountSQL}, {&s.insertAccountStmt, insertAccountSQL},
{&s.updatePasswordStmt, updatePasswordSQL}, {&s.updatePasswordStmt, updatePasswordSQL},
{&s.deactivateAccountStmt, deactivateAccountSQL}, {&s.deactivateAccountStmt, deactivateAccountSQL},
@ -112,17 +117,17 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) insertAccount( func (s *accountsStatements) InsertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, accountType api.AccountType,
) (*api.Account, error) { ) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
var err error var err error
if appserviceID == "" { if accountType != api.AccountTypeAppService {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, policyVersion) _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion)
} else { } else {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, "") _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, "")
} }
if err != nil { if err != nil {
return nil, err return nil, err
@ -133,38 +138,39 @@ func (s *accountsStatements) insertAccount(
UserID: userutil.MakeUserID(localpart, s.serverName), UserID: userutil.MakeUserID(localpart, s.serverName),
ServerName: s.serverName, ServerName: s.serverName,
AppServiceID: appserviceID, AppServiceID: appserviceID,
AccountType: accountType,
}, nil }, nil
} }
func (s *accountsStatements) updatePassword( func (s *accountsStatements) UpdatePassword(
ctx context.Context, localpart, passwordHash string, ctx context.Context, localpart, passwordHash string,
) (err error) { ) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
return return
} }
func (s *accountsStatements) deactivateAccount( func (s *accountsStatements) DeactivateAccount(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (err error) { ) (err error) {
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
return return
} }
func (s *accountsStatements) selectPasswordHash( func (s *accountsStatements) SelectPasswordHash(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (hash string, err error) { ) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
return return
} }
func (s *accountsStatements) selectAccountByLocalpart( func (s *accountsStatements) SelectAccountByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (*api.Account, error) { ) (*api.Account, error) {
var appserviceIDPtr sql.NullString var appserviceIDPtr sql.NullString
var acc api.Account var acc api.Account
stmt := s.selectAccountByLocalpartStmt stmt := s.selectAccountByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType)
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve user from the db") log.WithError(err).Error("Unable to retrieve user from the db")
@ -181,7 +187,7 @@ func (s *accountsStatements) selectAccountByLocalpart(
return &acc, nil return &acc, nil
} }
func (s *accountsStatements) selectNewNumericLocalpart( func (s *accountsStatements) SelectNewNumericLocalpart(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
) (id int64, err error) { ) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt stmt := s.selectNewNumericLocalpartStmt

View file

@ -4,12 +4,14 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose" "github.com/pressly/goose"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
func LoadFromGoose() { func LoadFromGoose() {
goose.AddMigration(UpIsActive, DownIsActive) goose.AddMigration(UpIsActive, DownIsActive)
goose.AddMigration(UpAddAccountType, DownAddAccountType)
goose.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion) goose.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion)
} }

View file

@ -5,13 +5,8 @@ import (
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
) )
func LoadFromGoose() {
goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
}
func LoadLastSeenTSIP(m *sqlutil.Migrations) { func LoadLastSeenTSIP(m *sqlutil.Migrations) {
m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP) m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
} }

View file

@ -0,0 +1,34 @@
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
func LoadAddAccountType(m *sqlutil.Migrations) {
m.AddMigration(UpAddAccountType, DownAddAccountType)
}
func UpAddAccountType(tx *sql.Tx) error {
// initially set every account to useraccount, change appservice and guest accounts afterwards
// (user = 1, guest = 2, admin = 3, appservice = 4)
_, err := tx.Exec(`ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS account_type SMALLINT NOT NULL DEFAULT 1;
UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> '';
UPDATE account_accounts SET account_type = 2 WHERE localpart ~ '^[0-9]+$';
ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`,
)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownAddAccountType(tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN account_type;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -111,50 +112,32 @@ type devicesStatements struct {
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
func (s *devicesStatements) execSchema(db *sql.DB) error { func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) {
s := &devicesStatements{
serverName: serverName,
}
_, err := db.Exec(devicesSchema) _, err := db.Exec(devicesSchema)
return err if err != nil {
} return nil, err
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
return
} }
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil { return s, sqlutil.StatementList{
return {&s.insertDeviceStmt, insertDeviceSQL},
} {&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL},
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil { {&s.selectDeviceByIDStmt, selectDeviceByIDSQL},
return {&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL},
} {&s.updateDeviceNameStmt, updateDeviceNameSQL},
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil { {&s.deleteDeviceStmt, deleteDeviceSQL},
return {&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL},
} {&s.deleteDevicesStmt, deleteDevicesSQL},
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil { {&s.selectDevicesByIDStmt, selectDevicesByIDSQL},
return {&s.updateDeviceLastSeenStmt, updateDeviceLastSeen},
} }.Prepare(db)
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
return
}
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return
}
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
return
}
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
return
}
if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil {
return
}
s.serverName = server
return
} }
// insertDevice creates a new device. Returns an error if any device with the same access token already exists. // insertDevice creates a new device. Returns an error if any device with the same access token already exists.
// Returns an error if the user already has a device with the given device ID. // Returns an error if the user already has a device with the given device ID.
// Returns the device on success. // Returns the device on success.
func (s *devicesStatements) insertDevice( func (s *devicesStatements) InsertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string, ipAddr, userAgent string, displayName *string, ipAddr, userAgent string,
) (*api.Device, error) { ) (*api.Device, error) {
@ -176,7 +159,7 @@ func (s *devicesStatements) insertDevice(
} }
// deleteDevice removes a single device by id and user localpart. // deleteDevice removes a single device by id and user localpart.
func (s *devicesStatements) deleteDevice( func (s *devicesStatements) DeleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string, ctx context.Context, txn *sql.Tx, id, localpart string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
@ -186,7 +169,7 @@ func (s *devicesStatements) deleteDevice(
// deleteDevices removes a single or multiple devices by ids and user localpart. // deleteDevices removes a single or multiple devices by ids and user localpart.
// Returns an error if the execution failed. // Returns an error if the execution failed.
func (s *devicesStatements) deleteDevices( func (s *devicesStatements) DeleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string, ctx context.Context, txn *sql.Tx, localpart string, devices []string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt) stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt)
@ -196,7 +179,7 @@ func (s *devicesStatements) deleteDevices(
// deleteDevicesByLocalpart removes all devices for the // deleteDevicesByLocalpart removes all devices for the
// given user localpart. // given user localpart.
func (s *devicesStatements) deleteDevicesByLocalpart( func (s *devicesStatements) DeleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
@ -204,7 +187,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
return err return err
} }
func (s *devicesStatements) updateDeviceName( func (s *devicesStatements) UpdateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
@ -212,7 +195,7 @@ func (s *devicesStatements) updateDeviceName(
return err return err
} }
func (s *devicesStatements) selectDeviceByToken( func (s *devicesStatements) SelectDeviceByToken(
ctx context.Context, accessToken string, ctx context.Context, accessToken string,
) (*api.Device, error) { ) (*api.Device, error) {
var dev api.Device var dev api.Device
@ -228,7 +211,7 @@ func (s *devicesStatements) selectDeviceByToken(
// selectDeviceByID retrieves a device from the database with the given user // selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID // localpart and deviceID
func (s *devicesStatements) selectDeviceByID( func (s *devicesStatements) SelectDeviceByID(
ctx context.Context, localpart, deviceID string, ctx context.Context, localpart, deviceID string,
) (*api.Device, error) { ) (*api.Device, error) {
var dev api.Device var dev api.Device
@ -245,7 +228,7 @@ func (s *devicesStatements) selectDeviceByID(
return &dev, err return &dev, err
} }
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs)) rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs))
if err != nil { if err != nil {
return nil, err return nil, err
@ -268,7 +251,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
return devices, rows.Err() return devices, rows.Err()
} }
func (s *devicesStatements) selectDevicesByLocalpart( func (s *devicesStatements) SelectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) { ) ([]api.Device, error) {
devices := []api.Device{} devices := []api.Device{}
@ -310,7 +293,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
return devices, rows.Err() return devices, rows.Err()
} }
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
lastSeenTs := time.Now().UnixNano() / 1000000 lastSeenTs := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID) _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID)

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
) )
const keyBackupTableSchema = ` const keyBackupTableSchema = `
@ -72,12 +73,13 @@ type keyBackupStatements struct {
selectKeysByRoomIDAndSessionIDStmt *sql.Stmt selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
} }
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
_, err = db.Exec(keyBackupTableSchema) s := &keyBackupStatements{}
_, err := db.Exec(keyBackupTableSchema)
if err != nil { if err != nil {
return return nil, err
} }
return sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.insertBackupKeyStmt, insertBackupKeySQL},
{&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL},
{&s.countKeysStmt, countKeysSQL}, {&s.countKeysStmt, countKeysSQL},
@ -87,14 +89,14 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db) }.Prepare(db)
} }
func (s keyBackupStatements) countKeys( func (s keyBackupStatements) CountKeys(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (count int64, err error) { ) (count int64, err error) {
err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count) err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count)
return return
} }
func (s *keyBackupStatements) insertBackupKey( func (s *keyBackupStatements) InsertBackupKey(
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext( _, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext(
@ -103,7 +105,7 @@ func (s *keyBackupStatements) insertBackupKey(
return return
} }
func (s *keyBackupStatements) updateBackupKey( func (s *keyBackupStatements) UpdateBackupKey(
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext( _, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext(
@ -112,7 +114,7 @@ func (s *keyBackupStatements) updateBackupKey(
return return
} }
func (s *keyBackupStatements) selectKeys( func (s *keyBackupStatements) SelectKeys(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (map[string]map[string]api.KeyBackupSession, error) { ) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
@ -122,7 +124,7 @@ func (s *keyBackupStatements) selectKeys(
return unpackKeys(ctx, rows) return unpackKeys(ctx, rows)
} }
func (s *keyBackupStatements) selectKeysByRoomID( func (s *keyBackupStatements) SelectKeysByRoomID(
ctx context.Context, txn *sql.Tx, userID, version, roomID string, ctx context.Context, txn *sql.Tx, userID, version, roomID string,
) (map[string]map[string]api.KeyBackupSession, error) { ) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID) rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
@ -132,7 +134,7 @@ func (s *keyBackupStatements) selectKeysByRoomID(
return unpackKeys(ctx, rows) return unpackKeys(ctx, rows)
} }
func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID( func (s *keyBackupStatements) SelectKeysByRoomIDAndSessionID(
ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string, ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
) (map[string]map[string]api.KeyBackupSession, error) { ) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID) rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)

View file

@ -22,6 +22,7 @@ import (
"strconv" "strconv"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
) )
const keyBackupVersionTableSchema = ` const keyBackupVersionTableSchema = `
@ -69,12 +70,13 @@ type keyBackupVersionStatements struct {
updateKeyBackupETagStmt *sql.Stmt updateKeyBackupETagStmt *sql.Stmt
} }
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { func NewPostgresKeyBackupVersionTable(db *sql.DB) (tables.KeyBackupVersionTable, error) {
_, err = db.Exec(keyBackupVersionTableSchema) s := &keyBackupVersionStatements{}
_, err := db.Exec(keyBackupVersionTableSchema)
if err != nil { if err != nil {
return return nil, err
} }
return sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertKeyBackupStmt, insertKeyBackupSQL}, {&s.insertKeyBackupStmt, insertKeyBackupSQL},
{&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL}, {&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL},
{&s.deleteKeyBackupStmt, deleteKeyBackupSQL}, {&s.deleteKeyBackupStmt, deleteKeyBackupSQL},
@ -84,7 +86,7 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db) }.Prepare(db)
} }
func (s *keyBackupVersionStatements) insertKeyBackup( func (s *keyBackupVersionStatements) InsertKeyBackup(
ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string, ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string,
) (version string, err error) { ) (version string, err error) {
var versionInt int64 var versionInt int64
@ -92,7 +94,7 @@ func (s *keyBackupVersionStatements) insertKeyBackup(
return strconv.FormatInt(versionInt, 10), err return strconv.FormatInt(versionInt, 10), err
} }
func (s *keyBackupVersionStatements) updateKeyBackupAuthData( func (s *keyBackupVersionStatements) UpdateKeyBackupAuthData(
ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage, ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage,
) error { ) error {
versionInt, err := strconv.ParseInt(version, 10, 64) versionInt, err := strconv.ParseInt(version, 10, 64)
@ -103,7 +105,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
return err return err
} }
func (s *keyBackupVersionStatements) updateKeyBackupETag( func (s *keyBackupVersionStatements) UpdateKeyBackupETag(
ctx context.Context, txn *sql.Tx, userID, version, etag string, ctx context.Context, txn *sql.Tx, userID, version, etag string,
) error { ) error {
versionInt, err := strconv.ParseInt(version, 10, 64) versionInt, err := strconv.ParseInt(version, 10, 64)
@ -114,7 +116,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag(
return err return err
} }
func (s *keyBackupVersionStatements) deleteKeyBackup( func (s *keyBackupVersionStatements) DeleteKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (bool, error) { ) (bool, error) {
versionInt, err := strconv.ParseInt(version, 10, 64) versionInt, err := strconv.ParseInt(version, 10, 64)
@ -132,7 +134,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup(
return ra == 1, nil return ra == 1, nil
} }
func (s *keyBackupVersionStatements) selectKeyBackup( func (s *keyBackupVersionStatements) SelectKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { ) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
var versionInt int64 var versionInt int64

View file

@ -21,18 +21,11 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
type loginTokenStatements struct { const loginTokenSchema = `
insertStmt *sql.Stmt
deleteStmt *sql.Stmt
selectByTokenStmt *sql.Stmt
}
// execSchema ensures tables and indices exist.
func (s *loginTokenStatements) execSchema(db *sql.DB) error {
_, err := db.Exec(`
CREATE TABLE IF NOT EXISTS login_tokens ( CREATE TABLE IF NOT EXISTS login_tokens (
-- The random value of the token issued to a user -- The random value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY, token TEXT NOT NULL PRIMARY KEY,
@ -45,21 +38,38 @@ CREATE TABLE IF NOT EXISTS login_tokens (
-- This index allows efficient garbage collection of expired tokens. -- This index allows efficient garbage collection of expired tokens.
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at);
`) `
return err
const insertLoginTokenSQL = "" +
"INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
const deleteLoginTokenSQL = "" +
"DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"
const selectLoginTokenSQL = "" +
"SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"
type loginTokenStatements struct {
insertStmt *sql.Stmt
deleteStmt *sql.Stmt
selectStmt *sql.Stmt
} }
// prepare runs statement preparation. func NewPostgresLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) {
func (s *loginTokenStatements) prepare(db *sql.DB) error { s := &loginTokenStatements{}
return sqlutil.StatementList{ _, err := db.Exec(loginTokenSchema)
{&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, if err != nil {
{&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"}, return nil, err
{&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"}, }
return s, sqlutil.StatementList{
{&s.insertStmt, insertLoginTokenSQL},
{&s.deleteStmt, deleteLoginTokenSQL},
{&s.selectStmt, selectLoginTokenSQL},
}.Prepare(db) }.Prepare(db)
} }
// insert adds an already generated token to the database. // insert adds an already generated token to the database.
func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error {
stmt := sqlutil.TxStmt(txn, s.insertStmt) stmt := sqlutil.TxStmt(txn, s.insertStmt)
_, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID) _, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID)
return err return err
@ -69,7 +79,7 @@ func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata
// //
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens. // As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
// The login_tokens_expiration_idx index should make that efficient. // The login_tokens_expiration_idx index should make that efficient.
func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error { func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error {
stmt := sqlutil.TxStmt(txn, s.deleteStmt) stmt := sqlutil.TxStmt(txn, s.deleteStmt)
res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) res, err := stmt.ExecContext(ctx, token, time.Now().UTC())
if err != nil { if err != nil {
@ -82,9 +92,9 @@ func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, t
} }
// selectByToken returns the data associated with the given token. May return sql.ErrNoRows. // selectByToken returns the data associated with the given token. May return sql.ErrNoRows.
func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { func (s *loginTokenStatements) SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
var data api.LoginTokenData var data api.LoginTokenData
err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -6,6 +6,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -22,33 +23,35 @@ CREATE TABLE IF NOT EXISTS open_id_tokens (
); );
` `
const insertTokenSQL = "" + const insertOpenIDTokenSQL = "" +
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
const selectTokenSQL = "" + const selectOpenIDTokenSQL = "" +
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
type tokenStatements struct { type openIDTokenStatements struct {
insertTokenStmt *sql.Stmt insertTokenStmt *sql.Stmt
selectTokenStmt *sql.Stmt selectTokenStmt *sql.Stmt
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.OpenIDTable, error) {
_, err = db.Exec(openIDTokenSchema) s := &openIDTokenStatements{
if err != nil { serverName: serverName,
return
} }
s.serverName = server _, err := db.Exec(openIDTokenSchema)
return sqlutil.StatementList{ if err != nil {
{&s.insertTokenStmt, insertTokenSQL}, return nil, err
{&s.selectTokenStmt, selectTokenSQL}, }
return s, sqlutil.StatementList{
{&s.insertTokenStmt, insertOpenIDTokenSQL},
{&s.selectTokenStmt, selectOpenIDTokenSQL},
}.Prepare(db) }.Prepare(db)
} }
// insertToken inserts a new OpenID Connect token to the DB. // insertToken inserts a new OpenID Connect token to the DB.
// Returns new token, otherwise returns error if the token already exists. // Returns new token, otherwise returns error if the token already exists.
func (s *tokenStatements) insertToken( func (s *openIDTokenStatements) InsertOpenIDToken(
ctx context.Context, ctx context.Context,
txn *sql.Tx, txn *sql.Tx,
token, localpart string, token, localpart string,
@ -61,7 +64,7 @@ func (s *tokenStatements) insertToken(
// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB // selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
// Returns the existing token's attributes, or err if no token is found // Returns the existing token's attributes, or err if no token is found
func (s *tokenStatements) selectOpenIDTokenAtrributes( func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
ctx context.Context, ctx context.Context,
token string, token string,
) (*api.OpenIDTokenAttributes, error) { ) (*api.OpenIDTokenAttributes, error) {

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
) )
const profilesSchema = ` const profilesSchema = `
@ -59,12 +60,13 @@ type profilesStatements struct {
selectProfilesBySearchStmt *sql.Stmt selectProfilesBySearchStmt *sql.Stmt
} }
func (s *profilesStatements) prepare(db *sql.DB) (err error) { func NewPostgresProfilesTable(db *sql.DB) (tables.ProfileTable, error) {
_, err = db.Exec(profilesSchema) s := &profilesStatements{}
_, err := db.Exec(profilesSchema)
if err != nil { if err != nil {
return return nil, err
} }
return sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertProfileStmt, insertProfileSQL}, {&s.insertProfileStmt, insertProfileSQL},
{&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL}, {&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL},
{&s.setAvatarURLStmt, setAvatarURLSQL}, {&s.setAvatarURLStmt, setAvatarURLSQL},
@ -73,14 +75,14 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db) }.Prepare(db)
} }
func (s *profilesStatements) insertProfile( func (s *profilesStatements) InsertProfile(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart string,
) (err error) { ) (err error) {
_, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
return return
} }
func (s *profilesStatements) selectProfileByLocalpart( func (s *profilesStatements) SelectProfileByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
var profile authtypes.Profile var profile authtypes.Profile
@ -93,21 +95,21 @@ func (s *profilesStatements) selectProfileByLocalpart(
return &profile, nil return &profile, nil
} }
func (s *profilesStatements) setAvatarURL( func (s *profilesStatements) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string, ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
) (err error) { ) (err error) {
_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart) _, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart)
return return
} }
func (s *profilesStatements) setDisplayName( func (s *profilesStatements) SetDisplayName(
ctx context.Context, localpart string, displayName string, ctx context.Context, txn *sql.Tx, localpart string, displayName string,
) (err error) { ) (err error) {
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
return return
} }
func (s *profilesStatements) selectProfilesBySearch( func (s *profilesStatements) SelectProfilesBySearch(
ctx context.Context, searchString string, limit int, ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) { ) ([]authtypes.Profile, error) {
var profiles []authtypes.Profile var profiles []authtypes.Profile

View file

@ -0,0 +1,105 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"fmt"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/userapi/storage/shared"
// Import the postgres database driver.
_ "github.com/lib/pq"
)
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*shared.Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
if _, err = db.Exec(accountsSchema); err != nil {
// do this so that the migration can and we don't fail on
// preparing statements for columns that don't exist yet
return nil, err
}
deltas.LoadIsActive(m)
//deltas.LoadLastSeenTSIP(m)
deltas.LoadAddAccountType(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
accountDataTable, err := NewPostgresAccountDataTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)
}
accountsTable, err := NewPostgresAccountsTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err)
}
devicesTable, err := NewPostgresDevicesTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err)
}
keyBackupTable, err := NewPostgresKeyBackupTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresKeyBackupTable: %w", err)
}
keyBackupVersionTable, err := NewPostgresKeyBackupVersionTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresKeyBackupVersionTable: %w", err)
}
loginTokenTable, err := NewPostgresLoginTokenTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresLoginTokenTable: %w", err)
}
openIDTable, err := NewPostgresOpenIDTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewPostgresOpenIDTable: %w", err)
}
profilesTable, err := NewPostgresProfilesTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresProfilesTable: %w", err)
}
threePIDTable, err := NewPostgresThreePIDTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresThreePIDTable: %w", err)
}
return &shared.Database{
AccountDatas: accountDataTable,
Accounts: accountsTable,
Devices: devicesTable,
KeyBackups: keyBackupTable,
KeyBackupVersions: keyBackupVersionTable,
LoginTokens: loginTokenTable,
OpenIDTokens: openIDTable,
Profiles: profilesTable,
ThreePIDs: threePIDTable,
ServerName: serverName,
DB: db,
Writer: sqlutil.NewDummyWriter(),
LoginTokenLifetime: loginTokenLifetime,
BcryptCost: bcryptCost,
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
}, nil
}

Some files were not shown because too many files have changed in this diff Show more