Pull latest dendrite fork into harmony (#252)

Latest dendrite main has changes for knockable rooms, and the fix for login crash. Pulled into dendrite fork. Rebased dendrite fork from dendrite main.

Co-authored-by: Till <2353100+S7evinK@users.noreply.github.com>
Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
Co-authored-by: kegsay <kegan@matrix.org>
Co-authored-by: Tak Wai Wong <tak@hntlabs.com>
Co-authored-by: texuf <texuf.eth@gmail.com>
Co-authored-by: Brian Meek <brian@hntlabs.com>
Co-authored-by: Tak Wai Wong <takwaiw@gmail.com>
This commit is contained in:
Tak Wai Wong 2022-08-13 16:09:49 -07:00 committed by GitHub
parent 8a66a55771
commit ea0a8804d2
98 changed files with 2939 additions and 2818 deletions

View file

@ -19,10 +19,10 @@ jobs:
runs-on: ubuntu-latest
if: ${{ false }} # disable for now
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Install Go
uses: actions/setup-go@v2
uses: actions/setup-go@v3
with:
go-version: 1.18
@ -66,8 +66,12 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Install Go
uses: actions/setup-go@v3
with:
go-version: 1.18
- name: golangci-lint
uses: golangci/golangci-lint-action@v2
uses: golangci/golangci-lint-action@v3
# run go test with different go versions
test:
@ -101,7 +105,7 @@ jobs:
steps:
- uses: actions/checkout@v3
- name: Setup go
uses: actions/setup-go@v2
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.go }}
- uses: actions/cache@v3
@ -133,7 +137,7 @@ jobs:
steps:
- uses: actions/checkout@v3
- name: Setup go
uses: actions/setup-go@v2
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.go }}
- name: Install dependencies x86
@ -167,7 +171,7 @@ jobs:
steps:
- uses: actions/checkout@v3
- name: Setup Go ${{ matrix.go }}
uses: actions/setup-go@v2
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.go }}
- name: Install dependencies
@ -208,7 +212,7 @@ jobs:
steps:
- uses: actions/checkout@v3
- name: Setup go
uses: actions/setup-go@v2
uses: actions/setup-go@v3
with:
go-version: "1.18"
- uses: actions/cache@v3
@ -233,7 +237,7 @@ jobs:
steps:
- uses: actions/checkout@v3
- name: Setup go
uses: actions/setup-go@v2
uses: actions/setup-go@v3
with:
go-version: "1.18"
- uses: actions/cache@v3

View file

@ -1,5 +1,30 @@
# Changelog
## Dendrite 0.9.2 (2022-08-12)
### Features
* Dendrite now supports history visibility on the `/sync`, `/messages` and `/context` endpoints
* It should now be possible to view the history of a room in more cases (as opposed to limiting scrollback to the join event or defaulting to the restrictive `"join"` visibility rule as before)
* The default room version for newly created rooms is now room version 9
* New admin endpoint `/_dendrite/admin/resetPassword/{userID}` has been added, which replaces the `-reset-password` flag in `create-account`
* The `create-account` binary now uses shared secret registration over HTTP to create new accounts, which fixes a number of problems with account data and push rules not being configured correctly for new accounts
* The internal HTTP APIs for polylith deployments have been refactored for correctness and consistency
* The federation API will now automatically clean up some EDUs that have failed to send within a certain period of time
* The `/hierarchy` endpoint will now return potentially joinable rooms (contributed by [texuf](https://github.com/texuf))
* The user directory will now show or hide users correctly
### Fixes
* Send-to-device messages should no longer be incorrectly duplicated in `/sync`
* The federation sender will no longer create unnecessary destination queues as a result of a logic error
* A bug where database migrations may not execute properly when upgrading from older versions has been fixed
* A crash when failing to update user account data has been fixed
* A race condition when generating notification counts has been fixed
* A race condition when setting up NATS has been fixed (contributed by [brianathere](https://github.com/brianathere))
* Stale cache data for membership lazy-loading is now correctly invalidated when doing a complete sync
* Data races within user-interactive authentication have been fixed (contributed by [tak-hntlabs](https://github.com/tak-hntlabs))
## Dendrite 0.9.1 (2022-08-03)
### Fixes
@ -10,7 +35,7 @@
* The media endpoint now sets the `Cache-Control` header correctly to prevent web-based clients from hitting media endpoints excessively
* The sync API will now advance the PDU stream position correctly in all cases (contributed by [sergekh2](https://github.com/sergekh2))
* The sync API will now delete the correct range of send-to-device messages when advancing the stream position
* The device list `changed` key in the `/sync` response should now return the correct users
* The device list `changed` key in the `/sync` response should now return the correct users
* A data race when looking up missing state has been fixed
* The `/send_join` API is now applying stronger validation to the received membership event

View file

@ -80,7 +80,7 @@ $ ./bin/dendrite-monolith-server --tls-cert server.crt --tls-key server.key --co
# Create an user account (add -admin for an admin user).
# Specify the localpart only, e.g. 'alice' for '@alice:domain.com'
$ ./bin/create-account --config dendrite.yaml -username alice
$ ./bin/create-account --config dendrite.yaml --url http://localhost:8008 --username alice
```
Then point your favourite Matrix client at `http://localhost:8008` or `https://localhost:8448`.
@ -89,12 +89,12 @@ Then point your favourite Matrix client at `http://localhost:8008` or `https://l
We use a script called Are We Synapse Yet which checks Sytest compliance rates. Sytest is a black-box homeserver
test rig with around 900 tests. The script works out how many of these tests are passing on Dendrite and it
updates with CI. As of August 2022 we're at around 83% CS API coverage and 95% Federation coverage, though check
updates with CI. As of August 2022 we're at around 90% CS API coverage and 95% Federation coverage, though check
CI for the latest numbers. In practice, this means you can communicate locally and via federation with Synapse
servers such as matrix.org reasonably well, although there are still some missing features (like Search).
We are prioritising features that will benefit single-user homeservers first (e.g Receipts, E2E) rather
than features that massive deployments may be interested in (User Directory, OpenID, Guests, Admin APIs, AS API).
than features that massive deployments may be interested in (OpenID, Guests, Admin APIs, AS API).
This means Dendrite supports amongst others:
- Core room functionality (creating rooms, invites, auth rules)

View file

@ -7,7 +7,6 @@ import (
"github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/opentracing/opentracing-go"
)
// HTTP paths for the internal HTTP APIs
@ -42,11 +41,10 @@ func (h *httpAppServiceQueryAPI) RoomAliasExists(
request *api.RoomAliasExistsRequest,
response *api.RoomAliasExistsResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "appserviceRoomAliasExists")
defer span.Finish()
apiURL := h.appserviceURL + AppServiceRoomAliasExistsPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"RoomAliasExists", h.appserviceURL+AppServiceRoomAliasExistsPath,
h.httpClient, ctx, request, response,
)
}
// UserIDExists implements AppServiceQueryAPI
@ -55,9 +53,8 @@ func (h *httpAppServiceQueryAPI) UserIDExists(
request *api.UserIDExistsRequest,
response *api.UserIDExistsResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "appserviceUserIDExists")
defer span.Finish()
apiURL := h.appserviceURL + AppServiceUserIDExistsPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"UserIDExists", h.appserviceURL+AppServiceUserIDExistsPath,
h.httpClient, ctx, request, response,
)
}

View file

@ -1,43 +1,20 @@
package inthttp
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/util"
)
// AddRoutes adds the AppServiceQueryAPI handlers to the http.ServeMux.
func AddRoutes(a api.AppServiceInternalAPI, internalAPIMux *mux.Router) {
internalAPIMux.Handle(
AppServiceRoomAliasExistsPath,
httputil.MakeInternalAPI("appserviceRoomAliasExists", func(req *http.Request) util.JSONResponse {
var request api.RoomAliasExistsRequest
var response api.RoomAliasExistsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := a.RoomAliasExists(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("AppserviceRoomAliasExists", a.RoomAliasExists),
)
internalAPIMux.Handle(
AppServiceUserIDExistsPath,
httputil.MakeInternalAPI("appserviceUserIDExists", func(req *http.Request) util.JSONResponse {
var request api.UserIDExistsRequest
var response api.UserIDExistsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := a.UserIDExists(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("AppserviceUserIDExists", a.UserIDExists),
)
}

View file

@ -18,6 +18,7 @@ import (
"context"
"encoding/json"
"io"
"io/ioutil"
"net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -41,7 +42,7 @@ func LoginFromJSONReader(
userInteractiveAuth *UserInteractive,
cfg *config.ClientAPI,
) (*Login, LoginCleanupFunc, *util.JSONResponse) {
reqBytes, err := io.ReadAll(r)
reqBytes, err := ioutil.ReadAll(r)
if err != nil {
err := &util.JSONResponse{
Code: http.StatusBadRequest,

View file

@ -15,11 +15,13 @@
package jsonerror
import (
"context"
"fmt"
"net/http"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
)
// MatrixError represents the "standard error response" in Matrix.
@ -213,3 +215,15 @@ func NotTrusted(serverName string) *MatrixError {
Err: fmt.Sprintf("Untrusted server '%s'", serverName),
}
}
// InternalAPIError is returned when Dendrite failed to reach an internal API.
func InternalAPIError(ctx context.Context, err error) util.JSONResponse {
logrus.WithContext(ctx).WithError(err).Error("Error reaching an internal API")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: &MatrixError{
ErrCode: "M_INTERNAL_SERVER_ERROR",
Err: "Dendrite encountered an error reaching an internal API.",
},
}
}

View file

@ -1,23 +1,20 @@
package routing
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/httputil"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
func AdminEvacuateRoom(req *http.Request, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
if device.AccountType != userapi.AccountTypeAdmin {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("This API can only be used by admin users."),
}
}
func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@ -30,13 +27,15 @@ func AdminEvacuateRoom(req *http.Request, device *userapi.Device, rsAPI roomserv
}
}
res := &roomserverAPI.PerformAdminEvacuateRoomResponse{}
rsAPI.PerformAdminEvacuateRoom(
if err := rsAPI.PerformAdminEvacuateRoom(
req.Context(),
&roomserverAPI.PerformAdminEvacuateRoomRequest{
RoomID: roomID,
},
res,
)
); err != nil {
return util.ErrorResponse(err)
}
if err := res.Error; err != nil {
return err.JSONResponse()
}
@ -48,13 +47,7 @@ func AdminEvacuateRoom(req *http.Request, device *userapi.Device, rsAPI roomserv
}
}
func AdminEvacuateUser(req *http.Request, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
if device.AccountType != userapi.AccountTypeAdmin {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("This API can only be used by admin users."),
}
}
func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@ -66,14 +59,26 @@ func AdminEvacuateUser(req *http.Request, device *userapi.Device, rsAPI roomserv
JSON: jsonerror.MissingArgument("Expecting user ID."),
}
}
_, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if domain != cfg.Matrix.ServerName {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument("User ID must belong to this server."),
}
}
res := &roomserverAPI.PerformAdminEvacuateUserResponse{}
rsAPI.PerformAdminEvacuateUser(
if err := rsAPI.PerformAdminEvacuateUser(
req.Context(),
&roomserverAPI.PerformAdminEvacuateUserRequest{
UserID: userID,
},
res,
)
); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if err := res.Error; err != nil {
return err.JSONResponse()
}
@ -84,3 +89,52 @@ func AdminEvacuateUser(req *http.Request, device *userapi.Device, rsAPI roomserv
},
}
}
func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
localpart, ok := vars["localpart"]
if !ok {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument("Expecting user localpart."),
}
}
request := struct {
Password string `json:"password"`
}{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown("Failed to decode request body: " + err.Error()),
}
}
if request.Password == "" {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument("Expecting non-empty password."),
}
}
updateReq := &userapi.PerformPasswordUpdateRequest{
Localpart: localpart,
Password: request.Password,
LogoutDevices: true,
}
updateRes := &userapi.PerformPasswordUpdateResponse{}
if err := userAPI.PerformPasswordUpdate(req.Context(), updateReq, updateRes); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown("Failed to perform password update: " + err.Error()),
}
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct {
Updated bool `json:"password_updated"`
}{
Updated: updateRes.PasswordUpdated,
},
}
}

View file

@ -556,10 +556,12 @@ func createRoom(
if r.Visibility == "public" {
// expose this room in the published room list
var pubRes roomserverAPI.PerformPublishResponse
rsAPI.PerformPublish(ctx, &roomserverAPI.PerformPublishRequest{
if err := rsAPI.PerformPublish(ctx, &roomserverAPI.PerformPublishRequest{
RoomID: roomID,
Visibility: "public",
}, &pubRes)
}, &pubRes); err != nil {
return jsonerror.InternalAPIError(ctx, err)
}
if pubRes.Error != nil {
// treat as non-fatal since the room is already made by this point
util.GetLogger(ctx).WithError(pubRes.Error).Error("failed to visibility:public")

View file

@ -302,10 +302,12 @@ func SetVisibility(
}
var publishRes roomserverAPI.PerformPublishResponse
rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{
if err := rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{
RoomID: roomID,
Visibility: v.Visibility,
}, &publishRes)
}, &publishRes); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if publishRes.Error != nil {
util.GetLogger(req.Context()).WithError(publishRes.Error).Error("PerformPublish failed")
return publishRes.Error.JSONResponse()

View file

@ -81,8 +81,9 @@ func JoinRoomByIDOrAlias(
done := make(chan util.JSONResponse, 1)
go func() {
defer close(done)
rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes)
if joinRes.Error != nil {
if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil {
done <- jsonerror.InternalAPIError(req.Context(), err)
} else if joinRes.Error != nil {
done <- joinRes.Error.JSONResponse()
} else {
done <- util.JSONResponse{

View file

@ -91,10 +91,12 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de
// Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version}
func KeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse {
var queryResp userapi.QueryKeyBackupResponse
userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
if err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
UserID: device.UserID,
Version: version,
}, &queryResp)
}, &queryResp); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if queryResp.Error != "" {
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error))
}
@ -233,13 +235,15 @@ func GetBackupKeys(
req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version, roomID, sessionID string,
) util.JSONResponse {
var queryResp userapi.QueryKeyBackupResponse
userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
if err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
UserID: device.UserID,
Version: version,
ReturnKeys: true,
KeysForRoomID: roomID,
KeysForSessionID: sessionID,
}, &queryResp)
}, &queryResp); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if queryResp.Error != "" {
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error))
}

View file

@ -72,7 +72,9 @@ func UploadCrossSigningDeviceKeys(
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
uploadReq.UserID = device.UserID
keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes)
if err := keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if err := uploadRes.Error; err != nil {
switch {
@ -114,7 +116,9 @@ func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.Clie
}
uploadReq.UserID = device.UserID
keyserverAPI.PerformUploadDeviceSignatures(req.Context(), uploadReq, uploadRes)
if err := keyserverAPI.PerformUploadDeviceSignatures(req.Context(), uploadReq, uploadRes); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if err := uploadRes.Error; err != nil {
switch {

View file

@ -62,7 +62,9 @@ func UploadKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Devi
}
var uploadRes api.PerformUploadKeysResponse
keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes)
if err := keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes); err != nil {
return util.ErrorResponse(err)
}
if uploadRes.Error != nil {
util.GetLogger(req.Context()).WithError(uploadRes.Error).Error("Failed to PerformUploadKeys")
return jsonerror.InternalServerError()
@ -107,12 +109,14 @@ func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Devic
return *resErr
}
queryRes := api.QueryKeysResponse{}
keyAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{
if err := keyAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{
UserID: device.UserID,
UserToDevices: r.DeviceKeys,
Timeout: r.GetTimeout(),
// TODO: Token?
}, &queryRes)
}, &queryRes); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{
Code: 200,
JSON: map[string]interface{}{
@ -145,10 +149,12 @@ func ClaimKeys(req *http.Request, keyAPI api.ClientKeyAPI) util.JSONResponse {
return *resErr
}
claimRes := api.PerformClaimKeysResponse{}
keyAPI.PerformClaimKeys(req.Context(), &api.PerformClaimKeysRequest{
if err := keyAPI.PerformClaimKeys(req.Context(), &api.PerformClaimKeysRequest{
OneTimeKeys: r.OneTimeKeys,
Timeout: r.GetTimeout(),
}, &claimRes)
}, &claimRes); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if claimRes.Error != nil {
util.GetLogger(req.Context()).WithError(claimRes.Error).Error("failed to PerformClaimKeys")
return jsonerror.InternalServerError()

View file

@ -17,6 +17,7 @@ package routing
import (
"net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
@ -54,7 +55,9 @@ func PeekRoomByIDOrAlias(
}
// Ask the roomserver to perform the peek.
rsAPI.PerformPeek(req.Context(), &peekReq, &peekRes)
if err := rsAPI.PerformPeek(req.Context(), &peekReq, &peekRes); err != nil {
return util.ErrorResponse(err)
}
if peekRes.Error != nil {
return peekRes.Error.JSONResponse()
}
@ -89,7 +92,9 @@ func UnpeekRoomByID(
}
unpeekRes := roomserverAPI.PerformUnpeekResponse{}
rsAPI.PerformUnpeek(req.Context(), &unpeekReq, &unpeekRes)
if err := rsAPI.PerformUnpeek(req.Context(), &unpeekReq, &unpeekRes); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if unpeekRes.Error != nil {
return unpeekRes.Error.JSONResponse()
}

View file

@ -144,17 +144,23 @@ func Setup(
}
dendriteAdminRouter.Handle("/admin/evacuateRoom/{roomID}",
httputil.MakeAuthAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminEvacuateRoom(req, device, rsAPI)
httputil.MakeAdminAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminEvacuateRoom(req, cfg, device, rsAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
dendriteAdminRouter.Handle("/admin/evacuateUser/{userID}",
httputil.MakeAuthAPI("admin_evacuate_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminEvacuateUser(req, device, rsAPI)
httputil.MakeAdminAPI("admin_evacuate_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminEvacuateUser(req, cfg, device, rsAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
dendriteAdminRouter.Handle("/admin/resetPassword/{localpart}",
httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminResetPassword(req, cfg, device, userAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
// server notifications
if cfg.Matrix.ServerNotices.Enabled {
logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice")
@ -929,12 +935,12 @@ func Setup(
return SearchUserDirectory(
req.Context(),
device,
userAPI,
rsAPI,
userDirectoryProvider,
cfg.Matrix.ServerName,
postContent.SearchString,
postContent.Limit,
federation,
cfg.Matrix.ServerName,
)
}),
).Methods(http.MethodPost, http.MethodOptions)

View file

@ -64,7 +64,9 @@ func UpgradeRoom(
}
upgradeResp := roomserverAPI.PerformRoomUpgradeResponse{}
rsAPI.PerformRoomUpgrade(req.Context(), &upgradeReq, &upgradeResp)
if err := rsAPI.PerformRoomUpgrade(req.Context(), &upgradeReq, &upgradeResp); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if upgradeResp.Error != nil {
if upgradeResp.Error.Code == roomserverAPI.PerformErrorNoRoom {

View file

@ -18,10 +18,13 @@ import (
"context"
"database/sql"
"fmt"
"net/http"
"strings"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
@ -34,12 +37,12 @@ type UserDirectoryResponse struct {
func SearchUserDirectory(
ctx context.Context,
device *userapi.Device,
userAPI userapi.ClientUserAPI,
rsAPI api.ClientRoomserverAPI,
provider userapi.QuerySearchProfilesAPI,
serverName gomatrixserverlib.ServerName,
searchString string,
limit int,
federation *gomatrixserverlib.FederationClient,
localServerName gomatrixserverlib.ServerName,
) util.JSONResponse {
if limit < 10 {
limit = 10
@ -51,59 +54,74 @@ func SearchUserDirectory(
Limited: false,
}
// First start searching local users.
userReq := &userapi.QuerySearchProfilesRequest{
SearchString: searchString,
Limit: limit,
// Get users we share a room with
knownUsersReq := &api.QueryKnownUsersRequest{
UserID: device.UserID,
Limit: limit,
}
userRes := &userapi.QuerySearchProfilesResponse{}
if err := provider.QuerySearchProfiles(ctx, userReq, userRes); err != nil {
return util.ErrorResponse(fmt.Errorf("userAPI.QuerySearchProfiles: %w", err))
knownUsersRes := &api.QueryKnownUsersResponse{}
if err := rsAPI.QueryKnownUsers(ctx, knownUsersReq, knownUsersRes); err != nil && err != sql.ErrNoRows {
return util.ErrorResponse(fmt.Errorf("rsAPI.QueryKnownUsers: %w", err))
}
for _, user := range userRes.Profiles {
knownUsersLoop:
for _, profile := range knownUsersRes.Users {
if len(results) == limit {
response.Limited = true
break
}
var userID string
if user.ServerName != "" {
userID = fmt.Sprintf("@%s:%s", user.Localpart, user.ServerName)
userID := profile.UserID
// get the full profile of the local user
localpart, serverName, _ := gomatrixserverlib.SplitID('@', userID)
if serverName == localServerName {
userReq := &userapi.QuerySearchProfilesRequest{
SearchString: localpart,
Limit: limit,
}
userRes := &userapi.QuerySearchProfilesResponse{}
if err := provider.QuerySearchProfiles(ctx, userReq, userRes); err != nil {
return util.ErrorResponse(fmt.Errorf("userAPI.QuerySearchProfiles: %w", err))
}
for _, p := range userRes.Profiles {
if strings.Contains(p.DisplayName, searchString) ||
strings.Contains(p.Localpart, searchString) {
profile.DisplayName = p.DisplayName
profile.AvatarURL = p.AvatarURL
results[userID] = profile
if len(results) == limit {
response.Limited = true
break knownUsersLoop
}
}
}
} else {
userID = fmt.Sprintf("@%s:%s", user.Localpart, serverName)
}
if _, ok := results[userID]; !ok {
results[userID] = authtypes.FullyQualifiedProfile{
UserID: userID,
DisplayName: user.DisplayName,
AvatarURL: user.AvatarURL,
// If the username already contains the search string, don't bother hitting federation.
// This will result in missing avatars and displaynames, but saves the federation roundtrip.
if strings.Contains(localpart, searchString) {
results[userID] = profile
if len(results) == limit {
response.Limited = true
break knownUsersLoop
}
continue
}
}
}
// Then, if we have enough room left in the response,
// start searching for known users from joined rooms.
if len(results) <= limit {
stateReq := &api.QueryKnownUsersRequest{
UserID: device.UserID,
SearchString: searchString,
Limit: limit - len(results),
}
stateRes := &api.QueryKnownUsersResponse{}
if err := rsAPI.QueryKnownUsers(ctx, stateReq, stateRes); err != nil && err != sql.ErrNoRows {
return util.ErrorResponse(fmt.Errorf("rsAPI.QueryKnownUsers: %w", err))
}
for _, user := range stateRes.Users {
if len(results) == limit {
response.Limited = true
break
// TODO: We should probably cache/store this
fedProfile, fedErr := federation.LookupProfile(ctx, serverName, userID, "")
if fedErr != nil {
if x, ok := fedErr.(gomatrix.HTTPError); ok {
if x.Code == http.StatusNotFound {
continue
}
}
}
if _, ok := results[user.UserID]; !ok {
results[user.UserID] = user
if strings.Contains(fedProfile.DisplayName, searchString) {
profile.DisplayName = fedProfile.DisplayName
profile.AvatarURL = fedProfile.AvatarURL
results[userID] = profile
if len(results) == limit {
response.Limited = true
break knownUsersLoop
}
}
}
}

View file

@ -15,20 +15,26 @@
package main
import (
"context"
"bytes"
"crypto/hmac"
"crypto/sha1"
"encoding/hex"
"encoding/json"
"flag"
"fmt"
"io"
"net/http"
"os"
"regexp"
"strings"
"time"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/sirupsen/logrus"
"golang.org/x/term"
"github.com/matrix-org/dendrite/setup"
)
const usage = `Usage: %s
@ -46,8 +52,6 @@ Example:
# read password from stdin
%s --config dendrite.yaml -username alice -passwordstdin < my.pass
cat my.pass | %s --config dendrite.yaml -username alice -passwordstdin
# reset password for a user, can be used with a combination above to read the password
%s --config dendrite.yaml -reset-password -username alice -password foobarbaz
Arguments:
@ -58,29 +62,34 @@ var (
password = flag.String("password", "", "The password to associate with the account")
pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)")
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
pwdLess = flag.Bool("passwordless", false, "Create a passwordless account, e.g. if only an accesstoken is required")
isAdmin = flag.Bool("admin", false, "Create an admin account")
resetPassword = flag.Bool("reset-password", false, "Resets the password for the given username")
resetPassword = flag.Bool("reset-password", false, "Deprecated")
serverURL = flag.String("url", "https://localhost:8448", "The URL to connect to.")
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
)
var cl = http.Client{
Timeout: time.Second * 10,
Transport: http.DefaultTransport,
}
func main() {
name := os.Args[0]
flag.Usage = func() {
_, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name, name)
_, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name)
flag.PrintDefaults()
}
cfg := setup.ParseFlags(true)
if *resetPassword {
logrus.Fatalf("The reset-password flag has been replaced by the POST /_dendrite/admin/resetPassword/{localpart} admin API.")
}
if *username == "" {
flag.Usage()
os.Exit(1)
}
if *pwdLess && *resetPassword {
logrus.Fatalf("Can not reset to an empty password, unable to login afterwards.")
}
if !validUsernameRegex.MatchString(*username) {
logrus.Warn("Username can only contain characters a-z, 0-9, or '_-./='")
os.Exit(1)
@ -90,67 +99,94 @@ func main() {
logrus.Fatalf("Username can not be longer than 255 characters: %s", fmt.Sprintf("@%s:%s", *username, cfg.Global.ServerName))
}
var pass string
var err error
if !*pwdLess {
pass, err = getPassword(*password, *pwdFile, *pwdStdin, os.Stdin)
if err != nil {
logrus.Fatalln(err)
}
}
// avoid warning about open registration
cfg.ClientAPI.RegistrationDisabled = true
b := base.NewBaseDendrite(cfg, "")
defer b.Close() // nolint: errcheck
accountDB, err := storage.NewUserAPIDatabase(
b,
&cfg.UserAPI.AccountDatabase,
cfg.Global.ServerName,
cfg.UserAPI.BCryptCost,
cfg.UserAPI.OpenIDTokenLifetimeMS,
0, // TODO
cfg.Global.ServerNotices.LocalPart,
)
pass, err := getPassword(*password, *pwdFile, *pwdStdin, os.Stdin)
if err != nil {
logrus.WithError(err).Fatalln("Failed to connect to the database")
logrus.Fatalln(err)
}
accType := api.AccountTypeUser
if *isAdmin {
accType = api.AccountTypeAdmin
}
available, err := accountDB.CheckAccountAvailability(context.Background(), *username)
if err != nil {
logrus.Fatalln("Unable check username existence.")
}
if *resetPassword {
if available {
logrus.Fatalln("Username could not be found.")
}
err = accountDB.SetPassword(context.Background(), *username, pass)
if err != nil {
logrus.Fatalf("Failed to update password for user %s: %s", *username, err.Error())
}
if _, err = accountDB.RemoveAllDevices(context.Background(), *username, ""); err != nil {
logrus.Fatalf("Failed to remove all devices: %s", err.Error())
}
logrus.Infof("Updated password for user %s and invalidated all logins\n", *username)
return
}
if !available {
logrus.Fatalln("Username is already in use.")
}
_, err = accountDB.CreateAccount(context.Background(), *username, pass, "", accType)
accessToken, err := sharedSecretRegister(cfg.ClientAPI.RegistrationSharedSecret, *serverURL, *username, pass, *isAdmin)
if err != nil {
logrus.Fatalln("Failed to create the account:", err.Error())
}
logrus.Infoln("Created account", *username)
logrus.Infof("Created account: %s (AccessToken: %s)", *username, accessToken)
}
type sharedSecretRegistrationRequest struct {
User string `json:"username"`
Password string `json:"password"`
Nonce string `json:"nonce"`
MacStr string `json:"mac"`
Admin bool `json:"admin"`
}
func sharedSecretRegister(sharedSecret, serverURL, localpart, password string, admin bool) (accesToken string, err error) {
registerURL := fmt.Sprintf("%s/_synapse/admin/v1/register", serverURL)
nonceReq, err := http.NewRequest(http.MethodGet, registerURL, nil)
if err != nil {
return "", fmt.Errorf("unable to create http request: %w", err)
}
nonceResp, err := cl.Do(nonceReq)
if err != nil {
return "", fmt.Errorf("unable to get nonce: %w", err)
}
body, err := io.ReadAll(nonceResp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response body: %w", err)
}
defer nonceResp.Body.Close() // nolint: errcheck
nonce := gjson.GetBytes(body, "nonce").Str
adminStr := "notadmin"
if admin {
adminStr = "admin"
}
reg := sharedSecretRegistrationRequest{
User: localpart,
Password: password,
Nonce: nonce,
Admin: admin,
}
macStr, err := getRegisterMac(sharedSecret, nonce, localpart, password, adminStr)
if err != nil {
return "", err
}
reg.MacStr = macStr
js, err := json.Marshal(reg)
if err != nil {
return "", fmt.Errorf("unable to marshal json: %w", err)
}
registerReq, err := http.NewRequest(http.MethodPost, registerURL, bytes.NewBuffer(js))
if err != nil {
return "", fmt.Errorf("unable to create http request: %w", err)
}
regResp, err := cl.Do(registerReq)
if err != nil {
return "", fmt.Errorf("unable to create account: %w", err)
}
defer regResp.Body.Close() // nolint: errcheck
if regResp.StatusCode < 200 || regResp.StatusCode >= 300 {
body, _ = io.ReadAll(regResp.Body)
return "", fmt.Errorf(gjson.GetBytes(body, "error").Str)
}
r, _ := io.ReadAll(regResp.Body)
return gjson.GetBytes(r, "access_token").Str, nil
}
func getRegisterMac(sharedSecret, nonce, localpart, password, adminStr string) (string, error) {
joined := strings.Join([]string{nonce, localpart, password, adminStr}, "\x00")
mac := hmac.New(sha1.New, []byte(sharedSecret))
_, err := mac.Write([]byte(joined))
if err != nil {
return "", fmt.Errorf("unable to construct mac: %w", err)
}
regMac := mac.Sum(nil)
return hex.EncodeToString(regMac), nil
}
func getPassword(password, pwdFile string, pwdStdin bool, r io.Reader) (string, error) {

View file

@ -14,9 +14,8 @@ User accounts can be created on a Dendrite instance in a number of ways.
The `create-account` tool is built in the `bin` folder when building Dendrite with
the `build.sh` script.
It uses the `dendrite.yaml` configuration file to connect to the Dendrite user database
and create the account entries directly. It can therefore be used even if Dendrite is not
running yet, as long as the database is up.
It uses the `dendrite.yaml` configuration file to connect to a running Dendrite instance and requires
shared secret registration to be enabled as explained below.
An example of using `create-account` to create a **normal account**:
@ -32,6 +31,13 @@ To create a new **admin account**, add the `-admin` flag:
./bin/create-account -config /path/to/dendrite.yaml -username USERNAME -admin
```
By default `create-account` uses `https://localhost:8448` to connect to Dendrite, this can be overwritten using
the `-url` flag:
```bash
./bin/create-account -config /path/to/dendrite.yaml -username USERNAME -url http://localhost:8008
```
An example of using `create-account` when running in **Docker**, having found the `CONTAINERNAME` from `docker ps`:
```bash

View file

@ -13,19 +13,32 @@ without warning.
More endpoints will be added in the future.
## `/_dendrite/admin/evacuateRoom/{roomID}`
## GET `/_dendrite/admin/evacuateRoom/{roomID}`
This endpoint will instruct Dendrite to part all local users from the given `roomID`
in the URL. It may take some time to complete. A JSON body will be returned containing
the user IDs of all affected users.
## `/_dendrite/admin/evacuateUser/{userID}`
## GET `/_dendrite/admin/evacuateUser/{userID}`
This endpoint will instruct Dendrite to part the given local `userID` in the URL from
all rooms which they are currently joined. A JSON body will be returned containing
the room IDs of all affected rooms.
## `/_synapse/admin/v1/register`
## POST `/_dendrite/admin/resetPassword/{localpart}`
Request body format:
```
{
"password": "new_password_here"
}
```
Reset the password of a local user. The `localpart` is the username only, i.e. if
the full user ID is `@alice:domain.com` then the local part is `alice`.
## GET `/_synapse/admin/v1/register`
Shared secret registration — please see the [user creation page](createusers) for
guidance on configuring and using this endpoint.

View file

@ -110,7 +110,7 @@ type FederationClientError struct {
Blacklisted bool
}
func (e *FederationClientError) Error() string {
func (e FederationClientError) Error() string {
return fmt.Sprintf("%s - (retry_after=%s, blacklisted=%v)", e.Err, e.RetryAfter.String(), e.Blacklisted)
}

View file

@ -32,11 +32,12 @@ type fedRoomserverAPI struct {
}
// PerformJoin will call this function
func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) {
func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) error {
if f.inputRoomEvents == nil {
return
return nil
}
f.inputRoomEvents(ctx, req, res)
return nil
}
// keychange consumer calls this

View file

@ -10,7 +10,6 @@ import (
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/opentracing/opentracing-go"
)
// HTTP paths for the internal HTTP API
@ -48,7 +47,11 @@ func NewFederationAPIClient(federationSenderURL string, httpClient *http.Client,
if httpClient == nil {
return nil, errors.New("NewFederationInternalAPIHTTP: httpClient is <nil>")
}
return &httpFederationInternalAPI{federationSenderURL, httpClient, cache}, nil
return &httpFederationInternalAPI{
federationAPIURL: federationSenderURL,
httpClient: httpClient,
cache: cache,
}, nil
}
type httpFederationInternalAPI struct {
@ -63,11 +66,10 @@ func (h *httpFederationInternalAPI) PerformLeave(
request *api.PerformLeaveRequest,
response *api.PerformLeaveResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLeaveRequest")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIPerformLeaveRequestPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformLeave", h.federationAPIURL+FederationAPIPerformLeaveRequestPath,
h.httpClient, ctx, request, response,
)
}
// Handle sending an invite to a remote server.
@ -76,11 +78,10 @@ func (h *httpFederationInternalAPI) PerformInvite(
request *api.PerformInviteRequest,
response *api.PerformInviteResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformInviteRequest")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIPerformInviteRequestPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformInvite", h.federationAPIURL+FederationAPIPerformInviteRequestPath,
h.httpClient, ctx, request, response,
)
}
// Handle starting a peek on a remote server.
@ -89,11 +90,10 @@ func (h *httpFederationInternalAPI) PerformOutboundPeek(
request *api.PerformOutboundPeekRequest,
response *api.PerformOutboundPeekResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformOutboundPeekRequest")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIPerformOutboundPeekRequestPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformOutboundPeek", h.federationAPIURL+FederationAPIPerformOutboundPeekRequestPath,
h.httpClient, ctx, request, response,
)
}
// QueryJoinedHostServerNamesInRoom implements FederationInternalAPI
@ -102,11 +102,10 @@ func (h *httpFederationInternalAPI) QueryJoinedHostServerNamesInRoom(
request *api.QueryJoinedHostServerNamesInRoomRequest,
response *api.QueryJoinedHostServerNamesInRoomResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryJoinedHostServerNamesInRoom")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIQueryJoinedHostServerNamesInRoomPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryJoinedHostServerNamesInRoom", h.federationAPIURL+FederationAPIQueryJoinedHostServerNamesInRoomPath,
h.httpClient, ctx, request, response,
)
}
// Handle an instruction to make_join & send_join with a remote server.
@ -115,12 +114,10 @@ func (h *httpFederationInternalAPI) PerformJoin(
request *api.PerformJoinRequest,
response *api.PerformJoinResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformJoinRequest")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIPerformJoinRequestPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
if err := httputil.CallInternalRPCAPI(
"PerformJoinRequest", h.federationAPIURL+FederationAPIPerformJoinRequestPath,
h.httpClient, ctx, request, response,
); err != nil {
response.LastError = &gomatrix.HTTPError{
Message: err.Error(),
Code: 0,
@ -135,11 +132,10 @@ func (h *httpFederationInternalAPI) PerformDirectoryLookup(
request *api.PerformDirectoryLookupRequest,
response *api.PerformDirectoryLookupResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDirectoryLookup")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIPerformDirectoryLookupRequestPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformDirectoryLookup", h.federationAPIURL+FederationAPIPerformDirectoryLookupRequestPath,
h.httpClient, ctx, request, response,
)
}
// Handle an instruction to broadcast an EDU to all servers in rooms we are joined to.
@ -148,101 +144,61 @@ func (h *httpFederationInternalAPI) PerformBroadcastEDU(
request *api.PerformBroadcastEDURequest,
response *api.PerformBroadcastEDUResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformBroadcastEDU")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIPerformBroadcastEDUPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformBroadcastEDU", h.federationAPIURL+FederationAPIPerformBroadcastEDUPath,
h.httpClient, ctx, request, response,
)
}
type getUserDevices struct {
S gomatrixserverlib.ServerName
UserID string
Res *gomatrixserverlib.RespUserDevices
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) GetUserDevices(
ctx context.Context, s gomatrixserverlib.ServerName, userID string,
) (gomatrixserverlib.RespUserDevices, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "GetUserDevices")
defer span.Finish()
var result gomatrixserverlib.RespUserDevices
request := getUserDevices{
S: s,
UserID: userID,
}
var response getUserDevices
apiURL := h.federationAPIURL + FederationAPIGetUserDevicesPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return result, err
}
if response.Err != nil {
return result, response.Err
}
return *response.Res, nil
return httputil.CallInternalProxyAPI[getUserDevices, gomatrixserverlib.RespUserDevices, *api.FederationClientError](
"GetUserDevices", h.federationAPIURL+FederationAPIGetUserDevicesPath, h.httpClient,
ctx, &getUserDevices{
S: s,
UserID: userID,
},
)
}
type claimKeys struct {
S gomatrixserverlib.ServerName
OneTimeKeys map[string]map[string]string
Res *gomatrixserverlib.RespClaimKeys
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) ClaimKeys(
ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string,
) (gomatrixserverlib.RespClaimKeys, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "ClaimKeys")
defer span.Finish()
var result gomatrixserverlib.RespClaimKeys
request := claimKeys{
S: s,
OneTimeKeys: oneTimeKeys,
}
var response claimKeys
apiURL := h.federationAPIURL + FederationAPIClaimKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return result, err
}
if response.Err != nil {
return result, response.Err
}
return *response.Res, nil
return httputil.CallInternalProxyAPI[claimKeys, gomatrixserverlib.RespClaimKeys, *api.FederationClientError](
"ClaimKeys", h.federationAPIURL+FederationAPIClaimKeysPath, h.httpClient,
ctx, &claimKeys{
S: s,
OneTimeKeys: oneTimeKeys,
},
)
}
type queryKeys struct {
S gomatrixserverlib.ServerName
Keys map[string][]string
Res *gomatrixserverlib.RespQueryKeys
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) QueryKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string,
) (gomatrixserverlib.RespQueryKeys, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeys")
defer span.Finish()
var result gomatrixserverlib.RespQueryKeys
request := queryKeys{
S: s,
Keys: keys,
}
var response queryKeys
apiURL := h.federationAPIURL + FederationAPIQueryKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return result, err
}
if response.Err != nil {
return result, response.Err
}
return *response.Res, nil
return httputil.CallInternalProxyAPI[queryKeys, gomatrixserverlib.RespQueryKeys, *api.FederationClientError](
"QueryKeys", h.federationAPIURL+FederationAPIQueryKeysPath, h.httpClient,
ctx, &queryKeys{
S: s,
Keys: keys,
},
)
}
type backfill struct {
@ -250,32 +206,20 @@ type backfill struct {
RoomID string
Limit int
EventIDs []string
Res *gomatrixserverlib.Transaction
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) Backfill(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string,
) (gomatrixserverlib.Transaction, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "Backfill")
defer span.Finish()
request := backfill{
S: s,
RoomID: roomID,
Limit: limit,
EventIDs: eventIDs,
}
var response backfill
apiURL := h.federationAPIURL + FederationAPIBackfillPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return gomatrixserverlib.Transaction{}, err
}
if response.Err != nil {
return gomatrixserverlib.Transaction{}, response.Err
}
return *response.Res, nil
return httputil.CallInternalProxyAPI[backfill, gomatrixserverlib.Transaction, *api.FederationClientError](
"Backfill", h.federationAPIURL+FederationAPIBackfillPath, h.httpClient,
ctx, &backfill{
S: s,
RoomID: roomID,
Limit: limit,
EventIDs: eventIDs,
},
)
}
type lookupState struct {
@ -283,63 +227,39 @@ type lookupState struct {
RoomID string
EventID string
RoomVersion gomatrixserverlib.RoomVersion
Res *gomatrixserverlib.RespState
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) LookupState(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
) (gomatrixserverlib.RespState, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "LookupState")
defer span.Finish()
request := lookupState{
S: s,
RoomID: roomID,
EventID: eventID,
RoomVersion: roomVersion,
}
var response lookupState
apiURL := h.federationAPIURL + FederationAPILookupStatePath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return gomatrixserverlib.RespState{}, err
}
if response.Err != nil {
return gomatrixserverlib.RespState{}, response.Err
}
return *response.Res, nil
return httputil.CallInternalProxyAPI[lookupState, gomatrixserverlib.RespState, *api.FederationClientError](
"LookupState", h.federationAPIURL+FederationAPILookupStatePath, h.httpClient,
ctx, &lookupState{
S: s,
RoomID: roomID,
EventID: eventID,
RoomVersion: roomVersion,
},
)
}
type lookupStateIDs struct {
S gomatrixserverlib.ServerName
RoomID string
EventID string
Res *gomatrixserverlib.RespStateIDs
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) LookupStateIDs(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string,
) (gomatrixserverlib.RespStateIDs, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "LookupStateIDs")
defer span.Finish()
request := lookupStateIDs{
S: s,
RoomID: roomID,
EventID: eventID,
}
var response lookupStateIDs
apiURL := h.federationAPIURL + FederationAPILookupStateIDsPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return gomatrixserverlib.RespStateIDs{}, err
}
if response.Err != nil {
return gomatrixserverlib.RespStateIDs{}, response.Err
}
return *response.Res, nil
return httputil.CallInternalProxyAPI[lookupStateIDs, gomatrixserverlib.RespStateIDs, *api.FederationClientError](
"LookupStateIDs", h.federationAPIURL+FederationAPILookupStateIDsPath, h.httpClient,
ctx, &lookupStateIDs{
S: s,
RoomID: roomID,
EventID: eventID,
},
)
}
type lookupMissingEvents struct {
@ -347,64 +267,38 @@ type lookupMissingEvents struct {
RoomID string
Missing gomatrixserverlib.MissingEvents
RoomVersion gomatrixserverlib.RoomVersion
Res struct {
Events []gomatrixserverlib.RawJSON `json:"events"`
}
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) LookupMissingEvents(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string,
missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.RespMissingEvents, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "LookupMissingEvents")
defer span.Finish()
request := lookupMissingEvents{
S: s,
RoomID: roomID,
Missing: missing,
RoomVersion: roomVersion,
}
apiURL := h.federationAPIURL + FederationAPILookupMissingEventsPath
err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &request)
if err != nil {
return res, err
}
if request.Err != nil {
return res, request.Err
}
res.Events = request.Res.Events
return res, nil
return httputil.CallInternalProxyAPI[lookupMissingEvents, gomatrixserverlib.RespMissingEvents, *api.FederationClientError](
"LookupMissingEvents", h.federationAPIURL+FederationAPILookupMissingEventsPath, h.httpClient,
ctx, &lookupMissingEvents{
S: s,
RoomID: roomID,
Missing: missing,
RoomVersion: roomVersion,
},
)
}
type getEvent struct {
S gomatrixserverlib.ServerName
EventID string
Res *gomatrixserverlib.Transaction
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) GetEvent(
ctx context.Context, s gomatrixserverlib.ServerName, eventID string,
) (gomatrixserverlib.Transaction, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "GetEvent")
defer span.Finish()
request := getEvent{
S: s,
EventID: eventID,
}
var response getEvent
apiURL := h.federationAPIURL + FederationAPIGetEventPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return gomatrixserverlib.Transaction{}, err
}
if response.Err != nil {
return gomatrixserverlib.Transaction{}, response.Err
}
return *response.Res, nil
return httputil.CallInternalProxyAPI[getEvent, gomatrixserverlib.Transaction, *api.FederationClientError](
"GetEvent", h.federationAPIURL+FederationAPIGetEventPath, h.httpClient,
ctx, &getEvent{
S: s,
EventID: eventID,
},
)
}
type getEventAuth struct {
@ -412,135 +306,86 @@ type getEventAuth struct {
RoomVersion gomatrixserverlib.RoomVersion
RoomID string
EventID string
Res *gomatrixserverlib.RespEventAuth
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) GetEventAuth(
ctx context.Context, s gomatrixserverlib.ServerName,
roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string,
) (gomatrixserverlib.RespEventAuth, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "GetEventAuth")
defer span.Finish()
request := getEventAuth{
S: s,
RoomVersion: roomVersion,
RoomID: roomID,
EventID: eventID,
}
var response getEventAuth
apiURL := h.federationAPIURL + FederationAPIGetEventAuthPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return gomatrixserverlib.RespEventAuth{}, err
}
if response.Err != nil {
return gomatrixserverlib.RespEventAuth{}, response.Err
}
return *response.Res, nil
return httputil.CallInternalProxyAPI[getEventAuth, gomatrixserverlib.RespEventAuth, *api.FederationClientError](
"GetEventAuth", h.federationAPIURL+FederationAPIGetEventAuthPath, h.httpClient,
ctx, &getEventAuth{
S: s,
RoomVersion: roomVersion,
RoomID: roomID,
EventID: eventID,
},
)
}
func (h *httpFederationInternalAPI) QueryServerKeys(
ctx context.Context, req *api.QueryServerKeysRequest, res *api.QueryServerKeysResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerKeys")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIQueryServerKeysPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
return httputil.CallInternalRPCAPI(
"QueryServerKeys", h.federationAPIURL+FederationAPIQueryServerKeysPath,
h.httpClient, ctx, req, res,
)
}
type lookupServerKeys struct {
S gomatrixserverlib.ServerName
KeyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp
ServerKeys []gomatrixserverlib.ServerKeys
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) LookupServerKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) ([]gomatrixserverlib.ServerKeys, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "LookupServerKeys")
defer span.Finish()
request := lookupServerKeys{
S: s,
KeyRequests: keyRequests,
}
var response lookupServerKeys
apiURL := h.federationAPIURL + FederationAPILookupServerKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return []gomatrixserverlib.ServerKeys{}, err
}
if response.Err != nil {
return []gomatrixserverlib.ServerKeys{}, response.Err
}
return response.ServerKeys, nil
return httputil.CallInternalProxyAPI[lookupServerKeys, []gomatrixserverlib.ServerKeys, *api.FederationClientError](
"LookupServerKeys", h.federationAPIURL+FederationAPILookupServerKeysPath, h.httpClient,
ctx, &lookupServerKeys{
S: s,
KeyRequests: keyRequests,
},
)
}
type eventRelationships struct {
S gomatrixserverlib.ServerName
Req gomatrixserverlib.MSC2836EventRelationshipsRequest
RoomVer gomatrixserverlib.RoomVersion
Res gomatrixserverlib.MSC2836EventRelationshipsResponse
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) MSC2836EventRelationships(
ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest,
roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2836EventRelationships")
defer span.Finish()
request := eventRelationships{
S: s,
Req: r,
RoomVer: roomVersion,
}
var response eventRelationships
apiURL := h.federationAPIURL + FederationAPIEventRelationshipsPath
err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return res, err
}
if response.Err != nil {
return res, response.Err
}
return response.Res, nil
return httputil.CallInternalProxyAPI[eventRelationships, gomatrixserverlib.MSC2836EventRelationshipsResponse, *api.FederationClientError](
"MSC2836EventRelationships", h.federationAPIURL+FederationAPIEventRelationshipsPath, h.httpClient,
ctx, &eventRelationships{
S: s,
Req: r,
RoomVer: roomVersion,
},
)
}
type spacesReq struct {
S gomatrixserverlib.ServerName
SuggestedOnly bool
RoomID string
Res gomatrixserverlib.MSC2946SpacesResponse
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) MSC2946Spaces(
ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool,
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2946Spaces")
defer span.Finish()
request := spacesReq{
S: dst,
SuggestedOnly: suggestedOnly,
RoomID: roomID,
}
var response spacesReq
apiURL := h.federationAPIURL + FederationAPISpacesSummaryPath
err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return res, err
}
if response.Err != nil {
return res, response.Err
}
return response.Res, nil
return httputil.CallInternalProxyAPI[spacesReq, gomatrixserverlib.MSC2946SpacesResponse, *api.FederationClientError](
"MSC2836EventRelationships", h.federationAPIURL+FederationAPISpacesSummaryPath, h.httpClient,
ctx, &spacesReq{
S: dst,
SuggestedOnly: suggestedOnly,
RoomID: roomID,
},
)
}
func (s *httpFederationInternalAPI) KeyRing() *gomatrixserverlib.KeyRing {
@ -614,11 +459,10 @@ func (h *httpFederationInternalAPI) InputPublicKeys(
request *api.InputPublicKeysRequest,
response *api.InputPublicKeysResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "InputPublicKey")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIInputPublicKeyPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"InputPublicKey", h.federationAPIURL+FederationAPIInputPublicKeyPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpFederationInternalAPI) QueryPublicKeys(
@ -626,9 +470,8 @@ func (h *httpFederationInternalAPI) QueryPublicKeys(
request *api.QueryPublicKeysRequest,
response *api.QueryPublicKeysResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPublicKey")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIQueryPublicKeyPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryPublicKeys", h.federationAPIURL+FederationAPIQueryPublicKeyPath,
h.httpClient, ctx, request, response,
)
}

View file

@ -1,12 +1,14 @@
package inthttp
import (
"context"
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
@ -15,372 +17,180 @@ import (
func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
internalAPIMux.Handle(
FederationAPIQueryJoinedHostServerNamesInRoomPath,
httputil.MakeInternalAPI("QueryJoinedHostServerNamesInRoom", func(req *http.Request) util.JSONResponse {
var request api.QueryJoinedHostServerNamesInRoomRequest
var response api.QueryJoinedHostServerNamesInRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := intAPI.QueryJoinedHostServerNamesInRoom(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(
FederationAPIPerformJoinRequestPath,
httputil.MakeInternalAPI("PerformJoinRequest", func(req *http.Request) util.JSONResponse {
var request api.PerformJoinRequest
var response api.PerformJoinResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
intAPI.PerformJoin(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(
FederationAPIPerformLeaveRequestPath,
httputil.MakeInternalAPI("PerformLeaveRequest", func(req *http.Request) util.JSONResponse {
var request api.PerformLeaveRequest
var response api.PerformLeaveResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.PerformLeave(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("FederationAPIQueryJoinedHostServerNamesInRoom", intAPI.QueryJoinedHostServerNamesInRoom),
)
internalAPIMux.Handle(
FederationAPIPerformInviteRequestPath,
httputil.MakeInternalAPI("PerformInviteRequest", func(req *http.Request) util.JSONResponse {
var request api.PerformInviteRequest
var response api.PerformInviteResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.PerformInvite(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("FederationAPIPerformInvite", intAPI.PerformInvite),
)
internalAPIMux.Handle(
FederationAPIPerformLeaveRequestPath,
httputil.MakeInternalRPCAPI("FederationAPIPerformLeave", intAPI.PerformLeave),
)
internalAPIMux.Handle(
FederationAPIPerformDirectoryLookupRequestPath,
httputil.MakeInternalAPI("PerformDirectoryLookupRequest", func(req *http.Request) util.JSONResponse {
var request api.PerformDirectoryLookupRequest
var response api.PerformDirectoryLookupResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.PerformDirectoryLookup(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("FederationAPIPerformDirectoryLookupRequest", intAPI.PerformDirectoryLookup),
)
internalAPIMux.Handle(
FederationAPIPerformBroadcastEDUPath,
httputil.MakeInternalAPI("PerformBroadcastEDU", func(req *http.Request) util.JSONResponse {
var request api.PerformBroadcastEDURequest
var response api.PerformBroadcastEDUResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.PerformBroadcastEDU(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", intAPI.PerformBroadcastEDU),
)
internalAPIMux.Handle(
FederationAPIPerformJoinRequestPath,
httputil.MakeInternalRPCAPI(
"FederationAPIPerformJoinRequest",
func(ctx context.Context, req *api.PerformJoinRequest, res *api.PerformJoinResponse) error {
intAPI.PerformJoin(ctx, req, res)
return nil
},
),
)
internalAPIMux.Handle(
FederationAPIGetUserDevicesPath,
httputil.MakeInternalAPI("GetUserDevices", func(req *http.Request) util.JSONResponse {
var request getUserDevices
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.GetUserDevices(req.Context(), request.S, request.UserID)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
httputil.MakeInternalProxyAPI(
"FederationAPIGetUserDevices",
func(ctx context.Context, req *getUserDevices) (*gomatrixserverlib.RespUserDevices, error) {
res, err := intAPI.GetUserDevices(ctx, req.S, req.UserID)
return &res, federationClientError(err)
},
),
)
internalAPIMux.Handle(
FederationAPIClaimKeysPath,
httputil.MakeInternalAPI("ClaimKeys", func(req *http.Request) util.JSONResponse {
var request claimKeys
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.ClaimKeys(req.Context(), request.S, request.OneTimeKeys)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
httputil.MakeInternalProxyAPI(
"FederationAPIClaimKeys",
func(ctx context.Context, req *claimKeys) (*gomatrixserverlib.RespClaimKeys, error) {
res, err := intAPI.ClaimKeys(ctx, req.S, req.OneTimeKeys)
return &res, federationClientError(err)
},
),
)
internalAPIMux.Handle(
FederationAPIQueryKeysPath,
httputil.MakeInternalAPI("QueryKeys", func(req *http.Request) util.JSONResponse {
var request queryKeys
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.QueryKeys(req.Context(), request.S, request.Keys)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
httputil.MakeInternalProxyAPI(
"FederationAPIQueryKeys",
func(ctx context.Context, req *queryKeys) (*gomatrixserverlib.RespQueryKeys, error) {
res, err := intAPI.QueryKeys(ctx, req.S, req.Keys)
return &res, federationClientError(err)
},
),
)
internalAPIMux.Handle(
FederationAPIBackfillPath,
httputil.MakeInternalAPI("Backfill", func(req *http.Request) util.JSONResponse {
var request backfill
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.Backfill(req.Context(), request.S, request.RoomID, request.Limit, request.EventIDs)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
httputil.MakeInternalProxyAPI(
"FederationAPIBackfill",
func(ctx context.Context, req *backfill) (*gomatrixserverlib.Transaction, error) {
res, err := intAPI.Backfill(ctx, req.S, req.RoomID, req.Limit, req.EventIDs)
return &res, federationClientError(err)
},
),
)
internalAPIMux.Handle(
FederationAPILookupStatePath,
httputil.MakeInternalAPI("LookupState", func(req *http.Request) util.JSONResponse {
var request lookupState
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.LookupState(req.Context(), request.S, request.RoomID, request.EventID, request.RoomVersion)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
httputil.MakeInternalProxyAPI(
"FederationAPILookupState",
func(ctx context.Context, req *lookupState) (*gomatrixserverlib.RespState, error) {
res, err := intAPI.LookupState(ctx, req.S, req.RoomID, req.EventID, req.RoomVersion)
return &res, federationClientError(err)
},
),
)
internalAPIMux.Handle(
FederationAPILookupStateIDsPath,
httputil.MakeInternalAPI("LookupStateIDs", func(req *http.Request) util.JSONResponse {
var request lookupStateIDs
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.LookupStateIDs(req.Context(), request.S, request.RoomID, request.EventID)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
httputil.MakeInternalProxyAPI(
"FederationAPILookupStateIDs",
func(ctx context.Context, req *lookupStateIDs) (*gomatrixserverlib.RespStateIDs, error) {
res, err := intAPI.LookupStateIDs(ctx, req.S, req.RoomID, req.EventID)
return &res, federationClientError(err)
},
),
)
internalAPIMux.Handle(
FederationAPILookupMissingEventsPath,
httputil.MakeInternalAPI("LookupMissingEvents", func(req *http.Request) util.JSONResponse {
var request lookupMissingEvents
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.LookupMissingEvents(req.Context(), request.S, request.RoomID, request.Missing, request.RoomVersion)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
for _, event := range res.Events {
js, err := json.Marshal(event)
if err != nil {
return util.MessageResponse(http.StatusInternalServerError, err.Error())
}
request.Res.Events = append(request.Res.Events, js)
}
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
httputil.MakeInternalProxyAPI(
"FederationAPILookupMissingEvents",
func(ctx context.Context, req *lookupMissingEvents) (*gomatrixserverlib.RespMissingEvents, error) {
res, err := intAPI.LookupMissingEvents(ctx, req.S, req.RoomID, req.Missing, req.RoomVersion)
return &res, federationClientError(err)
},
),
)
internalAPIMux.Handle(
FederationAPIGetEventPath,
httputil.MakeInternalAPI("GetEvent", func(req *http.Request) util.JSONResponse {
var request getEvent
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.GetEvent(req.Context(), request.S, request.EventID)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
httputil.MakeInternalProxyAPI(
"FederationAPIGetEvent",
func(ctx context.Context, req *getEvent) (*gomatrixserverlib.Transaction, error) {
res, err := intAPI.GetEvent(ctx, req.S, req.EventID)
return &res, federationClientError(err)
},
),
)
internalAPIMux.Handle(
FederationAPIGetEventAuthPath,
httputil.MakeInternalAPI("GetEventAuth", func(req *http.Request) util.JSONResponse {
var request getEventAuth
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.GetEventAuth(req.Context(), request.S, request.RoomVersion, request.RoomID, request.EventID)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
httputil.MakeInternalProxyAPI(
"FederationAPIGetEventAuth",
func(ctx context.Context, req *getEventAuth) (*gomatrixserverlib.RespEventAuth, error) {
res, err := intAPI.GetEventAuth(ctx, req.S, req.RoomVersion, req.RoomID, req.EventID)
return &res, federationClientError(err)
},
),
)
internalAPIMux.Handle(
FederationAPIQueryServerKeysPath,
httputil.MakeInternalAPI("QueryServerKeys", func(req *http.Request) util.JSONResponse {
var request api.QueryServerKeysRequest
var response api.QueryServerKeysResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.QueryServerKeys(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("FederationAPIQueryServerKeys", intAPI.QueryServerKeys),
)
internalAPIMux.Handle(
FederationAPILookupServerKeysPath,
httputil.MakeInternalAPI("LookupServerKeys", func(req *http.Request) util.JSONResponse {
var request lookupServerKeys
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.LookupServerKeys(req.Context(), request.S, request.KeyRequests)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.ServerKeys = res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
httputil.MakeInternalProxyAPI(
"FederationAPILookupServerKeys",
func(ctx context.Context, req *lookupServerKeys) (*[]gomatrixserverlib.ServerKeys, error) {
res, err := intAPI.LookupServerKeys(ctx, req.S, req.KeyRequests)
return &res, federationClientError(err)
},
),
)
internalAPIMux.Handle(
FederationAPIEventRelationshipsPath,
httputil.MakeInternalAPI("MSC2836EventRelationships", func(req *http.Request) util.JSONResponse {
var request eventRelationships
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.MSC2836EventRelationships(req.Context(), request.S, request.Req, request.RoomVer)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
httputil.MakeInternalProxyAPI(
"FederationAPIMSC2836EventRelationships",
func(ctx context.Context, req *eventRelationships) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) {
res, err := intAPI.MSC2836EventRelationships(ctx, req.S, req.Req, req.RoomVer)
return &res, federationClientError(err)
},
),
)
internalAPIMux.Handle(
FederationAPISpacesSummaryPath,
httputil.MakeInternalAPI("MSC2946SpacesSummary", func(req *http.Request) util.JSONResponse {
var request spacesReq
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.SuggestedOnly)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
httputil.MakeInternalProxyAPI(
"FederationAPIMSC2946SpacesSummary",
func(ctx context.Context, req *spacesReq) (*gomatrixserverlib.MSC2946SpacesResponse, error) {
res, err := intAPI.MSC2946Spaces(ctx, req.S, req.RoomID, req.SuggestedOnly)
return &res, federationClientError(err)
},
),
)
// TODO: Look at this shape
internalAPIMux.Handle(FederationAPIQueryPublicKeyPath,
httputil.MakeInternalAPI("queryPublicKeys", func(req *http.Request) util.JSONResponse {
httputil.MakeInternalAPI("FederationAPIQueryPublicKeys", func(req *http.Request) util.JSONResponse {
request := api.QueryPublicKeysRequest{}
response := api.QueryPublicKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -394,8 +204,10 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
// TODO: Look at this shape
internalAPIMux.Handle(FederationAPIInputPublicKeyPath,
httputil.MakeInternalAPI("inputPublicKeys", func(req *http.Request) util.JSONResponse {
httputil.MakeInternalAPI("FederationAPIInputPublicKeys", func(req *http.Request) util.JSONResponse {
request := api.InputPublicKeysRequest{}
response := api.InputPublicKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -408,3 +220,18 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
}),
)
}
func federationClientError(err error) error {
switch ferr := err.(type) {
case nil:
return nil
case api.FederationClientError:
return &ferr
case *api.FederationClientError:
return ferr
default:
return &api.FederationClientError{
Err: err.Error(),
}
}
}

View file

@ -30,9 +30,11 @@ func GetUserDevices(
userID string,
) util.JSONResponse {
var res keyapi.QueryDeviceMessagesResponse
keyAPI.QueryDeviceMessages(req.Context(), &keyapi.QueryDeviceMessagesRequest{
if err := keyAPI.QueryDeviceMessages(req.Context(), &keyapi.QueryDeviceMessagesRequest{
UserID: userID,
}, &res)
}, &res); err != nil {
return util.ErrorResponse(err)
}
if res.Error != nil {
util.GetLogger(req.Context()).WithError(res.Error).Error("keyAPI.QueryDeviceMessages failed")
return jsonerror.InternalServerError()
@ -47,7 +49,9 @@ func GetUserDevices(
for _, dev := range res.Devices {
sigReq.TargetIDs[userID] = append(sigReq.TargetIDs[userID], gomatrixserverlib.KeyID(dev.DeviceID))
}
keyAPI.QuerySignatures(req.Context(), sigReq, sigRes)
if err := keyAPI.QuerySignatures(req.Context(), sigReq, sigRes); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
response := gomatrixserverlib.RespUserDevices{
UserID: userID,

View file

@ -392,7 +392,7 @@ func SendJoin(
// the room, so set SendAsServer to cfg.Matrix.ServerName
if !alreadyJoined {
var response api.InputRoomEventsResponse
rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{
if err := rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{
InputRoomEvents: []api.InputRoomEvent{
{
Kind: api.KindNew,
@ -401,7 +401,9 @@ func SendJoin(
TransactionID: nil,
},
},
}, &response)
}, &response); err != nil {
return jsonerror.InternalAPIError(httpReq.Context(), err)
}
if response.ErrMsg != "" {
util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, response.ErrMsg).Error("SendEvents failed")
if response.NotAllowed {

View file

@ -19,7 +19,7 @@ import (
"net/http"
"time"
"github.com/matrix-org/dendrite/clientapi/httputil"
clienthttputil "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/keyserver/api"
@ -61,9 +61,11 @@ func QueryDeviceKeys(
}
var queryRes api.QueryKeysResponse
keyAPI.QueryKeys(httpReq.Context(), &api.QueryKeysRequest{
if err := keyAPI.QueryKeys(httpReq.Context(), &api.QueryKeysRequest{
UserToDevices: qkr.DeviceKeys,
}, &queryRes)
}, &queryRes); err != nil {
return jsonerror.InternalAPIError(httpReq.Context(), err)
}
if queryRes.Error != nil {
util.GetLogger(httpReq.Context()).WithError(queryRes.Error).Error("Failed to QueryKeys")
return jsonerror.InternalServerError()
@ -113,9 +115,11 @@ func ClaimOneTimeKeys(
}
var claimRes api.PerformClaimKeysResponse
keyAPI.PerformClaimKeys(httpReq.Context(), &api.PerformClaimKeysRequest{
if err := keyAPI.PerformClaimKeys(httpReq.Context(), &api.PerformClaimKeysRequest{
OneTimeKeys: cor.OneTimeKeys,
}, &claimRes)
}, &claimRes); err != nil {
return jsonerror.InternalAPIError(httpReq.Context(), err)
}
if claimRes.Error != nil {
util.GetLogger(httpReq.Context()).WithError(claimRes.Error).Error("Failed to PerformClaimKeys")
return jsonerror.InternalServerError()
@ -184,7 +188,7 @@ func NotaryKeys(
) util.JSONResponse {
if req == nil {
req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{}
if reqErr := httputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil {
if reqErr := clienthttputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil {
return *reqErr
}
}

View file

@ -277,7 +277,7 @@ func SendLeave(
// We are responsible for notifying other servers that the user has left
// the room, so set SendAsServer to cfg.Matrix.ServerName
var response api.InputRoomEventsResponse
rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{
if err := rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{
InputRoomEvents: []api.InputRoomEvent{
{
Kind: api.KindNew,
@ -286,7 +286,9 @@ func SendLeave(
TransactionID: nil,
},
},
}, &response)
}, &response); err != nil {
return jsonerror.InternalAPIError(httpReq.Context(), err)
}
if response.ErrMsg != "" {
util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, response.ErrMsg).WithField("not_allowed", response.NotAllowed).Error("producer.SendEvents failed")

View file

@ -458,7 +458,9 @@ func (t *txnReq) processSigningKeyUpdate(ctx context.Context, e gomatrixserverli
UserID: updatePayload.UserID,
}
uploadRes := &keyapi.PerformUploadDeviceKeysResponse{}
t.keyAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes)
if err := t.keyAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil {
return err
}
if uploadRes.Error != nil {
return uploadRes.Error
}

View file

@ -64,11 +64,12 @@ func (t *testRoomserverAPI) InputRoomEvents(
ctx context.Context,
request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse,
) {
) error {
t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...)
for _, ire := range request.InputRoomEvents {
fmt.Println("InputRoomEvents: ", ire.Event.EventID())
}
return nil
}
// Query the latest events and state for a room from the room server.

12
go.mod
View file

@ -21,12 +21,12 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
github.com/matrix-org/gomatrixserverlib v0.0.0-20220725104114-b6003e522771
github.com/matrix-org/gomatrixserverlib v0.0.0-20220801083850-5ff38e2c2839
github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.13
github.com/nats-io/nats-server/v2 v2.8.5-0.20220731184415-903a06a5b4ee
github.com/nats-io/nats.go v1.16.1-0.20220731182438-87bbea85922b
github.com/nats-io/nats-server/v2 v2.8.5-0.20220811224153-d8d25d9b0b1c
github.com/nats-io/nats.go v1.16.1-0.20220810192301-fb5ca2cbc995
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79
@ -34,7 +34,7 @@ require (
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.12.2
github.com/sirupsen/logrus v1.8.1
github.com/sirupsen/logrus v1.9.0
github.com/spruceid/siwe-go v0.2.0
github.com/stretchr/objx v0.2.0 // indirect
github.com/stretchr/testify v1.7.1
@ -44,7 +44,7 @@ require (
github.com/uber/jaeger-lib v2.4.1+incompatible
github.com/yggdrasil-network/yggdrasil-go v0.4.3
go.uber.org/atomic v1.9.0
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa
golang.org/x/image v0.0.0-20220413100746-70e8d0d3baa9
golang.org/x/mobile v0.0.0-20220518205345-8578da9835fd
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e
@ -105,7 +105,7 @@ require (
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect
golang.org/x/sys v0.0.0-20220731174439-a90be440212d // indirect
golang.org/x/text v0.3.8-0.20211004125949-5bd84dd9b33b // indirect
golang.org/x/time v0.0.0-20220411224347-583f2d630306 // indirect
golang.org/x/tools v0.1.10 // indirect

24
go.sum
View file

@ -478,8 +478,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220725104114-b6003e522771 h1:ZIPHFIPNDS9dmEbPEiJbNmyCGJtn9exfpLC7JOcn/bE=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220725104114-b6003e522771/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220801083850-5ff38e2c2839 h1:QEFxKWH8PlEt3ZQKl31yJNAm8lvpNUwT51IMNTl9v1k=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220801083850-5ff38e2c2839/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk=
github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9 h1:ed8yvWhTLk7+sNeK/eOZRTvESFTOHDRevoRoyeqPtvY=
github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9/go.mod h1:P4MqPf+u83OPulPJ+XTbSDbbWrdFYNY4LZ/B1PIduFE=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
@ -536,10 +536,10 @@ github.com/naoina/go-stringutil v0.1.0/go.mod h1:XJ2SJL9jCtBh+P9q5btrd/Ylo8XwT/h
github.com/naoina/toml v0.1.2-0.20170918210437-9fafd6967416/go.mod h1:NBIhNtsFMo3G2szEBne+bO4gS192HuIYRqfvOWb4i1E=
github.com/nats-io/jwt/v2 v2.3.0 h1:z2mA1a7tIf5ShggOFlR1oBPgd6hGqcDYsISxZByUzdI=
github.com/nats-io/jwt/v2 v2.3.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k=
github.com/nats-io/nats-server/v2 v2.8.5-0.20220731184415-903a06a5b4ee h1:vAtoZ+LW6eIUjkCWWwO1DZ6o16UGrVOG+ot/AkwejO8=
github.com/nats-io/nats-server/v2 v2.8.5-0.20220731184415-903a06a5b4ee/go.mod h1:3Yg3ApyQxPlAs1KKHKV5pobV5VtZk+TtOiUJx/iqkkg=
github.com/nats-io/nats.go v1.16.1-0.20220731182438-87bbea85922b h1:CE9wSYLvwq8aC/0+6zH8lhhtZYvJ9p8PzwvZeYgdBc0=
github.com/nats-io/nats.go v1.16.1-0.20220731182438-87bbea85922b/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w=
github.com/nats-io/nats-server/v2 v2.8.5-0.20220811224153-d8d25d9b0b1c h1:U5qngWGZ7E/nQxz0544IpIEdKFUUaOJxQN2LHCYLGhg=
github.com/nats-io/nats-server/v2 v2.8.5-0.20220811224153-d8d25d9b0b1c/go.mod h1:+f++B/5jpr71JATt7b5KCX+G7bt43iWx1OYWGkpE/Kk=
github.com/nats-io/nats.go v1.16.1-0.20220810192301-fb5ca2cbc995 h1:CUcSQR8jwa9//qNgN/t3tW53DObnTPQ/G/K+qnS7yRc=
github.com/nats-io/nats.go v1.16.1-0.20220810192301-fb5ca2cbc995/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w=
github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8=
github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4=
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
@ -669,8 +669,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE=
github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0=
github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s=
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
@ -771,8 +771,8 @@ golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e h1:T8NU3HyQ8ClP4SEE+KbFlg6n0NhuTsN4MyznaarGsZM=
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c=
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -964,8 +964,10 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20211102192858-4dd72447c267/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a h1:dGzPydgVsqGcTRVwiLJ1jVbufYwmzD3LfVPLKsKg+0k=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220731174439-a90be440212d h1:Sv5ogFZatcgIMMtBSTTAgMYsicp25MXBubjXNDKwm80=
golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=

View file

@ -146,7 +146,7 @@ func (c *RistrettoCostedCachePartition[K, V]) Set(key K, value V) {
}
type RistrettoCachePartition[K keyable, V any] struct {
cache *ristretto.Cache
cache *ristretto.Cache //nolint:all,unused
Prefix byte
Mutable bool
MaxAge time.Duration

View file

@ -19,19 +19,21 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"github.com/matrix-org/dendrite/userapi/api"
opentracing "github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
)
// PostJSON performs a POST request with JSON on an internal HTTP API
func PostJSON(
// PostJSON performs a POST request with JSON on an internal HTTP API.
// The error will match the errtype if returned from the remote API, or
// will be a different type if there was a problem reaching the API.
func PostJSON[reqtype, restype any, errtype error](
ctx context.Context, span opentracing.Span, httpClient *http.Client,
apiURL string, request, response interface{},
apiURL string, request *reqtype, response *restype,
) error {
jsonBytes, err := json.Marshal(request)
if err != nil {
@ -69,17 +71,23 @@ func PostJSON(
if err != nil {
return err
}
if res.StatusCode != http.StatusOK {
var errorBody struct {
Message string `json:"message"`
}
if _, ok := response.(*api.PerformKeyBackupResponse); ok { // TODO: remove this, once cross-boundary errors are a thing
return nil
}
if msgerr := json.NewDecoder(res.Body).Decode(&errorBody); msgerr == nil {
return fmt.Errorf("internal API: %d from %s: %s", res.StatusCode, apiURL, errorBody.Message)
}
return fmt.Errorf("internal API: %d from %s", res.StatusCode, apiURL)
var body []byte
body, err = io.ReadAll(res.Body)
if err != nil {
return err
}
return json.NewDecoder(res.Body).Decode(response)
if res.StatusCode != http.StatusOK {
if len(body) == 0 {
return fmt.Errorf("HTTP %d from %s (no response body)", res.StatusCode, apiURL)
}
var reserr errtype
if err = json.Unmarshal(body, reserr); err != nil {
return fmt.Errorf("HTTP %d from %s", res.StatusCode, apiURL)
}
return reserr
}
if err = json.Unmarshal(body, response); err != nil {
return fmt.Errorf("json.Unmarshal: %w", err)
}
return nil
}

View file

@ -25,6 +25,7 @@ import (
"github.com/getsentry/sentry-go"
"github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go"
@ -83,6 +84,23 @@ func MakeAuthAPI(
return MakeExternalAPI(metricsName, h)
}
// MakeAdminAPI is a wrapper around MakeAuthAPI which enforces that the request can only be
// completed by a user that is a server administrator.
func MakeAdminAPI(
metricsName string, userAPI userapi.QueryAcccessTokenAPI,
f func(*http.Request, *userapi.Device) util.JSONResponse,
) http.Handler {
return MakeAuthAPI(metricsName, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if device.AccountType != userapi.AccountTypeAdmin {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("This API can only be used by admin users."),
}
}
return f(req, device)
})
}
// MakeExternalAPI turns a util.JSONRequestHandler function into an http.Handler.
// This is used for APIs that are called from the internet.
func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler {

View file

@ -0,0 +1,93 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httputil
import (
"context"
"encoding/json"
"fmt"
"net/http"
"reflect"
"github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go"
)
type InternalAPIError struct {
Type string
Message string
}
func (e InternalAPIError) Error() string {
return fmt.Sprintf("internal API returned %q error: %s", e.Type, e.Message)
}
func MakeInternalRPCAPI[reqtype, restype any](metricsName string, f func(context.Context, *reqtype, *restype) error) http.Handler {
return MakeInternalAPI(metricsName, func(req *http.Request) util.JSONResponse {
var request reqtype
var response restype
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := f(req.Context(), &request, &response); err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: &InternalAPIError{
Type: reflect.TypeOf(err).String(),
Message: fmt.Sprintf("%s", err),
},
}
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: &response,
}
})
}
func MakeInternalProxyAPI[reqtype, restype any](metricsName string, f func(context.Context, *reqtype) (*restype, error)) http.Handler {
return MakeInternalAPI(metricsName, func(req *http.Request) util.JSONResponse {
var request reqtype
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
response, err := f(req.Context(), &request)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: err,
}
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: response,
}
})
}
func CallInternalRPCAPI[reqtype, restype any](name, url string, client *http.Client, ctx context.Context, request *reqtype, response *restype) error {
span, ctx := opentracing.StartSpanFromContext(ctx, name)
defer span.Finish()
return PostJSON[reqtype, restype, InternalAPIError](ctx, span, client, url, request, response)
}
func CallInternalProxyAPI[reqtype, restype any, errtype error](name, url string, client *http.Client, ctx context.Context, request *reqtype) (restype, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, name)
defer span.Finish()
var response restype
return response, PostJSON[reqtype, restype, errtype](ctx, span, client, url, request, &response)
}

View file

@ -17,7 +17,7 @@ var build string
const (
VersionMajor = 0
VersionMinor = 9
VersionPatch = 1
VersionPatch = 2
VersionTag = "" // example: "rc1"
)

View file

@ -38,32 +38,32 @@ type KeyInternalAPI interface {
// API functions required by the clientapi
type ClientKeyAPI interface {
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse)
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse)
PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse)
PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse)
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error
PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error
PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error
// PerformClaimKeys claims one-time keys for use in pre-key messages
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse)
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
}
// API functions required by the userapi
type UserKeyAPI interface {
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse)
PerformDeleteKeys(ctx context.Context, req *PerformDeleteKeysRequest, res *PerformDeleteKeysResponse)
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error
PerformDeleteKeys(ctx context.Context, req *PerformDeleteKeysRequest, res *PerformDeleteKeysResponse) error
}
// API functions required by the syncapi
type SyncKeyAPI interface {
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse)
QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse)
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error
QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error
}
type FederationKeyAPI interface {
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse)
QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse)
QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse)
PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse)
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse)
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error
QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error
PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
}
// KeyError is returned if there was a problem performing/querying the server

View file

@ -103,7 +103,7 @@ func sanityCheckKey(key gomatrixserverlib.CrossSigningKey, userID string, purpos
}
// nolint:gocyclo
func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) {
func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error {
// Find the keys to store.
byPurpose := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{}
toStore := types.CrossSigningKeyMap{}
@ -115,7 +115,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
Err: "Master key sanity check failed: " + err.Error(),
IsInvalidParam: true,
}
return
return nil
}
byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster] = req.MasterKey
@ -131,7 +131,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
Err: "Self-signing key sanity check failed: " + err.Error(),
IsInvalidParam: true,
}
return
return nil
}
byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = req.SelfSigningKey
@ -146,7 +146,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
Err: "User-signing key sanity check failed: " + err.Error(),
IsInvalidParam: true,
}
return
return nil
}
byPurpose[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = req.UserSigningKey
@ -161,7 +161,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
Err: "No keys were supplied in the request",
IsMissingParam: true,
}
return
return nil
}
// We can't have a self-signing or user-signing key without a master
@ -174,7 +174,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
res.Error = &api.KeyError{
Err: "Retrieving cross-signing keys from database failed: " + err.Error(),
}
return
return nil
}
// If we still can't find a master key for the user then stop the upload.
@ -185,7 +185,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
Err: "No master key was found",
IsMissingParam: true,
}
return
return nil
}
}
@ -212,7 +212,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
}
}
if !changed {
return
return nil
}
// Store the keys.
@ -220,7 +220,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err),
}
return
return nil
}
// Now upload any signatures that were included with the keys.
@ -238,7 +238,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err),
}
return
return nil
}
}
}
@ -255,17 +255,18 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
update.SelfSigningKey = &ssk
}
if update.MasterKey == nil && update.SelfSigningKey == nil {
return
return nil
}
if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err),
}
return
return nil
}
return nil
}
func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) {
func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error {
// Before we do anything, we need the master and self-signing keys for this user.
// Then we can verify the signatures make sense.
queryReq := &api.QueryKeysRequest{
@ -276,7 +277,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req
for userID := range req.Signatures {
queryReq.UserToDevices[userID] = []string{}
}
a.QueryKeys(ctx, queryReq, queryRes)
_ = a.QueryKeys(ctx, queryReq, queryRes)
selfSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
otherSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
@ -322,14 +323,14 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.processSelfSignatures: %s", err),
}
return
return nil
}
if err := a.processOtherSignatures(ctx, req.UserID, queryRes, otherSignatures); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.processOtherSignatures: %s", err),
}
return
return nil
}
// Finally, generate a notification that we updated the signatures.
@ -345,9 +346,10 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err),
}
return
return nil
}
}
return nil
}
func (a *KeyInternalAPI) processSelfSignatures(
@ -520,7 +522,7 @@ func (a *KeyInternalAPI) crossSigningKeysFromDatabase(
}
}
func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) {
func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error {
for targetUserID, forTargetUser := range req.TargetIDs {
keyMap, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID)
if err != nil && err != sql.ErrNoRows {
@ -559,7 +561,7 @@ func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySign
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err),
}
return
return nil
}
for sourceUserID, forSourceUser := range sigMap {
@ -581,4 +583,5 @@ func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySign
}
}
}
return nil
}

View file

@ -119,7 +119,7 @@ type DeviceListUpdaterDatabase interface {
}
type DeviceListUpdaterAPI interface {
PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse)
PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error
}
// KeyChangeProducer is the interface for producers.KeyChange useful for testing.
@ -421,7 +421,7 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
uploadReq.SelfSigningKey = *res.SelfSigningKey
}
}
u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes)
_ = u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes)
}
err = u.updateDeviceList(&res)
if err != nil {

View file

@ -113,8 +113,8 @@ func (d *mockDeviceListUpdaterDatabase) DeviceKeysJSON(ctx context.Context, keys
type mockDeviceListUpdaterAPI struct {
}
func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) {
func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error {
return nil
}
type roundTripper struct {

View file

@ -48,18 +48,20 @@ func (a *KeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) {
a.UserAPI = i
}
func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) {
func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error {
userIDs, latest, err := a.DB.KeyChanges(ctx, req.Offset, req.ToOffset)
if err != nil {
res.Error = &api.KeyError{
Err: err.Error(),
}
return nil
}
res.Offset = latest
res.UserIDs = userIDs
return nil
}
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error {
res.KeyErrors = make(map[string]map[string]*api.KeyError)
if len(req.DeviceKeys) > 0 {
a.uploadLocalDeviceKeys(ctx, req, res)
@ -67,9 +69,10 @@ func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perform
if len(req.OneTimeKeys) > 0 {
a.uploadOneTimeKeys(ctx, req, res)
}
return nil
}
func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) {
func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error {
res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage)
res.Failures = make(map[string]interface{})
// wrap request map in a top-level by-domain map
@ -113,6 +116,7 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
if len(domainToDeviceKeys) > 0 {
a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
}
return nil
}
func (a *KeyInternalAPI) claimRemoteKeys(
@ -172,32 +176,34 @@ func (a *KeyInternalAPI) claimRemoteKeys(
util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys")
}
func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) {
func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error {
if err := a.DB.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("Failed to delete device keys: %s", err),
}
}
return nil
}
func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) {
func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error {
count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("Failed to query OTK counts: %s", err),
}
return
return nil
}
res.Count = *count
return nil
}
func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) {
func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error {
msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
}
return
return nil
}
maxStreamID := int64(0)
for _, m := range msgs {
@ -215,10 +221,11 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
}
res.Devices = result
res.StreamID = maxStreamID
return nil
}
// nolint:gocyclo
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error {
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
@ -244,7 +251,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query local device keys: %s", err),
}
return
return nil
}
// pull out display names after we have the keys so we handle wildcards correctly
@ -318,7 +325,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
// Stop executing the function if the context was canceled/the deadline was exceeded,
// as we can't continue without a valid context.
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return
return nil
}
logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
continue
@ -344,7 +351,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
// Stop executing the function if the context was canceled/the deadline was exceeded,
// as we can't continue without a valid context.
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return
return nil
}
logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
continue
@ -372,6 +379,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
}
}
}
return nil
}
func (a *KeyInternalAPI) remoteKeysFromDatabase(

View file

@ -22,7 +22,6 @@ import (
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/opentracing/opentracing-go"
)
// HTTP paths for the internal HTTP APIs
@ -68,168 +67,108 @@ func (h *httpKeyInternalAPI) PerformClaimKeys(
ctx context.Context,
request *api.PerformClaimKeysRequest,
response *api.PerformClaimKeysResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformClaimKeys")
defer span.Finish()
apiURL := h.apiURL + PerformClaimKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
) error {
return httputil.CallInternalRPCAPI(
"PerformClaimKeys", h.apiURL+PerformClaimKeysPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpKeyInternalAPI) PerformDeleteKeys(
ctx context.Context,
request *api.PerformDeleteKeysRequest,
response *api.PerformDeleteKeysResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformClaimKeys")
defer span.Finish()
apiURL := h.apiURL + PerformClaimKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
) error {
return httputil.CallInternalRPCAPI(
"PerformDeleteKeys", h.apiURL+PerformDeleteKeysPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpKeyInternalAPI) PerformUploadKeys(
ctx context.Context,
request *api.PerformUploadKeysRequest,
response *api.PerformUploadKeysResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUploadKeys")
defer span.Finish()
apiURL := h.apiURL + PerformUploadKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
) error {
return httputil.CallInternalRPCAPI(
"PerformUploadKeys", h.apiURL+PerformUploadKeysPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpKeyInternalAPI) QueryKeys(
ctx context.Context,
request *api.QueryKeysRequest,
response *api.QueryKeysResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeys")
defer span.Finish()
apiURL := h.apiURL + QueryKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
) error {
return httputil.CallInternalRPCAPI(
"QueryKeys", h.apiURL+QueryKeysPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpKeyInternalAPI) QueryOneTimeKeys(
ctx context.Context,
request *api.QueryOneTimeKeysRequest,
response *api.QueryOneTimeKeysResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOneTimeKeys")
defer span.Finish()
apiURL := h.apiURL + QueryOneTimeKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
) error {
return httputil.CallInternalRPCAPI(
"QueryOneTimeKeys", h.apiURL+QueryOneTimeKeysPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpKeyInternalAPI) QueryDeviceMessages(
ctx context.Context,
request *api.QueryDeviceMessagesRequest,
response *api.QueryDeviceMessagesResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDeviceMessages")
defer span.Finish()
apiURL := h.apiURL + QueryDeviceMessagesPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
) error {
return httputil.CallInternalRPCAPI(
"QueryDeviceMessages", h.apiURL+QueryDeviceMessagesPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpKeyInternalAPI) QueryKeyChanges(
ctx context.Context,
request *api.QueryKeyChangesRequest,
response *api.QueryKeyChangesResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyChanges")
defer span.Finish()
apiURL := h.apiURL + QueryKeyChangesPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
) error {
return httputil.CallInternalRPCAPI(
"QueryKeyChanges", h.apiURL+QueryKeyChangesPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpKeyInternalAPI) PerformUploadDeviceKeys(
ctx context.Context,
request *api.PerformUploadDeviceKeysRequest,
response *api.PerformUploadDeviceKeysResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUploadDeviceKeys")
defer span.Finish()
apiURL := h.apiURL + PerformUploadDeviceKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
) error {
return httputil.CallInternalRPCAPI(
"PerformUploadDeviceKeys", h.apiURL+PerformUploadDeviceKeysPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpKeyInternalAPI) PerformUploadDeviceSignatures(
ctx context.Context,
request *api.PerformUploadDeviceSignaturesRequest,
response *api.PerformUploadDeviceSignaturesResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUploadDeviceSignatures")
defer span.Finish()
apiURL := h.apiURL + PerformUploadDeviceSignaturesPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
) error {
return httputil.CallInternalRPCAPI(
"PerformUploadDeviceSignatures", h.apiURL+PerformUploadDeviceSignaturesPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpKeyInternalAPI) QuerySignatures(
ctx context.Context,
request *api.QuerySignaturesRequest,
response *api.QuerySignaturesResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySignatures")
defer span.Finish()
apiURL := h.apiURL + QuerySignaturesPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
) error {
return httputil.CallInternalRPCAPI(
"QuerySignatures", h.apiURL+QuerySignaturesPath,
h.httpClient, ctx, request, response,
)
}

View file

@ -15,124 +15,59 @@
package inthttp
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/util"
)
func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) {
internalAPIMux.Handle(PerformClaimKeysPath,
httputil.MakeInternalAPI("performClaimKeys", func(req *http.Request) util.JSONResponse {
request := api.PerformClaimKeysRequest{}
response := api.PerformClaimKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.PerformClaimKeys(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
PerformClaimKeysPath,
httputil.MakeInternalRPCAPI("KeyserverPerformClaimKeys", s.PerformClaimKeys),
)
internalAPIMux.Handle(PerformDeleteKeysPath,
httputil.MakeInternalAPI("performDeleteKeys", func(req *http.Request) util.JSONResponse {
request := api.PerformDeleteKeysRequest{}
response := api.PerformDeleteKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.PerformDeleteKeys(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
PerformDeleteKeysPath,
httputil.MakeInternalRPCAPI("KeyserverPerformDeleteKeys", s.PerformDeleteKeys),
)
internalAPIMux.Handle(PerformUploadKeysPath,
httputil.MakeInternalAPI("performUploadKeys", func(req *http.Request) util.JSONResponse {
request := api.PerformUploadKeysRequest{}
response := api.PerformUploadKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.PerformUploadKeys(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
PerformUploadKeysPath,
httputil.MakeInternalRPCAPI("KeyserverPerformUploadKeys", s.PerformUploadKeys),
)
internalAPIMux.Handle(PerformUploadDeviceKeysPath,
httputil.MakeInternalAPI("performUploadDeviceKeys", func(req *http.Request) util.JSONResponse {
request := api.PerformUploadDeviceKeysRequest{}
response := api.PerformUploadDeviceKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.PerformUploadDeviceKeys(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
PerformUploadDeviceKeysPath,
httputil.MakeInternalRPCAPI("KeyserverPerformUploadDeviceKeys", s.PerformUploadDeviceKeys),
)
internalAPIMux.Handle(PerformUploadDeviceSignaturesPath,
httputil.MakeInternalAPI("performUploadDeviceSignatures", func(req *http.Request) util.JSONResponse {
request := api.PerformUploadDeviceSignaturesRequest{}
response := api.PerformUploadDeviceSignaturesResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.PerformUploadDeviceSignatures(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
PerformUploadDeviceSignaturesPath,
httputil.MakeInternalRPCAPI("KeyserverPerformUploadDeviceSignatures", s.PerformUploadDeviceSignatures),
)
internalAPIMux.Handle(QueryKeysPath,
httputil.MakeInternalAPI("queryKeys", func(req *http.Request) util.JSONResponse {
request := api.QueryKeysRequest{}
response := api.QueryKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.QueryKeys(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
QueryKeysPath,
httputil.MakeInternalRPCAPI("KeyserverQueryKeys", s.QueryKeys),
)
internalAPIMux.Handle(QueryOneTimeKeysPath,
httputil.MakeInternalAPI("queryOneTimeKeys", func(req *http.Request) util.JSONResponse {
request := api.QueryOneTimeKeysRequest{}
response := api.QueryOneTimeKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.QueryOneTimeKeys(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
QueryOneTimeKeysPath,
httputil.MakeInternalRPCAPI("KeyserverQueryOneTimeKeys", s.QueryOneTimeKeys),
)
internalAPIMux.Handle(QueryDeviceMessagesPath,
httputil.MakeInternalAPI("queryDeviceMessages", func(req *http.Request) util.JSONResponse {
request := api.QueryDeviceMessagesRequest{}
response := api.QueryDeviceMessagesResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.QueryDeviceMessages(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
QueryDeviceMessagesPath,
httputil.MakeInternalRPCAPI("KeyserverQueryDeviceMessages", s.QueryDeviceMessages),
)
internalAPIMux.Handle(QueryKeyChangesPath,
httputil.MakeInternalAPI("queryKeyChanges", func(req *http.Request) util.JSONResponse {
request := api.QueryKeyChangesRequest{}
response := api.QueryKeyChangesResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.QueryKeyChanges(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
QueryKeyChangesPath,
httputil.MakeInternalRPCAPI("KeyserverQueryKeyChanges", s.QueryKeyChanges),
)
internalAPIMux.Handle(QuerySignaturesPath,
httputil.MakeInternalAPI("querySignatures", func(req *http.Request) util.JSONResponse {
request := api.QuerySignaturesRequest{}
response := api.QuerySignaturesResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.QuerySignatures(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
QuerySignaturesPath,
httputil.MakeInternalRPCAPI("KeyserverQuerySignatures", s.QuerySignatures),
)
}

View file

@ -40,7 +40,7 @@ type InputRoomEventsAPI interface {
ctx context.Context,
req *InputRoomEventsRequest,
res *InputRoomEventsResponse,
)
) error
}
// Query the latest events and state for a room from the room server.
@ -97,6 +97,14 @@ type SyncRoomserverAPI interface {
req *PerformBackfillRequest,
res *PerformBackfillResponse,
) error
// QueryMembershipAtEvent queries the memberships at the given events.
// Returns a map from eventID to a slice of gomatrixserverlib.HeaderedEvent.
QueryMembershipAtEvent(
ctx context.Context,
request *QueryMembershipAtEventRequest,
response *QueryMembershipAtEventResponse,
) error
}
type AppserviceRoomserverAPI interface {
@ -139,15 +147,15 @@ type ClientRoomserverAPI interface {
GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error
// PerformRoomUpgrade upgrades a room to a newer version
PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse)
PerformAdminEvacuateRoom(ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse)
PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse)
PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse)
PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse)
PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse) error
PerformAdminEvacuateRoom(ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse) error
PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error
PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse) error
PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse) error
PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error
PerformJoin(ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse)
PerformJoin(ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse) error
PerformLeave(ctx context.Context, req *PerformLeaveRequest, res *PerformLeaveResponse) error
PerformPublish(ctx context.Context, req *PerformPublishRequest, res *PerformPublishResponse)
PerformPublish(ctx context.Context, req *PerformPublishRequest, res *PerformPublishResponse) error
// PerformForget forgets a rooms history for a specific user
PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error
SetRoomAlias(ctx context.Context, req *SetRoomAliasRequest, res *SetRoomAliasResponse) error
@ -158,7 +166,7 @@ type UserRoomserverAPI interface {
QueryLatestEventsAndStateAPI
QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse)
PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error
}
type FederationRoomserverAPI interface {

View file

@ -35,9 +35,10 @@ func (t *RoomserverInternalAPITrace) InputRoomEvents(
ctx context.Context,
req *InputRoomEventsRequest,
res *InputRoomEventsResponse,
) {
t.Impl.InputRoomEvents(ctx, req, res)
util.GetLogger(ctx).Infof("InputRoomEvents req=%+v res=%+v", js(req), js(res))
) error {
err := t.Impl.InputRoomEvents(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("InputRoomEvents req=%+v res=%+v", js(req), js(res))
return err
}
func (t *RoomserverInternalAPITrace) PerformInvite(
@ -45,44 +46,49 @@ func (t *RoomserverInternalAPITrace) PerformInvite(
req *PerformInviteRequest,
res *PerformInviteResponse,
) error {
util.GetLogger(ctx).Infof("PerformInvite req=%+v res=%+v", js(req), js(res))
return t.Impl.PerformInvite(ctx, req, res)
err := t.Impl.PerformInvite(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("PerformInvite req=%+v res=%+v", js(req), js(res))
return err
}
func (t *RoomserverInternalAPITrace) PerformPeek(
ctx context.Context,
req *PerformPeekRequest,
res *PerformPeekResponse,
) {
t.Impl.PerformPeek(ctx, req, res)
util.GetLogger(ctx).Infof("PerformPeek req=%+v res=%+v", js(req), js(res))
) error {
err := t.Impl.PerformPeek(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("PerformPeek req=%+v res=%+v", js(req), js(res))
return err
}
func (t *RoomserverInternalAPITrace) PerformUnpeek(
ctx context.Context,
req *PerformUnpeekRequest,
res *PerformUnpeekResponse,
) {
t.Impl.PerformUnpeek(ctx, req, res)
util.GetLogger(ctx).Infof("PerformUnpeek req=%+v res=%+v", js(req), js(res))
) error {
err := t.Impl.PerformUnpeek(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("PerformUnpeek req=%+v res=%+v", js(req), js(res))
return err
}
func (t *RoomserverInternalAPITrace) PerformRoomUpgrade(
ctx context.Context,
req *PerformRoomUpgradeRequest,
res *PerformRoomUpgradeResponse,
) {
t.Impl.PerformRoomUpgrade(ctx, req, res)
util.GetLogger(ctx).Infof("PerformRoomUpgrade req=%+v res=%+v", js(req), js(res))
) error {
err := t.Impl.PerformRoomUpgrade(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("PerformRoomUpgrade req=%+v res=%+v", js(req), js(res))
return err
}
func (t *RoomserverInternalAPITrace) PerformJoin(
ctx context.Context,
req *PerformJoinRequest,
res *PerformJoinResponse,
) {
t.Impl.PerformJoin(ctx, req, res)
util.GetLogger(ctx).Infof("PerformJoin req=%+v res=%+v", js(req), js(res))
) error {
err := t.Impl.PerformJoin(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("PerformJoin req=%+v res=%+v", js(req), js(res))
return err
}
func (t *RoomserverInternalAPITrace) PerformLeave(
@ -99,27 +105,30 @@ func (t *RoomserverInternalAPITrace) PerformPublish(
ctx context.Context,
req *PerformPublishRequest,
res *PerformPublishResponse,
) {
t.Impl.PerformPublish(ctx, req, res)
util.GetLogger(ctx).Infof("PerformPublish req=%+v res=%+v", js(req), js(res))
) error {
err := t.Impl.PerformPublish(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("PerformPublish req=%+v res=%+v", js(req), js(res))
return err
}
func (t *RoomserverInternalAPITrace) PerformAdminEvacuateRoom(
ctx context.Context,
req *PerformAdminEvacuateRoomRequest,
res *PerformAdminEvacuateRoomResponse,
) {
t.Impl.PerformAdminEvacuateRoom(ctx, req, res)
util.GetLogger(ctx).Infof("PerformAdminEvacuateRoom req=%+v res=%+v", js(req), js(res))
) error {
err := t.Impl.PerformAdminEvacuateRoom(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("PerformAdminEvacuateRoom req=%+v res=%+v", js(req), js(res))
return err
}
func (t *RoomserverInternalAPITrace) PerformAdminEvacuateUser(
ctx context.Context,
req *PerformAdminEvacuateUserRequest,
res *PerformAdminEvacuateUserResponse,
) {
t.Impl.PerformAdminEvacuateUser(ctx, req, res)
util.GetLogger(ctx).Infof("PerformAdminEvacuateUser req=%+v res=%+v", js(req), js(res))
) error {
err := t.Impl.PerformAdminEvacuateUser(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("PerformAdminEvacuateUser req=%+v res=%+v", js(req), js(res))
return err
}
func (t *RoomserverInternalAPITrace) PerformInboundPeek(
@ -128,7 +137,7 @@ func (t *RoomserverInternalAPITrace) PerformInboundPeek(
res *PerformInboundPeekResponse,
) error {
err := t.Impl.PerformInboundPeek(ctx, req, res)
util.GetLogger(ctx).Infof("PerformInboundPeek req=%+v res=%+v", js(req), js(res))
util.GetLogger(ctx).WithError(err).Infof("PerformInboundPeek req=%+v res=%+v", js(req), js(res))
return err
}
@ -373,6 +382,16 @@ func (t *RoomserverInternalAPITrace) QueryRestrictedJoinAllowed(
return err
}
func (t *RoomserverInternalAPITrace) QueryMembershipAtEvent(
ctx context.Context,
request *QueryMembershipAtEventRequest,
response *QueryMembershipAtEventResponse,
) error {
err := t.Impl.QueryMembershipAtEvent(ctx, request, response)
util.GetLogger(ctx).WithError(err).Infof("QueryMembershipAtEvent req=%+v res=%+v", js(request), js(response))
return err
}
func js(thing interface{}) string {
b, err := json.Marshal(thing)
if err != nil {

View file

@ -427,3 +427,17 @@ func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error {
}
return nil
}
// QueryMembershipAtEventRequest requests the membership events for a user
// for a list of eventIDs.
type QueryMembershipAtEventRequest struct {
RoomID string
EventIDs []string
UserID string
}
// QueryMembershipAtEventResponse is the response to QueryMembershipAtEventRequest.
type QueryMembershipAtEventResponse struct {
// Memberships is a map from eventID to a list of events (if any).
Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"`
}

View file

@ -90,7 +90,9 @@ func SendInputRoomEvents(
Asynchronous: async,
}
var response InputRoomEventsResponse
rsAPI.InputRoomEvents(ctx, &request, &response)
if err := rsAPI.InputRoomEvents(ctx, &request, &response); err != nil {
return err
}
return response.Err()
}

View file

@ -208,6 +208,12 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.Room
return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
}
func MembershipAtEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) {
roomState := state.NewStateResolution(db, info)
// Fetch the state as it was when this event was fired
return roomState.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID)
}
func LoadEvents(
ctx context.Context, db storage.Database, eventNIDs []types.EventNID,
) ([]*gomatrixserverlib.Event, error) {

View file

@ -337,18 +337,18 @@ func (r *Inputer) InputRoomEvents(
ctx context.Context,
request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse,
) {
) error {
// Queue up the event into the roomserver.
replySub, err := r.queueInputRoomEvents(ctx, request)
if err != nil {
response.ErrMsg = err.Error()
return
return nil
}
// If we aren't waiting for synchronous responses then we can
// give up here, there is nothing further to do.
if replySub == nil {
return
return nil
}
// Otherwise, we'll want to sit and wait for the responses
@ -360,12 +360,14 @@ func (r *Inputer) InputRoomEvents(
msg, err := replySub.NextMsgWithContext(ctx)
if err != nil {
response.ErrMsg = err.Error()
return
return nil
}
if len(msg.Data) > 0 {
response.ErrMsg = string(msg.Data)
}
}
return nil
}
var roomserverInputBackpressure = prometheus.NewGaugeVec(

View file

@ -299,7 +299,7 @@ func (r *Inputer) processRoomEvent(
// allowed at the time, and also to get the history visibility. We won't
// bother doing this if the event was already rejected as it just ends up
// burning CPU time.
historyVisibility := gomatrixserverlib.HistoryVisibilityJoined // Default to restrictive.
historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared.
if rejectionErr == nil && !isRejected && !softfail {
var err error
historyVisibility, rejectionErr, err = r.processStateBefore(ctx, input, missingPrev)
@ -429,7 +429,7 @@ func (r *Inputer) processStateBefore(
input *api.InputRoomEvent,
missingPrev bool,
) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) {
historyVisibility = gomatrixserverlib.HistoryVisibilityJoined // Default to restrictive.
historyVisibility = gomatrixserverlib.HistoryVisibilityShared // Default to shared.
event := input.Event.Unwrap()
isCreateEvent := event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("")
var stateBeforeEvent []*gomatrixserverlib.Event

View file

@ -43,21 +43,21 @@ func (r *Admin) PerformAdminEvacuateRoom(
ctx context.Context,
req *api.PerformAdminEvacuateRoomRequest,
res *api.PerformAdminEvacuateRoomResponse,
) {
) error {
roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.RoomInfo: %s", err),
}
return
return nil
}
if roomInfo == nil || roomInfo.IsStub() {
res.Error = &api.PerformError{
Code: api.PerformErrorNoRoom,
Msg: fmt.Sprintf("Room %s not found", req.RoomID),
}
return
return nil
}
memberNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true)
@ -66,7 +66,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.GetMembershipEventNIDsForRoom: %s", err),
}
return
return nil
}
memberEvents, err := r.DB.Events(ctx, memberNIDs)
@ -75,7 +75,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.Events: %s", err),
}
return
return nil
}
inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents))
@ -89,7 +89,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Queryer.QueryLatestEventsAndState: %s", err),
}
return
return nil
}
prevEvents := latestRes.LatestEvents
@ -104,7 +104,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("json.Unmarshal: %s", err),
}
return
return nil
}
memberContent.Membership = gomatrixserverlib.Leave
@ -122,7 +122,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("json.Marshal: %s", err),
}
return
return nil
}
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent)
@ -131,7 +131,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("gomatrixserverlib.StateNeededForEventBuilder: %s", err),
}
return
return nil
}
event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, time.Now(), &eventsNeeded, latestRes)
@ -140,7 +140,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("eventutil.BuildEvent: %s", err),
}
return
return nil
}
inputEvents = append(inputEvents, api.InputRoomEvent{
@ -160,28 +160,28 @@ func (r *Admin) PerformAdminEvacuateRoom(
Asynchronous: true,
}
inputRes := &api.InputRoomEventsResponse{}
r.Inputer.InputRoomEvents(ctx, inputReq, inputRes)
return r.Inputer.InputRoomEvents(ctx, inputReq, inputRes)
}
func (r *Admin) PerformAdminEvacuateUser(
ctx context.Context,
req *api.PerformAdminEvacuateUserRequest,
res *api.PerformAdminEvacuateUserResponse,
) {
) error {
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Malformed user ID: %s", err),
}
return
return nil
}
if domain != r.Cfg.Matrix.ServerName {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: "Can only evacuate local users using this endpoint",
}
return
return nil
}
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, gomatrixserverlib.Join)
@ -190,7 +190,7 @@ func (r *Admin) PerformAdminEvacuateUser(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.GetRoomsByMembership: %s", err),
}
return
return nil
}
inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, gomatrixserverlib.Invite)
@ -199,7 +199,7 @@ func (r *Admin) PerformAdminEvacuateUser(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.GetRoomsByMembership: %s", err),
}
return
return nil
}
for _, roomID := range append(roomIDs, inviteRoomIDs...) {
@ -214,7 +214,7 @@ func (r *Admin) PerformAdminEvacuateUser(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Leaver.PerformLeave: %s", err),
}
return
return nil
}
if len(outputEvents) == 0 {
continue
@ -224,9 +224,10 @@ func (r *Admin) PerformAdminEvacuateUser(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Inputer.WriteOutputEvents: %s", err),
}
return
return nil
}
res.Affected = append(res.Affected, roomID)
}
return nil
}

View file

@ -241,7 +241,9 @@ func (r *Inviter) PerformInvite(
},
}
inputRes := &api.InputRoomEventsResponse{}
r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes)
if err = r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes); err != nil {
return nil, fmt.Errorf("r.Inputer.InputRoomEvents: %w", err)
}
if err = inputRes.Err(); err != nil {
res.Error = &api.PerformError{
Msg: fmt.Sprintf("r.InputRoomEvents: %s", err.Error()),

View file

@ -52,7 +52,7 @@ func (r *Joiner) PerformJoin(
ctx context.Context,
req *rsAPI.PerformJoinRequest,
res *rsAPI.PerformJoinResponse,
) {
) error {
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
"room_id": req.RoomIDOrAlias,
"user_id": req.UserID,
@ -71,11 +71,12 @@ func (r *Joiner) PerformJoin(
Msg: err.Error(),
}
}
return
return nil
}
logger.Info("User joined room successfully")
res.RoomID = roomID
res.JoinedVia = joinedVia
return nil
}
func (r *Joiner) performJoin(
@ -291,7 +292,12 @@ func (r *Joiner) performJoinRoomByID(
},
}
inputRes := rsAPI.InputRoomEventsResponse{}
r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes)
if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil {
return "", "", &rsAPI.PerformError{
Code: rsAPI.PerformErrorNoOperation,
Msg: fmt.Sprintf("InputRoomEvents failed: %s", err),
}
}
if err = inputRes.Err(); err != nil {
return "", "", &rsAPI.PerformError{
Code: rsAPI.PerformErrorNotAllowed,

View file

@ -186,7 +186,9 @@ func (r *Leaver) performLeaveRoomByID(
},
}
inputRes := api.InputRoomEventsResponse{}
r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes)
if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil {
return nil, fmt.Errorf("r.Inputer.InputRoomEvents: %w", err)
}
if err = inputRes.Err(); err != nil {
return nil, fmt.Errorf("r.InputRoomEvents: %w", err)
}

View file

@ -44,7 +44,7 @@ func (r *Peeker) PerformPeek(
ctx context.Context,
req *api.PerformPeekRequest,
res *api.PerformPeekResponse,
) {
) error {
roomID, err := r.performPeek(ctx, req)
if err != nil {
perr, ok := err.(*api.PerformError)
@ -57,6 +57,7 @@ func (r *Peeker) PerformPeek(
}
}
res.RoomID = roomID
return nil
}
func (r *Peeker) performPeek(

View file

@ -29,11 +29,12 @@ func (r *Publisher) PerformPublish(
ctx context.Context,
req *api.PerformPublishRequest,
res *api.PerformPublishResponse,
) {
) error {
err := r.DB.PublishRoom(ctx, req.RoomID, req.Visibility == "public")
if err != nil {
res.Error = &api.PerformError{
Msg: err.Error(),
}
}
return nil
}

View file

@ -41,7 +41,7 @@ func (r *Unpeeker) PerformUnpeek(
ctx context.Context,
req *api.PerformUnpeekRequest,
res *api.PerformUnpeekResponse,
) {
) error {
if err := r.performUnpeek(ctx, req); err != nil {
perr, ok := err.(*api.PerformError)
if ok {
@ -52,6 +52,7 @@ func (r *Unpeeker) PerformUnpeek(
}
}
}
return nil
}
func (r *Unpeeker) performUnpeek(

View file

@ -45,12 +45,13 @@ func (r *Upgrader) PerformRoomUpgrade(
ctx context.Context,
req *api.PerformRoomUpgradeRequest,
res *api.PerformRoomUpgradeResponse,
) {
) error {
res.NewRoomID, res.Error = r.performRoomUpgrade(ctx, req)
if res.Error != nil {
res.NewRoomID = ""
logrus.WithContext(ctx).WithError(res.Error).Error("Room upgrade failed")
}
return nil
}
func (r *Upgrader) performRoomUpgrade(
@ -286,22 +287,24 @@ func publishNewRoomAndUnpublishOldRoom(
) {
// expose this room in the published room list
var pubNewRoomRes api.PerformPublishResponse
URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{
if err := URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{
RoomID: newRoomID,
Visibility: "public",
}, &pubNewRoomRes)
if pubNewRoomRes.Error != nil {
}, &pubNewRoomRes); err != nil {
util.GetLogger(ctx).WithError(err).Error("failed to reach internal API")
} else if pubNewRoomRes.Error != nil {
// treat as non-fatal since the room is already made by this point
util.GetLogger(ctx).WithError(pubNewRoomRes.Error).Error("failed to visibility:public")
}
var unpubOldRoomRes api.PerformPublishResponse
// remove the old room from the published room list
URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{
if err := URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{
RoomID: oldRoomID,
Visibility: "private",
}, &unpubOldRoomRes)
if unpubOldRoomRes.Error != nil {
}, &unpubOldRoomRes); err != nil {
util.GetLogger(ctx).WithError(err).Error("failed to reach internal API")
} else if unpubOldRoomRes.Error != nil {
// treat as non-fatal since the room is already made by this point
util.GetLogger(ctx).WithError(unpubOldRoomRes.Error).Error("failed to visibility:private")
}

View file

@ -21,6 +21,10 @@ import (
"errors"
"fmt"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/roomserver/acls"
@ -30,9 +34,6 @@ import (
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/roomserver/version"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
)
type Queryer struct {
@ -204,6 +205,54 @@ func (r *Queryer) QueryMembershipForUser(
return err
}
func (r *Queryer) QueryMembershipAtEvent(
ctx context.Context,
request *api.QueryMembershipAtEventRequest,
response *api.QueryMembershipAtEventResponse,
) error {
response.Memberships = make(map[string][]*gomatrixserverlib.HeaderedEvent)
info, err := r.DB.RoomInfo(ctx, request.RoomID)
if err != nil {
return fmt.Errorf("unable to get roomInfo: %w", err)
}
if info == nil {
return fmt.Errorf("no roomInfo found")
}
// get the users stateKeyNID
stateKeyNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.UserID})
if err != nil {
return fmt.Errorf("unable to get stateKeyNIDs for %s: %w", request.UserID, err)
}
if _, ok := stateKeyNIDs[request.UserID]; !ok {
return fmt.Errorf("requested stateKeyNID for %s was not found", request.UserID)
}
stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, info, request.EventIDs, stateKeyNIDs[request.UserID])
if err != nil {
return fmt.Errorf("unable to get state before event: %w", err)
}
for _, eventID := range request.EventIDs {
stateEntry := stateEntries[eventID]
memberships, err := helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false)
if err != nil {
return fmt.Errorf("unable to get memberships at state: %w", err)
}
res := make([]*gomatrixserverlib.HeaderedEvent, 0, len(memberships))
for i := range memberships {
ev := memberships[i]
if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(request.UserID) {
res = append(res, ev.Headered(info.RoomVersion))
}
}
response.Memberships[eventID] = res
}
return nil
}
// QueryMembershipsForRoom implements api.RoomserverInternalAPI
func (r *Queryer) QueryMembershipsForRoom(
ctx context.Context,
@ -684,7 +733,7 @@ func (r *Queryer) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForU
func (r *Queryer) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error {
users, err := r.DB.GetKnownUsers(ctx, req.UserID, req.SearchString, req.Limit)
if err != nil {
if err != nil && err != sql.ErrNoRows {
return err
}
for _, user := range users {

View file

@ -3,18 +3,16 @@ package inthttp
import (
"context"
"errors"
"fmt"
"net/http"
"github.com/matrix-org/gomatrixserverlib"
asAPI "github.com/matrix-org/dendrite/appservice/api"
fsInputAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/opentracing/opentracing-go"
)
const (
@ -63,6 +61,7 @@ const (
RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom"
RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain"
RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed"
RoomserverQueryMembershipAtEventPath = "/roomserver/queryMembershipAtEvent"
)
type httpRoomserverInternalAPI struct {
@ -106,11 +105,10 @@ func (h *httpRoomserverInternalAPI) SetRoomAlias(
request *api.SetRoomAliasRequest,
response *api.SetRoomAliasResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "SetRoomAlias")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverSetRoomAliasPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"SetRoomAlias", h.roomserverURL+RoomserverSetRoomAliasPath,
h.httpClient, ctx, request, response,
)
}
// GetRoomIDForAlias implements RoomserverAliasAPI
@ -119,11 +117,10 @@ func (h *httpRoomserverInternalAPI) GetRoomIDForAlias(
request *api.GetRoomIDForAliasRequest,
response *api.GetRoomIDForAliasResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "GetRoomIDForAlias")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverGetRoomIDForAliasPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"GetRoomIDForAlias", h.roomserverURL+RoomserverGetRoomIDForAliasPath,
h.httpClient, ctx, request, response,
)
}
// GetAliasesForRoomID implements RoomserverAliasAPI
@ -132,11 +129,10 @@ func (h *httpRoomserverInternalAPI) GetAliasesForRoomID(
request *api.GetAliasesForRoomIDRequest,
response *api.GetAliasesForRoomIDResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "GetAliasesForRoomID")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverGetAliasesForRoomIDPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"GetAliasesForRoomID", h.roomserverURL+RoomserverGetAliasesForRoomIDPath,
h.httpClient, ctx, request, response,
)
}
// RemoveRoomAlias implements RoomserverAliasAPI
@ -145,11 +141,10 @@ func (h *httpRoomserverInternalAPI) RemoveRoomAlias(
request *api.RemoveRoomAliasRequest,
response *api.RemoveRoomAliasResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "RemoveRoomAlias")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverRemoveRoomAliasPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"RemoveRoomAlias", h.roomserverURL+RoomserverRemoveRoomAliasPath,
h.httpClient, ctx, request, response,
)
}
// InputRoomEvents implements RoomserverInputAPI
@ -157,15 +152,14 @@ func (h *httpRoomserverInternalAPI) InputRoomEvents(
ctx context.Context,
request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "InputRoomEvents")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverInputRoomEventsPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
) error {
if err := httputil.CallInternalRPCAPI(
"InputRoomEvents", h.roomserverURL+RoomserverInputRoomEventsPath,
h.httpClient, ctx, request, response,
); err != nil {
response.ErrMsg = err.Error()
}
return nil
}
func (h *httpRoomserverInternalAPI) PerformInvite(
@ -173,45 +167,32 @@ func (h *httpRoomserverInternalAPI) PerformInvite(
request *api.PerformInviteRequest,
response *api.PerformInviteResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformInvite")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformInvitePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformInvite", h.roomserverURL+RoomserverPerformInvitePath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) PerformJoin(
ctx context.Context,
request *api.PerformJoinRequest,
response *api.PerformJoinResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformJoin")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformJoinPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.PerformError{
Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
}
}
) error {
return httputil.CallInternalRPCAPI(
"PerformJoin", h.roomserverURL+RoomserverPerformJoinPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) PerformPeek(
ctx context.Context,
request *api.PerformPeekRequest,
response *api.PerformPeekResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPeek")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformPeekPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.PerformError{
Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
}
}
) error {
return httputil.CallInternalRPCAPI(
"PerformPeek", h.roomserverURL+RoomserverPerformPeekPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) PerformInboundPeek(
@ -219,45 +200,32 @@ func (h *httpRoomserverInternalAPI) PerformInboundPeek(
request *api.PerformInboundPeekRequest,
response *api.PerformInboundPeekResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformInboundPeek")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformInboundPeekPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformInboundPeek", h.roomserverURL+RoomserverPerformInboundPeekPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) PerformUnpeek(
ctx context.Context,
request *api.PerformUnpeekRequest,
response *api.PerformUnpeekResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUnpeek")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformUnpeekPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.PerformError{
Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
}
}
) error {
return httputil.CallInternalRPCAPI(
"PerformUnpeek", h.roomserverURL+RoomserverPerformUnpeekPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) PerformRoomUpgrade(
ctx context.Context,
request *api.PerformRoomUpgradeRequest,
response *api.PerformRoomUpgradeResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformRoomUpgrade")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformRoomUpgradePath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.PerformError{
Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
}
}
) error {
return httputil.CallInternalRPCAPI(
"PerformRoomUpgrade", h.roomserverURL+RoomserverPerformRoomUpgradePath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) PerformLeave(
@ -265,62 +233,43 @@ func (h *httpRoomserverInternalAPI) PerformLeave(
request *api.PerformLeaveRequest,
response *api.PerformLeaveResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLeave")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformLeavePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformLeave", h.roomserverURL+RoomserverPerformLeavePath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) PerformPublish(
ctx context.Context,
req *api.PerformPublishRequest,
res *api.PerformPublishResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPublish")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformPublishPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
if err != nil {
res.Error = &api.PerformError{
Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
}
}
request *api.PerformPublishRequest,
response *api.PerformPublishResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformPublish", h.roomserverURL+RoomserverPerformPublishPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) PerformAdminEvacuateRoom(
ctx context.Context,
req *api.PerformAdminEvacuateRoomRequest,
res *api.PerformAdminEvacuateRoomResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAdminEvacuateRoom")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformAdminEvacuateRoomPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
if err != nil {
res.Error = &api.PerformError{
Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
}
}
request *api.PerformAdminEvacuateRoomRequest,
response *api.PerformAdminEvacuateRoomResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformAdminEvacuateRoom", h.roomserverURL+RoomserverPerformAdminEvacuateRoomPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) PerformAdminEvacuateUser(
ctx context.Context,
req *api.PerformAdminEvacuateUserRequest,
res *api.PerformAdminEvacuateUserResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAdminEvacuateUser")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformAdminEvacuateUserPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
if err != nil {
res.Error = &api.PerformError{
Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
}
}
request *api.PerformAdminEvacuateUserRequest,
response *api.PerformAdminEvacuateUserResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformAdminEvacuateUser", h.roomserverURL+RoomserverPerformAdminEvacuateUserPath,
h.httpClient, ctx, request, response,
)
}
// QueryLatestEventsAndState implements RoomserverQueryAPI
@ -329,11 +278,10 @@ func (h *httpRoomserverInternalAPI) QueryLatestEventsAndState(
request *api.QueryLatestEventsAndStateRequest,
response *api.QueryLatestEventsAndStateResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryLatestEventsAndState")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryLatestEventsAndStatePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryLatestEventsAndState", h.roomserverURL+RoomserverQueryLatestEventsAndStatePath,
h.httpClient, ctx, request, response,
)
}
// QueryStateAfterEvents implements RoomserverQueryAPI
@ -342,11 +290,10 @@ func (h *httpRoomserverInternalAPI) QueryStateAfterEvents(
request *api.QueryStateAfterEventsRequest,
response *api.QueryStateAfterEventsResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryStateAfterEvents")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryStateAfterEventsPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryStateAfterEvents", h.roomserverURL+RoomserverQueryStateAfterEventsPath,
h.httpClient, ctx, request, response,
)
}
// QueryEventsByID implements RoomserverQueryAPI
@ -355,11 +302,10 @@ func (h *httpRoomserverInternalAPI) QueryEventsByID(
request *api.QueryEventsByIDRequest,
response *api.QueryEventsByIDResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryEventsByID")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryEventsByIDPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryEventsByID", h.roomserverURL+RoomserverQueryEventsByIDPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) QueryPublishedRooms(
@ -367,11 +313,10 @@ func (h *httpRoomserverInternalAPI) QueryPublishedRooms(
request *api.QueryPublishedRoomsRequest,
response *api.QueryPublishedRoomsResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPublishedRooms")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryPublishedRoomsPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryPublishedRooms", h.roomserverURL+RoomserverQueryPublishedRoomsPath,
h.httpClient, ctx, request, response,
)
}
// QueryMembershipForUser implements RoomserverQueryAPI
@ -380,11 +325,10 @@ func (h *httpRoomserverInternalAPI) QueryMembershipForUser(
request *api.QueryMembershipForUserRequest,
response *api.QueryMembershipForUserResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMembershipForUser")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryMembershipForUserPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryMembershipForUser", h.roomserverURL+RoomserverQueryMembershipForUserPath,
h.httpClient, ctx, request, response,
)
}
// QueryMembershipsForRoom implements RoomserverQueryAPI
@ -393,11 +337,10 @@ func (h *httpRoomserverInternalAPI) QueryMembershipsForRoom(
request *api.QueryMembershipsForRoomRequest,
response *api.QueryMembershipsForRoomResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMembershipsForRoom")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryMembershipsForRoomPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryMembershipsForRoom", h.roomserverURL+RoomserverQueryMembershipsForRoomPath,
h.httpClient, ctx, request, response,
)
}
// QueryMembershipsForRoom implements RoomserverQueryAPI
@ -406,11 +349,10 @@ func (h *httpRoomserverInternalAPI) QueryServerJoinedToRoom(
request *api.QueryServerJoinedToRoomRequest,
response *api.QueryServerJoinedToRoomResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerJoinedToRoom")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryServerJoinedToRoomPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryServerJoinedToRoom", h.roomserverURL+RoomserverQueryServerJoinedToRoomPath,
h.httpClient, ctx, request, response,
)
}
// QueryServerAllowedToSeeEvent implements RoomserverQueryAPI
@ -419,11 +361,10 @@ func (h *httpRoomserverInternalAPI) QueryServerAllowedToSeeEvent(
request *api.QueryServerAllowedToSeeEventRequest,
response *api.QueryServerAllowedToSeeEventResponse,
) (err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerAllowedToSeeEvent")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryServerAllowedToSeeEventPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryServerAllowedToSeeEvent", h.roomserverURL+RoomserverQueryServerAllowedToSeeEventPath,
h.httpClient, ctx, request, response,
)
}
// QueryMissingEvents implements RoomServerQueryAPI
@ -432,11 +373,10 @@ func (h *httpRoomserverInternalAPI) QueryMissingEvents(
request *api.QueryMissingEventsRequest,
response *api.QueryMissingEventsResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMissingEvents")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryMissingEventsPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryMissingEvents", h.roomserverURL+RoomserverQueryMissingEventsPath,
h.httpClient, ctx, request, response,
)
}
// QueryStateAndAuthChain implements RoomserverQueryAPI
@ -445,11 +385,10 @@ func (h *httpRoomserverInternalAPI) QueryStateAndAuthChain(
request *api.QueryStateAndAuthChainRequest,
response *api.QueryStateAndAuthChainResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryStateAndAuthChain")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryStateAndAuthChainPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryStateAndAuthChain", h.roomserverURL+RoomserverQueryStateAndAuthChainPath,
h.httpClient, ctx, request, response,
)
}
// PerformBackfill implements RoomServerQueryAPI
@ -458,11 +397,10 @@ func (h *httpRoomserverInternalAPI) PerformBackfill(
request *api.PerformBackfillRequest,
response *api.PerformBackfillResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformBackfill")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformBackfillPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformBackfill", h.roomserverURL+RoomserverPerformBackfillPath,
h.httpClient, ctx, request, response,
)
}
// QueryRoomVersionCapabilities implements RoomServerQueryAPI
@ -471,11 +409,10 @@ func (h *httpRoomserverInternalAPI) QueryRoomVersionCapabilities(
request *api.QueryRoomVersionCapabilitiesRequest,
response *api.QueryRoomVersionCapabilitiesResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomVersionCapabilities")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryRoomVersionCapabilitiesPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryRoomVersionCapabilities", h.roomserverURL+RoomserverQueryRoomVersionCapabilitiesPath,
h.httpClient, ctx, request, response,
)
}
// QueryRoomVersionForRoom implements RoomServerQueryAPI
@ -488,12 +425,10 @@ func (h *httpRoomserverInternalAPI) QueryRoomVersionForRoom(
response.RoomVersion = roomVersion
return nil
}
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomVersionForRoom")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryRoomVersionForRoomPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
err := httputil.CallInternalRPCAPI(
"QueryRoomVersionForRoom", h.roomserverURL+RoomserverQueryRoomVersionForRoomPath,
h.httpClient, ctx, request, response,
)
if err == nil {
h.cache.StoreRoomVersion(request.RoomID, response.RoomVersion)
}
@ -505,11 +440,10 @@ func (h *httpRoomserverInternalAPI) QueryCurrentState(
request *api.QueryCurrentStateRequest,
response *api.QueryCurrentStateResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryCurrentState")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryCurrentStatePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryCurrentState", h.roomserverURL+RoomserverQueryCurrentStatePath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) QueryRoomsForUser(
@ -517,11 +451,10 @@ func (h *httpRoomserverInternalAPI) QueryRoomsForUser(
request *api.QueryRoomsForUserRequest,
response *api.QueryRoomsForUserResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomsForUser")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryRoomsForUserPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryRoomsForUser", h.roomserverURL+RoomserverQueryRoomsForUserPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) QueryBulkStateContent(
@ -529,68 +462,82 @@ func (h *httpRoomserverInternalAPI) QueryBulkStateContent(
request *api.QueryBulkStateContentRequest,
response *api.QueryBulkStateContentResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryBulkStateContent")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryBulkStateContentPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryBulkStateContent", h.roomserverURL+RoomserverQueryBulkStateContentPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) QuerySharedUsers(
ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse,
ctx context.Context,
request *api.QuerySharedUsersRequest,
response *api.QuerySharedUsersResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySharedUsers")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQuerySharedUsersPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
return httputil.CallInternalRPCAPI(
"QuerySharedUsers", h.roomserverURL+RoomserverQuerySharedUsersPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) QueryKnownUsers(
ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse,
ctx context.Context,
request *api.QueryKnownUsersRequest,
response *api.QueryKnownUsersResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKnownUsers")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryKnownUsersPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
return httputil.CallInternalRPCAPI(
"QueryKnownUsers", h.roomserverURL+RoomserverQueryKnownUsersPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) QueryAuthChain(
ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse,
ctx context.Context,
request *api.QueryAuthChainRequest,
response *api.QueryAuthChainResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryAuthChain")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryAuthChainPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
return httputil.CallInternalRPCAPI(
"QueryAuthChain", h.roomserverURL+RoomserverQueryAuthChainPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) QueryServerBannedFromRoom(
ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse,
ctx context.Context,
request *api.QueryServerBannedFromRoomRequest,
response *api.QueryServerBannedFromRoomResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerBannedFromRoom")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryServerBannedFromRoomPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
return httputil.CallInternalRPCAPI(
"QueryServerBannedFromRoom", h.roomserverURL+RoomserverQueryServerBannedFromRoomPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) QueryRestrictedJoinAllowed(
ctx context.Context, req *api.QueryRestrictedJoinAllowedRequest, res *api.QueryRestrictedJoinAllowedResponse,
ctx context.Context,
request *api.QueryRestrictedJoinAllowedRequest,
response *api.QueryRestrictedJoinAllowedResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRestrictedJoinAllowed")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryRestrictedJoinAllowed
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
return httputil.CallInternalRPCAPI(
"QueryRestrictedJoinAllowed", h.roomserverURL+RoomserverQueryRestrictedJoinAllowed,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) PerformForget(ctx context.Context, req *api.PerformForgetRequest, res *api.PerformForgetResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformForget")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformForgetPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpRoomserverInternalAPI) PerformForget(
ctx context.Context,
request *api.PerformForgetRequest,
response *api.PerformForgetResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformForget", h.roomserverURL+RoomserverPerformForgetPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) QueryMembershipAtEvent(ctx context.Context, request *api.QueryMembershipAtEventRequest, response *api.QueryMembershipAtEventResponse) error {
return httputil.CallInternalRPCAPI(
"QueryMembershiptAtEvent", h.roomserverURL+RoomserverQueryMembershipAtEventPath,
h.httpClient, ctx, request, response,
)
}

View file

@ -1,499 +1,201 @@
package inthttp
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/util"
)
// AddRoutes adds the RoomserverInternalAPI handlers to the http.ServeMux.
// nolint: gocyclo
func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
internalAPIMux.Handle(RoomserverInputRoomEventsPath,
httputil.MakeInternalAPI("inputRoomEvents", func(req *http.Request) util.JSONResponse {
var request api.InputRoomEventsRequest
var response api.InputRoomEventsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.InputRoomEvents(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverInputRoomEventsPath,
httputil.MakeInternalRPCAPI("RoomserverInputRoomEvents", r.InputRoomEvents),
)
internalAPIMux.Handle(RoomserverPerformInvitePath,
httputil.MakeInternalAPI("performInvite", func(req *http.Request) util.JSONResponse {
var request api.PerformInviteRequest
var response api.PerformInviteResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.PerformInvite(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverPerformInvitePath,
httputil.MakeInternalRPCAPI("RoomserverPerformInvite", r.PerformInvite),
)
internalAPIMux.Handle(RoomserverPerformJoinPath,
httputil.MakeInternalAPI("performJoin", func(req *http.Request) util.JSONResponse {
var request api.PerformJoinRequest
var response api.PerformJoinResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.PerformJoin(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverPerformJoinPath,
httputil.MakeInternalRPCAPI("RoomserverPerformJoin", r.PerformJoin),
)
internalAPIMux.Handle(RoomserverPerformLeavePath,
httputil.MakeInternalAPI("performLeave", func(req *http.Request) util.JSONResponse {
var request api.PerformLeaveRequest
var response api.PerformLeaveResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.PerformLeave(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverPerformLeavePath,
httputil.MakeInternalRPCAPI("RoomserverPerformLeave", r.PerformLeave),
)
internalAPIMux.Handle(RoomserverPerformPeekPath,
httputil.MakeInternalAPI("performPeek", func(req *http.Request) util.JSONResponse {
var request api.PerformPeekRequest
var response api.PerformPeekResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.PerformPeek(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverPerformPeekPath,
httputil.MakeInternalRPCAPI("RoomserverPerformPeek", r.PerformPeek),
)
internalAPIMux.Handle(RoomserverPerformInboundPeekPath,
httputil.MakeInternalAPI("performInboundPeek", func(req *http.Request) util.JSONResponse {
var request api.PerformInboundPeekRequest
var response api.PerformInboundPeekResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.PerformInboundPeek(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverPerformInboundPeekPath,
httputil.MakeInternalRPCAPI("RoomserverPerformInboundPeek", r.PerformInboundPeek),
)
internalAPIMux.Handle(RoomserverPerformPeekPath,
httputil.MakeInternalAPI("performUnpeek", func(req *http.Request) util.JSONResponse {
var request api.PerformUnpeekRequest
var response api.PerformUnpeekResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.PerformUnpeek(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverPerformUnpeekPath,
httputil.MakeInternalRPCAPI("RoomserverPerformUnpeek", r.PerformUnpeek),
)
internalAPIMux.Handle(RoomserverPerformRoomUpgradePath,
httputil.MakeInternalAPI("performRoomUpgrade", func(req *http.Request) util.JSONResponse {
var request api.PerformRoomUpgradeRequest
var response api.PerformRoomUpgradeResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.PerformRoomUpgrade(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverPerformRoomUpgradePath,
httputil.MakeInternalRPCAPI("RoomserverPerformRoomUpgrade", r.PerformRoomUpgrade),
)
internalAPIMux.Handle(RoomserverPerformPublishPath,
httputil.MakeInternalAPI("performPublish", func(req *http.Request) util.JSONResponse {
var request api.PerformPublishRequest
var response api.PerformPublishResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.PerformPublish(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverPerformPublishPath,
httputil.MakeInternalRPCAPI("RoomserverPerformPublish", r.PerformPublish),
)
internalAPIMux.Handle(RoomserverPerformAdminEvacuateRoomPath,
httputil.MakeInternalAPI("performAdminEvacuateRoom", func(req *http.Request) util.JSONResponse {
var request api.PerformAdminEvacuateRoomRequest
var response api.PerformAdminEvacuateRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.PerformAdminEvacuateRoom(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverPerformAdminEvacuateRoomPath,
httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateRoom", r.PerformAdminEvacuateRoom),
)
internalAPIMux.Handle(RoomserverPerformAdminEvacuateUserPath,
httputil.MakeInternalAPI("performAdminEvacuateUser", func(req *http.Request) util.JSONResponse {
var request api.PerformAdminEvacuateUserRequest
var response api.PerformAdminEvacuateUserResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.PerformAdminEvacuateUser(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverPerformAdminEvacuateUserPath,
httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateUser", r.PerformAdminEvacuateUser),
)
internalAPIMux.Handle(
RoomserverQueryPublishedRoomsPath,
httputil.MakeInternalAPI("queryPublishedRooms", func(req *http.Request) util.JSONResponse {
var request api.QueryPublishedRoomsRequest
var response api.QueryPublishedRoomsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryPublishedRooms(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverQueryPublishedRooms", r.QueryPublishedRooms),
)
internalAPIMux.Handle(
RoomserverQueryLatestEventsAndStatePath,
httputil.MakeInternalAPI("queryLatestEventsAndState", func(req *http.Request) util.JSONResponse {
var request api.QueryLatestEventsAndStateRequest
var response api.QueryLatestEventsAndStateResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryLatestEventsAndState(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverQueryLatestEventsAndState", r.QueryLatestEventsAndState),
)
internalAPIMux.Handle(
RoomserverQueryStateAfterEventsPath,
httputil.MakeInternalAPI("queryStateAfterEvents", func(req *http.Request) util.JSONResponse {
var request api.QueryStateAfterEventsRequest
var response api.QueryStateAfterEventsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryStateAfterEvents(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverQueryStateAfterEvents", r.QueryStateAfterEvents),
)
internalAPIMux.Handle(
RoomserverQueryEventsByIDPath,
httputil.MakeInternalAPI("queryEventsByID", func(req *http.Request) util.JSONResponse {
var request api.QueryEventsByIDRequest
var response api.QueryEventsByIDResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryEventsByID(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverQueryEventsByID", r.QueryEventsByID),
)
internalAPIMux.Handle(
RoomserverQueryMembershipForUserPath,
httputil.MakeInternalAPI("QueryMembershipForUser", func(req *http.Request) util.JSONResponse {
var request api.QueryMembershipForUserRequest
var response api.QueryMembershipForUserResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryMembershipForUser(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverQueryMembershipForUser", r.QueryMembershipForUser),
)
internalAPIMux.Handle(
RoomserverQueryMembershipsForRoomPath,
httputil.MakeInternalAPI("queryMembershipsForRoom", func(req *http.Request) util.JSONResponse {
var request api.QueryMembershipsForRoomRequest
var response api.QueryMembershipsForRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryMembershipsForRoom(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverQueryMembershipsForRoom", r.QueryMembershipsForRoom),
)
internalAPIMux.Handle(
RoomserverQueryServerJoinedToRoomPath,
httputil.MakeInternalAPI("queryServerJoinedToRoom", func(req *http.Request) util.JSONResponse {
var request api.QueryServerJoinedToRoomRequest
var response api.QueryServerJoinedToRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryServerJoinedToRoom(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverQueryServerJoinedToRoom", r.QueryServerJoinedToRoom),
)
internalAPIMux.Handle(
RoomserverQueryServerAllowedToSeeEventPath,
httputil.MakeInternalAPI("queryServerAllowedToSeeEvent", func(req *http.Request) util.JSONResponse {
var request api.QueryServerAllowedToSeeEventRequest
var response api.QueryServerAllowedToSeeEventResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryServerAllowedToSeeEvent(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverQueryServerAllowedToSeeEvent", r.QueryServerAllowedToSeeEvent),
)
internalAPIMux.Handle(
RoomserverQueryMissingEventsPath,
httputil.MakeInternalAPI("queryMissingEvents", func(req *http.Request) util.JSONResponse {
var request api.QueryMissingEventsRequest
var response api.QueryMissingEventsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryMissingEvents(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverQueryMissingEvents", r.QueryMissingEvents),
)
internalAPIMux.Handle(
RoomserverQueryStateAndAuthChainPath,
httputil.MakeInternalAPI("queryStateAndAuthChain", func(req *http.Request) util.JSONResponse {
var request api.QueryStateAndAuthChainRequest
var response api.QueryStateAndAuthChainResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryStateAndAuthChain(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverQueryStateAndAuthChain", r.QueryStateAndAuthChain),
)
internalAPIMux.Handle(
RoomserverPerformBackfillPath,
httputil.MakeInternalAPI("PerformBackfill", func(req *http.Request) util.JSONResponse {
var request api.PerformBackfillRequest
var response api.PerformBackfillResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.PerformBackfill(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverPerformBackfill", r.PerformBackfill),
)
internalAPIMux.Handle(
RoomserverPerformForgetPath,
httputil.MakeInternalAPI("PerformForget", func(req *http.Request) util.JSONResponse {
var request api.PerformForgetRequest
var response api.PerformForgetResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.PerformForget(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverPerformForget", r.PerformForget),
)
internalAPIMux.Handle(
RoomserverQueryRoomVersionCapabilitiesPath,
httputil.MakeInternalAPI("QueryRoomVersionCapabilities", func(req *http.Request) util.JSONResponse {
var request api.QueryRoomVersionCapabilitiesRequest
var response api.QueryRoomVersionCapabilitiesResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryRoomVersionCapabilities(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverQueryRoomVersionCapabilities", r.QueryRoomVersionCapabilities),
)
internalAPIMux.Handle(
RoomserverQueryRoomVersionForRoomPath,
httputil.MakeInternalAPI("QueryRoomVersionForRoom", func(req *http.Request) util.JSONResponse {
var request api.QueryRoomVersionForRoomRequest
var response api.QueryRoomVersionForRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryRoomVersionForRoom(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverQueryRoomVersionForRoom", r.QueryRoomVersionForRoom),
)
internalAPIMux.Handle(
RoomserverSetRoomAliasPath,
httputil.MakeInternalAPI("setRoomAlias", func(req *http.Request) util.JSONResponse {
var request api.SetRoomAliasRequest
var response api.SetRoomAliasResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.SetRoomAlias(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverSetRoomAlias", r.SetRoomAlias),
)
internalAPIMux.Handle(
RoomserverGetRoomIDForAliasPath,
httputil.MakeInternalAPI("GetRoomIDForAlias", func(req *http.Request) util.JSONResponse {
var request api.GetRoomIDForAliasRequest
var response api.GetRoomIDForAliasResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.GetRoomIDForAlias(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverGetRoomIDForAlias", r.GetRoomIDForAlias),
)
internalAPIMux.Handle(
RoomserverGetAliasesForRoomIDPath,
httputil.MakeInternalAPI("getAliasesForRoomID", func(req *http.Request) util.JSONResponse {
var request api.GetAliasesForRoomIDRequest
var response api.GetAliasesForRoomIDResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.GetAliasesForRoomID(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverGetAliasesForRoomID", r.GetAliasesForRoomID),
)
internalAPIMux.Handle(
RoomserverRemoveRoomAliasPath,
httputil.MakeInternalAPI("removeRoomAlias", func(req *http.Request) util.JSONResponse {
var request api.RemoveRoomAliasRequest
var response api.RemoveRoomAliasResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.RemoveRoomAlias(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverRemoveRoomAlias", r.RemoveRoomAlias),
)
internalAPIMux.Handle(RoomserverQueryCurrentStatePath,
httputil.MakeInternalAPI("queryCurrentState", func(req *http.Request) util.JSONResponse {
request := api.QueryCurrentStateRequest{}
response := api.QueryCurrentStateResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QueryCurrentState(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverQueryCurrentStatePath,
httputil.MakeInternalRPCAPI("RoomserverQueryCurrentState", r.QueryCurrentState),
)
internalAPIMux.Handle(RoomserverQueryRoomsForUserPath,
httputil.MakeInternalAPI("queryRoomsForUser", func(req *http.Request) util.JSONResponse {
request := api.QueryRoomsForUserRequest{}
response := api.QueryRoomsForUserResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QueryRoomsForUser(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverQueryRoomsForUserPath,
httputil.MakeInternalRPCAPI("RoomserverQueryRoomsForUser", r.QueryRoomsForUser),
)
internalAPIMux.Handle(RoomserverQueryBulkStateContentPath,
httputil.MakeInternalAPI("queryBulkStateContent", func(req *http.Request) util.JSONResponse {
request := api.QueryBulkStateContentRequest{}
response := api.QueryBulkStateContentResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QueryBulkStateContent(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverQueryBulkStateContentPath,
httputil.MakeInternalRPCAPI("RoomserverQueryBulkStateContent", r.QueryBulkStateContent),
)
internalAPIMux.Handle(RoomserverQuerySharedUsersPath,
httputil.MakeInternalAPI("querySharedUsers", func(req *http.Request) util.JSONResponse {
request := api.QuerySharedUsersRequest{}
response := api.QuerySharedUsersResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QuerySharedUsers(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverQuerySharedUsersPath,
httputil.MakeInternalRPCAPI("RoomserverQuerySharedUsers", r.QuerySharedUsers),
)
internalAPIMux.Handle(RoomserverQueryKnownUsersPath,
httputil.MakeInternalAPI("queryKnownUsers", func(req *http.Request) util.JSONResponse {
request := api.QueryKnownUsersRequest{}
response := api.QueryKnownUsersResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QueryKnownUsers(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverQueryKnownUsersPath,
httputil.MakeInternalRPCAPI("RoomserverQueryKnownUsers", r.QueryKnownUsers),
)
internalAPIMux.Handle(RoomserverQueryServerBannedFromRoomPath,
httputil.MakeInternalAPI("queryServerBannedFromRoom", func(req *http.Request) util.JSONResponse {
request := api.QueryServerBannedFromRoomRequest{}
response := api.QueryServerBannedFromRoomResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QueryServerBannedFromRoom(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverQueryServerBannedFromRoomPath,
httputil.MakeInternalRPCAPI("RoomserverQueryServerBannedFromRoom", r.QueryServerBannedFromRoom),
)
internalAPIMux.Handle(RoomserverQueryAuthChainPath,
httputil.MakeInternalAPI("queryAuthChain", func(req *http.Request) util.JSONResponse {
request := api.QueryAuthChainRequest{}
response := api.QueryAuthChainResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QueryAuthChain(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverQueryAuthChainPath,
httputil.MakeInternalRPCAPI("RoomserverQueryAuthChain", r.QueryAuthChain),
)
internalAPIMux.Handle(RoomserverQueryRestrictedJoinAllowed,
httputil.MakeInternalAPI("queryRestrictedJoinAllowed", func(req *http.Request) util.JSONResponse {
request := api.QueryRestrictedJoinAllowedRequest{}
response := api.QueryRestrictedJoinAllowedResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QueryRestrictedJoinAllowed(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverQueryRestrictedJoinAllowed,
httputil.MakeInternalRPCAPI("RoomserverQueryRestrictedJoinAllowed", r.QueryRestrictedJoinAllowed),
)
internalAPIMux.Handle(
RoomserverQueryMembershipAtEventPath,
httputil.MakeInternalRPCAPI("RoomserverQueryMembershipAtEventPath", r.QueryMembershipAtEvent),
)
}

View file

@ -23,12 +23,11 @@ import (
"sync"
"time"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/opentracing/opentracing-go"
"github.com/prometheus/client_golang/prometheus"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
type StateResolutionStorage interface {
@ -124,6 +123,61 @@ func (v *StateResolution) LoadStateAtEvent(
return stateEntries, nil
}
func (v *StateResolution) LoadMembershipAtEvent(
ctx context.Context, eventIDs []string, stateKeyNID types.EventStateKeyNID,
) (map[string][]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadMembershipAtEvent")
defer span.Finish()
// De-dupe snapshotNIDs
snapshotNIDMap := make(map[types.StateSnapshotNID][]string) // map from snapshot NID to eventIDs
for i := range eventIDs {
eventID := eventIDs[i]
snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID)
if err != nil {
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %w", eventID, err)
}
if snapshotNID == 0 {
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID)
}
snapshotNIDMap[snapshotNID] = append(snapshotNIDMap[snapshotNID], eventID)
}
snapshotNIDs := make([]types.StateSnapshotNID, 0, len(snapshotNIDMap))
for nid := range snapshotNIDMap {
snapshotNIDs = append(snapshotNIDs, nid)
}
stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, snapshotNIDs)
if err != nil {
return nil, err
}
result := make(map[string][]types.StateEntry)
for _, stateBlockNIDList := range stateBlockNIDLists {
// Query the membership event for the user at the given stateblocks
stateEntryLists, err := v.db.StateEntriesForTuples(ctx, stateBlockNIDList.StateBlockNIDs, []types.StateKeyTuple{
{
EventTypeNID: types.MRoomMemberNID,
EventStateKeyNID: stateKeyNID,
},
})
if err != nil {
return nil, err
}
evIDs := snapshotNIDMap[stateBlockNIDList.StateSnapshotNID]
for _, evID := range evIDs {
for _, x := range stateEntryLists {
result[evID] = append(result[evID], x.StateEntries...)
}
}
}
return result, nil
}
// LoadStateAtEvent loads the full state of a room before a particular event.
func (v *StateResolution) LoadStateAtEventForHistoryVisibility(
ctx context.Context, eventID string,

View file

@ -23,7 +23,7 @@ import (
// DefaultRoomVersion contains the room version that will, by
// default, be used to create new rooms on this server.
func DefaultRoomVersion() gomatrixserverlib.RoomVersion {
return gomatrixserverlib.RoomVersionV6
return gomatrixserverlib.RoomVersionV9
}
// RoomVersions returns a map of all known room versions to this

View file

@ -164,9 +164,9 @@ func TestMSC2836(t *testing.T) {
// make everyone joined to each other's rooms
nopRsAPI := &testRoomserverAPI{
userToJoinedRooms: map[string][]string{
alice: []string{roomID},
bob: []string{roomID},
charlie: []string{roomID},
alice: {roomID},
bob: {roomID},
charlie: {roomID},
},
events: map[string]*gomatrixserverlib.HeaderedEvent{
eventA.EventID(): eventA,

View file

@ -45,9 +45,6 @@ const (
ConstCreateEventContentValueSpace = "m.space"
ConstSpaceChildEventType = "m.space.child"
ConstSpaceParentEventType = "m.space.parent"
ConstJoinRulePublic = "public"
ConstJoinRuleKnock = "knock"
ConstJoinRuleRestricted = "restricted"
)
type MSC2946ClientResponse struct {
@ -524,11 +521,11 @@ func (w *walker) authorisedServer(roomID string) bool {
return false
}
if rule == ConstJoinRulePublic || rule == ConstJoinRuleKnock {
if rule == gomatrixserverlib.Public || rule == gomatrixserverlib.Knock {
return true
}
if rule == ConstJoinRuleRestricted {
if rule == gomatrixserverlib.Restricted {
allowJoinedToRoomIDs = append(allowJoinedToRoomIDs, w.restrictedJoinRuleAllowedRooms(joinRuleEv, "m.room_membership")...)
}
}
@ -600,9 +597,9 @@ func (w *walker) authorisedUser(roomID, parentRoomID string) (authed bool, isJoi
rule, ruleErr := joinRuleEv.JoinRule()
if ruleErr != nil {
util.GetLogger(w.ctx).WithError(ruleErr).WithField("parent_room_id", parentRoomID).Warn("failed to get join rule")
} else if rule == ConstJoinRulePublic || rule == ConstJoinRuleKnock {
} else if rule == gomatrixserverlib.Public || rule == gomatrixserverlib.Knock {
allowed = true
} else if rule == ConstJoinRuleRestricted {
} else if rule == gomatrixserverlib.Restricted {
allowedRoomIDs := w.restrictedJoinRuleAllowedRooms(joinRuleEv, "m.room_membership")
// check parent is in the allowed set
for _, a := range allowedRoomIDs {
@ -639,7 +636,7 @@ func (w *walker) authorisedUser(roomID, parentRoomID string) (authed bool, isJoi
func (w *walker) restrictedJoinRuleAllowedRooms(joinRuleEv *gomatrixserverlib.HeaderedEvent, allowType string) (allows []string) {
rule, _ := joinRuleEv.JoinRule()
if rule != ConstJoinRuleRestricted {
if rule != gomatrixserverlib.Restricted {
return nil
}
var jrContent gomatrixserverlib.JoinRuleContent

View file

@ -0,0 +1,217 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"math"
"time"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/prometheus/client_golang/prometheus"
"github.com/tidwall/gjson"
)
func init() {
prometheus.MustRegister(calculateHistoryVisibilityDuration)
}
// calculateHistoryVisibilityDuration stores the time it takes to
// calculate the history visibility. In polylith mode the roundtrip
// to the roomserver is included in this time.
var calculateHistoryVisibilityDuration = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: "dendrite",
Subsystem: "syncapi",
Name: "calculateHistoryVisibility_duration_millis",
Help: "How long it takes to calculate the history visibility",
Buckets: []float64{ // milliseconds
5, 10, 25, 50, 75, 100, 250, 500,
1000, 2000, 3000, 4000, 5000, 6000,
7000, 8000, 9000, 10000, 15000, 20000,
},
},
[]string{"api"},
)
var historyVisibilityPriority = map[gomatrixserverlib.HistoryVisibility]uint8{
gomatrixserverlib.WorldReadable: 0,
gomatrixserverlib.HistoryVisibilityShared: 1,
gomatrixserverlib.HistoryVisibilityInvited: 2,
gomatrixserverlib.HistoryVisibilityJoined: 3,
}
// eventVisibility contains the history visibility and membership state at a given event
type eventVisibility struct {
visibility gomatrixserverlib.HistoryVisibility
membershipAtEvent string
membershipCurrent string
}
// allowed checks the eventVisibility if the user is allowed to see the event.
// Rules as defined by https://spec.matrix.org/v1.3/client-server-api/#server-behaviour-5
func (ev eventVisibility) allowed() (allowed bool) {
switch ev.visibility {
case gomatrixserverlib.HistoryVisibilityWorldReadable:
// If the history_visibility was set to world_readable, allow.
return true
case gomatrixserverlib.HistoryVisibilityJoined:
// If the users membership was join, allow.
if ev.membershipAtEvent == gomatrixserverlib.Join {
return true
}
return false
case gomatrixserverlib.HistoryVisibilityShared:
// If the users membership was join, allow.
// If history_visibility was set to shared, and the user joined the room at any point after the event was sent, allow.
if ev.membershipAtEvent == gomatrixserverlib.Join || ev.membershipCurrent == gomatrixserverlib.Join {
return true
}
return false
case gomatrixserverlib.HistoryVisibilityInvited:
// If the users membership was join, allow.
if ev.membershipAtEvent == gomatrixserverlib.Join {
return true
}
if ev.membershipAtEvent == gomatrixserverlib.Invite {
return true
}
return false
default:
return false
}
}
// ApplyHistoryVisibilityFilter applies the room history visibility filter on gomatrixserverlib.HeaderedEvents.
// Returns the filtered events and an error, if any.
func ApplyHistoryVisibilityFilter(
ctx context.Context,
syncDB storage.Database,
rsAPI api.SyncRoomserverAPI,
events []*gomatrixserverlib.HeaderedEvent,
alwaysIncludeEventIDs map[string]struct{},
userID, endpoint string,
) ([]*gomatrixserverlib.HeaderedEvent, error) {
if len(events) == 0 {
return events, nil
}
start := time.Now()
// try to get the current membership of the user
membershipCurrent, _, err := syncDB.SelectMembershipForUser(ctx, events[0].RoomID(), userID, math.MaxInt64)
if err != nil {
return nil, err
}
// Get the mapping from eventID -> eventVisibility
eventsFiltered := make([]*gomatrixserverlib.HeaderedEvent, 0, len(events))
visibilities, err := visibilityForEvents(ctx, rsAPI, events, userID, events[0].RoomID())
if err != nil {
return eventsFiltered, err
}
for _, ev := range events {
evVis := visibilities[ev.EventID()]
evVis.membershipCurrent = membershipCurrent
// Always include specific state events for /sync responses
if alwaysIncludeEventIDs != nil {
if _, ok := alwaysIncludeEventIDs[ev.EventID()]; ok {
eventsFiltered = append(eventsFiltered, ev)
continue
}
}
// NOTSPEC: Always allow user to see their own membership events (spec contains more "rules")
if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(userID) {
eventsFiltered = append(eventsFiltered, ev)
continue
}
// Always allow history evVis events on boundaries. This is done
// by setting the effective evVis to the least restrictive
// of the old vs new.
// https://spec.matrix.org/v1.3/client-server-api/#server-behaviour-5
if hisVis, err := ev.HistoryVisibility(); err == nil {
prevHisVis := gjson.GetBytes(ev.Unsigned(), "prev_content.history_visibility").String()
oldPrio, ok := historyVisibilityPriority[gomatrixserverlib.HistoryVisibility(prevHisVis)]
// if we can't get the previous history visibility, default to shared.
if !ok {
oldPrio = historyVisibilityPriority[gomatrixserverlib.HistoryVisibilityShared]
}
// no OK check, since this should have been validated when setting the value
newPrio := historyVisibilityPriority[hisVis]
if oldPrio < newPrio {
evVis.visibility = gomatrixserverlib.HistoryVisibility(prevHisVis)
}
}
// do the actual check
allowed := evVis.allowed()
if allowed {
eventsFiltered = append(eventsFiltered, ev)
}
}
calculateHistoryVisibilityDuration.With(prometheus.Labels{"api": endpoint}).Observe(float64(time.Since(start).Milliseconds()))
return eventsFiltered, nil
}
// visibilityForEvents returns a map from eventID to eventVisibility containing the visibility and the membership
// of `userID` at the given event.
// Returns an error if the roomserver can't calculate the memberships.
func visibilityForEvents(
ctx context.Context,
rsAPI api.SyncRoomserverAPI,
events []*gomatrixserverlib.HeaderedEvent,
userID, roomID string,
) (map[string]eventVisibility, error) {
eventIDs := make([]string, len(events))
for i := range events {
eventIDs[i] = events[i].EventID()
}
result := make(map[string]eventVisibility, len(eventIDs))
// get the membership events for all eventIDs
membershipResp := &api.QueryMembershipAtEventResponse{}
err := rsAPI.QueryMembershipAtEvent(ctx, &api.QueryMembershipAtEventRequest{
RoomID: roomID,
EventIDs: eventIDs,
UserID: userID,
}, membershipResp)
if err != nil {
return result, err
}
// Create a map from eventID -> eventVisibility
for _, event := range events {
eventID := event.EventID()
vis := eventVisibility{
membershipAtEvent: gomatrixserverlib.Leave, // default to leave, to not expose events by accident
visibility: event.Visibility,
}
membershipEvs, ok := membershipResp.Memberships[eventID]
if !ok {
result[eventID] = vis
continue
}
for _, ev := range membershipEvs {
membership, err := ev.Membership()
if err != nil {
return result, err
}
vis.membershipAtEvent = membership
}
result[eventID] = vis
}
return result, nil
}

View file

@ -31,7 +31,7 @@ import (
// DeviceOTKCounts adds one-time key counts to the /sync response
func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, deviceID string, res *types.Response) error {
var queryRes keyapi.QueryOneTimeKeysResponse
keyAPI.QueryOneTimeKeys(ctx, &keyapi.QueryOneTimeKeysRequest{
_ = keyAPI.QueryOneTimeKeys(ctx, &keyapi.QueryOneTimeKeysRequest{
UserID: userID,
DeviceID: deviceID,
}, &queryRes)
@ -73,7 +73,7 @@ func DeviceListCatchup(
offset = int64(from)
}
var queryRes keyapi.QueryKeyChangesResponse
keyAPI.QueryKeyChanges(ctx, &keyapi.QueryKeyChangesRequest{
_ = keyAPI.QueryKeyChanges(ctx, &keyapi.QueryKeyChangesRequest{
Offset: offset,
ToOffset: toOffset,
}, &queryRes)

View file

@ -22,31 +22,41 @@ var (
type mockKeyAPI struct{}
func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) {
func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) error {
return nil
}
func (k *mockKeyAPI) SetUserAPI(i userapi.UserInternalAPI) {}
// PerformClaimKeys claims one-time keys for use in pre-key messages
func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) {
func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) error {
return nil
}
func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *keyapi.PerformDeleteKeysRequest, res *keyapi.PerformDeleteKeysResponse) {
func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *keyapi.PerformDeleteKeysRequest, res *keyapi.PerformDeleteKeysResponse) error {
return nil
}
func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *keyapi.PerformUploadDeviceKeysRequest, res *keyapi.PerformUploadDeviceKeysResponse) {
func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *keyapi.PerformUploadDeviceKeysRequest, res *keyapi.PerformUploadDeviceKeysResponse) error {
return nil
}
func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *keyapi.PerformUploadDeviceSignaturesRequest, res *keyapi.PerformUploadDeviceSignaturesResponse) {
func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *keyapi.PerformUploadDeviceSignaturesRequest, res *keyapi.PerformUploadDeviceSignaturesResponse) error {
return nil
}
func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) {
func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) error {
return nil
}
func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) {
func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) error {
return nil
}
func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) {
func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) error {
return nil
}
func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *keyapi.QueryDeviceMessagesRequest, res *keyapi.QueryDeviceMessagesResponse) {
func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *keyapi.QueryDeviceMessagesRequest, res *keyapi.QueryDeviceMessagesResponse) error {
return nil
}
func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *keyapi.QuerySignaturesRequest, res *keyapi.QuerySignaturesResponse) {
func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *keyapi.QuerySignaturesRequest, res *keyapi.QuerySignaturesResponse) error {
return nil
}
type mockRoomserverAPI struct {

View file

@ -21,10 +21,12 @@ import (
"fmt"
"net/http"
"strconv"
"time"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/caching"
roomserver "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
@ -95,24 +97,6 @@ func Context(
ContainsURL: filter.ContainsURL,
}
// TODO: Get the actual state at the last event returned by SelectContextAfterEvent
state, _ := syncDB.CurrentState(ctx, roomID, &stateFilter, nil)
// verify the user is allowed to see the context for this room/event
for _, x := range state {
var hisVis gomatrixserverlib.HistoryVisibility
hisVis, err = x.HistoryVisibility()
if err != nil {
continue
}
allowed := hisVis == gomatrixserverlib.WorldReadable || membershipRes.Membership == gomatrixserverlib.Join
if !allowed {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("User is not allowed to query context"),
}
}
}
id, requestedEvent, err := syncDB.SelectContextEvent(ctx, roomID, eventID)
if err != nil {
if err == sql.ErrNoRows {
@ -125,6 +109,24 @@ func Context(
return jsonerror.InternalServerError()
}
// verify the user is allowed to see the context for this room/event
startTime := time.Now()
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, []*gomatrixserverlib.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context")
if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter")
return jsonerror.InternalServerError()
}
logrus.WithFields(logrus.Fields{
"duration": time.Since(startTime),
"room_id": roomID,
}).Debug("applied history visibility (context)")
if len(filteredEvents) == 0 {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("User is not allowed to query context"),
}
}
eventsBefore, err := syncDB.SelectContextBeforeEvent(ctx, id, roomID, filter)
if err != nil && err != sql.ErrNoRows {
logrus.WithError(err).Error("unable to fetch before events")
@ -137,8 +139,27 @@ func Context(
return jsonerror.InternalServerError()
}
eventsBeforeClient := gomatrixserverlib.HeaderedToClientEvents(eventsBefore, gomatrixserverlib.FormatAll)
eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfter, gomatrixserverlib.FormatAll)
startTime = time.Now()
eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, syncDB, rsAPI, eventsBefore, eventsAfter, device.UserID)
if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter")
return jsonerror.InternalServerError()
}
logrus.WithFields(logrus.Fields{
"duration": time.Since(startTime),
"room_id": roomID,
}).Debug("applied history visibility (context eventsBefore/eventsAfter)")
// TODO: Get the actual state at the last event returned by SelectContextAfterEvent
state, err := syncDB.CurrentState(ctx, roomID, &stateFilter, nil)
if err != nil {
logrus.WithError(err).Error("unable to fetch current room state")
return jsonerror.InternalServerError()
}
eventsBeforeClient := gomatrixserverlib.HeaderedToClientEvents(eventsBeforeFiltered, gomatrixserverlib.FormatAll)
eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfterFiltered, gomatrixserverlib.FormatAll)
newState := applyLazyLoadMembers(device, filter, eventsAfterClient, eventsBeforeClient, state, lazyLoadCache)
response := ContextRespsonse{
@ -162,6 +183,44 @@ func Context(
}
}
// applyHistoryVisibilityOnContextEvents is a helper function to avoid roundtrips to the roomserver
// by combining the events before and after the context event. Returns the filtered events,
// and an error, if any.
func applyHistoryVisibilityOnContextEvents(
ctx context.Context, syncDB storage.Database, rsAPI roomserver.SyncRoomserverAPI,
eventsBefore, eventsAfter []*gomatrixserverlib.HeaderedEvent,
userID string,
) (filteredBefore, filteredAfter []*gomatrixserverlib.HeaderedEvent, err error) {
eventIDsBefore := make(map[string]struct{}, len(eventsBefore))
eventIDsAfter := make(map[string]struct{}, len(eventsAfter))
// Remember before/after eventIDs, so we can restore them
// after applying history visibility checks
for _, ev := range eventsBefore {
eventIDsBefore[ev.EventID()] = struct{}{}
}
for _, ev := range eventsAfter {
eventIDsAfter[ev.EventID()] = struct{}{}
}
allEvents := append(eventsBefore, eventsAfter...)
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, allEvents, nil, userID, "context")
if err != nil {
return nil, nil, err
}
// "Restore" events in the correct context
for _, ev := range filteredEvents {
if _, ok := eventIDsBefore[ev.EventID()]; ok {
filteredBefore = append(filteredBefore, ev)
}
if _, ok := eventIDsAfter[ev.EventID()]; ok {
filteredAfter = append(filteredAfter, ev)
}
}
return filteredBefore, filteredAfter, nil
}
func getStartEnd(ctx context.Context, syncDB storage.Database, startEvents, endEvents []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) {
if len(startEvents) > 0 {
start, err = syncDB.EventPositionInTopology(ctx, startEvents[0].EventID())

View file

@ -19,6 +19,7 @@ import (
"fmt"
"net/http"
"sort"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
@ -28,6 +29,7 @@ import (
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types"
@ -324,6 +326,9 @@ func (r *messagesReq) retrieveEvents() (
// reliable way to define it), it would be easier and less troublesome to
// only have to change it in one place, i.e. the database.
start, end, err = r.getStartEnd(events)
if err != nil {
return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, err
}
// Sort the events to ensure we send them in the right order.
if r.backwardOrdering {
@ -337,97 +342,18 @@ func (r *messagesReq) retrieveEvents() (
}
events = reversed(events)
}
events = r.filterHistoryVisible(events)
if len(events) == 0 {
return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, nil
}
// Convert all of the events into client events.
clientEvents = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatAll)
return clientEvents, start, end, err
}
func (r *messagesReq) filterHistoryVisible(events []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
// TODO FIXME: We don't fully implement history visibility yet. To avoid leaking events which the
// user shouldn't see, we check the recent events and remove any prior to the join event of the user
// which is equiv to history_visibility: joined
joinEventIndex := -1
for i, ev := range events {
if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(r.device.UserID) {
membership, _ := ev.Membership()
if membership == "join" {
joinEventIndex = i
break
}
}
}
var result []*gomatrixserverlib.HeaderedEvent
var eventsToCheck []*gomatrixserverlib.HeaderedEvent
if joinEventIndex != -1 {
if r.backwardOrdering {
result = events[:joinEventIndex+1]
eventsToCheck = append(eventsToCheck, result[0])
} else {
result = events[joinEventIndex:]
eventsToCheck = append(eventsToCheck, result[len(result)-1])
}
} else {
eventsToCheck = []*gomatrixserverlib.HeaderedEvent{events[0], events[len(events)-1]}
result = events
}
// make sure the user was in the room for both the earliest and latest events, we need this because
// some backpagination results will not have the join event (e.g if they hit /messages at the join event itself)
wasJoined := true
for _, ev := range eventsToCheck {
var queryRes api.QueryStateAfterEventsResponse
err := r.rsAPI.QueryStateAfterEvents(r.ctx, &api.QueryStateAfterEventsRequest{
RoomID: ev.RoomID(),
PrevEventIDs: ev.PrevEventIDs(),
StateToFetch: []gomatrixserverlib.StateKeyTuple{
{EventType: gomatrixserverlib.MRoomMember, StateKey: r.device.UserID},
{EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: ""},
},
}, &queryRes)
if err != nil {
wasJoined = false
break
}
var hisVisEvent, membershipEvent *gomatrixserverlib.HeaderedEvent
for i := range queryRes.StateEvents {
switch queryRes.StateEvents[i].Type() {
case gomatrixserverlib.MRoomMember:
membershipEvent = queryRes.StateEvents[i]
case gomatrixserverlib.MRoomHistoryVisibility:
hisVisEvent = queryRes.StateEvents[i]
}
}
if hisVisEvent == nil {
return events // apply no filtering as it defaults to Shared.
}
hisVis, _ := hisVisEvent.HistoryVisibility()
if hisVis == "shared" || hisVis == "world_readable" {
return events // apply no filtering
}
if membershipEvent == nil {
wasJoined = false
break
}
membership, err := membershipEvent.Membership()
if err != nil {
wasJoined = false
break
}
if membership != "join" {
wasJoined = false
break
}
}
if !wasJoined {
util.GetLogger(r.ctx).WithField("num_events", len(events)).Warnf("%s was not joined to room during these events, omitting them", r.device.UserID)
return []*gomatrixserverlib.HeaderedEvent{}
}
return result
// Apply room history visibility filter
startTime := time.Now()
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.db, r.rsAPI, events, nil, r.device.UserID, "messages")
logrus.WithFields(logrus.Fields{
"duration": time.Since(startTime),
"room_id": r.roomID,
}).Debug("applied history visibility (messages)")
return gomatrixserverlib.HeaderedToClientEvents(filteredEvents, gomatrixserverlib.FormatAll), start, end, err
}
func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) {

View file

@ -161,6 +161,10 @@ type Database interface {
IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error)
UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
// string as the membership.
SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
}
type Presence interface {

View file

@ -17,7 +17,10 @@ package deltas
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"github.com/matrix-org/gomatrixserverlib"
)
func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.Tx) error {
@ -31,6 +34,27 @@ func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.T
return nil
}
// UpSetHistoryVisibility sets the history visibility for already stored events.
// Requires current_room_state and output_room_events to be created.
func UpSetHistoryVisibility(ctx context.Context, tx *sql.Tx) error {
// get the current room history visibilities
historyVisibilities, err := currentHistoryVisibilities(ctx, tx)
if err != nil {
return err
}
// update the history visibility
for roomID, hisVis := range historyVisibilities {
_, err = tx.ExecContext(ctx, `UPDATE syncapi_output_room_events SET history_visibility = $1
WHERE type IN ('m.room.message', 'm.room.encrypted') AND room_id = $2 AND history_visibility <> $1`, hisVis, roomID)
if err != nil {
return fmt.Errorf("failed to update history visibility: %w", err)
}
}
return nil
}
func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE syncapi_current_room_state ADD COLUMN IF NOT EXISTS history_visibility SMALLINT NOT NULL DEFAULT 2;
@ -39,9 +63,40 @@ func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.T
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
// currentHistoryVisibilities returns a map from roomID to current history visibility.
// If the history visibility was changed after room creation, defaults to joined.
func currentHistoryVisibilities(ctx context.Context, tx *sql.Tx) (map[string]gomatrixserverlib.HistoryVisibility, error) {
rows, err := tx.QueryContext(ctx, `SELECT DISTINCT room_id, headered_event_json FROM syncapi_current_room_state
WHERE type = 'm.room.history_visibility' AND state_key = '';
`)
if err != nil {
return nil, fmt.Errorf("failed to query current room state: %w", err)
}
defer rows.Close() // nolint: errcheck
var eventBytes []byte
var roomID string
var event gomatrixserverlib.HeaderedEvent
var hisVis gomatrixserverlib.HistoryVisibility
historyVisibilities := make(map[string]gomatrixserverlib.HistoryVisibility)
for rows.Next() {
if err = rows.Scan(&roomID, &eventBytes); err != nil {
return nil, fmt.Errorf("failed to scan row: %w", err)
}
if err = json.Unmarshal(eventBytes, &event); err != nil {
return nil, fmt.Errorf("failed to unmarshal event: %w", err)
}
historyVisibilities[roomID] = gomatrixserverlib.HistoryVisibilityJoined
if hisVis, err = event.HistoryVisibility(); err == nil && event.Depth() < 10 {
historyVisibilities[roomID] = hisVis
}
}
return historyVisibilities, nil
}
func DownAddHistoryVisibilityColumn(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE syncapi_output_room_events DROP COLUMN IF EXISTS history_visibility;

View file

@ -66,10 +66,14 @@ const selectMembershipCountSQL = "" +
const selectHeroesSQL = "" +
"SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership = ANY($3) LIMIT 5"
const selectMembershipBeforeSQL = "" +
"SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1"
type membershipsStatements struct {
upsertMembershipStmt *sql.Stmt
selectMembershipCountStmt *sql.Stmt
selectHeroesStmt *sql.Stmt
upsertMembershipStmt *sql.Stmt
selectMembershipCountStmt *sql.Stmt
selectHeroesStmt *sql.Stmt
selectMembershipForUserStmt *sql.Stmt
}
func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
@ -82,6 +86,7 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
{&s.upsertMembershipStmt, upsertMembershipSQL},
{&s.selectMembershipCountStmt, selectMembershipCountSQL},
{&s.selectHeroesStmt, selectHeroesSQL},
{&s.selectMembershipForUserStmt, selectMembershipBeforeSQL},
}.Prepare(db)
}
@ -132,3 +137,20 @@ func (s *membershipsStatements) SelectHeroes(
}
return heroes, rows.Err()
}
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
// string as the membership.
func (s *membershipsStatements) SelectMembershipForUser(
ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64,
) (membership string, topologyPos int, err error) {
stmt := sqlutil.TxStmt(txn, s.selectMembershipForUserStmt)
err = stmt.QueryRowContext(ctx, roomID, userID, pos).Scan(&membership, &topologyPos)
if err != nil {
if err == sql.ErrNoRows {
return "leave", 0, nil
}
return "", 0, err
}
return membership, topologyPos, nil
}

View file

@ -191,10 +191,12 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "syncapi: add history visibility column (output_room_events)",
Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents,
})
m.AddMigrations(
sqlutil.Migration{
Version: "syncapi: add history visibility column (output_room_events)",
Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents,
},
)
err = m.Up(context.Background())
if err != nil {
return nil, err

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/shared"
)
@ -97,6 +98,20 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
if err != nil {
return nil, err
}
// apply migrations which need multiple tables
m := sqlutil.NewMigrator(d.db)
m.AddMigrations(
sqlutil.Migration{
Version: "syncapi: set history visibility for existing events",
Up: deltas.UpSetHistoryVisibility, // Requires current_room_state and output_room_events to be created.
},
)
err = m.Up(base.Context())
if err != nil {
return nil, err
}
d.Database = shared.Database{
DB: d.db,
Writer: d.writer,

View file

@ -231,7 +231,7 @@ func (d *Database) AddPeek(
return
}
// DeletePeeks tracks the fact that a user has stopped peeking from the specified
// DeletePeek tracks the fact that a user has stopped peeking from the specified
// device. If the peeks was successfully deleted this returns the stream ID it was
// stored at. Returns an error if there was a problem communicating with the database.
func (d *Database) DeletePeek(
@ -372,6 +372,7 @@ func (d *Database) WriteEvent(
) (pduPosition types.StreamPosition, returnErr error) {
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error
ev.Visibility = historyVisibility
pos, err := d.OutputEvents.InsertEvent(
ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, historyVisibility,
)
@ -563,7 +564,7 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda
return err
}
// Retrieve the backward topology position, i.e. the position of the
// GetBackwardTopologyPos retrieves the backward topology position, i.e. the position of the
// oldest event in the room's topology.
func (d *Database) GetBackwardTopologyPos(
ctx context.Context,
@ -674,7 +675,7 @@ func (d *Database) fetchMissingStateEvents(
return events, nil
}
// getStateDeltas returns the state deltas between fromPos and toPos,
// GetStateDeltas returns the state deltas between fromPos and toPos,
// exclusive of oldPos, inclusive of newPos, for the rooms in which
// the user has new membership events.
// A list of joined room IDs is also returned in case the caller needs it.
@ -812,7 +813,7 @@ func (d *Database) GetStateDeltas(
return deltas, joinedRoomIDs, nil
}
// getStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync
// GetStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync
// requests with full_state=true.
// Fetches full state for all joined rooms and uses selectStateInRange to get
// updates for other rooms.
@ -1039,37 +1040,41 @@ func (d *Database) GetUserUnreadNotificationCounts(ctx context.Context, userID s
return d.NotificationData.SelectUserUnreadCounts(ctx, userID, from, to)
}
func (s *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) {
return s.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID)
func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) {
return d.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID)
}
func (s *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) {
return s.OutputEvents.SelectContextBeforeEvent(ctx, nil, id, roomID, filter)
func (d *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) {
return d.OutputEvents.SelectContextBeforeEvent(ctx, nil, id, roomID, filter)
}
func (s *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) {
return s.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter)
func (d *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) {
return d.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter)
}
func (s *Database) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) {
return s.Ignores.SelectIgnores(ctx, userID)
func (d *Database) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) {
return d.Ignores.SelectIgnores(ctx, userID)
}
func (s *Database) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error {
return s.Ignores.UpsertIgnores(ctx, userID, ignores)
func (d *Database) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error {
return d.Ignores.UpsertIgnores(ctx, userID, ignores)
}
func (s *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) {
return s.Presence.UpsertPresence(ctx, nil, userID, statusMsg, presence, lastActiveTS, fromSync)
func (d *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) {
return d.Presence.UpsertPresence(ctx, nil, userID, statusMsg, presence, lastActiveTS, fromSync)
}
func (s *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
return s.Presence.GetPresenceForUser(ctx, nil, userID)
func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
return d.Presence.GetPresenceForUser(ctx, nil, userID)
}
func (s *Database) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {
return s.Presence.GetPresenceAfter(ctx, nil, after, filter)
func (d *Database) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {
return d.Presence.GetPresenceAfter(ctx, nil, after, filter)
}
func (s *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) {
return s.Presence.GetMaxPresenceID(ctx, nil)
func (d *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) {
return d.Presence.GetMaxPresenceID(ctx, nil)
}
func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) {
return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos)
}

View file

@ -17,7 +17,10 @@ package deltas
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"github.com/matrix-org/gomatrixserverlib"
)
func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.Tx) error {
@ -37,6 +40,27 @@ func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.T
return nil
}
// UpSetHistoryVisibility sets the history visibility for already stored events.
// Requires current_room_state and output_room_events to be created.
func UpSetHistoryVisibility(ctx context.Context, tx *sql.Tx) error {
// get the current room history visibilities
historyVisibilities, err := currentHistoryVisibilities(ctx, tx)
if err != nil {
return err
}
// update the history visibility
for roomID, hisVis := range historyVisibilities {
_, err = tx.ExecContext(ctx, `UPDATE syncapi_output_room_events SET history_visibility = $1
WHERE type IN ('m.room.message', 'm.room.encrypted') AND room_id = $2 AND history_visibility <> $1`, hisVis, roomID)
if err != nil {
return fmt.Errorf("failed to update history visibility: %w", err)
}
}
return nil
}
func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.Tx) error {
// SQLite doesn't have "if exists", so check if the column exists. If the query doesn't return an error, it already exists.
// Required for unit tests, as otherwise a duplicate column error will show up.
@ -51,9 +75,40 @@ func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.T
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
// currentHistoryVisibilities returns a map from roomID to current history visibility.
// If the history visibility was changed after room creation, defaults to joined.
func currentHistoryVisibilities(ctx context.Context, tx *sql.Tx) (map[string]gomatrixserverlib.HistoryVisibility, error) {
rows, err := tx.QueryContext(ctx, `SELECT DISTINCT room_id, headered_event_json FROM syncapi_current_room_state
WHERE type = 'm.room.history_visibility' AND state_key = '';
`)
if err != nil {
return nil, fmt.Errorf("failed to query current room state: %w", err)
}
defer rows.Close() // nolint: errcheck
var eventBytes []byte
var roomID string
var event gomatrixserverlib.HeaderedEvent
var hisVis gomatrixserverlib.HistoryVisibility
historyVisibilities := make(map[string]gomatrixserverlib.HistoryVisibility)
for rows.Next() {
if err = rows.Scan(&roomID, &eventBytes); err != nil {
return nil, fmt.Errorf("failed to scan row: %w", err)
}
if err = json.Unmarshal(eventBytes, &event); err != nil {
return nil, fmt.Errorf("failed to unmarshal event: %w", err)
}
historyVisibilities[roomID] = gomatrixserverlib.HistoryVisibilityJoined
if hisVis, err = event.HistoryVisibility(); err == nil && event.Depth() < 10 {
historyVisibilities[roomID] = hisVis
}
}
return historyVisibilities, nil
}
func DownAddHistoryVisibilityColumn(ctx context.Context, tx *sql.Tx) error {
// SQLite doesn't have "if exists", so check if the column exists.
_, err := tx.QueryContext(ctx, "SELECT history_visibility FROM syncapi_output_room_events LIMIT 1")

View file

@ -66,11 +66,15 @@ const selectMembershipCountSQL = "" +
const selectHeroesSQL = "" +
"SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership IN ($3) LIMIT 5"
const selectMembershipBeforeSQL = "" +
"SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1"
type membershipsStatements struct {
db *sql.DB
upsertMembershipStmt *sql.Stmt
selectMembershipCountStmt *sql.Stmt
//selectHeroesStmt *sql.Stmt - prepared at runtime due to variadic
selectMembershipForUserStmt *sql.Stmt
}
func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
@ -84,6 +88,7 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
return s, sqlutil.StatementList{
{&s.upsertMembershipStmt, upsertMembershipSQL},
{&s.selectMembershipCountStmt, selectMembershipCountSQL},
{&s.selectMembershipForUserStmt, selectMembershipBeforeSQL},
// {&s.selectHeroesStmt, selectHeroesSQL}, - prepared at runtime due to variadic
}.Prepare(db)
}
@ -148,3 +153,20 @@ func (s *membershipsStatements) SelectHeroes(
}
return heroes, rows.Err()
}
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
// string as the membership.
func (s *membershipsStatements) SelectMembershipForUser(
ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64,
) (membership string, topologyPos int, err error) {
stmt := sqlutil.TxStmt(txn, s.selectMembershipForUserStmt)
err = stmt.QueryRowContext(ctx, roomID, userID, pos).Scan(&membership, &topologyPos)
if err != nil {
if err == sql.ErrNoRows {
return "leave", 0, nil
}
return "", 0, err
}
return membership, topologyPos, nil
}

View file

@ -139,10 +139,12 @@ func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Even
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "syncapi: add history visibility column (output_room_events)",
Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents,
})
m.AddMigrations(
sqlutil.Migration{
Version: "syncapi: add history visibility column (output_room_events)",
Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents,
},
)
err = m.Up(context.Background())
if err != nil {
return nil, err

View file

@ -16,12 +16,14 @@
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/shared"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas"
)
// SyncServerDatasource represents a sync server datasource which manages
@ -41,13 +43,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil {
return nil, err
}
if err = d.prepare(); err != nil {
if err = d.prepare(base.Context()); err != nil {
return nil, err
}
return &d, nil
}
func (d *SyncServerDatasource) prepare() (err error) {
func (d *SyncServerDatasource) prepare(ctx context.Context) (err error) {
if err = d.streamID.Prepare(d.db); err != nil {
return err
}
@ -107,6 +109,19 @@ func (d *SyncServerDatasource) prepare() (err error) {
if err != nil {
return err
}
// apply migrations which need multiple tables
m := sqlutil.NewMigrator(d.db)
m.AddMigrations(
sqlutil.Migration{
Version: "syncapi: set history visibility for existing events",
Up: deltas.UpSetHistoryVisibility, // Requires current_room_state and output_room_events to be created.
},
)
err = m.Up(ctx)
if err != nil {
return err
}
d.Database = shared.Database{
DB: d.db,
Writer: d.writer,

View file

@ -12,20 +12,22 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrixserverlib"
)
var ctx = context.Background()
func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func(), func()) {
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewSyncServerDatasource(nil, &config.DatabaseOptions{
base, closeBase := testrig.CreateBaseDendrite(t, dbType)
db, err := storage.NewSyncServerDatasource(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
})
if err != nil {
t.Fatalf("NewSyncServerDatasource returned %s", err)
}
return db, close
return db, close, closeBase
}
func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (positions []types.StreamPosition) {
@ -51,8 +53,9 @@ func TestWriteEvents(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
alice := test.NewUser(t)
r := test.NewRoom(t, alice)
db, close := MustCreateDatabase(t, dbType)
db, close, closeBase := MustCreateDatabase(t, dbType)
defer close()
defer closeBase()
MustWriteEvents(t, db, r.Events())
})
}
@ -60,8 +63,9 @@ func TestWriteEvents(t *testing.T) {
// These tests assert basic functionality of RecentEvents for PDUs
func TestRecentEventsPDU(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := MustCreateDatabase(t, dbType)
db, close, closeBase := MustCreateDatabase(t, dbType)
defer close()
defer closeBase()
alice := test.NewUser(t)
// dummy room to make sure SQL queries are filtering on room ID
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
@ -163,8 +167,9 @@ func TestRecentEventsPDU(t *testing.T) {
// The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token
func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := MustCreateDatabase(t, dbType)
db, close, closeBase := MustCreateDatabase(t, dbType)
defer close()
defer closeBase()
alice := test.NewUser(t)
r := test.NewRoom(t, alice)
for i := 0; i < 10; i++ {
@ -404,8 +409,9 @@ func TestSendToDeviceBehaviour(t *testing.T) {
bob := test.NewUser(t)
deviceID := "one"
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := MustCreateDatabase(t, dbType)
db, close, closeBase := MustCreateDatabase(t, dbType)
defer close()
defer closeBase()
// At this point there should be no messages. We haven't sent anything
// yet.
_, events, err := db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100)

View file

@ -185,6 +185,7 @@ type Memberships interface {
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error)
SelectHeroes(ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string) (heroes []string, err error)
SelectMembershipForUser(ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
}
type NotificationData interface {

View file

@ -10,10 +10,13 @@ import (
"github.com/matrix-org/dendrite/internal/caching"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"go.uber.org/atomic"
@ -123,7 +126,7 @@ func (p *PDUStreamProvider) CompleteSync(
defer reqWaitGroup.Done()
jr, jerr := p.getJoinResponseForCompleteSync(
ctx, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device,
ctx, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, false,
)
if jerr != nil {
req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed")
@ -149,7 +152,7 @@ func (p *PDUStreamProvider) CompleteSync(
if !peek.Deleted {
var jr *types.JoinResponse
jr, err = p.getJoinResponseForCompleteSync(
ctx, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device,
ctx, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, true,
)
if err != nil {
req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed")
@ -281,12 +284,6 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
}
}
}
if len(recentEvents) > 0 {
updateLatestPosition(recentEvents[len(recentEvents)-1].EventID())
}
if len(delta.StateEvents) > 0 {
updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID())
}
if stateFilter.LazyLoadMembers {
delta.StateEvents, err = p.lazyLoadMembers(
@ -306,6 +303,19 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
}
}
// Applies the history visibility rules
events, err := applyHistoryVisibilityFilter(ctx, p.DB, p.rsAPI, delta.RoomID, device.UserID, eventFilter.Limit, recentEvents)
if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter")
}
if len(events) > 0 {
updateLatestPosition(events[len(events)-1].EventID())
}
if len(delta.StateEvents) > 0 {
updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID())
}
switch delta.Membership {
case gomatrixserverlib.Join:
jr := types.NewJoinResponse()
@ -313,14 +323,17 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
p.addRoomSummary(ctx, jr, delta.RoomID, device.UserID, latestPosition)
}
jr.Timeline.PrevBatch = &prevBatch
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
// If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited.
jr.Timeline.Limited = limited && len(events) == len(recentEvents)
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Join[delta.RoomID] = *jr
case gomatrixserverlib.Peek:
jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = &prevBatch
// TODO: Apply history visibility on peeked rooms
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync)
@ -330,12 +343,12 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
fallthrough // transitions to leave are the same as ban
case gomatrixserverlib.Ban:
// TODO: recentEvents may contain events that this user is not allowed to see because they are
// no longer in the room.
lr := types.NewLeaveResponse()
lr.Timeline.PrevBatch = &prevBatch
lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
// If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited.
lr.Timeline.Limited = limited && len(events) == len(recentEvents)
lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Leave[delta.RoomID] = *lr
}
@ -343,6 +356,41 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
return latestPosition, nil
}
// applyHistoryVisibilityFilter gets the current room state and supplies it to ApplyHistoryVisibilityFilter, to make
// sure we always return the required events in the timeline.
func applyHistoryVisibilityFilter(
ctx context.Context,
db storage.Database,
rsAPI roomserverAPI.SyncRoomserverAPI,
roomID, userID string,
limit int,
recentEvents []*gomatrixserverlib.HeaderedEvent,
) ([]*gomatrixserverlib.HeaderedEvent, error) {
// We need to make sure we always include the latest states events, if they are in the timeline.
// We grep at least limit * 2 events, to ensure we really get the needed events.
stateEvents, err := db.CurrentState(ctx, roomID, &gomatrixserverlib.StateFilter{Limit: limit * 2}, nil)
if err != nil {
// Not a fatal error, we can continue without the stateEvents,
// they are only needed if there are state events in the timeline.
logrus.WithError(err).Warnf("failed to get current room state")
}
alwaysIncludeIDs := make(map[string]struct{}, len(stateEvents))
for _, ev := range stateEvents {
alwaysIncludeIDs[ev.EventID()] = struct{}{}
}
startTime := time.Now()
events, err := internal.ApplyHistoryVisibilityFilter(ctx, db, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync")
if err != nil {
return nil, err
}
logrus.WithFields(logrus.Fields{
"duration": time.Since(startTime),
"room_id": roomID,
}).Debug("applied history visibility (sync)")
return events, nil
}
func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) {
// Work out how many members are in the room.
joinedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition)
@ -390,6 +438,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
eventFilter *gomatrixserverlib.RoomEventFilter,
wantFullState bool,
device *userapi.Device,
isPeek bool,
) (jr *types.JoinResponse, err error) {
jr = types.NewJoinResponse()
// TODO: When filters are added, we may need to call this multiple times to get enough events.
@ -404,33 +453,6 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
return
}
// TODO FIXME: We don't fully implement history visibility yet. To avoid leaking events which the
// user shouldn't see, we check the recent events and remove any prior to the join event of the user
// which is equiv to history_visibility: joined
joinEventIndex := -1
for i := len(recentStreamEvents) - 1; i >= 0; i-- {
ev := recentStreamEvents[i]
if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(device.UserID) {
membership, _ := ev.Membership()
if membership == "join" {
joinEventIndex = i
if i > 0 {
// the create event happens before the first join, so we should cut it at that point instead
if recentStreamEvents[i-1].Type() == gomatrixserverlib.MRoomCreate && recentStreamEvents[i-1].StateKeyEquals("") {
joinEventIndex = i - 1
break
}
}
break
}
}
}
if joinEventIndex != -1 {
// cut all events earlier than the join (but not the join itself)
recentStreamEvents = recentStreamEvents[joinEventIndex:]
limited = false // so clients know not to try to backpaginate
}
// Work our way through the timeline events and pick out the event IDs
// of any state events that appear in the timeline. We'll specifically
// exclude them at the next step, so that we don't get duplicate state
@ -474,6 +496,19 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents)
events := recentEvents
// Only apply history visibility checks if the response is for joined rooms
if !isPeek {
events, err = applyHistoryVisibilityFilter(ctx, p.DB, p.rsAPI, roomID, device.UserID, eventFilter.Limit, recentEvents)
if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter")
}
}
// If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited.
limited = limited && len(events) == len(recentEvents)
if stateFilter.LazyLoadMembers {
if err != nil {
return nil, err
@ -488,8 +523,10 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
}
jr.Timeline.PrevBatch = prevBatch
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
// If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited.
jr.Timeline.Limited = limited && len(events) == len(recentEvents)
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync)
return jr, nil
}

View file

@ -12,6 +12,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/producers"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/roomserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
@ -54,6 +55,16 @@ func (s *syncRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *rsap
return nil
}
func (s *syncRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *rsapi.QueryMembershipForUserRequest, res *rsapi.QueryMembershipForUserResponse) error {
res.IsRoomForgotten = false
res.RoomExists = true
return nil
}
func (s *syncRoomserverAPI) QueryMembershipAtEvent(ctx context.Context, req *rsapi.QueryMembershipAtEventRequest, res *rsapi.QueryMembershipAtEventResponse) error {
return nil
}
type syncUserAPI struct {
userapi.SyncUserAPI
accounts []userapi.Device
@ -78,10 +89,11 @@ type syncKeyAPI struct {
keyapi.SyncKeyAPI
}
func (s *syncKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) {
func (s *syncKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) error {
return nil
}
func (s *syncKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) {
func (s *syncKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) error {
return nil
}
func TestSyncAPIAccessTokens(t *testing.T) {
@ -106,7 +118,7 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
msgs := toNATSMsgs(t, base, room.Events())
msgs := toNATSMsgs(t, base, room.Events()...)
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
testrig.MustPublishMsgs(t, jsctx, msgs...)
@ -199,7 +211,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
// m.room.power_levels
// m.room.join_rules
// m.room.history_visibility
msgs := toNATSMsgs(t, base, room.Events())
msgs := toNATSMsgs(t, base, room.Events()...)
sinceTokens := make([]string, len(msgs))
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
for i, msg := range msgs {
@ -314,6 +326,174 @@ func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) {
}
// This is mainly what Sytest is doing in "test_history_visibility"
func TestMessageHistoryVisibility(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
testHistoryVisibility(t, dbType)
})
}
func testHistoryVisibility(t *testing.T, dbType test.DBType) {
type result struct {
seeWithoutJoin bool
seeBeforeJoin bool
seeAfterInvite bool
}
// create the users
alice := test.NewUser(t)
bob := test.NewUser(t)
bobDev := userapi.Device{
ID: "BOBID",
UserID: bob.ID,
AccessToken: "BOD_BEARER_TOKEN",
DisplayName: "BOB",
}
ctx := context.Background()
// check guest and normal user accounts
for _, accType := range []userapi.AccountType{userapi.AccountTypeGuest, userapi.AccountTypeUser} {
testCases := []struct {
historyVisibility gomatrixserverlib.HistoryVisibility
wantResult result
}{
{
historyVisibility: gomatrixserverlib.HistoryVisibilityWorldReadable,
wantResult: result{
seeWithoutJoin: true,
seeBeforeJoin: true,
seeAfterInvite: true,
},
},
{
historyVisibility: gomatrixserverlib.HistoryVisibilityShared,
wantResult: result{
seeWithoutJoin: false,
seeBeforeJoin: true,
seeAfterInvite: true,
},
},
{
historyVisibility: gomatrixserverlib.HistoryVisibilityInvited,
wantResult: result{
seeWithoutJoin: false,
seeBeforeJoin: false,
seeAfterInvite: true,
},
},
{
historyVisibility: gomatrixserverlib.HistoryVisibilityJoined,
wantResult: result{
seeWithoutJoin: false,
seeBeforeJoin: false,
seeAfterInvite: false,
},
},
}
bobDev.AccountType = accType
userType := "guest"
if accType == userapi.AccountTypeUser {
userType = "real user"
}
base, close := testrig.CreateBaseDendrite(t, dbType)
defer close()
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
// Use the actual internal roomserver API
rsAPI := roomserver.NewInternalAPI(base)
rsAPI.SetFederationAPI(nil, nil)
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{bobDev}}, rsAPI, &syncKeyAPI{})
for _, tc := range testCases {
testname := fmt.Sprintf("%s - %s", tc.historyVisibility, userType)
t.Run(testname, func(t *testing.T) {
// create a room with the given visibility
room := test.NewRoom(t, alice, test.RoomHistoryVisibility(tc.historyVisibility))
// send the events/messages to NATS to create the rooms
beforeJoinEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("Before invite in a %s room", tc.historyVisibility)})
eventsToSend := append(room.Events(), beforeJoinEv)
if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil {
t.Fatalf("failed to send events: %v", err)
}
// There is only one event, we expect only to be able to see this, if the room is world_readable
w := httptest.NewRecorder()
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/messages", room.ID), test.WithQueryParams(map[string]string{
"access_token": bobDev.AccessToken,
"dir": "b",
})))
if w.Code != 200 {
t.Logf("%s", w.Body.String())
t.Fatalf("got HTTP %d want %d", w.Code, 200)
}
// We only care about the returned events at this point
var res struct {
Chunk []gomatrixserverlib.ClientEvent `json:"chunk"`
}
if err := json.NewDecoder(w.Body).Decode(&res); err != nil {
t.Errorf("failed to decode response body: %s", err)
}
verifyEventVisible(t, tc.wantResult.seeWithoutJoin, beforeJoinEv, res.Chunk)
// Create invite, a message, join the room and create another message.
inviteEv := room.CreateAndInsert(t, alice, "m.room.member", map[string]interface{}{"membership": "invite"}, test.WithStateKey(bob.ID))
afterInviteEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After invite in a %s room", tc.historyVisibility)})
joinEv := room.CreateAndInsert(t, bob, "m.room.member", map[string]interface{}{"membership": "join"}, test.WithStateKey(bob.ID))
msgEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After join in a %s room", tc.historyVisibility)})
eventsToSend = append([]*gomatrixserverlib.HeaderedEvent{}, inviteEv, afterInviteEv, joinEv, msgEv)
if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil {
t.Fatalf("failed to send events: %v", err)
}
// Verify the messages after/before invite are visible or not
w = httptest.NewRecorder()
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/messages", room.ID), test.WithQueryParams(map[string]string{
"access_token": bobDev.AccessToken,
"dir": "b",
})))
if w.Code != 200 {
t.Logf("%s", w.Body.String())
t.Fatalf("got HTTP %d want %d", w.Code, 200)
}
if err := json.NewDecoder(w.Body).Decode(&res); err != nil {
t.Errorf("failed to decode response body: %s", err)
}
// verify results
verifyEventVisible(t, tc.wantResult.seeBeforeJoin, beforeJoinEv, res.Chunk)
verifyEventVisible(t, tc.wantResult.seeAfterInvite, afterInviteEv, res.Chunk)
})
}
}
}
func verifyEventVisible(t *testing.T, wantVisible bool, wantVisibleEvent *gomatrixserverlib.HeaderedEvent, chunk []gomatrixserverlib.ClientEvent) {
t.Helper()
if wantVisible {
for _, ev := range chunk {
if ev.EventID == wantVisibleEvent.EventID() {
return
}
}
t.Fatalf("expected to see event %s but didn't: %+v", wantVisibleEvent.EventID(), chunk)
} else {
for _, ev := range chunk {
if ev.EventID == wantVisibleEvent.EventID() {
t.Fatalf("expected not to see event %s: %+v", wantVisibleEvent.EventID(), string(ev.Content))
}
}
}
}
func TestSendToDevice(t *testing.T) {
test.WithAllDatabases(t, testSendToDevice)
}
@ -447,7 +627,7 @@ func testSendToDevice(t *testing.T, dbType test.DBType) {
}
}
func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverlib.HeaderedEvent) []*nats.Msg {
func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input ...*gomatrixserverlib.HeaderedEvent) []*nats.Msg {
result := make([]*nats.Msg, len(input))
for i, ev := range input {
var addsStateIDs []string
@ -459,6 +639,7 @@ func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverli
NewRoomEvent: &rsapi.OutputNewRoomEvent{
Event: ev,
AddsStateEventIDs: addsStateIDs,
HistoryVisibility: ev.Visibility,
},
})
}

View file

@ -49,7 +49,3 @@ Notifications can be viewed with GET /notifications
If remote user leaves room we no longer receive device updates
Guest users can join guest_access rooms
# You'll be shocked to discover this is flakey too
Inbound /v1/send_join rejects joins from other servers

View file

@ -111,8 +111,6 @@ Newly joined room includes presence in incremental sync
User is offline if they set_presence=offline in their sync
Changes to state are included in an incremental sync
A change to displayname should appear in incremental /sync
Current state appears in timeline in private history
Current state appears in timeline in private history with many messages before
Rooms a user is invited to appear in an initial sync
Rooms a user is invited to appear in an incremental sync
Sync can be polled for updates
@ -459,7 +457,6 @@ After changing password, a different session no longer works by default
Read markers appear in incremental v2 /sync
Read markers appear in initial v2 /sync
Read markers can be updated
Local users can peek into world_readable rooms by room ID
We can't peek into rooms with shared history_visibility
We can't peek into rooms with invited history_visibility
We can't peek into rooms with joined history_visibility
@ -721,4 +718,30 @@ Setting state twice is idempotent
Joining room twice is idempotent
Inbound federation can return missing events for shared visibility
Inbound federation ignores redactions from invalid servers room > v3
Newly joined room includes presence in incremental sync
Joining room twice is idempotent
Getting messages going forward is limited for a departed room (SPEC-216)
m.room.history_visibility == "shared" allows/forbids appropriately for Guest users
m.room.history_visibility == "invited" allows/forbids appropriately for Guest users
m.room.history_visibility == "default" allows/forbids appropriately for Guest users
m.room.history_visibility == "shared" allows/forbids appropriately for Real users
m.room.history_visibility == "invited" allows/forbids appropriately for Real users
m.room.history_visibility == "default" allows/forbids appropriately for Real users
Guest users can sync from world_readable guest_access rooms if joined
Guest users can sync from shared guest_access rooms if joined
Guest users can sync from invited guest_access rooms if joined
Guest users can sync from joined guest_access rooms if joined
Guest users can sync from default guest_access rooms if joined
Real users can sync from world_readable guest_access rooms if joined
Real users can sync from shared guest_access rooms if joined
Real users can sync from invited guest_access rooms if joined
Real users can sync from joined guest_access rooms if joined
Real users can sync from default guest_access rooms if joined
Only see history_visibility changes on boundaries
Current state appears in timeline in private history
Current state appears in timeline in private history with many messages before
Local users can peek into world_readable rooms by room ID
Newly joined room includes presence in incremental sync
User in private room doesn't appear in user directory
User joining then leaving public room appears and dissappears from directory
User in remote room doesn't appear in user directory after server left room
User in shared private room does appear in user directory until leave

View file

@ -37,10 +37,11 @@ var (
)
type Room struct {
ID string
Version gomatrixserverlib.RoomVersion
preset Preset
creator *User
ID string
Version gomatrixserverlib.RoomVersion
preset Preset
visibility gomatrixserverlib.HistoryVisibility
creator *User
authEvents gomatrixserverlib.AuthEvents
currentState map[string]*gomatrixserverlib.HeaderedEvent
@ -61,6 +62,7 @@ func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room {
preset: PresetPublicChat,
Version: gomatrixserverlib.RoomVersionV9,
currentState: make(map[string]*gomatrixserverlib.HeaderedEvent),
visibility: gomatrixserverlib.HistoryVisibilityShared,
}
for _, m := range modifiers {
m(t, r)
@ -97,10 +99,14 @@ func (r *Room) insertCreateEvents(t *testing.T) {
fallthrough
case PresetPrivateChat:
joinRule.JoinRule = "invite"
hisVis.HistoryVisibility = "shared"
hisVis.HistoryVisibility = gomatrixserverlib.HistoryVisibilityShared
case PresetPublicChat:
joinRule.JoinRule = "public"
hisVis.HistoryVisibility = "shared"
hisVis.HistoryVisibility = gomatrixserverlib.HistoryVisibilityShared
}
if r.visibility != "" {
hisVis.HistoryVisibility = r.visibility
}
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{
@ -183,7 +189,9 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten
if err = gomatrixserverlib.Allowed(ev, &r.authEvents); err != nil {
t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err)
}
return ev.Headered(r.Version)
headeredEvent := ev.Headered(r.Version)
headeredEvent.Visibility = r.visibility
return headeredEvent
}
// Add a new event to this room DAG. Not thread-safe.
@ -242,6 +250,12 @@ func RoomPreset(p Preset) roomModifier {
}
}
func RoomHistoryVisibility(vis gomatrixserverlib.HistoryVisibility) roomModifier {
return func(t *testing.T, r *Room) {
r.visibility = vis
}
}
func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier {
return func(t *testing.T, r *Room) {
r.Version = ver

View file

@ -20,6 +20,7 @@ import (
"sync/atomic"
"testing"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
@ -45,7 +46,8 @@ var (
)
type User struct {
ID string
ID string
accountType api.AccountType
// key ID and private key of the server who has this user, if known.
keyID gomatrixserverlib.KeyID
privKey ed25519.PrivateKey
@ -62,6 +64,12 @@ func WithSigningServer(srvName gomatrixserverlib.ServerName, keyID gomatrixserve
}
}
func WithAccountType(accountType api.AccountType) UserOpt {
return func(u *User) {
u.accountType = accountType
}
}
func NewUser(t *testing.T, opts ...UserOpt) *User {
counter := atomic.AddInt64(&userIDCounter, 1)
var u User

View file

@ -100,7 +100,7 @@ type ClientUserAPI interface {
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse)
QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) error
QueryThreePIDsForLocalpart(ctx context.Context, req *QueryThreePIDsForLocalpartRequest, res *QueryThreePIDsForLocalpartResponse) error
QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error
@ -334,8 +334,9 @@ type PerformAccountCreationResponse struct {
// PerformAccountCreationRequest is the request for PerformAccountCreation
type PerformPasswordUpdateRequest struct {
Localpart string // Required: The localpart for this account.
Password string // Required: The new password to set.
Localpart string // Required: The localpart for this account.
Password string // Required: The new password to set.
LogoutDevices bool // Optional: Whether to log out all user devices.
}
// PerformAccountCreationResponse is the response for PerformAccountCreation

View file

@ -94,9 +94,10 @@ func (t *UserInternalAPITrace) PerformPushRulesPut(ctx context.Context, req *Per
util.GetLogger(ctx).Infof("PerformPushRulesPut req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) {
t.Impl.QueryKeyBackup(ctx, req, res)
func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) error {
err := t.Impl.QueryKeyBackup(ctx, req, res)
util.GetLogger(ctx).Infof("QueryKeyBackup req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error {
err := t.Impl.QueryProfile(ctx, req, res)

View file

@ -139,6 +139,11 @@ func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.Pe
if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
return err
}
if req.LogoutDevices {
if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, ""); err != nil {
return err
}
}
res.PasswordUpdated = true
return nil
}
@ -192,7 +197,9 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deleteReq.KeyIDs = append(deleteReq.KeyIDs, gomatrixserverlib.KeyID(keyID))
}
deleteRes := &keyapi.PerformDeleteKeysResponse{}
a.KeyAPI.PerformDeleteKeys(ctx, deleteReq, deleteRes)
if err := a.KeyAPI.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil {
return err
}
if err := deleteRes.Error; err != nil {
return fmt.Errorf("a.KeyAPI.PerformDeleteKeys: %w", err)
}
@ -211,10 +218,12 @@ func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) er
}
var uploadRes keyapi.PerformUploadKeysResponse
a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
UserID: userID,
DeviceKeys: deviceKeys,
}, &uploadRes)
}, &uploadRes); err != nil {
return err
}
if uploadRes.Error != nil {
return fmt.Errorf("failed to delete device keys: %v", uploadRes.Error)
}
@ -268,7 +277,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
if req.DisplayName != nil && dev.DisplayName != *req.DisplayName {
// display name has changed: update the device key
var uploadRes keyapi.PerformUploadKeysResponse
a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
UserID: req.RequestingUserID,
DeviceKeys: []keyapi.DeviceKeys{
{
@ -279,7 +288,9 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
},
},
OnlyDisplayNameUpdates: true,
}, &uploadRes)
}, &uploadRes); err != nil {
return err
}
if uploadRes.Error != nil {
return fmt.Errorf("failed to update device key display name: %v", uploadRes.Error)
}
@ -479,7 +490,9 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName),
}
evacuateRes := &rsapi.PerformAdminEvacuateUserResponse{}
a.RSAPI.PerformAdminEvacuateUser(ctx, evacuateReq, evacuateRes)
if err := a.RSAPI.PerformAdminEvacuateUser(ctx, evacuateReq, evacuateRes); err != nil {
return err
}
if err := evacuateRes.Error; err != nil {
logrus.WithError(err).Errorf("Failed to evacuate user after account deactivation")
}
@ -538,9 +551,6 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
if req.Version == "" {
res.BadInput = true
res.Error = "must specify a version to delete"
if res.Error != "" {
return fmt.Errorf(res.Error)
}
return nil
}
exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version)
@ -549,9 +559,6 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
res.Exists = exists
res.Version = req.Version
if res.Error != "" {
return fmt.Errorf(res.Error)
}
return nil
}
// Create metadata
@ -562,9 +569,6 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
res.Exists = err == nil
res.Version = version
if res.Error != "" {
return fmt.Errorf(res.Error)
}
return nil
}
// Update metadata
@ -575,16 +579,10 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
res.Exists = err == nil
res.Version = req.Version
if res.Error != "" {
return fmt.Errorf(res.Error)
}
return nil
}
// Upload Keys for a specific version metadata
a.uploadBackupKeys(ctx, req, res)
if res.Error != "" {
return fmt.Errorf(res.Error)
}
return nil
}
@ -627,16 +625,16 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
res.KeyETag = etag
}
func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) {
func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) error {
version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version)
res.Version = version
if err != nil {
if err == sql.ErrNoRows {
res.Exists = false
return
return nil
}
res.Error = fmt.Sprintf("failed to query key backup: %s", err)
return
return nil
}
res.Algorithm = algorithm
res.AuthData = authData
@ -648,15 +646,16 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
if err != nil {
res.Error = fmt.Sprintf("failed to count keys: %s", err)
}
return
return nil
}
result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
if err != nil {
res.Error = fmt.Sprintf("failed to query keys: %s", err)
return
return nil
}
res.Keys = result
return nil
}
func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error {

View file

@ -21,7 +21,6 @@ import (
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/opentracing/opentracing-go"
)
// HTTP paths for the internal HTTP APIs
@ -84,11 +83,10 @@ type httpUserInternalAPI struct {
}
func (h *httpUserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "InputAccountData")
defer span.Finish()
apiURL := h.apiURL + InputAccountDataPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
return httputil.CallInternalRPCAPI(
"InputAccountData", h.apiURL+InputAccountDataPath,
h.httpClient, ctx, req, res,
)
}
func (h *httpUserInternalAPI) PerformAccountCreation(
@ -96,11 +94,10 @@ func (h *httpUserInternalAPI) PerformAccountCreation(
request *api.PerformAccountCreationRequest,
response *api.PerformAccountCreationResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAccountCreation")
defer span.Finish()
apiURL := h.apiURL + PerformAccountCreationPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformAccountCreation", h.apiURL+PerformAccountCreationPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformPasswordUpdate(
@ -108,11 +105,10 @@ func (h *httpUserInternalAPI) PerformPasswordUpdate(
request *api.PerformPasswordUpdateRequest,
response *api.PerformPasswordUpdateResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPasswordUpdate")
defer span.Finish()
apiURL := h.apiURL + PerformPasswordUpdatePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformPasswordUpdate", h.apiURL+PerformPasswordUpdatePath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformDeviceCreation(
@ -120,11 +116,10 @@ func (h *httpUserInternalAPI) PerformDeviceCreation(
request *api.PerformDeviceCreationRequest,
response *api.PerformDeviceCreationResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceCreation")
defer span.Finish()
apiURL := h.apiURL + PerformDeviceCreationPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformDeviceCreation", h.apiURL+PerformDeviceCreationPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformDeviceDeletion(
@ -132,47 +127,54 @@ func (h *httpUserInternalAPI) PerformDeviceDeletion(
request *api.PerformDeviceDeletionRequest,
response *api.PerformDeviceDeletionResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceDeletion")
defer span.Finish()
apiURL := h.apiURL + PerformDeviceDeletionPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformDeviceDeletion", h.apiURL+PerformDeviceDeletionPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformLastSeenUpdate(
ctx context.Context,
req *api.PerformLastSeenUpdateRequest,
res *api.PerformLastSeenUpdateResponse,
request *api.PerformLastSeenUpdateRequest,
response *api.PerformLastSeenUpdateResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLastSeen")
defer span.Finish()
apiURL := h.apiURL + PerformLastSeenUpdatePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
return httputil.CallInternalRPCAPI(
"PerformLastSeen", h.apiURL+PerformLastSeenUpdatePath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceUpdate")
defer span.Finish()
apiURL := h.apiURL + PerformDeviceUpdatePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) PerformDeviceUpdate(
ctx context.Context,
request *api.PerformDeviceUpdateRequest,
response *api.PerformDeviceUpdateResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformDeviceUpdate", h.apiURL+PerformDeviceUpdatePath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAccountDeactivation")
defer span.Finish()
apiURL := h.apiURL + PerformAccountDeactivationPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) PerformAccountDeactivation(
ctx context.Context,
request *api.PerformAccountDeactivationRequest,
response *api.PerformAccountDeactivationResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformAccountDeactivation", h.apiURL+PerformAccountDeactivationPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, request *api.PerformOpenIDTokenCreationRequest, response *api.PerformOpenIDTokenCreationResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformOpenIDTokenCreation")
defer span.Finish()
apiURL := h.apiURL + PerformOpenIDTokenCreationPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
func (h *httpUserInternalAPI) PerformOpenIDTokenCreation(
ctx context.Context,
request *api.PerformOpenIDTokenCreationRequest,
response *api.PerformOpenIDTokenCreationResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformOpenIDTokenCreation", h.apiURL+PerformOpenIDTokenCreationPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryProfile(
@ -180,11 +182,10 @@ func (h *httpUserInternalAPI) QueryProfile(
request *api.QueryProfileRequest,
response *api.QueryProfileResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryProfile")
defer span.Finish()
apiURL := h.apiURL + QueryProfilePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryProfile", h.apiURL+QueryProfilePath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryDeviceInfos(
@ -192,11 +193,10 @@ func (h *httpUserInternalAPI) QueryDeviceInfos(
request *api.QueryDeviceInfosRequest,
response *api.QueryDeviceInfosResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDeviceInfos")
defer span.Finish()
apiURL := h.apiURL + QueryDeviceInfosPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryDeviceInfos", h.apiURL+QueryDeviceInfosPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryAccessToken(
@ -204,72 +204,87 @@ func (h *httpUserInternalAPI) QueryAccessToken(
request *api.QueryAccessTokenRequest,
response *api.QueryAccessTokenResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryAccessToken")
defer span.Finish()
apiURL := h.apiURL + QueryAccessTokenPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryAccessToken", h.apiURL+QueryAccessTokenPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevicesRequest, res *api.QueryDevicesResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDevices")
defer span.Finish()
apiURL := h.apiURL + QueryDevicesPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) QueryDevices(
ctx context.Context,
request *api.QueryDevicesRequest,
response *api.QueryDevicesResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryDevices", h.apiURL+QueryDevicesPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAccountDataRequest, res *api.QueryAccountDataResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryAccountData")
defer span.Finish()
apiURL := h.apiURL + QueryAccountDataPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) QueryAccountData(
ctx context.Context,
request *api.QueryAccountDataRequest,
response *api.QueryAccountDataResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryAccountData", h.apiURL+QueryAccountDataPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySearchProfiles")
defer span.Finish()
apiURL := h.apiURL + QuerySearchProfilesPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) QuerySearchProfiles(
ctx context.Context,
request *api.QuerySearchProfilesRequest,
response *api.QuerySearchProfilesResponse,
) error {
return httputil.CallInternalRPCAPI(
"QuerySearchProfiles", h.apiURL+QuerySearchProfilesPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOpenIDToken")
defer span.Finish()
apiURL := h.apiURL + QueryOpenIDTokenPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) QueryOpenIDToken(
ctx context.Context,
request *api.QueryOpenIDTokenRequest,
response *api.QueryOpenIDTokenResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryOpenIDToken", h.apiURL+QueryOpenIDTokenPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformKeyBackup")
defer span.Finish()
apiURL := h.apiURL + PerformKeyBackupPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
if err != nil {
res.Error = err.Error()
}
return nil
}
func (h *httpUserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup")
defer span.Finish()
apiURL := h.apiURL + QueryKeyBackupPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
if err != nil {
res.Error = err.Error()
}
func (h *httpUserInternalAPI) PerformKeyBackup(
ctx context.Context,
request *api.PerformKeyBackupRequest,
response *api.PerformKeyBackupResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformKeyBackup", h.apiURL+PerformKeyBackupPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryNotifications")
defer span.Finish()
func (h *httpUserInternalAPI) QueryKeyBackup(
ctx context.Context,
request *api.QueryKeyBackupRequest,
response *api.QueryKeyBackupResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryKeyBackup", h.apiURL+QueryKeyBackupPath,
h.httpClient, ctx, request, response,
)
}
return httputil.PostJSON(ctx, span, h.httpClient, h.apiURL+QueryNotificationsPath, req, res)
func (h *httpUserInternalAPI) QueryNotifications(
ctx context.Context,
request *api.QueryNotificationsRequest,
response *api.QueryNotificationsResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryNotifications", h.apiURL+QueryNotificationsPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformPusherSet(
@ -277,27 +292,32 @@ func (h *httpUserInternalAPI) PerformPusherSet(
request *api.PerformPusherSetRequest,
response *struct{},
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherSet")
defer span.Finish()
apiURL := h.apiURL + PerformPusherSetPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformPusherSet", h.apiURL+PerformPusherSetPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherDeletion")
defer span.Finish()
apiURL := h.apiURL + PerformPusherDeletionPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) PerformPusherDeletion(
ctx context.Context,
request *api.PerformPusherDeletionRequest,
response *struct{},
) error {
return httputil.CallInternalRPCAPI(
"PerformPusherDeletion", h.apiURL+PerformPusherDeletionPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushers")
defer span.Finish()
apiURL := h.apiURL + QueryPushersPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) QueryPushers(
ctx context.Context,
request *api.QueryPushersRequest,
response *api.QueryPushersResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryPushers", h.apiURL+QueryPushersPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformPushRulesPut(
@ -305,89 +325,117 @@ func (h *httpUserInternalAPI) PerformPushRulesPut(
request *api.PerformPushRulesPutRequest,
response *struct{},
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPushRulesPut")
defer span.Finish()
apiURL := h.apiURL + PerformPushRulesPutPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformPushRulesPut", h.apiURL+PerformPushRulesPutPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushRules")
defer span.Finish()
apiURL := h.apiURL + QueryPushRulesPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) QueryPushRules(
ctx context.Context,
request *api.QueryPushRulesRequest,
response *api.QueryPushRulesResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryPushRules", h.apiURL+QueryPushRulesPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, PerformSetAvatarURLPath)
defer span.Finish()
apiURL := h.apiURL + PerformSetAvatarURLPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) SetAvatarURL(
ctx context.Context,
request *api.PerformSetAvatarURLRequest,
response *api.PerformSetAvatarURLResponse,
) error {
return httputil.CallInternalRPCAPI(
"SetAvatarURL", h.apiURL+PerformSetAvatarURLPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, QueryNumericLocalpartPath)
defer span.Finish()
apiURL := h.apiURL + QueryNumericLocalpartPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, struct{}{}, res)
func (h *httpUserInternalAPI) QueryNumericLocalpart(
ctx context.Context,
response *api.QueryNumericLocalpartResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath,
h.httpClient, ctx, &struct{}{}, response,
)
}
func (h *httpUserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, QueryAccountAvailabilityPath)
defer span.Finish()
apiURL := h.apiURL + QueryAccountAvailabilityPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) QueryAccountAvailability(
ctx context.Context,
request *api.QueryAccountAvailabilityRequest,
response *api.QueryAccountAvailabilityResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryAccountAvailability", h.apiURL+QueryAccountAvailabilityPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, QueryAccountByPasswordPath)
defer span.Finish()
apiURL := h.apiURL + QueryAccountByPasswordPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) QueryAccountByPassword(
ctx context.Context,
request *api.QueryAccountByPasswordRequest,
response *api.QueryAccountByPasswordResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryAccountByPassword", h.apiURL+QueryAccountByPasswordPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *struct{}) error {
span, ctx := opentracing.StartSpanFromContext(ctx, PerformSetDisplayNamePath)
defer span.Finish()
apiURL := h.apiURL + PerformSetDisplayNamePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) SetDisplayName(
ctx context.Context,
request *api.PerformUpdateDisplayNameRequest,
response *struct{},
) error {
return httputil.CallInternalRPCAPI(
"SetDisplayName", h.apiURL+PerformSetDisplayNamePath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, QueryLocalpartForThreePIDPath)
defer span.Finish()
apiURL := h.apiURL + QueryLocalpartForThreePIDPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) QueryLocalpartForThreePID(
ctx context.Context,
request *api.QueryLocalpartForThreePIDRequest,
response *api.QueryLocalpartForThreePIDResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryLocalpartForThreePID", h.apiURL+QueryLocalpartForThreePIDPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, QueryThreePIDsForLocalpartPath)
defer span.Finish()
apiURL := h.apiURL + QueryThreePIDsForLocalpartPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) QueryThreePIDsForLocalpart(
ctx context.Context,
request *api.QueryThreePIDsForLocalpartRequest,
response *api.QueryThreePIDsForLocalpartResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryThreePIDsForLocalpart", h.apiURL+QueryThreePIDsForLocalpartPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.PerformForgetThreePIDRequest, res *struct{}) error {
span, ctx := opentracing.StartSpanFromContext(ctx, PerformForgetThreePIDPath)
defer span.Finish()
apiURL := h.apiURL + PerformForgetThreePIDPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) PerformForgetThreePID(
ctx context.Context,
request *api.PerformForgetThreePIDRequest,
response *struct{},
) error {
return httputil.CallInternalRPCAPI(
"PerformForgetThreePID", h.apiURL+PerformForgetThreePIDPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error {
span, ctx := opentracing.StartSpanFromContext(ctx, PerformSaveThreePIDAssociationPath)
defer span.Finish()
apiURL := h.apiURL + PerformSaveThreePIDAssociationPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(
ctx context.Context,
request *api.PerformSaveThreePIDAssociationRequest,
response *struct{},
) error {
return httputil.CallInternalRPCAPI(
"PerformSaveThreePIDAssociation", h.apiURL+PerformSaveThreePIDAssociationPath,
h.httpClient, ctx, request, response,
)
}

View file

@ -19,7 +19,6 @@ import (
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/opentracing/opentracing-go"
)
const (
@ -33,11 +32,10 @@ func (h *httpUserInternalAPI) PerformLoginTokenCreation(
request *api.PerformLoginTokenCreationRequest,
response *api.PerformLoginTokenCreationResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLoginTokenCreation")
defer span.Finish()
apiURL := h.apiURL + PerformLoginTokenCreationPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformLoginTokenCreation", h.apiURL+PerformLoginTokenCreationPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformLoginTokenDeletion(
@ -45,11 +43,10 @@ func (h *httpUserInternalAPI) PerformLoginTokenDeletion(
request *api.PerformLoginTokenDeletionRequest,
response *api.PerformLoginTokenDeletionResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLoginTokenDeletion")
defer span.Finish()
apiURL := h.apiURL + PerformLoginTokenDeletionPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformLoginTokenDeletion", h.apiURL+PerformLoginTokenDeletionPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryLoginToken(
@ -57,9 +54,8 @@ func (h *httpUserInternalAPI) QueryLoginToken(
request *api.QueryLoginTokenRequest,
response *api.QueryLoginTokenResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryLoginToken")
defer span.Finish()
apiURL := h.apiURL + QueryLoginTokenPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryLoginToken", h.apiURL+QueryLoginTokenPath,
h.httpClient, ctx, request, response,
)
}

View file

@ -15,8 +15,6 @@
package inthttp
import (
"encoding/json"
"fmt"
"net/http"
"github.com/gorilla/mux"
@ -29,339 +27,134 @@ import (
func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
addRoutesLoginToken(internalAPIMux, s)
internalAPIMux.Handle(PerformAccountCreationPath,
httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse {
request := api.PerformAccountCreationRequest{}
response := api.PerformAccountCreationResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformAccountCreation(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformPasswordUpdatePath,
httputil.MakeInternalAPI("performPasswordUpdate", func(req *http.Request) util.JSONResponse {
request := api.PerformPasswordUpdateRequest{}
response := api.PerformPasswordUpdateResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformPasswordUpdate(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformDeviceCreationPath,
httputil.MakeInternalAPI("performDeviceCreation", func(req *http.Request) util.JSONResponse {
request := api.PerformDeviceCreationRequest{}
response := api.PerformDeviceCreationResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformDeviceCreation(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformLastSeenUpdatePath,
httputil.MakeInternalAPI("performLastSeenUpdate", func(req *http.Request) util.JSONResponse {
request := api.PerformLastSeenUpdateRequest{}
response := api.PerformLastSeenUpdateResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformLastSeenUpdate(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformDeviceUpdatePath,
httputil.MakeInternalAPI("performDeviceUpdate", func(req *http.Request) util.JSONResponse {
request := api.PerformDeviceUpdateRequest{}
response := api.PerformDeviceUpdateResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformDeviceUpdate(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformDeviceDeletionPath,
httputil.MakeInternalAPI("performDeviceDeletion", func(req *http.Request) util.JSONResponse {
request := api.PerformDeviceDeletionRequest{}
response := api.PerformDeviceDeletionResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformDeviceDeletion(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformAccountDeactivationPath,
httputil.MakeInternalAPI("performAccountDeactivation", func(req *http.Request) util.JSONResponse {
request := api.PerformAccountDeactivationRequest{}
response := api.PerformAccountDeactivationResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformAccountDeactivation(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformOpenIDTokenCreationPath,
httputil.MakeInternalAPI("performOpenIDTokenCreation", func(req *http.Request) util.JSONResponse {
request := api.PerformOpenIDTokenCreationRequest{}
response := api.PerformOpenIDTokenCreationResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformOpenIDTokenCreation(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryProfilePath,
httputil.MakeInternalAPI("queryProfile", func(req *http.Request) util.JSONResponse {
request := api.QueryProfileRequest{}
response := api.QueryProfileResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryProfile(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryAccessTokenPath,
httputil.MakeInternalAPI("queryAccessToken", func(req *http.Request) util.JSONResponse {
request := api.QueryAccessTokenRequest{}
response := api.QueryAccessTokenResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryAccessToken(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryDevicesPath,
httputil.MakeInternalAPI("queryDevices", func(req *http.Request) util.JSONResponse {
request := api.QueryDevicesRequest{}
response := api.QueryDevicesResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryDevices(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryAccountDataPath,
httputil.MakeInternalAPI("queryAccountData", func(req *http.Request) util.JSONResponse {
request := api.QueryAccountDataRequest{}
response := api.QueryAccountDataResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryAccountData(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryDeviceInfosPath,
httputil.MakeInternalAPI("queryDeviceInfos", func(req *http.Request) util.JSONResponse {
request := api.QueryDeviceInfosRequest{}
response := api.QueryDeviceInfosResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryDeviceInfos(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QuerySearchProfilesPath,
httputil.MakeInternalAPI("querySearchProfiles", func(req *http.Request) util.JSONResponse {
request := api.QuerySearchProfilesRequest{}
response := api.QuerySearchProfilesResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QuerySearchProfiles(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryOpenIDTokenPath,
httputil.MakeInternalAPI("queryOpenIDToken", func(req *http.Request) util.JSONResponse {
request := api.QueryOpenIDTokenRequest{}
response := api.QueryOpenIDTokenResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryOpenIDToken(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(InputAccountDataPath,
httputil.MakeInternalAPI("inputAccountDataPath", func(req *http.Request) util.JSONResponse {
request := api.InputAccountDataRequest{}
response := api.InputAccountDataResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.InputAccountData(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryKeyBackupPath,
httputil.MakeInternalAPI("queryKeyBackup", func(req *http.Request) util.JSONResponse {
request := api.QueryKeyBackupRequest{}
response := api.QueryKeyBackupResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.QueryKeyBackup(req.Context(), &request, &response)
if response.Error != "" {
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", response.Error))
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformKeyBackupPath,
httputil.MakeInternalAPI("performKeyBackup", func(req *http.Request) util.JSONResponse {
request := api.PerformKeyBackupRequest{}
response := api.PerformKeyBackupResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
err := s.PerformKeyBackup(req.Context(), &request, &response)
if err != nil {
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryNotificationsPath,
httputil.MakeInternalAPI("queryNotifications", func(req *http.Request) util.JSONResponse {
var request api.QueryNotificationsRequest
var response api.QueryNotificationsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryNotifications(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
PerformAccountCreationPath,
httputil.MakeInternalRPCAPI("UserAPIPerformAccountCreation", s.PerformAccountCreation),
)
internalAPIMux.Handle(PerformPusherSetPath,
httputil.MakeInternalAPI("performPusherSet", func(req *http.Request) util.JSONResponse {
request := api.PerformPusherSetRequest{}
response := struct{}{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformPusherSet(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformPusherDeletionPath,
httputil.MakeInternalAPI("performPusherDeletion", func(req *http.Request) util.JSONResponse {
request := api.PerformPusherDeletionRequest{}
response := struct{}{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformPusherDeletion(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
PerformPasswordUpdatePath,
httputil.MakeInternalRPCAPI("UserAPIPerformPasswordUpdate", s.PerformPasswordUpdate),
)
internalAPIMux.Handle(QueryPushersPath,
httputil.MakeInternalAPI("queryPushers", func(req *http.Request) util.JSONResponse {
request := api.QueryPushersRequest{}
response := api.QueryPushersResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryPushers(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
PerformDeviceCreationPath,
httputil.MakeInternalRPCAPI("UserAPIPerformDeviceCreation", s.PerformDeviceCreation),
)
internalAPIMux.Handle(PerformPushRulesPutPath,
httputil.MakeInternalAPI("performPushRulesPut", func(req *http.Request) util.JSONResponse {
request := api.PerformPushRulesPutRequest{}
response := struct{}{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformPushRulesPut(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
PerformLastSeenUpdatePath,
httputil.MakeInternalRPCAPI("UserAPIPerformLastSeenUpdate", s.PerformLastSeenUpdate),
)
internalAPIMux.Handle(QueryPushRulesPath,
httputil.MakeInternalAPI("queryPushRules", func(req *http.Request) util.JSONResponse {
request := api.QueryPushRulesRequest{}
response := api.QueryPushRulesResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryPushRules(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
PerformDeviceUpdatePath,
httputil.MakeInternalRPCAPI("UserAPIPerformDeviceUpdate", s.PerformDeviceUpdate),
)
internalAPIMux.Handle(PerformSetAvatarURLPath,
httputil.MakeInternalAPI("performSetAvatarURL", func(req *http.Request) util.JSONResponse {
request := api.PerformSetAvatarURLRequest{}
response := api.PerformSetAvatarURLResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.SetAvatarURL(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
PerformDeviceDeletionPath,
httputil.MakeInternalRPCAPI("UserAPIPerformDeviceDeletion", s.PerformDeviceDeletion),
)
internalAPIMux.Handle(
PerformAccountDeactivationPath,
httputil.MakeInternalRPCAPI("UserAPIPerformAccountDeactivation", s.PerformAccountDeactivation),
)
internalAPIMux.Handle(
PerformOpenIDTokenCreationPath,
httputil.MakeInternalRPCAPI("UserAPIPerformOpenIDTokenCreation", s.PerformOpenIDTokenCreation),
)
internalAPIMux.Handle(
QueryProfilePath,
httputil.MakeInternalRPCAPI("UserAPIQueryProfile", s.QueryProfile),
)
internalAPIMux.Handle(
QueryAccessTokenPath,
httputil.MakeInternalRPCAPI("UserAPIQueryAccessToken", s.QueryAccessToken),
)
internalAPIMux.Handle(
QueryDevicesPath,
httputil.MakeInternalRPCAPI("UserAPIQueryDevices", s.QueryDevices),
)
internalAPIMux.Handle(
QueryAccountDataPath,
httputil.MakeInternalRPCAPI("UserAPIQueryAccountData", s.QueryAccountData),
)
internalAPIMux.Handle(
QueryDeviceInfosPath,
httputil.MakeInternalRPCAPI("UserAPIQueryDeviceInfos", s.QueryDeviceInfos),
)
internalAPIMux.Handle(
QuerySearchProfilesPath,
httputil.MakeInternalRPCAPI("UserAPIQuerySearchProfiles", s.QuerySearchProfiles),
)
internalAPIMux.Handle(
QueryOpenIDTokenPath,
httputil.MakeInternalRPCAPI("UserAPIQueryOpenIDToken", s.QueryOpenIDToken),
)
internalAPIMux.Handle(
InputAccountDataPath,
httputil.MakeInternalRPCAPI("UserAPIInputAccountData", s.InputAccountData),
)
internalAPIMux.Handle(
QueryKeyBackupPath,
httputil.MakeInternalRPCAPI("UserAPIQueryKeyBackup", s.QueryKeyBackup),
)
internalAPIMux.Handle(
PerformKeyBackupPath,
httputil.MakeInternalRPCAPI("UserAPIPerformKeyBackup", s.PerformKeyBackup),
)
internalAPIMux.Handle(
QueryNotificationsPath,
httputil.MakeInternalRPCAPI("UserAPIQueryNotifications", s.QueryNotifications),
)
internalAPIMux.Handle(
PerformPusherSetPath,
httputil.MakeInternalRPCAPI("UserAPIPerformPusherSet", s.PerformPusherSet),
)
internalAPIMux.Handle(
PerformPusherDeletionPath,
httputil.MakeInternalRPCAPI("UserAPIPerformPusherDeletion", s.PerformPusherDeletion),
)
internalAPIMux.Handle(
QueryPushersPath,
httputil.MakeInternalRPCAPI("UserAPIQueryPushers", s.QueryPushers),
)
internalAPIMux.Handle(
PerformPushRulesPutPath,
httputil.MakeInternalRPCAPI("UserAPIPerformPushRulesPut", s.PerformPushRulesPut),
)
internalAPIMux.Handle(
QueryPushRulesPath,
httputil.MakeInternalRPCAPI("UserAPIQueryPushRules", s.QueryPushRules),
)
internalAPIMux.Handle(
PerformSetAvatarURLPath,
httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL),
)
// TODO: Look at the shape of this
internalAPIMux.Handle(QueryNumericLocalpartPath,
httputil.MakeInternalAPI("queryNumericLocalpart", func(req *http.Request) util.JSONResponse {
httputil.MakeInternalAPI("UserAPIQueryNumericLocalpart", func(req *http.Request) util.JSONResponse {
response := api.QueryNumericLocalpartResponse{}
if err := s.QueryNumericLocalpart(req.Context(), &response); err != nil {
return util.ErrorResponse(err)
@ -369,92 +162,39 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryAccountAvailabilityPath,
httputil.MakeInternalAPI("queryAccountAvailability", func(req *http.Request) util.JSONResponse {
request := api.QueryAccountAvailabilityRequest{}
response := api.QueryAccountAvailabilityResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryAccountAvailability(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
QueryAccountAvailabilityPath,
httputil.MakeInternalRPCAPI("UserAPIQueryAccountAvailability", s.QueryAccountAvailability),
)
internalAPIMux.Handle(QueryAccountByPasswordPath,
httputil.MakeInternalAPI("queryAccountByPassword", func(req *http.Request) util.JSONResponse {
request := api.QueryAccountByPasswordRequest{}
response := api.QueryAccountByPasswordResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryAccountByPassword(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
QueryAccountByPasswordPath,
httputil.MakeInternalRPCAPI("UserAPIQueryAccountByPassword", s.QueryAccountByPassword),
)
internalAPIMux.Handle(PerformSetDisplayNamePath,
httputil.MakeInternalAPI("performSetDisplayName", func(req *http.Request) util.JSONResponse {
request := api.PerformUpdateDisplayNameRequest{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.SetDisplayName(req.Context(), &request, &struct{}{}); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}}
}),
internalAPIMux.Handle(
PerformSetDisplayNamePath,
httputil.MakeInternalRPCAPI("UserAPISetDisplayName", s.SetDisplayName),
)
internalAPIMux.Handle(QueryLocalpartForThreePIDPath,
httputil.MakeInternalAPI("queryLocalpartForThreePID", func(req *http.Request) util.JSONResponse {
request := api.QueryLocalpartForThreePIDRequest{}
response := api.QueryLocalpartForThreePIDResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryLocalpartForThreePID(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
QueryLocalpartForThreePIDPath,
httputil.MakeInternalRPCAPI("UserAPIQueryLocalpartForThreePID", s.QueryLocalpartForThreePID),
)
internalAPIMux.Handle(QueryThreePIDsForLocalpartPath,
httputil.MakeInternalAPI("queryThreePIDsForLocalpart", func(req *http.Request) util.JSONResponse {
request := api.QueryThreePIDsForLocalpartRequest{}
response := api.QueryThreePIDsForLocalpartResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryThreePIDsForLocalpart(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
QueryThreePIDsForLocalpartPath,
httputil.MakeInternalRPCAPI("UserAPIQueryThreePIDsForLocalpart", s.QueryThreePIDsForLocalpart),
)
internalAPIMux.Handle(PerformForgetThreePIDPath,
httputil.MakeInternalAPI("performForgetThreePID", func(req *http.Request) util.JSONResponse {
request := api.PerformForgetThreePIDRequest{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformForgetThreePID(req.Context(), &request, &struct{}{}); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}}
}),
internalAPIMux.Handle(
PerformForgetThreePIDPath,
httputil.MakeInternalRPCAPI("UserAPIPerformForgetThreePID", s.PerformForgetThreePID),
)
internalAPIMux.Handle(PerformSaveThreePIDAssociationPath,
httputil.MakeInternalAPI("performSaveThreePIDAssociation", func(req *http.Request) util.JSONResponse {
request := api.PerformSaveThreePIDAssociationRequest{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformSaveThreePIDAssociation(req.Context(), &request, &struct{}{}); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}}
}),
internalAPIMux.Handle(
PerformSaveThreePIDAssociationPath,
httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", s.PerformSaveThreePIDAssociation),
)
}

View file

@ -15,54 +15,25 @@
package inthttp
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
)
// addRoutesLoginToken adds routes for all login token API calls.
func addRoutesLoginToken(internalAPIMux *mux.Router, s api.UserInternalAPI) {
internalAPIMux.Handle(PerformLoginTokenCreationPath,
httputil.MakeInternalAPI("performLoginTokenCreation", func(req *http.Request) util.JSONResponse {
request := api.PerformLoginTokenCreationRequest{}
response := api.PerformLoginTokenCreationResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformLoginTokenCreation(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
PerformLoginTokenCreationPath,
httputil.MakeInternalRPCAPI("UserAPIPerformLoginTokenCreation", s.PerformLoginTokenCreation),
)
internalAPIMux.Handle(PerformLoginTokenDeletionPath,
httputil.MakeInternalAPI("performLoginTokenDeletion", func(req *http.Request) util.JSONResponse {
request := api.PerformLoginTokenDeletionRequest{}
response := api.PerformLoginTokenDeletionResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformLoginTokenDeletion(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
PerformLoginTokenDeletionPath,
httputil.MakeInternalRPCAPI("UserAPIPerformLoginTokenDeletion", s.PerformLoginTokenDeletion),
)
internalAPIMux.Handle(QueryLoginTokenPath,
httputil.MakeInternalAPI("queryLoginToken", func(req *http.Request) util.JSONResponse {
request := api.QueryLoginTokenRequest{}
response := api.QueryLoginTokenResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryLoginToken(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
QueryLoginTokenPath,
httputil.MakeInternalRPCAPI("UserAPIQueryLoginToken", s.QueryLoginToken),
)
}

View file

@ -117,16 +117,20 @@ func TestQueryProfile(t *testing.T) {
},
}
runCases := func(testAPI api.UserInternalAPI) {
runCases := func(testAPI api.UserInternalAPI, http bool) {
mode := "monolith"
if http {
mode = "HTTP"
}
for _, tc := range testCases {
var gotRes api.QueryProfileResponse
gotErr := testAPI.QueryProfile(context.TODO(), &tc.req, &gotRes)
if tc.wantErr == nil && gotErr != nil || tc.wantErr != nil && gotErr == nil {
t.Errorf("QueryProfile error, got %s want %s", gotErr, tc.wantErr)
t.Errorf("QueryProfile %s error, got %s want %s", mode, gotErr, tc.wantErr)
continue
}
if !reflect.DeepEqual(tc.wantRes, gotRes) {
t.Errorf("QueryProfile response got %+v want %+v", gotRes, tc.wantRes)
t.Errorf("QueryProfile %s response got %+v want %+v", mode, gotRes, tc.wantRes)
}
}
}
@ -140,10 +144,10 @@ func TestQueryProfile(t *testing.T) {
if err != nil {
t.Fatalf("failed to create HTTP client")
}
runCases(httpAPI)
runCases(httpAPI, true)
})
t.Run("Monolith", func(t *testing.T) {
runCases(userAPI)
runCases(userAPI, false)
})
}