Pull latest dendrite fork into harmony (#252)

Latest dendrite main has changes for knockable rooms, and the fix for login crash. Pulled into dendrite fork. Rebased dendrite fork from dendrite main.

Co-authored-by: Till <2353100+S7evinK@users.noreply.github.com>
Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
Co-authored-by: kegsay <kegan@matrix.org>
Co-authored-by: Tak Wai Wong <tak@hntlabs.com>
Co-authored-by: texuf <texuf.eth@gmail.com>
Co-authored-by: Brian Meek <brian@hntlabs.com>
Co-authored-by: Tak Wai Wong <takwaiw@gmail.com>
This commit is contained in:
Tak Wai Wong 2022-08-13 16:09:49 -07:00 committed by GitHub
parent 8a66a55771
commit ea0a8804d2
98 changed files with 2939 additions and 2818 deletions

View file

@ -19,10 +19,10 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
if: ${{ false }} # disable for now if: ${{ false }} # disable for now
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v3
- name: Install Go - name: Install Go
uses: actions/setup-go@v2 uses: actions/setup-go@v3
with: with:
go-version: 1.18 go-version: 1.18
@ -66,8 +66,12 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- name: Install Go
uses: actions/setup-go@v3
with:
go-version: 1.18
- name: golangci-lint - name: golangci-lint
uses: golangci/golangci-lint-action@v2 uses: golangci/golangci-lint-action@v3
# run go test with different go versions # run go test with different go versions
test: test:
@ -101,7 +105,7 @@ jobs:
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- name: Setup go - name: Setup go
uses: actions/setup-go@v2 uses: actions/setup-go@v3
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
- uses: actions/cache@v3 - uses: actions/cache@v3
@ -133,7 +137,7 @@ jobs:
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- name: Setup go - name: Setup go
uses: actions/setup-go@v2 uses: actions/setup-go@v3
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
- name: Install dependencies x86 - name: Install dependencies x86
@ -167,7 +171,7 @@ jobs:
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- name: Setup Go ${{ matrix.go }} - name: Setup Go ${{ matrix.go }}
uses: actions/setup-go@v2 uses: actions/setup-go@v3
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
- name: Install dependencies - name: Install dependencies
@ -208,7 +212,7 @@ jobs:
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- name: Setup go - name: Setup go
uses: actions/setup-go@v2 uses: actions/setup-go@v3
with: with:
go-version: "1.18" go-version: "1.18"
- uses: actions/cache@v3 - uses: actions/cache@v3
@ -233,7 +237,7 @@ jobs:
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- name: Setup go - name: Setup go
uses: actions/setup-go@v2 uses: actions/setup-go@v3
with: with:
go-version: "1.18" go-version: "1.18"
- uses: actions/cache@v3 - uses: actions/cache@v3

View file

@ -1,5 +1,30 @@
# Changelog # Changelog
## Dendrite 0.9.2 (2022-08-12)
### Features
* Dendrite now supports history visibility on the `/sync`, `/messages` and `/context` endpoints
* It should now be possible to view the history of a room in more cases (as opposed to limiting scrollback to the join event or defaulting to the restrictive `"join"` visibility rule as before)
* The default room version for newly created rooms is now room version 9
* New admin endpoint `/_dendrite/admin/resetPassword/{userID}` has been added, which replaces the `-reset-password` flag in `create-account`
* The `create-account` binary now uses shared secret registration over HTTP to create new accounts, which fixes a number of problems with account data and push rules not being configured correctly for new accounts
* The internal HTTP APIs for polylith deployments have been refactored for correctness and consistency
* The federation API will now automatically clean up some EDUs that have failed to send within a certain period of time
* The `/hierarchy` endpoint will now return potentially joinable rooms (contributed by [texuf](https://github.com/texuf))
* The user directory will now show or hide users correctly
### Fixes
* Send-to-device messages should no longer be incorrectly duplicated in `/sync`
* The federation sender will no longer create unnecessary destination queues as a result of a logic error
* A bug where database migrations may not execute properly when upgrading from older versions has been fixed
* A crash when failing to update user account data has been fixed
* A race condition when generating notification counts has been fixed
* A race condition when setting up NATS has been fixed (contributed by [brianathere](https://github.com/brianathere))
* Stale cache data for membership lazy-loading is now correctly invalidated when doing a complete sync
* Data races within user-interactive authentication have been fixed (contributed by [tak-hntlabs](https://github.com/tak-hntlabs))
## Dendrite 0.9.1 (2022-08-03) ## Dendrite 0.9.1 (2022-08-03)
### Fixes ### Fixes

View file

@ -80,7 +80,7 @@ $ ./bin/dendrite-monolith-server --tls-cert server.crt --tls-key server.key --co
# Create an user account (add -admin for an admin user). # Create an user account (add -admin for an admin user).
# Specify the localpart only, e.g. 'alice' for '@alice:domain.com' # Specify the localpart only, e.g. 'alice' for '@alice:domain.com'
$ ./bin/create-account --config dendrite.yaml -username alice $ ./bin/create-account --config dendrite.yaml --url http://localhost:8008 --username alice
``` ```
Then point your favourite Matrix client at `http://localhost:8008` or `https://localhost:8448`. Then point your favourite Matrix client at `http://localhost:8008` or `https://localhost:8448`.
@ -89,12 +89,12 @@ Then point your favourite Matrix client at `http://localhost:8008` or `https://l
We use a script called Are We Synapse Yet which checks Sytest compliance rates. Sytest is a black-box homeserver We use a script called Are We Synapse Yet which checks Sytest compliance rates. Sytest is a black-box homeserver
test rig with around 900 tests. The script works out how many of these tests are passing on Dendrite and it test rig with around 900 tests. The script works out how many of these tests are passing on Dendrite and it
updates with CI. As of August 2022 we're at around 83% CS API coverage and 95% Federation coverage, though check updates with CI. As of August 2022 we're at around 90% CS API coverage and 95% Federation coverage, though check
CI for the latest numbers. In practice, this means you can communicate locally and via federation with Synapse CI for the latest numbers. In practice, this means you can communicate locally and via federation with Synapse
servers such as matrix.org reasonably well, although there are still some missing features (like Search). servers such as matrix.org reasonably well, although there are still some missing features (like Search).
We are prioritising features that will benefit single-user homeservers first (e.g Receipts, E2E) rather We are prioritising features that will benefit single-user homeservers first (e.g Receipts, E2E) rather
than features that massive deployments may be interested in (User Directory, OpenID, Guests, Admin APIs, AS API). than features that massive deployments may be interested in (OpenID, Guests, Admin APIs, AS API).
This means Dendrite supports amongst others: This means Dendrite supports amongst others:
- Core room functionality (creating rooms, invites, auth rules) - Core room functionality (creating rooms, invites, auth rules)

View file

@ -7,7 +7,6 @@ import (
"github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/opentracing/opentracing-go"
) )
// HTTP paths for the internal HTTP APIs // HTTP paths for the internal HTTP APIs
@ -42,11 +41,10 @@ func (h *httpAppServiceQueryAPI) RoomAliasExists(
request *api.RoomAliasExistsRequest, request *api.RoomAliasExistsRequest,
response *api.RoomAliasExistsResponse, response *api.RoomAliasExistsResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "appserviceRoomAliasExists") return httputil.CallInternalRPCAPI(
defer span.Finish() "RoomAliasExists", h.appserviceURL+AppServiceRoomAliasExistsPath,
h.httpClient, ctx, request, response,
apiURL := h.appserviceURL + AppServiceRoomAliasExistsPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// UserIDExists implements AppServiceQueryAPI // UserIDExists implements AppServiceQueryAPI
@ -55,9 +53,8 @@ func (h *httpAppServiceQueryAPI) UserIDExists(
request *api.UserIDExistsRequest, request *api.UserIDExistsRequest,
response *api.UserIDExistsResponse, response *api.UserIDExistsResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "appserviceUserIDExists") return httputil.CallInternalRPCAPI(
defer span.Finish() "UserIDExists", h.appserviceURL+AppServiceUserIDExistsPath,
h.httpClient, ctx, request, response,
apiURL := h.appserviceURL + AppServiceUserIDExistsPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }

View file

@ -1,43 +1,20 @@
package inthttp package inthttp
import ( import (
"encoding/json"
"net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/util"
) )
// AddRoutes adds the AppServiceQueryAPI handlers to the http.ServeMux. // AddRoutes adds the AppServiceQueryAPI handlers to the http.ServeMux.
func AddRoutes(a api.AppServiceInternalAPI, internalAPIMux *mux.Router) { func AddRoutes(a api.AppServiceInternalAPI, internalAPIMux *mux.Router) {
internalAPIMux.Handle( internalAPIMux.Handle(
AppServiceRoomAliasExistsPath, AppServiceRoomAliasExistsPath,
httputil.MakeInternalAPI("appserviceRoomAliasExists", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("AppserviceRoomAliasExists", a.RoomAliasExists),
var request api.RoomAliasExistsRequest
var response api.RoomAliasExistsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := a.RoomAliasExists(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
AppServiceUserIDExistsPath, AppServiceUserIDExistsPath,
httputil.MakeInternalAPI("appserviceUserIDExists", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("AppserviceUserIDExists", a.UserIDExists),
var request api.UserIDExistsRequest
var response api.UserIDExistsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := a.UserIDExists(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
} }

View file

@ -18,6 +18,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -41,7 +42,7 @@ func LoginFromJSONReader(
userInteractiveAuth *UserInteractive, userInteractiveAuth *UserInteractive,
cfg *config.ClientAPI, cfg *config.ClientAPI,
) (*Login, LoginCleanupFunc, *util.JSONResponse) { ) (*Login, LoginCleanupFunc, *util.JSONResponse) {
reqBytes, err := io.ReadAll(r) reqBytes, err := ioutil.ReadAll(r)
if err != nil { if err != nil {
err := &util.JSONResponse{ err := &util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,

View file

@ -15,11 +15,13 @@
package jsonerror package jsonerror
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
) )
// MatrixError represents the "standard error response" in Matrix. // MatrixError represents the "standard error response" in Matrix.
@ -213,3 +215,15 @@ func NotTrusted(serverName string) *MatrixError {
Err: fmt.Sprintf("Untrusted server '%s'", serverName), Err: fmt.Sprintf("Untrusted server '%s'", serverName),
} }
} }
// InternalAPIError is returned when Dendrite failed to reach an internal API.
func InternalAPIError(ctx context.Context, err error) util.JSONResponse {
logrus.WithContext(ctx).WithError(err).Error("Error reaching an internal API")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: &MatrixError{
ErrCode: "M_INTERNAL_SERVER_ERROR",
Err: "Dendrite encountered an error reaching an internal API.",
},
}
}

View file

@ -1,23 +1,20 @@
package routing package routing
import ( import (
"encoding/json"
"net/http" "net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
func AdminEvacuateRoom(req *http.Request, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
if device.AccountType != userapi.AccountTypeAdmin {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("This API can only be used by admin users."),
}
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -30,13 +27,15 @@ func AdminEvacuateRoom(req *http.Request, device *userapi.Device, rsAPI roomserv
} }
} }
res := &roomserverAPI.PerformAdminEvacuateRoomResponse{} res := &roomserverAPI.PerformAdminEvacuateRoomResponse{}
rsAPI.PerformAdminEvacuateRoom( if err := rsAPI.PerformAdminEvacuateRoom(
req.Context(), req.Context(),
&roomserverAPI.PerformAdminEvacuateRoomRequest{ &roomserverAPI.PerformAdminEvacuateRoomRequest{
RoomID: roomID, RoomID: roomID,
}, },
res, res,
) ); err != nil {
return util.ErrorResponse(err)
}
if err := res.Error; err != nil { if err := res.Error; err != nil {
return err.JSONResponse() return err.JSONResponse()
} }
@ -48,13 +47,7 @@ func AdminEvacuateRoom(req *http.Request, device *userapi.Device, rsAPI roomserv
} }
} }
func AdminEvacuateUser(req *http.Request, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
if device.AccountType != userapi.AccountTypeAdmin {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("This API can only be used by admin users."),
}
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -66,14 +59,26 @@ func AdminEvacuateUser(req *http.Request, device *userapi.Device, rsAPI roomserv
JSON: jsonerror.MissingArgument("Expecting user ID."), JSON: jsonerror.MissingArgument("Expecting user ID."),
} }
} }
_, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if domain != cfg.Matrix.ServerName {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument("User ID must belong to this server."),
}
}
res := &roomserverAPI.PerformAdminEvacuateUserResponse{} res := &roomserverAPI.PerformAdminEvacuateUserResponse{}
rsAPI.PerformAdminEvacuateUser( if err := rsAPI.PerformAdminEvacuateUser(
req.Context(), req.Context(),
&roomserverAPI.PerformAdminEvacuateUserRequest{ &roomserverAPI.PerformAdminEvacuateUserRequest{
UserID: userID, UserID: userID,
}, },
res, res,
) ); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if err := res.Error; err != nil { if err := res.Error; err != nil {
return err.JSONResponse() return err.JSONResponse()
} }
@ -84,3 +89,52 @@ func AdminEvacuateUser(req *http.Request, device *userapi.Device, rsAPI roomserv
}, },
} }
} }
func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
localpart, ok := vars["localpart"]
if !ok {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument("Expecting user localpart."),
}
}
request := struct {
Password string `json:"password"`
}{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown("Failed to decode request body: " + err.Error()),
}
}
if request.Password == "" {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument("Expecting non-empty password."),
}
}
updateReq := &userapi.PerformPasswordUpdateRequest{
Localpart: localpart,
Password: request.Password,
LogoutDevices: true,
}
updateRes := &userapi.PerformPasswordUpdateResponse{}
if err := userAPI.PerformPasswordUpdate(req.Context(), updateReq, updateRes); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown("Failed to perform password update: " + err.Error()),
}
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct {
Updated bool `json:"password_updated"`
}{
Updated: updateRes.PasswordUpdated,
},
}
}

View file

@ -556,10 +556,12 @@ func createRoom(
if r.Visibility == "public" { if r.Visibility == "public" {
// expose this room in the published room list // expose this room in the published room list
var pubRes roomserverAPI.PerformPublishResponse var pubRes roomserverAPI.PerformPublishResponse
rsAPI.PerformPublish(ctx, &roomserverAPI.PerformPublishRequest{ if err := rsAPI.PerformPublish(ctx, &roomserverAPI.PerformPublishRequest{
RoomID: roomID, RoomID: roomID,
Visibility: "public", Visibility: "public",
}, &pubRes) }, &pubRes); err != nil {
return jsonerror.InternalAPIError(ctx, err)
}
if pubRes.Error != nil { if pubRes.Error != nil {
// treat as non-fatal since the room is already made by this point // treat as non-fatal since the room is already made by this point
util.GetLogger(ctx).WithError(pubRes.Error).Error("failed to visibility:public") util.GetLogger(ctx).WithError(pubRes.Error).Error("failed to visibility:public")

View file

@ -302,10 +302,12 @@ func SetVisibility(
} }
var publishRes roomserverAPI.PerformPublishResponse var publishRes roomserverAPI.PerformPublishResponse
rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{ if err := rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{
RoomID: roomID, RoomID: roomID,
Visibility: v.Visibility, Visibility: v.Visibility,
}, &publishRes) }, &publishRes); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if publishRes.Error != nil { if publishRes.Error != nil {
util.GetLogger(req.Context()).WithError(publishRes.Error).Error("PerformPublish failed") util.GetLogger(req.Context()).WithError(publishRes.Error).Error("PerformPublish failed")
return publishRes.Error.JSONResponse() return publishRes.Error.JSONResponse()

View file

@ -81,8 +81,9 @@ func JoinRoomByIDOrAlias(
done := make(chan util.JSONResponse, 1) done := make(chan util.JSONResponse, 1)
go func() { go func() {
defer close(done) defer close(done)
rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes) if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil {
if joinRes.Error != nil { done <- jsonerror.InternalAPIError(req.Context(), err)
} else if joinRes.Error != nil {
done <- joinRes.Error.JSONResponse() done <- joinRes.Error.JSONResponse()
} else { } else {
done <- util.JSONResponse{ done <- util.JSONResponse{

View file

@ -91,10 +91,12 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de
// Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version} // Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version}
func KeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse { func KeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse {
var queryResp userapi.QueryKeyBackupResponse var queryResp userapi.QueryKeyBackupResponse
userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ if err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
UserID: device.UserID, UserID: device.UserID,
Version: version, Version: version,
}, &queryResp) }, &queryResp); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if queryResp.Error != "" { if queryResp.Error != "" {
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error)) return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error))
} }
@ -233,13 +235,15 @@ func GetBackupKeys(
req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version, roomID, sessionID string, req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version, roomID, sessionID string,
) util.JSONResponse { ) util.JSONResponse {
var queryResp userapi.QueryKeyBackupResponse var queryResp userapi.QueryKeyBackupResponse
userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ if err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
UserID: device.UserID, UserID: device.UserID,
Version: version, Version: version,
ReturnKeys: true, ReturnKeys: true,
KeysForRoomID: roomID, KeysForRoomID: roomID,
KeysForSessionID: sessionID, KeysForSessionID: sessionID,
}, &queryResp) }, &queryResp); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if queryResp.Error != "" { if queryResp.Error != "" {
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error)) return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error))
} }

View file

@ -72,7 +72,9 @@ func UploadCrossSigningDeviceKeys(
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
uploadReq.UserID = device.UserID uploadReq.UserID = device.UserID
keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes) if err := keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if err := uploadRes.Error; err != nil { if err := uploadRes.Error; err != nil {
switch { switch {
@ -114,7 +116,9 @@ func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.Clie
} }
uploadReq.UserID = device.UserID uploadReq.UserID = device.UserID
keyserverAPI.PerformUploadDeviceSignatures(req.Context(), uploadReq, uploadRes) if err := keyserverAPI.PerformUploadDeviceSignatures(req.Context(), uploadReq, uploadRes); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if err := uploadRes.Error; err != nil { if err := uploadRes.Error; err != nil {
switch { switch {

View file

@ -62,7 +62,9 @@ func UploadKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Devi
} }
var uploadRes api.PerformUploadKeysResponse var uploadRes api.PerformUploadKeysResponse
keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes) if err := keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes); err != nil {
return util.ErrorResponse(err)
}
if uploadRes.Error != nil { if uploadRes.Error != nil {
util.GetLogger(req.Context()).WithError(uploadRes.Error).Error("Failed to PerformUploadKeys") util.GetLogger(req.Context()).WithError(uploadRes.Error).Error("Failed to PerformUploadKeys")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -107,12 +109,14 @@ func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Devic
return *resErr return *resErr
} }
queryRes := api.QueryKeysResponse{} queryRes := api.QueryKeysResponse{}
keyAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{ if err := keyAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{
UserID: device.UserID, UserID: device.UserID,
UserToDevices: r.DeviceKeys, UserToDevices: r.DeviceKeys,
Timeout: r.GetTimeout(), Timeout: r.GetTimeout(),
// TODO: Token? // TODO: Token?
}, &queryRes) }, &queryRes); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{ return util.JSONResponse{
Code: 200, Code: 200,
JSON: map[string]interface{}{ JSON: map[string]interface{}{
@ -145,10 +149,12 @@ func ClaimKeys(req *http.Request, keyAPI api.ClientKeyAPI) util.JSONResponse {
return *resErr return *resErr
} }
claimRes := api.PerformClaimKeysResponse{} claimRes := api.PerformClaimKeysResponse{}
keyAPI.PerformClaimKeys(req.Context(), &api.PerformClaimKeysRequest{ if err := keyAPI.PerformClaimKeys(req.Context(), &api.PerformClaimKeysRequest{
OneTimeKeys: r.OneTimeKeys, OneTimeKeys: r.OneTimeKeys,
Timeout: r.GetTimeout(), Timeout: r.GetTimeout(),
}, &claimRes) }, &claimRes); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if claimRes.Error != nil { if claimRes.Error != nil {
util.GetLogger(req.Context()).WithError(claimRes.Error).Error("failed to PerformClaimKeys") util.GetLogger(req.Context()).WithError(claimRes.Error).Error("failed to PerformClaimKeys")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()

View file

@ -17,6 +17,7 @@ package routing
import ( import (
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -54,7 +55,9 @@ func PeekRoomByIDOrAlias(
} }
// Ask the roomserver to perform the peek. // Ask the roomserver to perform the peek.
rsAPI.PerformPeek(req.Context(), &peekReq, &peekRes) if err := rsAPI.PerformPeek(req.Context(), &peekReq, &peekRes); err != nil {
return util.ErrorResponse(err)
}
if peekRes.Error != nil { if peekRes.Error != nil {
return peekRes.Error.JSONResponse() return peekRes.Error.JSONResponse()
} }
@ -89,7 +92,9 @@ func UnpeekRoomByID(
} }
unpeekRes := roomserverAPI.PerformUnpeekResponse{} unpeekRes := roomserverAPI.PerformUnpeekResponse{}
rsAPI.PerformUnpeek(req.Context(), &unpeekReq, &unpeekRes) if err := rsAPI.PerformUnpeek(req.Context(), &unpeekReq, &unpeekRes); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if unpeekRes.Error != nil { if unpeekRes.Error != nil {
return unpeekRes.Error.JSONResponse() return unpeekRes.Error.JSONResponse()
} }

View file

@ -144,17 +144,23 @@ func Setup(
} }
dendriteAdminRouter.Handle("/admin/evacuateRoom/{roomID}", dendriteAdminRouter.Handle("/admin/evacuateRoom/{roomID}",
httputil.MakeAuthAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAdminAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminEvacuateRoom(req, device, rsAPI) return AdminEvacuateRoom(req, cfg, device, rsAPI)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
dendriteAdminRouter.Handle("/admin/evacuateUser/{userID}", dendriteAdminRouter.Handle("/admin/evacuateUser/{userID}",
httputil.MakeAuthAPI("admin_evacuate_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAdminAPI("admin_evacuate_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminEvacuateUser(req, device, rsAPI) return AdminEvacuateUser(req, cfg, device, rsAPI)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
dendriteAdminRouter.Handle("/admin/resetPassword/{localpart}",
httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminResetPassword(req, cfg, device, userAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
// server notifications // server notifications
if cfg.Matrix.ServerNotices.Enabled { if cfg.Matrix.ServerNotices.Enabled {
logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice") logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice")
@ -929,12 +935,12 @@ func Setup(
return SearchUserDirectory( return SearchUserDirectory(
req.Context(), req.Context(),
device, device,
userAPI,
rsAPI, rsAPI,
userDirectoryProvider, userDirectoryProvider,
cfg.Matrix.ServerName,
postContent.SearchString, postContent.SearchString,
postContent.Limit, postContent.Limit,
federation,
cfg.Matrix.ServerName,
) )
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)

View file

@ -64,7 +64,9 @@ func UpgradeRoom(
} }
upgradeResp := roomserverAPI.PerformRoomUpgradeResponse{} upgradeResp := roomserverAPI.PerformRoomUpgradeResponse{}
rsAPI.PerformRoomUpgrade(req.Context(), &upgradeReq, &upgradeResp) if err := rsAPI.PerformRoomUpgrade(req.Context(), &upgradeReq, &upgradeResp); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if upgradeResp.Error != nil { if upgradeResp.Error != nil {
if upgradeResp.Error.Code == roomserverAPI.PerformErrorNoRoom { if upgradeResp.Error.Code == roomserverAPI.PerformErrorNoRoom {

View file

@ -18,10 +18,13 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"net/http"
"strings"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -34,12 +37,12 @@ type UserDirectoryResponse struct {
func SearchUserDirectory( func SearchUserDirectory(
ctx context.Context, ctx context.Context,
device *userapi.Device, device *userapi.Device,
userAPI userapi.ClientUserAPI,
rsAPI api.ClientRoomserverAPI, rsAPI api.ClientRoomserverAPI,
provider userapi.QuerySearchProfilesAPI, provider userapi.QuerySearchProfilesAPI,
serverName gomatrixserverlib.ServerName,
searchString string, searchString string,
limit int, limit int,
federation *gomatrixserverlib.FederationClient,
localServerName gomatrixserverlib.ServerName,
) util.JSONResponse { ) util.JSONResponse {
if limit < 10 { if limit < 10 {
limit = 10 limit = 10
@ -51,59 +54,74 @@ func SearchUserDirectory(
Limited: false, Limited: false,
} }
// First start searching local users. // Get users we share a room with
knownUsersReq := &api.QueryKnownUsersRequest{
UserID: device.UserID,
Limit: limit,
}
knownUsersRes := &api.QueryKnownUsersResponse{}
if err := rsAPI.QueryKnownUsers(ctx, knownUsersReq, knownUsersRes); err != nil && err != sql.ErrNoRows {
return util.ErrorResponse(fmt.Errorf("rsAPI.QueryKnownUsers: %w", err))
}
knownUsersLoop:
for _, profile := range knownUsersRes.Users {
if len(results) == limit {
response.Limited = true
break
}
userID := profile.UserID
// get the full profile of the local user
localpart, serverName, _ := gomatrixserverlib.SplitID('@', userID)
if serverName == localServerName {
userReq := &userapi.QuerySearchProfilesRequest{ userReq := &userapi.QuerySearchProfilesRequest{
SearchString: searchString, SearchString: localpart,
Limit: limit, Limit: limit,
} }
userRes := &userapi.QuerySearchProfilesResponse{} userRes := &userapi.QuerySearchProfilesResponse{}
if err := provider.QuerySearchProfiles(ctx, userReq, userRes); err != nil { if err := provider.QuerySearchProfiles(ctx, userReq, userRes); err != nil {
return util.ErrorResponse(fmt.Errorf("userAPI.QuerySearchProfiles: %w", err)) return util.ErrorResponse(fmt.Errorf("userAPI.QuerySearchProfiles: %w", err))
} }
for _, p := range userRes.Profiles {
for _, user := range userRes.Profiles { if strings.Contains(p.DisplayName, searchString) ||
strings.Contains(p.Localpart, searchString) {
profile.DisplayName = p.DisplayName
profile.AvatarURL = p.AvatarURL
results[userID] = profile
if len(results) == limit { if len(results) == limit {
response.Limited = true response.Limited = true
break break knownUsersLoop
}
}
} }
var userID string
if user.ServerName != "" {
userID = fmt.Sprintf("@%s:%s", user.Localpart, user.ServerName)
} else { } else {
userID = fmt.Sprintf("@%s:%s", user.Localpart, serverName) // If the username already contains the search string, don't bother hitting federation.
} // This will result in missing avatars and displaynames, but saves the federation roundtrip.
if _, ok := results[userID]; !ok { if strings.Contains(localpart, searchString) {
results[userID] = authtypes.FullyQualifiedProfile{ results[userID] = profile
UserID: userID,
DisplayName: user.DisplayName,
AvatarURL: user.AvatarURL,
}
}
}
// Then, if we have enough room left in the response,
// start searching for known users from joined rooms.
if len(results) <= limit {
stateReq := &api.QueryKnownUsersRequest{
UserID: device.UserID,
SearchString: searchString,
Limit: limit - len(results),
}
stateRes := &api.QueryKnownUsersResponse{}
if err := rsAPI.QueryKnownUsers(ctx, stateReq, stateRes); err != nil && err != sql.ErrNoRows {
return util.ErrorResponse(fmt.Errorf("rsAPI.QueryKnownUsers: %w", err))
}
for _, user := range stateRes.Users {
if len(results) == limit { if len(results) == limit {
response.Limited = true response.Limited = true
break break knownUsersLoop
}
continue
}
// TODO: We should probably cache/store this
fedProfile, fedErr := federation.LookupProfile(ctx, serverName, userID, "")
if fedErr != nil {
if x, ok := fedErr.(gomatrix.HTTPError); ok {
if x.Code == http.StatusNotFound {
continue
}
}
}
if strings.Contains(fedProfile.DisplayName, searchString) {
profile.DisplayName = fedProfile.DisplayName
profile.AvatarURL = fedProfile.AvatarURL
results[userID] = profile
if len(results) == limit {
response.Limited = true
break knownUsersLoop
} }
if _, ok := results[user.UserID]; !ok {
results[user.UserID] = user
} }
} }
} }

View file

@ -15,20 +15,26 @@
package main package main
import ( import (
"context" "bytes"
"crypto/hmac"
"crypto/sha1"
"encoding/hex"
"encoding/json"
"flag" "flag"
"fmt" "fmt"
"io" "io"
"net/http"
"os" "os"
"regexp" "regexp"
"strings" "strings"
"time"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/term" "golang.org/x/term"
"github.com/matrix-org/dendrite/setup"
) )
const usage = `Usage: %s const usage = `Usage: %s
@ -46,8 +52,6 @@ Example:
# read password from stdin # read password from stdin
%s --config dendrite.yaml -username alice -passwordstdin < my.pass %s --config dendrite.yaml -username alice -passwordstdin < my.pass
cat my.pass | %s --config dendrite.yaml -username alice -passwordstdin cat my.pass | %s --config dendrite.yaml -username alice -passwordstdin
# reset password for a user, can be used with a combination above to read the password
%s --config dendrite.yaml -reset-password -username alice -password foobarbaz
Arguments: Arguments:
@ -58,29 +62,34 @@ var (
password = flag.String("password", "", "The password to associate with the account") password = flag.String("password", "", "The password to associate with the account")
pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)")
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
pwdLess = flag.Bool("passwordless", false, "Create a passwordless account, e.g. if only an accesstoken is required")
isAdmin = flag.Bool("admin", false, "Create an admin account") isAdmin = flag.Bool("admin", false, "Create an admin account")
resetPassword = flag.Bool("reset-password", false, "Resets the password for the given username") resetPassword = flag.Bool("reset-password", false, "Deprecated")
serverURL = flag.String("url", "https://localhost:8448", "The URL to connect to.")
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
) )
var cl = http.Client{
Timeout: time.Second * 10,
Transport: http.DefaultTransport,
}
func main() { func main() {
name := os.Args[0] name := os.Args[0]
flag.Usage = func() { flag.Usage = func() {
_, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name, name) _, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name)
flag.PrintDefaults() flag.PrintDefaults()
} }
cfg := setup.ParseFlags(true) cfg := setup.ParseFlags(true)
if *resetPassword {
logrus.Fatalf("The reset-password flag has been replaced by the POST /_dendrite/admin/resetPassword/{localpart} admin API.")
}
if *username == "" { if *username == "" {
flag.Usage() flag.Usage()
os.Exit(1) os.Exit(1)
} }
if *pwdLess && *resetPassword {
logrus.Fatalf("Can not reset to an empty password, unable to login afterwards.")
}
if !validUsernameRegex.MatchString(*username) { if !validUsernameRegex.MatchString(*username) {
logrus.Warn("Username can only contain characters a-z, 0-9, or '_-./='") logrus.Warn("Username can only contain characters a-z, 0-9, or '_-./='")
os.Exit(1) os.Exit(1)
@ -90,67 +99,94 @@ func main() {
logrus.Fatalf("Username can not be longer than 255 characters: %s", fmt.Sprintf("@%s:%s", *username, cfg.Global.ServerName)) logrus.Fatalf("Username can not be longer than 255 characters: %s", fmt.Sprintf("@%s:%s", *username, cfg.Global.ServerName))
} }
var pass string pass, err := getPassword(*password, *pwdFile, *pwdStdin, os.Stdin)
var err error
if !*pwdLess {
pass, err = getPassword(*password, *pwdFile, *pwdStdin, os.Stdin)
if err != nil { if err != nil {
logrus.Fatalln(err) logrus.Fatalln(err)
} }
}
// avoid warning about open registration accessToken, err := sharedSecretRegister(cfg.ClientAPI.RegistrationSharedSecret, *serverURL, *username, pass, *isAdmin)
cfg.ClientAPI.RegistrationDisabled = true
b := base.NewBaseDendrite(cfg, "")
defer b.Close() // nolint: errcheck
accountDB, err := storage.NewUserAPIDatabase(
b,
&cfg.UserAPI.AccountDatabase,
cfg.Global.ServerName,
cfg.UserAPI.BCryptCost,
cfg.UserAPI.OpenIDTokenLifetimeMS,
0, // TODO
cfg.Global.ServerNotices.LocalPart,
)
if err != nil {
logrus.WithError(err).Fatalln("Failed to connect to the database")
}
accType := api.AccountTypeUser
if *isAdmin {
accType = api.AccountTypeAdmin
}
available, err := accountDB.CheckAccountAvailability(context.Background(), *username)
if err != nil {
logrus.Fatalln("Unable check username existence.")
}
if *resetPassword {
if available {
logrus.Fatalln("Username could not be found.")
}
err = accountDB.SetPassword(context.Background(), *username, pass)
if err != nil {
logrus.Fatalf("Failed to update password for user %s: %s", *username, err.Error())
}
if _, err = accountDB.RemoveAllDevices(context.Background(), *username, ""); err != nil {
logrus.Fatalf("Failed to remove all devices: %s", err.Error())
}
logrus.Infof("Updated password for user %s and invalidated all logins\n", *username)
return
}
if !available {
logrus.Fatalln("Username is already in use.")
}
_, err = accountDB.CreateAccount(context.Background(), *username, pass, "", accType)
if err != nil { if err != nil {
logrus.Fatalln("Failed to create the account:", err.Error()) logrus.Fatalln("Failed to create the account:", err.Error())
} }
logrus.Infoln("Created account", *username) logrus.Infof("Created account: %s (AccessToken: %s)", *username, accessToken)
}
type sharedSecretRegistrationRequest struct {
User string `json:"username"`
Password string `json:"password"`
Nonce string `json:"nonce"`
MacStr string `json:"mac"`
Admin bool `json:"admin"`
}
func sharedSecretRegister(sharedSecret, serverURL, localpart, password string, admin bool) (accesToken string, err error) {
registerURL := fmt.Sprintf("%s/_synapse/admin/v1/register", serverURL)
nonceReq, err := http.NewRequest(http.MethodGet, registerURL, nil)
if err != nil {
return "", fmt.Errorf("unable to create http request: %w", err)
}
nonceResp, err := cl.Do(nonceReq)
if err != nil {
return "", fmt.Errorf("unable to get nonce: %w", err)
}
body, err := io.ReadAll(nonceResp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response body: %w", err)
}
defer nonceResp.Body.Close() // nolint: errcheck
nonce := gjson.GetBytes(body, "nonce").Str
adminStr := "notadmin"
if admin {
adminStr = "admin"
}
reg := sharedSecretRegistrationRequest{
User: localpart,
Password: password,
Nonce: nonce,
Admin: admin,
}
macStr, err := getRegisterMac(sharedSecret, nonce, localpart, password, adminStr)
if err != nil {
return "", err
}
reg.MacStr = macStr
js, err := json.Marshal(reg)
if err != nil {
return "", fmt.Errorf("unable to marshal json: %w", err)
}
registerReq, err := http.NewRequest(http.MethodPost, registerURL, bytes.NewBuffer(js))
if err != nil {
return "", fmt.Errorf("unable to create http request: %w", err)
}
regResp, err := cl.Do(registerReq)
if err != nil {
return "", fmt.Errorf("unable to create account: %w", err)
}
defer regResp.Body.Close() // nolint: errcheck
if regResp.StatusCode < 200 || regResp.StatusCode >= 300 {
body, _ = io.ReadAll(regResp.Body)
return "", fmt.Errorf(gjson.GetBytes(body, "error").Str)
}
r, _ := io.ReadAll(regResp.Body)
return gjson.GetBytes(r, "access_token").Str, nil
}
func getRegisterMac(sharedSecret, nonce, localpart, password, adminStr string) (string, error) {
joined := strings.Join([]string{nonce, localpart, password, adminStr}, "\x00")
mac := hmac.New(sha1.New, []byte(sharedSecret))
_, err := mac.Write([]byte(joined))
if err != nil {
return "", fmt.Errorf("unable to construct mac: %w", err)
}
regMac := mac.Sum(nil)
return hex.EncodeToString(regMac), nil
} }
func getPassword(password, pwdFile string, pwdStdin bool, r io.Reader) (string, error) { func getPassword(password, pwdFile string, pwdStdin bool, r io.Reader) (string, error) {

View file

@ -14,9 +14,8 @@ User accounts can be created on a Dendrite instance in a number of ways.
The `create-account` tool is built in the `bin` folder when building Dendrite with The `create-account` tool is built in the `bin` folder when building Dendrite with
the `build.sh` script. the `build.sh` script.
It uses the `dendrite.yaml` configuration file to connect to the Dendrite user database It uses the `dendrite.yaml` configuration file to connect to a running Dendrite instance and requires
and create the account entries directly. It can therefore be used even if Dendrite is not shared secret registration to be enabled as explained below.
running yet, as long as the database is up.
An example of using `create-account` to create a **normal account**: An example of using `create-account` to create a **normal account**:
@ -32,6 +31,13 @@ To create a new **admin account**, add the `-admin` flag:
./bin/create-account -config /path/to/dendrite.yaml -username USERNAME -admin ./bin/create-account -config /path/to/dendrite.yaml -username USERNAME -admin
``` ```
By default `create-account` uses `https://localhost:8448` to connect to Dendrite, this can be overwritten using
the `-url` flag:
```bash
./bin/create-account -config /path/to/dendrite.yaml -username USERNAME -url http://localhost:8008
```
An example of using `create-account` when running in **Docker**, having found the `CONTAINERNAME` from `docker ps`: An example of using `create-account` when running in **Docker**, having found the `CONTAINERNAME` from `docker ps`:
```bash ```bash

View file

@ -13,19 +13,32 @@ without warning.
More endpoints will be added in the future. More endpoints will be added in the future.
## `/_dendrite/admin/evacuateRoom/{roomID}` ## GET `/_dendrite/admin/evacuateRoom/{roomID}`
This endpoint will instruct Dendrite to part all local users from the given `roomID` This endpoint will instruct Dendrite to part all local users from the given `roomID`
in the URL. It may take some time to complete. A JSON body will be returned containing in the URL. It may take some time to complete. A JSON body will be returned containing
the user IDs of all affected users. the user IDs of all affected users.
## `/_dendrite/admin/evacuateUser/{userID}` ## GET `/_dendrite/admin/evacuateUser/{userID}`
This endpoint will instruct Dendrite to part the given local `userID` in the URL from This endpoint will instruct Dendrite to part the given local `userID` in the URL from
all rooms which they are currently joined. A JSON body will be returned containing all rooms which they are currently joined. A JSON body will be returned containing
the room IDs of all affected rooms. the room IDs of all affected rooms.
## `/_synapse/admin/v1/register` ## POST `/_dendrite/admin/resetPassword/{localpart}`
Request body format:
```
{
"password": "new_password_here"
}
```
Reset the password of a local user. The `localpart` is the username only, i.e. if
the full user ID is `@alice:domain.com` then the local part is `alice`.
## GET `/_synapse/admin/v1/register`
Shared secret registration — please see the [user creation page](createusers) for Shared secret registration — please see the [user creation page](createusers) for
guidance on configuring and using this endpoint. guidance on configuring and using this endpoint.

View file

@ -110,7 +110,7 @@ type FederationClientError struct {
Blacklisted bool Blacklisted bool
} }
func (e *FederationClientError) Error() string { func (e FederationClientError) Error() string {
return fmt.Sprintf("%s - (retry_after=%s, blacklisted=%v)", e.Err, e.RetryAfter.String(), e.Blacklisted) return fmt.Sprintf("%s - (retry_after=%s, blacklisted=%v)", e.Err, e.RetryAfter.String(), e.Blacklisted)
} }

View file

@ -32,11 +32,12 @@ type fedRoomserverAPI struct {
} }
// PerformJoin will call this function // PerformJoin will call this function
func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) { func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) error {
if f.inputRoomEvents == nil { if f.inputRoomEvents == nil {
return return nil
} }
f.inputRoomEvents(ctx, req, res) f.inputRoomEvents(ctx, req, res)
return nil
} }
// keychange consumer calls this // keychange consumer calls this

View file

@ -10,7 +10,6 @@ import (
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/opentracing/opentracing-go"
) )
// HTTP paths for the internal HTTP API // HTTP paths for the internal HTTP API
@ -48,7 +47,11 @@ func NewFederationAPIClient(federationSenderURL string, httpClient *http.Client,
if httpClient == nil { if httpClient == nil {
return nil, errors.New("NewFederationInternalAPIHTTP: httpClient is <nil>") return nil, errors.New("NewFederationInternalAPIHTTP: httpClient is <nil>")
} }
return &httpFederationInternalAPI{federationSenderURL, httpClient, cache}, nil return &httpFederationInternalAPI{
federationAPIURL: federationSenderURL,
httpClient: httpClient,
cache: cache,
}, nil
} }
type httpFederationInternalAPI struct { type httpFederationInternalAPI struct {
@ -63,11 +66,10 @@ func (h *httpFederationInternalAPI) PerformLeave(
request *api.PerformLeaveRequest, request *api.PerformLeaveRequest,
response *api.PerformLeaveResponse, response *api.PerformLeaveResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLeaveRequest") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformLeave", h.federationAPIURL+FederationAPIPerformLeaveRequestPath,
h.httpClient, ctx, request, response,
apiURL := h.federationAPIURL + FederationAPIPerformLeaveRequestPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// Handle sending an invite to a remote server. // Handle sending an invite to a remote server.
@ -76,11 +78,10 @@ func (h *httpFederationInternalAPI) PerformInvite(
request *api.PerformInviteRequest, request *api.PerformInviteRequest,
response *api.PerformInviteResponse, response *api.PerformInviteResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformInviteRequest") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformInvite", h.federationAPIURL+FederationAPIPerformInviteRequestPath,
h.httpClient, ctx, request, response,
apiURL := h.federationAPIURL + FederationAPIPerformInviteRequestPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// Handle starting a peek on a remote server. // Handle starting a peek on a remote server.
@ -89,11 +90,10 @@ func (h *httpFederationInternalAPI) PerformOutboundPeek(
request *api.PerformOutboundPeekRequest, request *api.PerformOutboundPeekRequest,
response *api.PerformOutboundPeekResponse, response *api.PerformOutboundPeekResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformOutboundPeekRequest") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformOutboundPeek", h.federationAPIURL+FederationAPIPerformOutboundPeekRequestPath,
h.httpClient, ctx, request, response,
apiURL := h.federationAPIURL + FederationAPIPerformOutboundPeekRequestPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// QueryJoinedHostServerNamesInRoom implements FederationInternalAPI // QueryJoinedHostServerNamesInRoom implements FederationInternalAPI
@ -102,11 +102,10 @@ func (h *httpFederationInternalAPI) QueryJoinedHostServerNamesInRoom(
request *api.QueryJoinedHostServerNamesInRoomRequest, request *api.QueryJoinedHostServerNamesInRoomRequest,
response *api.QueryJoinedHostServerNamesInRoomResponse, response *api.QueryJoinedHostServerNamesInRoomResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryJoinedHostServerNamesInRoom") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryJoinedHostServerNamesInRoom", h.federationAPIURL+FederationAPIQueryJoinedHostServerNamesInRoomPath,
h.httpClient, ctx, request, response,
apiURL := h.federationAPIURL + FederationAPIQueryJoinedHostServerNamesInRoomPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// Handle an instruction to make_join & send_join with a remote server. // Handle an instruction to make_join & send_join with a remote server.
@ -115,12 +114,10 @@ func (h *httpFederationInternalAPI) PerformJoin(
request *api.PerformJoinRequest, request *api.PerformJoinRequest,
response *api.PerformJoinResponse, response *api.PerformJoinResponse,
) { ) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformJoinRequest") if err := httputil.CallInternalRPCAPI(
defer span.Finish() "PerformJoinRequest", h.federationAPIURL+FederationAPIPerformJoinRequestPath,
h.httpClient, ctx, request, response,
apiURL := h.federationAPIURL + FederationAPIPerformJoinRequestPath ); err != nil {
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.LastError = &gomatrix.HTTPError{ response.LastError = &gomatrix.HTTPError{
Message: err.Error(), Message: err.Error(),
Code: 0, Code: 0,
@ -135,11 +132,10 @@ func (h *httpFederationInternalAPI) PerformDirectoryLookup(
request *api.PerformDirectoryLookupRequest, request *api.PerformDirectoryLookupRequest,
response *api.PerformDirectoryLookupResponse, response *api.PerformDirectoryLookupResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDirectoryLookup") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformDirectoryLookup", h.federationAPIURL+FederationAPIPerformDirectoryLookupRequestPath,
h.httpClient, ctx, request, response,
apiURL := h.federationAPIURL + FederationAPIPerformDirectoryLookupRequestPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// Handle an instruction to broadcast an EDU to all servers in rooms we are joined to. // Handle an instruction to broadcast an EDU to all servers in rooms we are joined to.
@ -148,101 +144,61 @@ func (h *httpFederationInternalAPI) PerformBroadcastEDU(
request *api.PerformBroadcastEDURequest, request *api.PerformBroadcastEDURequest,
response *api.PerformBroadcastEDUResponse, response *api.PerformBroadcastEDUResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformBroadcastEDU") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformBroadcastEDU", h.federationAPIURL+FederationAPIPerformBroadcastEDUPath,
h.httpClient, ctx, request, response,
apiURL := h.federationAPIURL + FederationAPIPerformBroadcastEDUPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
type getUserDevices struct { type getUserDevices struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
UserID string UserID string
Res *gomatrixserverlib.RespUserDevices
Err *api.FederationClientError
} }
func (h *httpFederationInternalAPI) GetUserDevices( func (h *httpFederationInternalAPI) GetUserDevices(
ctx context.Context, s gomatrixserverlib.ServerName, userID string, ctx context.Context, s gomatrixserverlib.ServerName, userID string,
) (gomatrixserverlib.RespUserDevices, error) { ) (gomatrixserverlib.RespUserDevices, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "GetUserDevices") return httputil.CallInternalProxyAPI[getUserDevices, gomatrixserverlib.RespUserDevices, *api.FederationClientError](
defer span.Finish() "GetUserDevices", h.federationAPIURL+FederationAPIGetUserDevicesPath, h.httpClient,
ctx, &getUserDevices{
var result gomatrixserverlib.RespUserDevices
request := getUserDevices{
S: s, S: s,
UserID: userID, UserID: userID,
} },
var response getUserDevices )
apiURL := h.federationAPIURL + FederationAPIGetUserDevicesPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return result, err
}
if response.Err != nil {
return result, response.Err
}
return *response.Res, nil
} }
type claimKeys struct { type claimKeys struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
OneTimeKeys map[string]map[string]string OneTimeKeys map[string]map[string]string
Res *gomatrixserverlib.RespClaimKeys
Err *api.FederationClientError
} }
func (h *httpFederationInternalAPI) ClaimKeys( func (h *httpFederationInternalAPI) ClaimKeys(
ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string,
) (gomatrixserverlib.RespClaimKeys, error) { ) (gomatrixserverlib.RespClaimKeys, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "ClaimKeys") return httputil.CallInternalProxyAPI[claimKeys, gomatrixserverlib.RespClaimKeys, *api.FederationClientError](
defer span.Finish() "ClaimKeys", h.federationAPIURL+FederationAPIClaimKeysPath, h.httpClient,
ctx, &claimKeys{
var result gomatrixserverlib.RespClaimKeys
request := claimKeys{
S: s, S: s,
OneTimeKeys: oneTimeKeys, OneTimeKeys: oneTimeKeys,
} },
var response claimKeys )
apiURL := h.federationAPIURL + FederationAPIClaimKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return result, err
}
if response.Err != nil {
return result, response.Err
}
return *response.Res, nil
} }
type queryKeys struct { type queryKeys struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Keys map[string][]string Keys map[string][]string
Res *gomatrixserverlib.RespQueryKeys
Err *api.FederationClientError
} }
func (h *httpFederationInternalAPI) QueryKeys( func (h *httpFederationInternalAPI) QueryKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string, ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string,
) (gomatrixserverlib.RespQueryKeys, error) { ) (gomatrixserverlib.RespQueryKeys, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeys") return httputil.CallInternalProxyAPI[queryKeys, gomatrixserverlib.RespQueryKeys, *api.FederationClientError](
defer span.Finish() "QueryKeys", h.federationAPIURL+FederationAPIQueryKeysPath, h.httpClient,
ctx, &queryKeys{
var result gomatrixserverlib.RespQueryKeys
request := queryKeys{
S: s, S: s,
Keys: keys, Keys: keys,
} },
var response queryKeys )
apiURL := h.federationAPIURL + FederationAPIQueryKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return result, err
}
if response.Err != nil {
return result, response.Err
}
return *response.Res, nil
} }
type backfill struct { type backfill struct {
@ -250,32 +206,20 @@ type backfill struct {
RoomID string RoomID string
Limit int Limit int
EventIDs []string EventIDs []string
Res *gomatrixserverlib.Transaction
Err *api.FederationClientError
} }
func (h *httpFederationInternalAPI) Backfill( func (h *httpFederationInternalAPI) Backfill(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string,
) (gomatrixserverlib.Transaction, error) { ) (gomatrixserverlib.Transaction, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "Backfill") return httputil.CallInternalProxyAPI[backfill, gomatrixserverlib.Transaction, *api.FederationClientError](
defer span.Finish() "Backfill", h.federationAPIURL+FederationAPIBackfillPath, h.httpClient,
ctx, &backfill{
request := backfill{
S: s, S: s,
RoomID: roomID, RoomID: roomID,
Limit: limit, Limit: limit,
EventIDs: eventIDs, EventIDs: eventIDs,
} },
var response backfill )
apiURL := h.federationAPIURL + FederationAPIBackfillPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return gomatrixserverlib.Transaction{}, err
}
if response.Err != nil {
return gomatrixserverlib.Transaction{}, response.Err
}
return *response.Res, nil
} }
type lookupState struct { type lookupState struct {
@ -283,63 +227,39 @@ type lookupState struct {
RoomID string RoomID string
EventID string EventID string
RoomVersion gomatrixserverlib.RoomVersion RoomVersion gomatrixserverlib.RoomVersion
Res *gomatrixserverlib.RespState
Err *api.FederationClientError
} }
func (h *httpFederationInternalAPI) LookupState( func (h *httpFederationInternalAPI) LookupState(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
) (gomatrixserverlib.RespState, error) { ) (gomatrixserverlib.RespState, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "LookupState") return httputil.CallInternalProxyAPI[lookupState, gomatrixserverlib.RespState, *api.FederationClientError](
defer span.Finish() "LookupState", h.federationAPIURL+FederationAPILookupStatePath, h.httpClient,
ctx, &lookupState{
request := lookupState{
S: s, S: s,
RoomID: roomID, RoomID: roomID,
EventID: eventID, EventID: eventID,
RoomVersion: roomVersion, RoomVersion: roomVersion,
} },
var response lookupState )
apiURL := h.federationAPIURL + FederationAPILookupStatePath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return gomatrixserverlib.RespState{}, err
}
if response.Err != nil {
return gomatrixserverlib.RespState{}, response.Err
}
return *response.Res, nil
} }
type lookupStateIDs struct { type lookupStateIDs struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
RoomID string RoomID string
EventID string EventID string
Res *gomatrixserverlib.RespStateIDs
Err *api.FederationClientError
} }
func (h *httpFederationInternalAPI) LookupStateIDs( func (h *httpFederationInternalAPI) LookupStateIDs(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string,
) (gomatrixserverlib.RespStateIDs, error) { ) (gomatrixserverlib.RespStateIDs, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "LookupStateIDs") return httputil.CallInternalProxyAPI[lookupStateIDs, gomatrixserverlib.RespStateIDs, *api.FederationClientError](
defer span.Finish() "LookupStateIDs", h.federationAPIURL+FederationAPILookupStateIDsPath, h.httpClient,
ctx, &lookupStateIDs{
request := lookupStateIDs{
S: s, S: s,
RoomID: roomID, RoomID: roomID,
EventID: eventID, EventID: eventID,
} },
var response lookupStateIDs )
apiURL := h.federationAPIURL + FederationAPILookupStateIDsPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return gomatrixserverlib.RespStateIDs{}, err
}
if response.Err != nil {
return gomatrixserverlib.RespStateIDs{}, response.Err
}
return *response.Res, nil
} }
type lookupMissingEvents struct { type lookupMissingEvents struct {
@ -347,64 +267,38 @@ type lookupMissingEvents struct {
RoomID string RoomID string
Missing gomatrixserverlib.MissingEvents Missing gomatrixserverlib.MissingEvents
RoomVersion gomatrixserverlib.RoomVersion RoomVersion gomatrixserverlib.RoomVersion
Res struct {
Events []gomatrixserverlib.RawJSON `json:"events"`
}
Err *api.FederationClientError
} }
func (h *httpFederationInternalAPI) LookupMissingEvents( func (h *httpFederationInternalAPI) LookupMissingEvents(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, ctx context.Context, s gomatrixserverlib.ServerName, roomID string,
missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.RespMissingEvents, err error) { ) (res gomatrixserverlib.RespMissingEvents, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "LookupMissingEvents") return httputil.CallInternalProxyAPI[lookupMissingEvents, gomatrixserverlib.RespMissingEvents, *api.FederationClientError](
defer span.Finish() "LookupMissingEvents", h.federationAPIURL+FederationAPILookupMissingEventsPath, h.httpClient,
ctx, &lookupMissingEvents{
request := lookupMissingEvents{
S: s, S: s,
RoomID: roomID, RoomID: roomID,
Missing: missing, Missing: missing,
RoomVersion: roomVersion, RoomVersion: roomVersion,
} },
apiURL := h.federationAPIURL + FederationAPILookupMissingEventsPath )
err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &request)
if err != nil {
return res, err
}
if request.Err != nil {
return res, request.Err
}
res.Events = request.Res.Events
return res, nil
} }
type getEvent struct { type getEvent struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
EventID string EventID string
Res *gomatrixserverlib.Transaction
Err *api.FederationClientError
} }
func (h *httpFederationInternalAPI) GetEvent( func (h *httpFederationInternalAPI) GetEvent(
ctx context.Context, s gomatrixserverlib.ServerName, eventID string, ctx context.Context, s gomatrixserverlib.ServerName, eventID string,
) (gomatrixserverlib.Transaction, error) { ) (gomatrixserverlib.Transaction, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "GetEvent") return httputil.CallInternalProxyAPI[getEvent, gomatrixserverlib.Transaction, *api.FederationClientError](
defer span.Finish() "GetEvent", h.federationAPIURL+FederationAPIGetEventPath, h.httpClient,
ctx, &getEvent{
request := getEvent{
S: s, S: s,
EventID: eventID, EventID: eventID,
} },
var response getEvent )
apiURL := h.federationAPIURL + FederationAPIGetEventPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return gomatrixserverlib.Transaction{}, err
}
if response.Err != nil {
return gomatrixserverlib.Transaction{}, response.Err
}
return *response.Res, nil
} }
type getEventAuth struct { type getEventAuth struct {
@ -412,135 +306,86 @@ type getEventAuth struct {
RoomVersion gomatrixserverlib.RoomVersion RoomVersion gomatrixserverlib.RoomVersion
RoomID string RoomID string
EventID string EventID string
Res *gomatrixserverlib.RespEventAuth
Err *api.FederationClientError
} }
func (h *httpFederationInternalAPI) GetEventAuth( func (h *httpFederationInternalAPI) GetEventAuth(
ctx context.Context, s gomatrixserverlib.ServerName, ctx context.Context, s gomatrixserverlib.ServerName,
roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string,
) (gomatrixserverlib.RespEventAuth, error) { ) (gomatrixserverlib.RespEventAuth, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "GetEventAuth") return httputil.CallInternalProxyAPI[getEventAuth, gomatrixserverlib.RespEventAuth, *api.FederationClientError](
defer span.Finish() "GetEventAuth", h.federationAPIURL+FederationAPIGetEventAuthPath, h.httpClient,
ctx, &getEventAuth{
request := getEventAuth{
S: s, S: s,
RoomVersion: roomVersion, RoomVersion: roomVersion,
RoomID: roomID, RoomID: roomID,
EventID: eventID, EventID: eventID,
} },
var response getEventAuth )
apiURL := h.federationAPIURL + FederationAPIGetEventAuthPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return gomatrixserverlib.RespEventAuth{}, err
}
if response.Err != nil {
return gomatrixserverlib.RespEventAuth{}, response.Err
}
return *response.Res, nil
} }
func (h *httpFederationInternalAPI) QueryServerKeys( func (h *httpFederationInternalAPI) QueryServerKeys(
ctx context.Context, req *api.QueryServerKeysRequest, res *api.QueryServerKeysResponse, ctx context.Context, req *api.QueryServerKeysRequest, res *api.QueryServerKeysResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerKeys") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryServerKeys", h.federationAPIURL+FederationAPIQueryServerKeysPath,
h.httpClient, ctx, req, res,
apiURL := h.federationAPIURL + FederationAPIQueryServerKeysPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
} }
type lookupServerKeys struct { type lookupServerKeys struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
KeyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp KeyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp
ServerKeys []gomatrixserverlib.ServerKeys
Err *api.FederationClientError
} }
func (h *httpFederationInternalAPI) LookupServerKeys( func (h *httpFederationInternalAPI) LookupServerKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) ([]gomatrixserverlib.ServerKeys, error) { ) ([]gomatrixserverlib.ServerKeys, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "LookupServerKeys") return httputil.CallInternalProxyAPI[lookupServerKeys, []gomatrixserverlib.ServerKeys, *api.FederationClientError](
defer span.Finish() "LookupServerKeys", h.federationAPIURL+FederationAPILookupServerKeysPath, h.httpClient,
ctx, &lookupServerKeys{
request := lookupServerKeys{
S: s, S: s,
KeyRequests: keyRequests, KeyRequests: keyRequests,
} },
var response lookupServerKeys )
apiURL := h.federationAPIURL + FederationAPILookupServerKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return []gomatrixserverlib.ServerKeys{}, err
}
if response.Err != nil {
return []gomatrixserverlib.ServerKeys{}, response.Err
}
return response.ServerKeys, nil
} }
type eventRelationships struct { type eventRelationships struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Req gomatrixserverlib.MSC2836EventRelationshipsRequest Req gomatrixserverlib.MSC2836EventRelationshipsRequest
RoomVer gomatrixserverlib.RoomVersion RoomVer gomatrixserverlib.RoomVersion
Res gomatrixserverlib.MSC2836EventRelationshipsResponse
Err *api.FederationClientError
} }
func (h *httpFederationInternalAPI) MSC2836EventRelationships( func (h *httpFederationInternalAPI) MSC2836EventRelationships(
ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest,
roomVersion gomatrixserverlib.RoomVersion, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { ) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2836EventRelationships") return httputil.CallInternalProxyAPI[eventRelationships, gomatrixserverlib.MSC2836EventRelationshipsResponse, *api.FederationClientError](
defer span.Finish() "MSC2836EventRelationships", h.federationAPIURL+FederationAPIEventRelationshipsPath, h.httpClient,
ctx, &eventRelationships{
request := eventRelationships{
S: s, S: s,
Req: r, Req: r,
RoomVer: roomVersion, RoomVer: roomVersion,
} },
var response eventRelationships )
apiURL := h.federationAPIURL + FederationAPIEventRelationshipsPath
err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return res, err
}
if response.Err != nil {
return res, response.Err
}
return response.Res, nil
} }
type spacesReq struct { type spacesReq struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
SuggestedOnly bool SuggestedOnly bool
RoomID string RoomID string
Res gomatrixserverlib.MSC2946SpacesResponse
Err *api.FederationClientError
} }
func (h *httpFederationInternalAPI) MSC2946Spaces( func (h *httpFederationInternalAPI) MSC2946Spaces(
ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool,
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2946Spaces") return httputil.CallInternalProxyAPI[spacesReq, gomatrixserverlib.MSC2946SpacesResponse, *api.FederationClientError](
defer span.Finish() "MSC2836EventRelationships", h.federationAPIURL+FederationAPISpacesSummaryPath, h.httpClient,
ctx, &spacesReq{
request := spacesReq{
S: dst, S: dst,
SuggestedOnly: suggestedOnly, SuggestedOnly: suggestedOnly,
RoomID: roomID, RoomID: roomID,
} },
var response spacesReq )
apiURL := h.federationAPIURL + FederationAPISpacesSummaryPath
err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return res, err
}
if response.Err != nil {
return res, response.Err
}
return response.Res, nil
} }
func (s *httpFederationInternalAPI) KeyRing() *gomatrixserverlib.KeyRing { func (s *httpFederationInternalAPI) KeyRing() *gomatrixserverlib.KeyRing {
@ -614,11 +459,10 @@ func (h *httpFederationInternalAPI) InputPublicKeys(
request *api.InputPublicKeysRequest, request *api.InputPublicKeysRequest,
response *api.InputPublicKeysResponse, response *api.InputPublicKeysResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "InputPublicKey") return httputil.CallInternalRPCAPI(
defer span.Finish() "InputPublicKey", h.federationAPIURL+FederationAPIInputPublicKeyPath,
h.httpClient, ctx, request, response,
apiURL := h.federationAPIURL + FederationAPIInputPublicKeyPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpFederationInternalAPI) QueryPublicKeys( func (h *httpFederationInternalAPI) QueryPublicKeys(
@ -626,9 +470,8 @@ func (h *httpFederationInternalAPI) QueryPublicKeys(
request *api.QueryPublicKeysRequest, request *api.QueryPublicKeysRequest,
response *api.QueryPublicKeysResponse, response *api.QueryPublicKeysResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPublicKey") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryPublicKeys", h.federationAPIURL+FederationAPIQueryPublicKeyPath,
h.httpClient, ctx, request, response,
apiURL := h.federationAPIURL + FederationAPIQueryPublicKeyPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }

View file

@ -1,12 +1,14 @@
package inthttp package inthttp
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -15,372 +17,180 @@ import (
func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
internalAPIMux.Handle( internalAPIMux.Handle(
FederationAPIQueryJoinedHostServerNamesInRoomPath, FederationAPIQueryJoinedHostServerNamesInRoomPath,
httputil.MakeInternalAPI("QueryJoinedHostServerNamesInRoom", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("FederationAPIQueryJoinedHostServerNamesInRoom", intAPI.QueryJoinedHostServerNamesInRoom),
var request api.QueryJoinedHostServerNamesInRoomRequest
var response api.QueryJoinedHostServerNamesInRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := intAPI.QueryJoinedHostServerNamesInRoom(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(
FederationAPIPerformJoinRequestPath,
httputil.MakeInternalAPI("PerformJoinRequest", func(req *http.Request) util.JSONResponse {
var request api.PerformJoinRequest
var response api.PerformJoinResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
intAPI.PerformJoin(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(
FederationAPIPerformLeaveRequestPath,
httputil.MakeInternalAPI("PerformLeaveRequest", func(req *http.Request) util.JSONResponse {
var request api.PerformLeaveRequest
var response api.PerformLeaveResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.PerformLeave(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
FederationAPIPerformInviteRequestPath, FederationAPIPerformInviteRequestPath,
httputil.MakeInternalAPI("PerformInviteRequest", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("FederationAPIPerformInvite", intAPI.PerformInvite),
var request api.PerformInviteRequest
var response api.PerformInviteResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.PerformInvite(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(
FederationAPIPerformLeaveRequestPath,
httputil.MakeInternalRPCAPI("FederationAPIPerformLeave", intAPI.PerformLeave),
)
internalAPIMux.Handle( internalAPIMux.Handle(
FederationAPIPerformDirectoryLookupRequestPath, FederationAPIPerformDirectoryLookupRequestPath,
httputil.MakeInternalAPI("PerformDirectoryLookupRequest", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("FederationAPIPerformDirectoryLookupRequest", intAPI.PerformDirectoryLookup),
var request api.PerformDirectoryLookupRequest
var response api.PerformDirectoryLookupResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.PerformDirectoryLookup(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
FederationAPIPerformBroadcastEDUPath, FederationAPIPerformBroadcastEDUPath,
httputil.MakeInternalAPI("PerformBroadcastEDU", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", intAPI.PerformBroadcastEDU),
var request api.PerformBroadcastEDURequest
var response api.PerformBroadcastEDUResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.PerformBroadcastEDU(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(
FederationAPIPerformJoinRequestPath,
httputil.MakeInternalRPCAPI(
"FederationAPIPerformJoinRequest",
func(ctx context.Context, req *api.PerformJoinRequest, res *api.PerformJoinResponse) error {
intAPI.PerformJoin(ctx, req, res)
return nil
},
),
)
internalAPIMux.Handle( internalAPIMux.Handle(
FederationAPIGetUserDevicesPath, FederationAPIGetUserDevicesPath,
httputil.MakeInternalAPI("GetUserDevices", func(req *http.Request) util.JSONResponse { httputil.MakeInternalProxyAPI(
var request getUserDevices "FederationAPIGetUserDevices",
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { func(ctx context.Context, req *getUserDevices) (*gomatrixserverlib.RespUserDevices, error) {
return util.MessageResponse(http.StatusBadRequest, err.Error()) res, err := intAPI.GetUserDevices(ctx, req.S, req.UserID)
} return &res, federationClientError(err)
res, err := intAPI.GetUserDevices(req.Context(), request.S, request.UserID) },
if err != nil { ),
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
FederationAPIClaimKeysPath, FederationAPIClaimKeysPath,
httputil.MakeInternalAPI("ClaimKeys", func(req *http.Request) util.JSONResponse { httputil.MakeInternalProxyAPI(
var request claimKeys "FederationAPIClaimKeys",
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { func(ctx context.Context, req *claimKeys) (*gomatrixserverlib.RespClaimKeys, error) {
return util.MessageResponse(http.StatusBadRequest, err.Error()) res, err := intAPI.ClaimKeys(ctx, req.S, req.OneTimeKeys)
} return &res, federationClientError(err)
res, err := intAPI.ClaimKeys(req.Context(), request.S, request.OneTimeKeys) },
if err != nil { ),
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
FederationAPIQueryKeysPath, FederationAPIQueryKeysPath,
httputil.MakeInternalAPI("QueryKeys", func(req *http.Request) util.JSONResponse { httputil.MakeInternalProxyAPI(
var request queryKeys "FederationAPIQueryKeys",
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { func(ctx context.Context, req *queryKeys) (*gomatrixserverlib.RespQueryKeys, error) {
return util.MessageResponse(http.StatusBadRequest, err.Error()) res, err := intAPI.QueryKeys(ctx, req.S, req.Keys)
} return &res, federationClientError(err)
res, err := intAPI.QueryKeys(req.Context(), request.S, request.Keys) },
if err != nil { ),
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
FederationAPIBackfillPath, FederationAPIBackfillPath,
httputil.MakeInternalAPI("Backfill", func(req *http.Request) util.JSONResponse { httputil.MakeInternalProxyAPI(
var request backfill "FederationAPIBackfill",
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { func(ctx context.Context, req *backfill) (*gomatrixserverlib.Transaction, error) {
return util.MessageResponse(http.StatusBadRequest, err.Error()) res, err := intAPI.Backfill(ctx, req.S, req.RoomID, req.Limit, req.EventIDs)
} return &res, federationClientError(err)
res, err := intAPI.Backfill(req.Context(), request.S, request.RoomID, request.Limit, request.EventIDs) },
if err != nil { ),
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
FederationAPILookupStatePath, FederationAPILookupStatePath,
httputil.MakeInternalAPI("LookupState", func(req *http.Request) util.JSONResponse { httputil.MakeInternalProxyAPI(
var request lookupState "FederationAPILookupState",
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { func(ctx context.Context, req *lookupState) (*gomatrixserverlib.RespState, error) {
return util.MessageResponse(http.StatusBadRequest, err.Error()) res, err := intAPI.LookupState(ctx, req.S, req.RoomID, req.EventID, req.RoomVersion)
} return &res, federationClientError(err)
res, err := intAPI.LookupState(req.Context(), request.S, request.RoomID, request.EventID, request.RoomVersion) },
if err != nil { ),
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
FederationAPILookupStateIDsPath, FederationAPILookupStateIDsPath,
httputil.MakeInternalAPI("LookupStateIDs", func(req *http.Request) util.JSONResponse { httputil.MakeInternalProxyAPI(
var request lookupStateIDs "FederationAPILookupStateIDs",
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { func(ctx context.Context, req *lookupStateIDs) (*gomatrixserverlib.RespStateIDs, error) {
return util.MessageResponse(http.StatusBadRequest, err.Error()) res, err := intAPI.LookupStateIDs(ctx, req.S, req.RoomID, req.EventID)
} return &res, federationClientError(err)
res, err := intAPI.LookupStateIDs(req.Context(), request.S, request.RoomID, request.EventID) },
if err != nil { ),
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
FederationAPILookupMissingEventsPath, FederationAPILookupMissingEventsPath,
httputil.MakeInternalAPI("LookupMissingEvents", func(req *http.Request) util.JSONResponse { httputil.MakeInternalProxyAPI(
var request lookupMissingEvents "FederationAPILookupMissingEvents",
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { func(ctx context.Context, req *lookupMissingEvents) (*gomatrixserverlib.RespMissingEvents, error) {
return util.MessageResponse(http.StatusBadRequest, err.Error()) res, err := intAPI.LookupMissingEvents(ctx, req.S, req.RoomID, req.Missing, req.RoomVersion)
} return &res, federationClientError(err)
res, err := intAPI.LookupMissingEvents(req.Context(), request.S, request.RoomID, request.Missing, request.RoomVersion) },
if err != nil { ),
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
for _, event := range res.Events {
js, err := json.Marshal(event)
if err != nil {
return util.MessageResponse(http.StatusInternalServerError, err.Error())
}
request.Res.Events = append(request.Res.Events, js)
}
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
FederationAPIGetEventPath, FederationAPIGetEventPath,
httputil.MakeInternalAPI("GetEvent", func(req *http.Request) util.JSONResponse { httputil.MakeInternalProxyAPI(
var request getEvent "FederationAPIGetEvent",
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { func(ctx context.Context, req *getEvent) (*gomatrixserverlib.Transaction, error) {
return util.MessageResponse(http.StatusBadRequest, err.Error()) res, err := intAPI.GetEvent(ctx, req.S, req.EventID)
} return &res, federationClientError(err)
res, err := intAPI.GetEvent(req.Context(), request.S, request.EventID) },
if err != nil { ),
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
FederationAPIGetEventAuthPath, FederationAPIGetEventAuthPath,
httputil.MakeInternalAPI("GetEventAuth", func(req *http.Request) util.JSONResponse { httputil.MakeInternalProxyAPI(
var request getEventAuth "FederationAPIGetEventAuth",
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { func(ctx context.Context, req *getEventAuth) (*gomatrixserverlib.RespEventAuth, error) {
return util.MessageResponse(http.StatusBadRequest, err.Error()) res, err := intAPI.GetEventAuth(ctx, req.S, req.RoomVersion, req.RoomID, req.EventID)
} return &res, federationClientError(err)
res, err := intAPI.GetEventAuth(req.Context(), request.S, request.RoomVersion, request.RoomID, request.EventID) },
if err != nil { ),
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
FederationAPIQueryServerKeysPath, FederationAPIQueryServerKeysPath,
httputil.MakeInternalAPI("QueryServerKeys", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("FederationAPIQueryServerKeys", intAPI.QueryServerKeys),
var request api.QueryServerKeysRequest
var response api.QueryServerKeysResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.QueryServerKeys(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
FederationAPILookupServerKeysPath, FederationAPILookupServerKeysPath,
httputil.MakeInternalAPI("LookupServerKeys", func(req *http.Request) util.JSONResponse { httputil.MakeInternalProxyAPI(
var request lookupServerKeys "FederationAPILookupServerKeys",
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { func(ctx context.Context, req *lookupServerKeys) (*[]gomatrixserverlib.ServerKeys, error) {
return util.MessageResponse(http.StatusBadRequest, err.Error()) res, err := intAPI.LookupServerKeys(ctx, req.S, req.KeyRequests)
} return &res, federationClientError(err)
res, err := intAPI.LookupServerKeys(req.Context(), request.S, request.KeyRequests) },
if err != nil { ),
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.ServerKeys = res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
FederationAPIEventRelationshipsPath, FederationAPIEventRelationshipsPath,
httputil.MakeInternalAPI("MSC2836EventRelationships", func(req *http.Request) util.JSONResponse { httputil.MakeInternalProxyAPI(
var request eventRelationships "FederationAPIMSC2836EventRelationships",
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { func(ctx context.Context, req *eventRelationships) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) {
return util.MessageResponse(http.StatusBadRequest, err.Error()) res, err := intAPI.MSC2836EventRelationships(ctx, req.S, req.Req, req.RoomVer)
} return &res, federationClientError(err)
res, err := intAPI.MSC2836EventRelationships(req.Context(), request.S, request.Req, request.RoomVer) },
if err != nil { ),
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
FederationAPISpacesSummaryPath, FederationAPISpacesSummaryPath,
httputil.MakeInternalAPI("MSC2946SpacesSummary", func(req *http.Request) util.JSONResponse { httputil.MakeInternalProxyAPI(
var request spacesReq "FederationAPIMSC2946SpacesSummary",
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { func(ctx context.Context, req *spacesReq) (*gomatrixserverlib.MSC2946SpacesResponse, error) {
return util.MessageResponse(http.StatusBadRequest, err.Error()) res, err := intAPI.MSC2946Spaces(ctx, req.S, req.RoomID, req.SuggestedOnly)
} return &res, federationClientError(err)
res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.SuggestedOnly) },
if err != nil { ),
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
) )
// TODO: Look at this shape
internalAPIMux.Handle(FederationAPIQueryPublicKeyPath, internalAPIMux.Handle(FederationAPIQueryPublicKeyPath,
httputil.MakeInternalAPI("queryPublicKeys", func(req *http.Request) util.JSONResponse { httputil.MakeInternalAPI("FederationAPIQueryPublicKeys", func(req *http.Request) util.JSONResponse {
request := api.QueryPublicKeysRequest{} request := api.QueryPublicKeysRequest{}
response := api.QueryPublicKeysResponse{} response := api.QueryPublicKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -394,8 +204,10 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
// TODO: Look at this shape
internalAPIMux.Handle(FederationAPIInputPublicKeyPath, internalAPIMux.Handle(FederationAPIInputPublicKeyPath,
httputil.MakeInternalAPI("inputPublicKeys", func(req *http.Request) util.JSONResponse { httputil.MakeInternalAPI("FederationAPIInputPublicKeys", func(req *http.Request) util.JSONResponse {
request := api.InputPublicKeysRequest{} request := api.InputPublicKeysRequest{}
response := api.InputPublicKeysResponse{} response := api.InputPublicKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -408,3 +220,18 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
}), }),
) )
} }
func federationClientError(err error) error {
switch ferr := err.(type) {
case nil:
return nil
case api.FederationClientError:
return &ferr
case *api.FederationClientError:
return ferr
default:
return &api.FederationClientError{
Err: err.Error(),
}
}
}

View file

@ -30,9 +30,11 @@ func GetUserDevices(
userID string, userID string,
) util.JSONResponse { ) util.JSONResponse {
var res keyapi.QueryDeviceMessagesResponse var res keyapi.QueryDeviceMessagesResponse
keyAPI.QueryDeviceMessages(req.Context(), &keyapi.QueryDeviceMessagesRequest{ if err := keyAPI.QueryDeviceMessages(req.Context(), &keyapi.QueryDeviceMessagesRequest{
UserID: userID, UserID: userID,
}, &res) }, &res); err != nil {
return util.ErrorResponse(err)
}
if res.Error != nil { if res.Error != nil {
util.GetLogger(req.Context()).WithError(res.Error).Error("keyAPI.QueryDeviceMessages failed") util.GetLogger(req.Context()).WithError(res.Error).Error("keyAPI.QueryDeviceMessages failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -47,7 +49,9 @@ func GetUserDevices(
for _, dev := range res.Devices { for _, dev := range res.Devices {
sigReq.TargetIDs[userID] = append(sigReq.TargetIDs[userID], gomatrixserverlib.KeyID(dev.DeviceID)) sigReq.TargetIDs[userID] = append(sigReq.TargetIDs[userID], gomatrixserverlib.KeyID(dev.DeviceID))
} }
keyAPI.QuerySignatures(req.Context(), sigReq, sigRes) if err := keyAPI.QuerySignatures(req.Context(), sigReq, sigRes); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
response := gomatrixserverlib.RespUserDevices{ response := gomatrixserverlib.RespUserDevices{
UserID: userID, UserID: userID,

View file

@ -392,7 +392,7 @@ func SendJoin(
// the room, so set SendAsServer to cfg.Matrix.ServerName // the room, so set SendAsServer to cfg.Matrix.ServerName
if !alreadyJoined { if !alreadyJoined {
var response api.InputRoomEventsResponse var response api.InputRoomEventsResponse
rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{ if err := rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{
InputRoomEvents: []api.InputRoomEvent{ InputRoomEvents: []api.InputRoomEvent{
{ {
Kind: api.KindNew, Kind: api.KindNew,
@ -401,7 +401,9 @@ func SendJoin(
TransactionID: nil, TransactionID: nil,
}, },
}, },
}, &response) }, &response); err != nil {
return jsonerror.InternalAPIError(httpReq.Context(), err)
}
if response.ErrMsg != "" { if response.ErrMsg != "" {
util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, response.ErrMsg).Error("SendEvents failed") util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, response.ErrMsg).Error("SendEvents failed")
if response.NotAllowed { if response.NotAllowed {

View file

@ -19,7 +19,7 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/httputil" clienthttputil "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
federationAPI "github.com/matrix-org/dendrite/federationapi/api" federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
@ -61,9 +61,11 @@ func QueryDeviceKeys(
} }
var queryRes api.QueryKeysResponse var queryRes api.QueryKeysResponse
keyAPI.QueryKeys(httpReq.Context(), &api.QueryKeysRequest{ if err := keyAPI.QueryKeys(httpReq.Context(), &api.QueryKeysRequest{
UserToDevices: qkr.DeviceKeys, UserToDevices: qkr.DeviceKeys,
}, &queryRes) }, &queryRes); err != nil {
return jsonerror.InternalAPIError(httpReq.Context(), err)
}
if queryRes.Error != nil { if queryRes.Error != nil {
util.GetLogger(httpReq.Context()).WithError(queryRes.Error).Error("Failed to QueryKeys") util.GetLogger(httpReq.Context()).WithError(queryRes.Error).Error("Failed to QueryKeys")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -113,9 +115,11 @@ func ClaimOneTimeKeys(
} }
var claimRes api.PerformClaimKeysResponse var claimRes api.PerformClaimKeysResponse
keyAPI.PerformClaimKeys(httpReq.Context(), &api.PerformClaimKeysRequest{ if err := keyAPI.PerformClaimKeys(httpReq.Context(), &api.PerformClaimKeysRequest{
OneTimeKeys: cor.OneTimeKeys, OneTimeKeys: cor.OneTimeKeys,
}, &claimRes) }, &claimRes); err != nil {
return jsonerror.InternalAPIError(httpReq.Context(), err)
}
if claimRes.Error != nil { if claimRes.Error != nil {
util.GetLogger(httpReq.Context()).WithError(claimRes.Error).Error("Failed to PerformClaimKeys") util.GetLogger(httpReq.Context()).WithError(claimRes.Error).Error("Failed to PerformClaimKeys")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -184,7 +188,7 @@ func NotaryKeys(
) util.JSONResponse { ) util.JSONResponse {
if req == nil { if req == nil {
req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{} req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{}
if reqErr := httputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil { if reqErr := clienthttputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil {
return *reqErr return *reqErr
} }
} }

View file

@ -277,7 +277,7 @@ func SendLeave(
// We are responsible for notifying other servers that the user has left // We are responsible for notifying other servers that the user has left
// the room, so set SendAsServer to cfg.Matrix.ServerName // the room, so set SendAsServer to cfg.Matrix.ServerName
var response api.InputRoomEventsResponse var response api.InputRoomEventsResponse
rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{ if err := rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{
InputRoomEvents: []api.InputRoomEvent{ InputRoomEvents: []api.InputRoomEvent{
{ {
Kind: api.KindNew, Kind: api.KindNew,
@ -286,7 +286,9 @@ func SendLeave(
TransactionID: nil, TransactionID: nil,
}, },
}, },
}, &response) }, &response); err != nil {
return jsonerror.InternalAPIError(httpReq.Context(), err)
}
if response.ErrMsg != "" { if response.ErrMsg != "" {
util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, response.ErrMsg).WithField("not_allowed", response.NotAllowed).Error("producer.SendEvents failed") util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, response.ErrMsg).WithField("not_allowed", response.NotAllowed).Error("producer.SendEvents failed")

View file

@ -458,7 +458,9 @@ func (t *txnReq) processSigningKeyUpdate(ctx context.Context, e gomatrixserverli
UserID: updatePayload.UserID, UserID: updatePayload.UserID,
} }
uploadRes := &keyapi.PerformUploadDeviceKeysResponse{} uploadRes := &keyapi.PerformUploadDeviceKeysResponse{}
t.keyAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes) if err := t.keyAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil {
return err
}
if uploadRes.Error != nil { if uploadRes.Error != nil {
return uploadRes.Error return uploadRes.Error
} }

View file

@ -64,11 +64,12 @@ func (t *testRoomserverAPI) InputRoomEvents(
ctx context.Context, ctx context.Context,
request *api.InputRoomEventsRequest, request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse, response *api.InputRoomEventsResponse,
) { ) error {
t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...) t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...)
for _, ire := range request.InputRoomEvents { for _, ire := range request.InputRoomEvents {
fmt.Println("InputRoomEvents: ", ire.Event.EventID()) fmt.Println("InputRoomEvents: ", ire.Event.EventID())
} }
return nil
} }
// Query the latest events and state for a room from the room server. // Query the latest events and state for a room from the room server.

12
go.mod
View file

@ -21,12 +21,12 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
github.com/matrix-org/gomatrixserverlib v0.0.0-20220725104114-b6003e522771 github.com/matrix-org/gomatrixserverlib v0.0.0-20220801083850-5ff38e2c2839
github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9 github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.13 github.com/mattn/go-sqlite3 v1.14.13
github.com/nats-io/nats-server/v2 v2.8.5-0.20220731184415-903a06a5b4ee github.com/nats-io/nats-server/v2 v2.8.5-0.20220811224153-d8d25d9b0b1c
github.com/nats-io/nats.go v1.16.1-0.20220731182438-87bbea85922b github.com/nats-io/nats.go v1.16.1-0.20220810192301-fb5ca2cbc995
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79 github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79
@ -34,7 +34,7 @@ require (
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.12.2 github.com/prometheus/client_golang v1.12.2
github.com/sirupsen/logrus v1.8.1 github.com/sirupsen/logrus v1.9.0
github.com/spruceid/siwe-go v0.2.0 github.com/spruceid/siwe-go v0.2.0
github.com/stretchr/objx v0.2.0 // indirect github.com/stretchr/objx v0.2.0 // indirect
github.com/stretchr/testify v1.7.1 github.com/stretchr/testify v1.7.1
@ -44,7 +44,7 @@ require (
github.com/uber/jaeger-lib v2.4.1+incompatible github.com/uber/jaeger-lib v2.4.1+incompatible
github.com/yggdrasil-network/yggdrasil-go v0.4.3 github.com/yggdrasil-network/yggdrasil-go v0.4.3
go.uber.org/atomic v1.9.0 go.uber.org/atomic v1.9.0
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa
golang.org/x/image v0.0.0-20220413100746-70e8d0d3baa9 golang.org/x/image v0.0.0-20220413100746-70e8d0d3baa9
golang.org/x/mobile v0.0.0-20220518205345-8578da9835fd golang.org/x/mobile v0.0.0-20220518205345-8578da9835fd
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e
@ -105,7 +105,7 @@ require (
github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect github.com/tidwall/pretty v1.2.0 // indirect
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect golang.org/x/sys v0.0.0-20220731174439-a90be440212d // indirect
golang.org/x/text v0.3.8-0.20211004125949-5bd84dd9b33b // indirect golang.org/x/text v0.3.8-0.20211004125949-5bd84dd9b33b // indirect
golang.org/x/time v0.0.0-20220411224347-583f2d630306 // indirect golang.org/x/time v0.0.0-20220411224347-583f2d630306 // indirect
golang.org/x/tools v0.1.10 // indirect golang.org/x/tools v0.1.10 // indirect

24
go.sum
View file

@ -478,8 +478,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220725104114-b6003e522771 h1:ZIPHFIPNDS9dmEbPEiJbNmyCGJtn9exfpLC7JOcn/bE= github.com/matrix-org/gomatrixserverlib v0.0.0-20220801083850-5ff38e2c2839 h1:QEFxKWH8PlEt3ZQKl31yJNAm8lvpNUwT51IMNTl9v1k=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220725104114-b6003e522771/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk= github.com/matrix-org/gomatrixserverlib v0.0.0-20220801083850-5ff38e2c2839/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk=
github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9 h1:ed8yvWhTLk7+sNeK/eOZRTvESFTOHDRevoRoyeqPtvY= github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9 h1:ed8yvWhTLk7+sNeK/eOZRTvESFTOHDRevoRoyeqPtvY=
github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9/go.mod h1:P4MqPf+u83OPulPJ+XTbSDbbWrdFYNY4LZ/B1PIduFE= github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9/go.mod h1:P4MqPf+u83OPulPJ+XTbSDbbWrdFYNY4LZ/B1PIduFE=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
@ -536,10 +536,10 @@ github.com/naoina/go-stringutil v0.1.0/go.mod h1:XJ2SJL9jCtBh+P9q5btrd/Ylo8XwT/h
github.com/naoina/toml v0.1.2-0.20170918210437-9fafd6967416/go.mod h1:NBIhNtsFMo3G2szEBne+bO4gS192HuIYRqfvOWb4i1E= github.com/naoina/toml v0.1.2-0.20170918210437-9fafd6967416/go.mod h1:NBIhNtsFMo3G2szEBne+bO4gS192HuIYRqfvOWb4i1E=
github.com/nats-io/jwt/v2 v2.3.0 h1:z2mA1a7tIf5ShggOFlR1oBPgd6hGqcDYsISxZByUzdI= github.com/nats-io/jwt/v2 v2.3.0 h1:z2mA1a7tIf5ShggOFlR1oBPgd6hGqcDYsISxZByUzdI=
github.com/nats-io/jwt/v2 v2.3.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= github.com/nats-io/jwt/v2 v2.3.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k=
github.com/nats-io/nats-server/v2 v2.8.5-0.20220731184415-903a06a5b4ee h1:vAtoZ+LW6eIUjkCWWwO1DZ6o16UGrVOG+ot/AkwejO8= github.com/nats-io/nats-server/v2 v2.8.5-0.20220811224153-d8d25d9b0b1c h1:U5qngWGZ7E/nQxz0544IpIEdKFUUaOJxQN2LHCYLGhg=
github.com/nats-io/nats-server/v2 v2.8.5-0.20220731184415-903a06a5b4ee/go.mod h1:3Yg3ApyQxPlAs1KKHKV5pobV5VtZk+TtOiUJx/iqkkg= github.com/nats-io/nats-server/v2 v2.8.5-0.20220811224153-d8d25d9b0b1c/go.mod h1:+f++B/5jpr71JATt7b5KCX+G7bt43iWx1OYWGkpE/Kk=
github.com/nats-io/nats.go v1.16.1-0.20220731182438-87bbea85922b h1:CE9wSYLvwq8aC/0+6zH8lhhtZYvJ9p8PzwvZeYgdBc0= github.com/nats-io/nats.go v1.16.1-0.20220810192301-fb5ca2cbc995 h1:CUcSQR8jwa9//qNgN/t3tW53DObnTPQ/G/K+qnS7yRc=
github.com/nats-io/nats.go v1.16.1-0.20220731182438-87bbea85922b/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w= github.com/nats-io/nats.go v1.16.1-0.20220810192301-fb5ca2cbc995/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w=
github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8=
github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4=
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
@ -669,8 +669,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0=
github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s= github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s=
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
@ -771,8 +771,8 @@ golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e h1:T8NU3HyQ8ClP4SEE+KbFlg6n0NhuTsN4MyznaarGsZM= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c=
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -964,8 +964,10 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20211102192858-4dd72447c267/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211102192858-4dd72447c267/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a h1:dGzPydgVsqGcTRVwiLJ1jVbufYwmzD3LfVPLKsKg+0k=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220731174439-a90be440212d h1:Sv5ogFZatcgIMMtBSTTAgMYsicp25MXBubjXNDKwm80=
golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=

View file

@ -146,7 +146,7 @@ func (c *RistrettoCostedCachePartition[K, V]) Set(key K, value V) {
} }
type RistrettoCachePartition[K keyable, V any] struct { type RistrettoCachePartition[K keyable, V any] struct {
cache *ristretto.Cache cache *ristretto.Cache //nolint:all,unused
Prefix byte Prefix byte
Mutable bool Mutable bool
MaxAge time.Duration MaxAge time.Duration

View file

@ -19,19 +19,21 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"github.com/matrix-org/dendrite/userapi/api"
opentracing "github.com/opentracing/opentracing-go" opentracing "github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext" "github.com/opentracing/opentracing-go/ext"
) )
// PostJSON performs a POST request with JSON on an internal HTTP API // PostJSON performs a POST request with JSON on an internal HTTP API.
func PostJSON( // The error will match the errtype if returned from the remote API, or
// will be a different type if there was a problem reaching the API.
func PostJSON[reqtype, restype any, errtype error](
ctx context.Context, span opentracing.Span, httpClient *http.Client, ctx context.Context, span opentracing.Span, httpClient *http.Client,
apiURL string, request, response interface{}, apiURL string, request *reqtype, response *restype,
) error { ) error {
jsonBytes, err := json.Marshal(request) jsonBytes, err := json.Marshal(request)
if err != nil { if err != nil {
@ -69,17 +71,23 @@ func PostJSON(
if err != nil { if err != nil {
return err return err
} }
if res.StatusCode != http.StatusOK { var body []byte
var errorBody struct { body, err = io.ReadAll(res.Body)
Message string `json:"message"` if err != nil {
return err
}
if res.StatusCode != http.StatusOK {
if len(body) == 0 {
return fmt.Errorf("HTTP %d from %s (no response body)", res.StatusCode, apiURL)
}
var reserr errtype
if err = json.Unmarshal(body, reserr); err != nil {
return fmt.Errorf("HTTP %d from %s", res.StatusCode, apiURL)
}
return reserr
}
if err = json.Unmarshal(body, response); err != nil {
return fmt.Errorf("json.Unmarshal: %w", err)
} }
if _, ok := response.(*api.PerformKeyBackupResponse); ok { // TODO: remove this, once cross-boundary errors are a thing
return nil return nil
} }
if msgerr := json.NewDecoder(res.Body).Decode(&errorBody); msgerr == nil {
return fmt.Errorf("internal API: %d from %s: %s", res.StatusCode, apiURL, errorBody.Message)
}
return fmt.Errorf("internal API: %d from %s", res.StatusCode, apiURL)
}
return json.NewDecoder(res.Body).Decode(response)
}

View file

@ -25,6 +25,7 @@ import (
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util" "github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go" opentracing "github.com/opentracing/opentracing-go"
@ -83,6 +84,23 @@ func MakeAuthAPI(
return MakeExternalAPI(metricsName, h) return MakeExternalAPI(metricsName, h)
} }
// MakeAdminAPI is a wrapper around MakeAuthAPI which enforces that the request can only be
// completed by a user that is a server administrator.
func MakeAdminAPI(
metricsName string, userAPI userapi.QueryAcccessTokenAPI,
f func(*http.Request, *userapi.Device) util.JSONResponse,
) http.Handler {
return MakeAuthAPI(metricsName, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if device.AccountType != userapi.AccountTypeAdmin {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("This API can only be used by admin users."),
}
}
return f(req, device)
})
}
// MakeExternalAPI turns a util.JSONRequestHandler function into an http.Handler. // MakeExternalAPI turns a util.JSONRequestHandler function into an http.Handler.
// This is used for APIs that are called from the internet. // This is used for APIs that are called from the internet.
func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler { func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler {

View file

@ -0,0 +1,93 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httputil
import (
"context"
"encoding/json"
"fmt"
"net/http"
"reflect"
"github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go"
)
type InternalAPIError struct {
Type string
Message string
}
func (e InternalAPIError) Error() string {
return fmt.Sprintf("internal API returned %q error: %s", e.Type, e.Message)
}
func MakeInternalRPCAPI[reqtype, restype any](metricsName string, f func(context.Context, *reqtype, *restype) error) http.Handler {
return MakeInternalAPI(metricsName, func(req *http.Request) util.JSONResponse {
var request reqtype
var response restype
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := f(req.Context(), &request, &response); err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: &InternalAPIError{
Type: reflect.TypeOf(err).String(),
Message: fmt.Sprintf("%s", err),
},
}
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: &response,
}
})
}
func MakeInternalProxyAPI[reqtype, restype any](metricsName string, f func(context.Context, *reqtype) (*restype, error)) http.Handler {
return MakeInternalAPI(metricsName, func(req *http.Request) util.JSONResponse {
var request reqtype
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
response, err := f(req.Context(), &request)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: err,
}
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: response,
}
})
}
func CallInternalRPCAPI[reqtype, restype any](name, url string, client *http.Client, ctx context.Context, request *reqtype, response *restype) error {
span, ctx := opentracing.StartSpanFromContext(ctx, name)
defer span.Finish()
return PostJSON[reqtype, restype, InternalAPIError](ctx, span, client, url, request, response)
}
func CallInternalProxyAPI[reqtype, restype any, errtype error](name, url string, client *http.Client, ctx context.Context, request *reqtype) (restype, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, name)
defer span.Finish()
var response restype
return response, PostJSON[reqtype, restype, errtype](ctx, span, client, url, request, &response)
}

View file

@ -17,7 +17,7 @@ var build string
const ( const (
VersionMajor = 0 VersionMajor = 0
VersionMinor = 9 VersionMinor = 9
VersionPatch = 1 VersionPatch = 2
VersionTag = "" // example: "rc1" VersionTag = "" // example: "rc1"
) )

View file

@ -38,32 +38,32 @@ type KeyInternalAPI interface {
// API functions required by the clientapi // API functions required by the clientapi
type ClientKeyAPI interface { type ClientKeyAPI interface {
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error
PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error
PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error
// PerformClaimKeys claims one-time keys for use in pre-key messages // PerformClaimKeys claims one-time keys for use in pre-key messages
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
} }
// API functions required by the userapi // API functions required by the userapi
type UserKeyAPI interface { type UserKeyAPI interface {
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error
PerformDeleteKeys(ctx context.Context, req *PerformDeleteKeysRequest, res *PerformDeleteKeysResponse) PerformDeleteKeys(ctx context.Context, req *PerformDeleteKeysRequest, res *PerformDeleteKeysResponse) error
} }
// API functions required by the syncapi // API functions required by the syncapi
type SyncKeyAPI interface { type SyncKeyAPI interface {
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error
QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error
} }
type FederationKeyAPI interface { type FederationKeyAPI interface {
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error
QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error
PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
} }
// KeyError is returned if there was a problem performing/querying the server // KeyError is returned if there was a problem performing/querying the server

View file

@ -103,7 +103,7 @@ func sanityCheckKey(key gomatrixserverlib.CrossSigningKey, userID string, purpos
} }
// nolint:gocyclo // nolint:gocyclo
func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) { func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error {
// Find the keys to store. // Find the keys to store.
byPurpose := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{} byPurpose := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{}
toStore := types.CrossSigningKeyMap{} toStore := types.CrossSigningKeyMap{}
@ -115,7 +115,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
Err: "Master key sanity check failed: " + err.Error(), Err: "Master key sanity check failed: " + err.Error(),
IsInvalidParam: true, IsInvalidParam: true,
} }
return return nil
} }
byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster] = req.MasterKey byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster] = req.MasterKey
@ -131,7 +131,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
Err: "Self-signing key sanity check failed: " + err.Error(), Err: "Self-signing key sanity check failed: " + err.Error(),
IsInvalidParam: true, IsInvalidParam: true,
} }
return return nil
} }
byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = req.SelfSigningKey byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = req.SelfSigningKey
@ -146,7 +146,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
Err: "User-signing key sanity check failed: " + err.Error(), Err: "User-signing key sanity check failed: " + err.Error(),
IsInvalidParam: true, IsInvalidParam: true,
} }
return return nil
} }
byPurpose[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = req.UserSigningKey byPurpose[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = req.UserSigningKey
@ -161,7 +161,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
Err: "No keys were supplied in the request", Err: "No keys were supplied in the request",
IsMissingParam: true, IsMissingParam: true,
} }
return return nil
} }
// We can't have a self-signing or user-signing key without a master // We can't have a self-signing or user-signing key without a master
@ -174,7 +174,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: "Retrieving cross-signing keys from database failed: " + err.Error(), Err: "Retrieving cross-signing keys from database failed: " + err.Error(),
} }
return return nil
} }
// If we still can't find a master key for the user then stop the upload. // If we still can't find a master key for the user then stop the upload.
@ -185,7 +185,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
Err: "No master key was found", Err: "No master key was found",
IsMissingParam: true, IsMissingParam: true,
} }
return return nil
} }
} }
@ -212,7 +212,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
} }
} }
if !changed { if !changed {
return return nil
} }
// Store the keys. // Store the keys.
@ -220,7 +220,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err), Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err),
} }
return return nil
} }
// Now upload any signatures that were included with the keys. // Now upload any signatures that were included with the keys.
@ -238,7 +238,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err), Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err),
} }
return return nil
} }
} }
} }
@ -255,17 +255,18 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
update.SelfSigningKey = &ssk update.SelfSigningKey = &ssk
} }
if update.MasterKey == nil && update.SelfSigningKey == nil { if update.MasterKey == nil && update.SelfSigningKey == nil {
return return nil
} }
if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil { if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err),
} }
return return nil
} }
return nil
} }
func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) { func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error {
// Before we do anything, we need the master and self-signing keys for this user. // Before we do anything, we need the master and self-signing keys for this user.
// Then we can verify the signatures make sense. // Then we can verify the signatures make sense.
queryReq := &api.QueryKeysRequest{ queryReq := &api.QueryKeysRequest{
@ -276,7 +277,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req
for userID := range req.Signatures { for userID := range req.Signatures {
queryReq.UserToDevices[userID] = []string{} queryReq.UserToDevices[userID] = []string{}
} }
a.QueryKeys(ctx, queryReq, queryRes) _ = a.QueryKeys(ctx, queryReq, queryRes)
selfSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} selfSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
otherSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} otherSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
@ -322,14 +323,14 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: fmt.Sprintf("a.processSelfSignatures: %s", err), Err: fmt.Sprintf("a.processSelfSignatures: %s", err),
} }
return return nil
} }
if err := a.processOtherSignatures(ctx, req.UserID, queryRes, otherSignatures); err != nil { if err := a.processOtherSignatures(ctx, req.UserID, queryRes, otherSignatures); err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: fmt.Sprintf("a.processOtherSignatures: %s", err), Err: fmt.Sprintf("a.processOtherSignatures: %s", err),
} }
return return nil
} }
// Finally, generate a notification that we updated the signatures. // Finally, generate a notification that we updated the signatures.
@ -345,9 +346,10 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err),
} }
return return nil
} }
} }
return nil
} }
func (a *KeyInternalAPI) processSelfSignatures( func (a *KeyInternalAPI) processSelfSignatures(
@ -520,7 +522,7 @@ func (a *KeyInternalAPI) crossSigningKeysFromDatabase(
} }
} }
func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) { func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error {
for targetUserID, forTargetUser := range req.TargetIDs { for targetUserID, forTargetUser := range req.TargetIDs {
keyMap, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID) keyMap, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
@ -559,7 +561,7 @@ func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySign
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err), Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err),
} }
return return nil
} }
for sourceUserID, forSourceUser := range sigMap { for sourceUserID, forSourceUser := range sigMap {
@ -581,4 +583,5 @@ func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySign
} }
} }
} }
return nil
} }

View file

@ -119,7 +119,7 @@ type DeviceListUpdaterDatabase interface {
} }
type DeviceListUpdaterAPI interface { type DeviceListUpdaterAPI interface {
PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error
} }
// KeyChangeProducer is the interface for producers.KeyChange useful for testing. // KeyChangeProducer is the interface for producers.KeyChange useful for testing.
@ -421,7 +421,7 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
uploadReq.SelfSigningKey = *res.SelfSigningKey uploadReq.SelfSigningKey = *res.SelfSigningKey
} }
} }
u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes) _ = u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes)
} }
err = u.updateDeviceList(&res) err = u.updateDeviceList(&res)
if err != nil { if err != nil {

View file

@ -113,8 +113,8 @@ func (d *mockDeviceListUpdaterDatabase) DeviceKeysJSON(ctx context.Context, keys
type mockDeviceListUpdaterAPI struct { type mockDeviceListUpdaterAPI struct {
} }
func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) { func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error {
return nil
} }
type roundTripper struct { type roundTripper struct {

View file

@ -48,18 +48,20 @@ func (a *KeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) {
a.UserAPI = i a.UserAPI = i
} }
func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) { func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error {
userIDs, latest, err := a.DB.KeyChanges(ctx, req.Offset, req.ToOffset) userIDs, latest, err := a.DB.KeyChanges(ctx, req.Offset, req.ToOffset)
if err != nil { if err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: err.Error(), Err: err.Error(),
} }
return nil
} }
res.Offset = latest res.Offset = latest
res.UserIDs = userIDs res.UserIDs = userIDs
return nil
} }
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error {
res.KeyErrors = make(map[string]map[string]*api.KeyError) res.KeyErrors = make(map[string]map[string]*api.KeyError)
if len(req.DeviceKeys) > 0 { if len(req.DeviceKeys) > 0 {
a.uploadLocalDeviceKeys(ctx, req, res) a.uploadLocalDeviceKeys(ctx, req, res)
@ -67,9 +69,10 @@ func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perform
if len(req.OneTimeKeys) > 0 { if len(req.OneTimeKeys) > 0 {
a.uploadOneTimeKeys(ctx, req, res) a.uploadOneTimeKeys(ctx, req, res)
} }
return nil
} }
func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) { func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error {
res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage) res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage)
res.Failures = make(map[string]interface{}) res.Failures = make(map[string]interface{})
// wrap request map in a top-level by-domain map // wrap request map in a top-level by-domain map
@ -113,6 +116,7 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
if len(domainToDeviceKeys) > 0 { if len(domainToDeviceKeys) > 0 {
a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
} }
return nil
} }
func (a *KeyInternalAPI) claimRemoteKeys( func (a *KeyInternalAPI) claimRemoteKeys(
@ -172,32 +176,34 @@ func (a *KeyInternalAPI) claimRemoteKeys(
util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys") util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys")
} }
func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) { func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error {
if err := a.DB.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil { if err := a.DB.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: fmt.Sprintf("Failed to delete device keys: %s", err), Err: fmt.Sprintf("Failed to delete device keys: %s", err),
} }
} }
return nil
} }
func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) { func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error {
count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
if err != nil { if err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: fmt.Sprintf("Failed to query OTK counts: %s", err), Err: fmt.Sprintf("Failed to query OTK counts: %s", err),
} }
return return nil
} }
res.Count = *count res.Count = *count
return nil
} }
func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) { func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error {
msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false) msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false)
if err != nil { if err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query DB for device keys: %s", err), Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
} }
return return nil
} }
maxStreamID := int64(0) maxStreamID := int64(0)
for _, m := range msgs { for _, m := range msgs {
@ -215,10 +221,11 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
} }
res.Devices = result res.Devices = result
res.StreamID = maxStreamID res.StreamID = maxStreamID
return nil
} }
// nolint:gocyclo // nolint:gocyclo
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error {
res.DeviceKeys = make(map[string]map[string]json.RawMessage) res.DeviceKeys = make(map[string]map[string]json.RawMessage)
res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey) res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey) res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
@ -244,7 +251,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query local device keys: %s", err), Err: fmt.Sprintf("failed to query local device keys: %s", err),
} }
return return nil
} }
// pull out display names after we have the keys so we handle wildcards correctly // pull out display names after we have the keys so we handle wildcards correctly
@ -318,7 +325,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
// Stop executing the function if the context was canceled/the deadline was exceeded, // Stop executing the function if the context was canceled/the deadline was exceeded,
// as we can't continue without a valid context. // as we can't continue without a valid context.
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return return nil
} }
logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed") logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
continue continue
@ -344,7 +351,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
// Stop executing the function if the context was canceled/the deadline was exceeded, // Stop executing the function if the context was canceled/the deadline was exceeded,
// as we can't continue without a valid context. // as we can't continue without a valid context.
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return return nil
} }
logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed") logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
continue continue
@ -372,6 +379,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
} }
} }
} }
return nil
} }
func (a *KeyInternalAPI) remoteKeysFromDatabase( func (a *KeyInternalAPI) remoteKeysFromDatabase(

View file

@ -22,7 +22,6 @@ import (
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/opentracing/opentracing-go"
) )
// HTTP paths for the internal HTTP APIs // HTTP paths for the internal HTTP APIs
@ -68,168 +67,108 @@ func (h *httpKeyInternalAPI) PerformClaimKeys(
ctx context.Context, ctx context.Context,
request *api.PerformClaimKeysRequest, request *api.PerformClaimKeysRequest,
response *api.PerformClaimKeysResponse, response *api.PerformClaimKeysResponse,
) { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformClaimKeys") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformClaimKeys", h.apiURL+PerformClaimKeysPath,
h.httpClient, ctx, request, response,
apiURL := h.apiURL + PerformClaimKeysPath )
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
} }
func (h *httpKeyInternalAPI) PerformDeleteKeys( func (h *httpKeyInternalAPI) PerformDeleteKeys(
ctx context.Context, ctx context.Context,
request *api.PerformDeleteKeysRequest, request *api.PerformDeleteKeysRequest,
response *api.PerformDeleteKeysResponse, response *api.PerformDeleteKeysResponse,
) { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformClaimKeys") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformDeleteKeys", h.apiURL+PerformDeleteKeysPath,
h.httpClient, ctx, request, response,
apiURL := h.apiURL + PerformClaimKeysPath )
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
} }
func (h *httpKeyInternalAPI) PerformUploadKeys( func (h *httpKeyInternalAPI) PerformUploadKeys(
ctx context.Context, ctx context.Context,
request *api.PerformUploadKeysRequest, request *api.PerformUploadKeysRequest,
response *api.PerformUploadKeysResponse, response *api.PerformUploadKeysResponse,
) { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUploadKeys") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformUploadKeys", h.apiURL+PerformUploadKeysPath,
h.httpClient, ctx, request, response,
apiURL := h.apiURL + PerformUploadKeysPath )
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
} }
func (h *httpKeyInternalAPI) QueryKeys( func (h *httpKeyInternalAPI) QueryKeys(
ctx context.Context, ctx context.Context,
request *api.QueryKeysRequest, request *api.QueryKeysRequest,
response *api.QueryKeysResponse, response *api.QueryKeysResponse,
) { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeys") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryKeys", h.apiURL+QueryKeysPath,
h.httpClient, ctx, request, response,
apiURL := h.apiURL + QueryKeysPath )
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
} }
func (h *httpKeyInternalAPI) QueryOneTimeKeys( func (h *httpKeyInternalAPI) QueryOneTimeKeys(
ctx context.Context, ctx context.Context,
request *api.QueryOneTimeKeysRequest, request *api.QueryOneTimeKeysRequest,
response *api.QueryOneTimeKeysResponse, response *api.QueryOneTimeKeysResponse,
) { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOneTimeKeys") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryOneTimeKeys", h.apiURL+QueryOneTimeKeysPath,
h.httpClient, ctx, request, response,
apiURL := h.apiURL + QueryOneTimeKeysPath )
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
} }
func (h *httpKeyInternalAPI) QueryDeviceMessages( func (h *httpKeyInternalAPI) QueryDeviceMessages(
ctx context.Context, ctx context.Context,
request *api.QueryDeviceMessagesRequest, request *api.QueryDeviceMessagesRequest,
response *api.QueryDeviceMessagesResponse, response *api.QueryDeviceMessagesResponse,
) { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDeviceMessages") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryDeviceMessages", h.apiURL+QueryDeviceMessagesPath,
h.httpClient, ctx, request, response,
apiURL := h.apiURL + QueryDeviceMessagesPath )
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
} }
func (h *httpKeyInternalAPI) QueryKeyChanges( func (h *httpKeyInternalAPI) QueryKeyChanges(
ctx context.Context, ctx context.Context,
request *api.QueryKeyChangesRequest, request *api.QueryKeyChangesRequest,
response *api.QueryKeyChangesResponse, response *api.QueryKeyChangesResponse,
) { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyChanges") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryKeyChanges", h.apiURL+QueryKeyChangesPath,
h.httpClient, ctx, request, response,
apiURL := h.apiURL + QueryKeyChangesPath )
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
} }
func (h *httpKeyInternalAPI) PerformUploadDeviceKeys( func (h *httpKeyInternalAPI) PerformUploadDeviceKeys(
ctx context.Context, ctx context.Context,
request *api.PerformUploadDeviceKeysRequest, request *api.PerformUploadDeviceKeysRequest,
response *api.PerformUploadDeviceKeysResponse, response *api.PerformUploadDeviceKeysResponse,
) { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUploadDeviceKeys") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformUploadDeviceKeys", h.apiURL+PerformUploadDeviceKeysPath,
h.httpClient, ctx, request, response,
apiURL := h.apiURL + PerformUploadDeviceKeysPath )
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
} }
func (h *httpKeyInternalAPI) PerformUploadDeviceSignatures( func (h *httpKeyInternalAPI) PerformUploadDeviceSignatures(
ctx context.Context, ctx context.Context,
request *api.PerformUploadDeviceSignaturesRequest, request *api.PerformUploadDeviceSignaturesRequest,
response *api.PerformUploadDeviceSignaturesResponse, response *api.PerformUploadDeviceSignaturesResponse,
) { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUploadDeviceSignatures") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformUploadDeviceSignatures", h.apiURL+PerformUploadDeviceSignaturesPath,
h.httpClient, ctx, request, response,
apiURL := h.apiURL + PerformUploadDeviceSignaturesPath )
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
} }
func (h *httpKeyInternalAPI) QuerySignatures( func (h *httpKeyInternalAPI) QuerySignatures(
ctx context.Context, ctx context.Context,
request *api.QuerySignaturesRequest, request *api.QuerySignaturesRequest,
response *api.QuerySignaturesResponse, response *api.QuerySignaturesResponse,
) { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySignatures") return httputil.CallInternalRPCAPI(
defer span.Finish() "QuerySignatures", h.apiURL+QuerySignaturesPath,
h.httpClient, ctx, request, response,
apiURL := h.apiURL + QuerySignaturesPath )
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
} }

View file

@ -15,124 +15,59 @@
package inthttp package inthttp
import ( import (
"encoding/json"
"net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/util"
) )
func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) { func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) {
internalAPIMux.Handle(PerformClaimKeysPath, internalAPIMux.Handle(
httputil.MakeInternalAPI("performClaimKeys", func(req *http.Request) util.JSONResponse { PerformClaimKeysPath,
request := api.PerformClaimKeysRequest{} httputil.MakeInternalRPCAPI("KeyserverPerformClaimKeys", s.PerformClaimKeys),
response := api.PerformClaimKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.PerformClaimKeys(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(PerformDeleteKeysPath,
httputil.MakeInternalAPI("performDeleteKeys", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.PerformDeleteKeysRequest{} PerformDeleteKeysPath,
response := api.PerformDeleteKeysResponse{} httputil.MakeInternalRPCAPI("KeyserverPerformDeleteKeys", s.PerformDeleteKeys),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.PerformDeleteKeys(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(PerformUploadKeysPath,
httputil.MakeInternalAPI("performUploadKeys", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.PerformUploadKeysRequest{} PerformUploadKeysPath,
response := api.PerformUploadKeysResponse{} httputil.MakeInternalRPCAPI("KeyserverPerformUploadKeys", s.PerformUploadKeys),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.PerformUploadKeys(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(PerformUploadDeviceKeysPath,
httputil.MakeInternalAPI("performUploadDeviceKeys", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.PerformUploadDeviceKeysRequest{} PerformUploadDeviceKeysPath,
response := api.PerformUploadDeviceKeysResponse{} httputil.MakeInternalRPCAPI("KeyserverPerformUploadDeviceKeys", s.PerformUploadDeviceKeys),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.PerformUploadDeviceKeys(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(PerformUploadDeviceSignaturesPath,
httputil.MakeInternalAPI("performUploadDeviceSignatures", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.PerformUploadDeviceSignaturesRequest{} PerformUploadDeviceSignaturesPath,
response := api.PerformUploadDeviceSignaturesResponse{} httputil.MakeInternalRPCAPI("KeyserverPerformUploadDeviceSignatures", s.PerformUploadDeviceSignatures),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.PerformUploadDeviceSignatures(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(QueryKeysPath,
httputil.MakeInternalAPI("queryKeys", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.QueryKeysRequest{} QueryKeysPath,
response := api.QueryKeysResponse{} httputil.MakeInternalRPCAPI("KeyserverQueryKeys", s.QueryKeys),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.QueryKeys(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(QueryOneTimeKeysPath,
httputil.MakeInternalAPI("queryOneTimeKeys", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.QueryOneTimeKeysRequest{} QueryOneTimeKeysPath,
response := api.QueryOneTimeKeysResponse{} httputil.MakeInternalRPCAPI("KeyserverQueryOneTimeKeys", s.QueryOneTimeKeys),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.QueryOneTimeKeys(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(QueryDeviceMessagesPath,
httputil.MakeInternalAPI("queryDeviceMessages", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.QueryDeviceMessagesRequest{} QueryDeviceMessagesPath,
response := api.QueryDeviceMessagesResponse{} httputil.MakeInternalRPCAPI("KeyserverQueryDeviceMessages", s.QueryDeviceMessages),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.QueryDeviceMessages(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(QueryKeyChangesPath,
httputil.MakeInternalAPI("queryKeyChanges", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.QueryKeyChangesRequest{} QueryKeyChangesPath,
response := api.QueryKeyChangesResponse{} httputil.MakeInternalRPCAPI("KeyserverQueryKeyChanges", s.QueryKeyChanges),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.QueryKeyChanges(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(QuerySignaturesPath,
httputil.MakeInternalAPI("querySignatures", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.QuerySignaturesRequest{} QuerySignaturesPath,
response := api.QuerySignaturesResponse{} httputil.MakeInternalRPCAPI("KeyserverQuerySignatures", s.QuerySignatures),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.QuerySignatures(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
} }

View file

@ -40,7 +40,7 @@ type InputRoomEventsAPI interface {
ctx context.Context, ctx context.Context,
req *InputRoomEventsRequest, req *InputRoomEventsRequest,
res *InputRoomEventsResponse, res *InputRoomEventsResponse,
) ) error
} }
// Query the latest events and state for a room from the room server. // Query the latest events and state for a room from the room server.
@ -97,6 +97,14 @@ type SyncRoomserverAPI interface {
req *PerformBackfillRequest, req *PerformBackfillRequest,
res *PerformBackfillResponse, res *PerformBackfillResponse,
) error ) error
// QueryMembershipAtEvent queries the memberships at the given events.
// Returns a map from eventID to a slice of gomatrixserverlib.HeaderedEvent.
QueryMembershipAtEvent(
ctx context.Context,
request *QueryMembershipAtEventRequest,
response *QueryMembershipAtEventResponse,
) error
} }
type AppserviceRoomserverAPI interface { type AppserviceRoomserverAPI interface {
@ -139,15 +147,15 @@ type ClientRoomserverAPI interface {
GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error
// PerformRoomUpgrade upgrades a room to a newer version // PerformRoomUpgrade upgrades a room to a newer version
PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse) PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse) error
PerformAdminEvacuateRoom(ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse) PerformAdminEvacuateRoom(ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse) error
PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error
PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse) PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse) error
PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse) PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse) error
PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error
PerformJoin(ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse) PerformJoin(ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse) error
PerformLeave(ctx context.Context, req *PerformLeaveRequest, res *PerformLeaveResponse) error PerformLeave(ctx context.Context, req *PerformLeaveRequest, res *PerformLeaveResponse) error
PerformPublish(ctx context.Context, req *PerformPublishRequest, res *PerformPublishResponse) PerformPublish(ctx context.Context, req *PerformPublishRequest, res *PerformPublishResponse) error
// PerformForget forgets a rooms history for a specific user // PerformForget forgets a rooms history for a specific user
PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error
SetRoomAlias(ctx context.Context, req *SetRoomAliasRequest, res *SetRoomAliasResponse) error SetRoomAlias(ctx context.Context, req *SetRoomAliasRequest, res *SetRoomAliasResponse) error
@ -158,7 +166,7 @@ type UserRoomserverAPI interface {
QueryLatestEventsAndStateAPI QueryLatestEventsAndStateAPI
QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error
} }
type FederationRoomserverAPI interface { type FederationRoomserverAPI interface {

View file

@ -35,9 +35,10 @@ func (t *RoomserverInternalAPITrace) InputRoomEvents(
ctx context.Context, ctx context.Context,
req *InputRoomEventsRequest, req *InputRoomEventsRequest,
res *InputRoomEventsResponse, res *InputRoomEventsResponse,
) { ) error {
t.Impl.InputRoomEvents(ctx, req, res) err := t.Impl.InputRoomEvents(ctx, req, res)
util.GetLogger(ctx).Infof("InputRoomEvents req=%+v res=%+v", js(req), js(res)) util.GetLogger(ctx).WithError(err).Infof("InputRoomEvents req=%+v res=%+v", js(req), js(res))
return err
} }
func (t *RoomserverInternalAPITrace) PerformInvite( func (t *RoomserverInternalAPITrace) PerformInvite(
@ -45,44 +46,49 @@ func (t *RoomserverInternalAPITrace) PerformInvite(
req *PerformInviteRequest, req *PerformInviteRequest,
res *PerformInviteResponse, res *PerformInviteResponse,
) error { ) error {
util.GetLogger(ctx).Infof("PerformInvite req=%+v res=%+v", js(req), js(res)) err := t.Impl.PerformInvite(ctx, req, res)
return t.Impl.PerformInvite(ctx, req, res) util.GetLogger(ctx).WithError(err).Infof("PerformInvite req=%+v res=%+v", js(req), js(res))
return err
} }
func (t *RoomserverInternalAPITrace) PerformPeek( func (t *RoomserverInternalAPITrace) PerformPeek(
ctx context.Context, ctx context.Context,
req *PerformPeekRequest, req *PerformPeekRequest,
res *PerformPeekResponse, res *PerformPeekResponse,
) { ) error {
t.Impl.PerformPeek(ctx, req, res) err := t.Impl.PerformPeek(ctx, req, res)
util.GetLogger(ctx).Infof("PerformPeek req=%+v res=%+v", js(req), js(res)) util.GetLogger(ctx).WithError(err).Infof("PerformPeek req=%+v res=%+v", js(req), js(res))
return err
} }
func (t *RoomserverInternalAPITrace) PerformUnpeek( func (t *RoomserverInternalAPITrace) PerformUnpeek(
ctx context.Context, ctx context.Context,
req *PerformUnpeekRequest, req *PerformUnpeekRequest,
res *PerformUnpeekResponse, res *PerformUnpeekResponse,
) { ) error {
t.Impl.PerformUnpeek(ctx, req, res) err := t.Impl.PerformUnpeek(ctx, req, res)
util.GetLogger(ctx).Infof("PerformUnpeek req=%+v res=%+v", js(req), js(res)) util.GetLogger(ctx).WithError(err).Infof("PerformUnpeek req=%+v res=%+v", js(req), js(res))
return err
} }
func (t *RoomserverInternalAPITrace) PerformRoomUpgrade( func (t *RoomserverInternalAPITrace) PerformRoomUpgrade(
ctx context.Context, ctx context.Context,
req *PerformRoomUpgradeRequest, req *PerformRoomUpgradeRequest,
res *PerformRoomUpgradeResponse, res *PerformRoomUpgradeResponse,
) { ) error {
t.Impl.PerformRoomUpgrade(ctx, req, res) err := t.Impl.PerformRoomUpgrade(ctx, req, res)
util.GetLogger(ctx).Infof("PerformRoomUpgrade req=%+v res=%+v", js(req), js(res)) util.GetLogger(ctx).WithError(err).Infof("PerformRoomUpgrade req=%+v res=%+v", js(req), js(res))
return err
} }
func (t *RoomserverInternalAPITrace) PerformJoin( func (t *RoomserverInternalAPITrace) PerformJoin(
ctx context.Context, ctx context.Context,
req *PerformJoinRequest, req *PerformJoinRequest,
res *PerformJoinResponse, res *PerformJoinResponse,
) { ) error {
t.Impl.PerformJoin(ctx, req, res) err := t.Impl.PerformJoin(ctx, req, res)
util.GetLogger(ctx).Infof("PerformJoin req=%+v res=%+v", js(req), js(res)) util.GetLogger(ctx).WithError(err).Infof("PerformJoin req=%+v res=%+v", js(req), js(res))
return err
} }
func (t *RoomserverInternalAPITrace) PerformLeave( func (t *RoomserverInternalAPITrace) PerformLeave(
@ -99,27 +105,30 @@ func (t *RoomserverInternalAPITrace) PerformPublish(
ctx context.Context, ctx context.Context,
req *PerformPublishRequest, req *PerformPublishRequest,
res *PerformPublishResponse, res *PerformPublishResponse,
) { ) error {
t.Impl.PerformPublish(ctx, req, res) err := t.Impl.PerformPublish(ctx, req, res)
util.GetLogger(ctx).Infof("PerformPublish req=%+v res=%+v", js(req), js(res)) util.GetLogger(ctx).WithError(err).Infof("PerformPublish req=%+v res=%+v", js(req), js(res))
return err
} }
func (t *RoomserverInternalAPITrace) PerformAdminEvacuateRoom( func (t *RoomserverInternalAPITrace) PerformAdminEvacuateRoom(
ctx context.Context, ctx context.Context,
req *PerformAdminEvacuateRoomRequest, req *PerformAdminEvacuateRoomRequest,
res *PerformAdminEvacuateRoomResponse, res *PerformAdminEvacuateRoomResponse,
) { ) error {
t.Impl.PerformAdminEvacuateRoom(ctx, req, res) err := t.Impl.PerformAdminEvacuateRoom(ctx, req, res)
util.GetLogger(ctx).Infof("PerformAdminEvacuateRoom req=%+v res=%+v", js(req), js(res)) util.GetLogger(ctx).WithError(err).Infof("PerformAdminEvacuateRoom req=%+v res=%+v", js(req), js(res))
return err
} }
func (t *RoomserverInternalAPITrace) PerformAdminEvacuateUser( func (t *RoomserverInternalAPITrace) PerformAdminEvacuateUser(
ctx context.Context, ctx context.Context,
req *PerformAdminEvacuateUserRequest, req *PerformAdminEvacuateUserRequest,
res *PerformAdminEvacuateUserResponse, res *PerformAdminEvacuateUserResponse,
) { ) error {
t.Impl.PerformAdminEvacuateUser(ctx, req, res) err := t.Impl.PerformAdminEvacuateUser(ctx, req, res)
util.GetLogger(ctx).Infof("PerformAdminEvacuateUser req=%+v res=%+v", js(req), js(res)) util.GetLogger(ctx).WithError(err).Infof("PerformAdminEvacuateUser req=%+v res=%+v", js(req), js(res))
return err
} }
func (t *RoomserverInternalAPITrace) PerformInboundPeek( func (t *RoomserverInternalAPITrace) PerformInboundPeek(
@ -128,7 +137,7 @@ func (t *RoomserverInternalAPITrace) PerformInboundPeek(
res *PerformInboundPeekResponse, res *PerformInboundPeekResponse,
) error { ) error {
err := t.Impl.PerformInboundPeek(ctx, req, res) err := t.Impl.PerformInboundPeek(ctx, req, res)
util.GetLogger(ctx).Infof("PerformInboundPeek req=%+v res=%+v", js(req), js(res)) util.GetLogger(ctx).WithError(err).Infof("PerformInboundPeek req=%+v res=%+v", js(req), js(res))
return err return err
} }
@ -373,6 +382,16 @@ func (t *RoomserverInternalAPITrace) QueryRestrictedJoinAllowed(
return err return err
} }
func (t *RoomserverInternalAPITrace) QueryMembershipAtEvent(
ctx context.Context,
request *QueryMembershipAtEventRequest,
response *QueryMembershipAtEventResponse,
) error {
err := t.Impl.QueryMembershipAtEvent(ctx, request, response)
util.GetLogger(ctx).WithError(err).Infof("QueryMembershipAtEvent req=%+v res=%+v", js(request), js(response))
return err
}
func js(thing interface{}) string { func js(thing interface{}) string {
b, err := json.Marshal(thing) b, err := json.Marshal(thing)
if err != nil { if err != nil {

View file

@ -427,3 +427,17 @@ func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error {
} }
return nil return nil
} }
// QueryMembershipAtEventRequest requests the membership events for a user
// for a list of eventIDs.
type QueryMembershipAtEventRequest struct {
RoomID string
EventIDs []string
UserID string
}
// QueryMembershipAtEventResponse is the response to QueryMembershipAtEventRequest.
type QueryMembershipAtEventResponse struct {
// Memberships is a map from eventID to a list of events (if any).
Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"`
}

View file

@ -90,7 +90,9 @@ func SendInputRoomEvents(
Asynchronous: async, Asynchronous: async,
} }
var response InputRoomEventsResponse var response InputRoomEventsResponse
rsAPI.InputRoomEvents(ctx, &request, &response) if err := rsAPI.InputRoomEvents(ctx, &request, &response); err != nil {
return err
}
return response.Err() return response.Err()
} }

View file

@ -208,6 +208,12 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.Room
return roomState.LoadCombinedStateAfterEvents(ctx, prevState) return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
} }
func MembershipAtEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) {
roomState := state.NewStateResolution(db, info)
// Fetch the state as it was when this event was fired
return roomState.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID)
}
func LoadEvents( func LoadEvents(
ctx context.Context, db storage.Database, eventNIDs []types.EventNID, ctx context.Context, db storage.Database, eventNIDs []types.EventNID,
) ([]*gomatrixserverlib.Event, error) { ) ([]*gomatrixserverlib.Event, error) {

View file

@ -337,18 +337,18 @@ func (r *Inputer) InputRoomEvents(
ctx context.Context, ctx context.Context,
request *api.InputRoomEventsRequest, request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse, response *api.InputRoomEventsResponse,
) { ) error {
// Queue up the event into the roomserver. // Queue up the event into the roomserver.
replySub, err := r.queueInputRoomEvents(ctx, request) replySub, err := r.queueInputRoomEvents(ctx, request)
if err != nil { if err != nil {
response.ErrMsg = err.Error() response.ErrMsg = err.Error()
return return nil
} }
// If we aren't waiting for synchronous responses then we can // If we aren't waiting for synchronous responses then we can
// give up here, there is nothing further to do. // give up here, there is nothing further to do.
if replySub == nil { if replySub == nil {
return return nil
} }
// Otherwise, we'll want to sit and wait for the responses // Otherwise, we'll want to sit and wait for the responses
@ -360,12 +360,14 @@ func (r *Inputer) InputRoomEvents(
msg, err := replySub.NextMsgWithContext(ctx) msg, err := replySub.NextMsgWithContext(ctx)
if err != nil { if err != nil {
response.ErrMsg = err.Error() response.ErrMsg = err.Error()
return return nil
} }
if len(msg.Data) > 0 { if len(msg.Data) > 0 {
response.ErrMsg = string(msg.Data) response.ErrMsg = string(msg.Data)
} }
} }
return nil
} }
var roomserverInputBackpressure = prometheus.NewGaugeVec( var roomserverInputBackpressure = prometheus.NewGaugeVec(

View file

@ -299,7 +299,7 @@ func (r *Inputer) processRoomEvent(
// allowed at the time, and also to get the history visibility. We won't // allowed at the time, and also to get the history visibility. We won't
// bother doing this if the event was already rejected as it just ends up // bother doing this if the event was already rejected as it just ends up
// burning CPU time. // burning CPU time.
historyVisibility := gomatrixserverlib.HistoryVisibilityJoined // Default to restrictive. historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared.
if rejectionErr == nil && !isRejected && !softfail { if rejectionErr == nil && !isRejected && !softfail {
var err error var err error
historyVisibility, rejectionErr, err = r.processStateBefore(ctx, input, missingPrev) historyVisibility, rejectionErr, err = r.processStateBefore(ctx, input, missingPrev)
@ -429,7 +429,7 @@ func (r *Inputer) processStateBefore(
input *api.InputRoomEvent, input *api.InputRoomEvent,
missingPrev bool, missingPrev bool,
) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) { ) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) {
historyVisibility = gomatrixserverlib.HistoryVisibilityJoined // Default to restrictive. historyVisibility = gomatrixserverlib.HistoryVisibilityShared // Default to shared.
event := input.Event.Unwrap() event := input.Event.Unwrap()
isCreateEvent := event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") isCreateEvent := event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("")
var stateBeforeEvent []*gomatrixserverlib.Event var stateBeforeEvent []*gomatrixserverlib.Event

View file

@ -43,21 +43,21 @@ func (r *Admin) PerformAdminEvacuateRoom(
ctx context.Context, ctx context.Context,
req *api.PerformAdminEvacuateRoomRequest, req *api.PerformAdminEvacuateRoomRequest,
res *api.PerformAdminEvacuateRoomResponse, res *api.PerformAdminEvacuateRoomResponse,
) { ) error {
roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID) roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID)
if err != nil { if err != nil {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.RoomInfo: %s", err), Msg: fmt.Sprintf("r.DB.RoomInfo: %s", err),
} }
return return nil
} }
if roomInfo == nil || roomInfo.IsStub() { if roomInfo == nil || roomInfo.IsStub() {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Code: api.PerformErrorNoRoom, Code: api.PerformErrorNoRoom,
Msg: fmt.Sprintf("Room %s not found", req.RoomID), Msg: fmt.Sprintf("Room %s not found", req.RoomID),
} }
return return nil
} }
memberNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true) memberNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true)
@ -66,7 +66,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.GetMembershipEventNIDsForRoom: %s", err), Msg: fmt.Sprintf("r.DB.GetMembershipEventNIDsForRoom: %s", err),
} }
return return nil
} }
memberEvents, err := r.DB.Events(ctx, memberNIDs) memberEvents, err := r.DB.Events(ctx, memberNIDs)
@ -75,7 +75,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.Events: %s", err), Msg: fmt.Sprintf("r.DB.Events: %s", err),
} }
return return nil
} }
inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents)) inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents))
@ -89,7 +89,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Queryer.QueryLatestEventsAndState: %s", err), Msg: fmt.Sprintf("r.Queryer.QueryLatestEventsAndState: %s", err),
} }
return return nil
} }
prevEvents := latestRes.LatestEvents prevEvents := latestRes.LatestEvents
@ -104,7 +104,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("json.Unmarshal: %s", err), Msg: fmt.Sprintf("json.Unmarshal: %s", err),
} }
return return nil
} }
memberContent.Membership = gomatrixserverlib.Leave memberContent.Membership = gomatrixserverlib.Leave
@ -122,7 +122,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("json.Marshal: %s", err), Msg: fmt.Sprintf("json.Marshal: %s", err),
} }
return return nil
} }
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent) eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent)
@ -131,7 +131,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("gomatrixserverlib.StateNeededForEventBuilder: %s", err), Msg: fmt.Sprintf("gomatrixserverlib.StateNeededForEventBuilder: %s", err),
} }
return return nil
} }
event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, time.Now(), &eventsNeeded, latestRes) event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, time.Now(), &eventsNeeded, latestRes)
@ -140,7 +140,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("eventutil.BuildEvent: %s", err), Msg: fmt.Sprintf("eventutil.BuildEvent: %s", err),
} }
return return nil
} }
inputEvents = append(inputEvents, api.InputRoomEvent{ inputEvents = append(inputEvents, api.InputRoomEvent{
@ -160,28 +160,28 @@ func (r *Admin) PerformAdminEvacuateRoom(
Asynchronous: true, Asynchronous: true,
} }
inputRes := &api.InputRoomEventsResponse{} inputRes := &api.InputRoomEventsResponse{}
r.Inputer.InputRoomEvents(ctx, inputReq, inputRes) return r.Inputer.InputRoomEvents(ctx, inputReq, inputRes)
} }
func (r *Admin) PerformAdminEvacuateUser( func (r *Admin) PerformAdminEvacuateUser(
ctx context.Context, ctx context.Context,
req *api.PerformAdminEvacuateUserRequest, req *api.PerformAdminEvacuateUserRequest,
res *api.PerformAdminEvacuateUserResponse, res *api.PerformAdminEvacuateUserResponse,
) { ) error {
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID) _, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil { if err != nil {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Malformed user ID: %s", err), Msg: fmt.Sprintf("Malformed user ID: %s", err),
} }
return return nil
} }
if domain != r.Cfg.Matrix.ServerName { if domain != r.Cfg.Matrix.ServerName {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
Msg: "Can only evacuate local users using this endpoint", Msg: "Can only evacuate local users using this endpoint",
} }
return return nil
} }
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, gomatrixserverlib.Join) roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, gomatrixserverlib.Join)
@ -190,7 +190,7 @@ func (r *Admin) PerformAdminEvacuateUser(
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.GetRoomsByMembership: %s", err), Msg: fmt.Sprintf("r.DB.GetRoomsByMembership: %s", err),
} }
return return nil
} }
inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, gomatrixserverlib.Invite) inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, gomatrixserverlib.Invite)
@ -199,7 +199,7 @@ func (r *Admin) PerformAdminEvacuateUser(
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.GetRoomsByMembership: %s", err), Msg: fmt.Sprintf("r.DB.GetRoomsByMembership: %s", err),
} }
return return nil
} }
for _, roomID := range append(roomIDs, inviteRoomIDs...) { for _, roomID := range append(roomIDs, inviteRoomIDs...) {
@ -214,7 +214,7 @@ func (r *Admin) PerformAdminEvacuateUser(
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Leaver.PerformLeave: %s", err), Msg: fmt.Sprintf("r.Leaver.PerformLeave: %s", err),
} }
return return nil
} }
if len(outputEvents) == 0 { if len(outputEvents) == 0 {
continue continue
@ -224,9 +224,10 @@ func (r *Admin) PerformAdminEvacuateUser(
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Inputer.WriteOutputEvents: %s", err), Msg: fmt.Sprintf("r.Inputer.WriteOutputEvents: %s", err),
} }
return return nil
} }
res.Affected = append(res.Affected, roomID) res.Affected = append(res.Affected, roomID)
} }
return nil
} }

View file

@ -241,7 +241,9 @@ func (r *Inviter) PerformInvite(
}, },
} }
inputRes := &api.InputRoomEventsResponse{} inputRes := &api.InputRoomEventsResponse{}
r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes) if err = r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes); err != nil {
return nil, fmt.Errorf("r.Inputer.InputRoomEvents: %w", err)
}
if err = inputRes.Err(); err != nil { if err = inputRes.Err(); err != nil {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Msg: fmt.Sprintf("r.InputRoomEvents: %s", err.Error()), Msg: fmt.Sprintf("r.InputRoomEvents: %s", err.Error()),

View file

@ -52,7 +52,7 @@ func (r *Joiner) PerformJoin(
ctx context.Context, ctx context.Context,
req *rsAPI.PerformJoinRequest, req *rsAPI.PerformJoinRequest,
res *rsAPI.PerformJoinResponse, res *rsAPI.PerformJoinResponse,
) { ) error {
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
"room_id": req.RoomIDOrAlias, "room_id": req.RoomIDOrAlias,
"user_id": req.UserID, "user_id": req.UserID,
@ -71,11 +71,12 @@ func (r *Joiner) PerformJoin(
Msg: err.Error(), Msg: err.Error(),
} }
} }
return return nil
} }
logger.Info("User joined room successfully") logger.Info("User joined room successfully")
res.RoomID = roomID res.RoomID = roomID
res.JoinedVia = joinedVia res.JoinedVia = joinedVia
return nil
} }
func (r *Joiner) performJoin( func (r *Joiner) performJoin(
@ -291,7 +292,12 @@ func (r *Joiner) performJoinRoomByID(
}, },
} }
inputRes := rsAPI.InputRoomEventsResponse{} inputRes := rsAPI.InputRoomEventsResponse{}
r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes) if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil {
return "", "", &rsAPI.PerformError{
Code: rsAPI.PerformErrorNoOperation,
Msg: fmt.Sprintf("InputRoomEvents failed: %s", err),
}
}
if err = inputRes.Err(); err != nil { if err = inputRes.Err(); err != nil {
return "", "", &rsAPI.PerformError{ return "", "", &rsAPI.PerformError{
Code: rsAPI.PerformErrorNotAllowed, Code: rsAPI.PerformErrorNotAllowed,

View file

@ -186,7 +186,9 @@ func (r *Leaver) performLeaveRoomByID(
}, },
} }
inputRes := api.InputRoomEventsResponse{} inputRes := api.InputRoomEventsResponse{}
r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes) if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil {
return nil, fmt.Errorf("r.Inputer.InputRoomEvents: %w", err)
}
if err = inputRes.Err(); err != nil { if err = inputRes.Err(); err != nil {
return nil, fmt.Errorf("r.InputRoomEvents: %w", err) return nil, fmt.Errorf("r.InputRoomEvents: %w", err)
} }

View file

@ -44,7 +44,7 @@ func (r *Peeker) PerformPeek(
ctx context.Context, ctx context.Context,
req *api.PerformPeekRequest, req *api.PerformPeekRequest,
res *api.PerformPeekResponse, res *api.PerformPeekResponse,
) { ) error {
roomID, err := r.performPeek(ctx, req) roomID, err := r.performPeek(ctx, req)
if err != nil { if err != nil {
perr, ok := err.(*api.PerformError) perr, ok := err.(*api.PerformError)
@ -57,6 +57,7 @@ func (r *Peeker) PerformPeek(
} }
} }
res.RoomID = roomID res.RoomID = roomID
return nil
} }
func (r *Peeker) performPeek( func (r *Peeker) performPeek(

View file

@ -29,11 +29,12 @@ func (r *Publisher) PerformPublish(
ctx context.Context, ctx context.Context,
req *api.PerformPublishRequest, req *api.PerformPublishRequest,
res *api.PerformPublishResponse, res *api.PerformPublishResponse,
) { ) error {
err := r.DB.PublishRoom(ctx, req.RoomID, req.Visibility == "public") err := r.DB.PublishRoom(ctx, req.RoomID, req.Visibility == "public")
if err != nil { if err != nil {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Msg: err.Error(), Msg: err.Error(),
} }
} }
return nil
} }

View file

@ -41,7 +41,7 @@ func (r *Unpeeker) PerformUnpeek(
ctx context.Context, ctx context.Context,
req *api.PerformUnpeekRequest, req *api.PerformUnpeekRequest,
res *api.PerformUnpeekResponse, res *api.PerformUnpeekResponse,
) { ) error {
if err := r.performUnpeek(ctx, req); err != nil { if err := r.performUnpeek(ctx, req); err != nil {
perr, ok := err.(*api.PerformError) perr, ok := err.(*api.PerformError)
if ok { if ok {
@ -52,6 +52,7 @@ func (r *Unpeeker) PerformUnpeek(
} }
} }
} }
return nil
} }
func (r *Unpeeker) performUnpeek( func (r *Unpeeker) performUnpeek(

View file

@ -45,12 +45,13 @@ func (r *Upgrader) PerformRoomUpgrade(
ctx context.Context, ctx context.Context,
req *api.PerformRoomUpgradeRequest, req *api.PerformRoomUpgradeRequest,
res *api.PerformRoomUpgradeResponse, res *api.PerformRoomUpgradeResponse,
) { ) error {
res.NewRoomID, res.Error = r.performRoomUpgrade(ctx, req) res.NewRoomID, res.Error = r.performRoomUpgrade(ctx, req)
if res.Error != nil { if res.Error != nil {
res.NewRoomID = "" res.NewRoomID = ""
logrus.WithContext(ctx).WithError(res.Error).Error("Room upgrade failed") logrus.WithContext(ctx).WithError(res.Error).Error("Room upgrade failed")
} }
return nil
} }
func (r *Upgrader) performRoomUpgrade( func (r *Upgrader) performRoomUpgrade(
@ -286,22 +287,24 @@ func publishNewRoomAndUnpublishOldRoom(
) { ) {
// expose this room in the published room list // expose this room in the published room list
var pubNewRoomRes api.PerformPublishResponse var pubNewRoomRes api.PerformPublishResponse
URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{ if err := URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{
RoomID: newRoomID, RoomID: newRoomID,
Visibility: "public", Visibility: "public",
}, &pubNewRoomRes) }, &pubNewRoomRes); err != nil {
if pubNewRoomRes.Error != nil { util.GetLogger(ctx).WithError(err).Error("failed to reach internal API")
} else if pubNewRoomRes.Error != nil {
// treat as non-fatal since the room is already made by this point // treat as non-fatal since the room is already made by this point
util.GetLogger(ctx).WithError(pubNewRoomRes.Error).Error("failed to visibility:public") util.GetLogger(ctx).WithError(pubNewRoomRes.Error).Error("failed to visibility:public")
} }
var unpubOldRoomRes api.PerformPublishResponse var unpubOldRoomRes api.PerformPublishResponse
// remove the old room from the published room list // remove the old room from the published room list
URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{ if err := URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{
RoomID: oldRoomID, RoomID: oldRoomID,
Visibility: "private", Visibility: "private",
}, &unpubOldRoomRes) }, &unpubOldRoomRes); err != nil {
if unpubOldRoomRes.Error != nil { util.GetLogger(ctx).WithError(err).Error("failed to reach internal API")
} else if unpubOldRoomRes.Error != nil {
// treat as non-fatal since the room is already made by this point // treat as non-fatal since the room is already made by this point
util.GetLogger(ctx).WithError(unpubOldRoomRes.Error).Error("failed to visibility:private") util.GetLogger(ctx).WithError(unpubOldRoomRes.Error).Error("failed to visibility:private")
} }

View file

@ -21,6 +21,10 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/roomserver/acls" "github.com/matrix-org/dendrite/roomserver/acls"
@ -30,9 +34,6 @@ import (
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/dendrite/roomserver/version"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
) )
type Queryer struct { type Queryer struct {
@ -204,6 +205,54 @@ func (r *Queryer) QueryMembershipForUser(
return err return err
} }
func (r *Queryer) QueryMembershipAtEvent(
ctx context.Context,
request *api.QueryMembershipAtEventRequest,
response *api.QueryMembershipAtEventResponse,
) error {
response.Memberships = make(map[string][]*gomatrixserverlib.HeaderedEvent)
info, err := r.DB.RoomInfo(ctx, request.RoomID)
if err != nil {
return fmt.Errorf("unable to get roomInfo: %w", err)
}
if info == nil {
return fmt.Errorf("no roomInfo found")
}
// get the users stateKeyNID
stateKeyNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.UserID})
if err != nil {
return fmt.Errorf("unable to get stateKeyNIDs for %s: %w", request.UserID, err)
}
if _, ok := stateKeyNIDs[request.UserID]; !ok {
return fmt.Errorf("requested stateKeyNID for %s was not found", request.UserID)
}
stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, info, request.EventIDs, stateKeyNIDs[request.UserID])
if err != nil {
return fmt.Errorf("unable to get state before event: %w", err)
}
for _, eventID := range request.EventIDs {
stateEntry := stateEntries[eventID]
memberships, err := helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false)
if err != nil {
return fmt.Errorf("unable to get memberships at state: %w", err)
}
res := make([]*gomatrixserverlib.HeaderedEvent, 0, len(memberships))
for i := range memberships {
ev := memberships[i]
if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(request.UserID) {
res = append(res, ev.Headered(info.RoomVersion))
}
}
response.Memberships[eventID] = res
}
return nil
}
// QueryMembershipsForRoom implements api.RoomserverInternalAPI // QueryMembershipsForRoom implements api.RoomserverInternalAPI
func (r *Queryer) QueryMembershipsForRoom( func (r *Queryer) QueryMembershipsForRoom(
ctx context.Context, ctx context.Context,
@ -684,7 +733,7 @@ func (r *Queryer) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForU
func (r *Queryer) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error { func (r *Queryer) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error {
users, err := r.DB.GetKnownUsers(ctx, req.UserID, req.SearchString, req.Limit) users, err := r.DB.GetKnownUsers(ctx, req.UserID, req.SearchString, req.Limit)
if err != nil { if err != nil && err != sql.ErrNoRows {
return err return err
} }
for _, user := range users { for _, user := range users {

View file

@ -3,18 +3,16 @@ package inthttp
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"net/http" "net/http"
"github.com/matrix-org/gomatrixserverlib"
asAPI "github.com/matrix-org/dendrite/appservice/api" asAPI "github.com/matrix-org/dendrite/appservice/api"
fsInputAPI "github.com/matrix-org/dendrite/federationapi/api" fsInputAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/opentracing/opentracing-go"
) )
const ( const (
@ -63,6 +61,7 @@ const (
RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom" RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom"
RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain" RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain"
RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed" RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed"
RoomserverQueryMembershipAtEventPath = "/roomserver/queryMembershipAtEvent"
) )
type httpRoomserverInternalAPI struct { type httpRoomserverInternalAPI struct {
@ -106,11 +105,10 @@ func (h *httpRoomserverInternalAPI) SetRoomAlias(
request *api.SetRoomAliasRequest, request *api.SetRoomAliasRequest,
response *api.SetRoomAliasResponse, response *api.SetRoomAliasResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "SetRoomAlias") return httputil.CallInternalRPCAPI(
defer span.Finish() "SetRoomAlias", h.roomserverURL+RoomserverSetRoomAliasPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverSetRoomAliasPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// GetRoomIDForAlias implements RoomserverAliasAPI // GetRoomIDForAlias implements RoomserverAliasAPI
@ -119,11 +117,10 @@ func (h *httpRoomserverInternalAPI) GetRoomIDForAlias(
request *api.GetRoomIDForAliasRequest, request *api.GetRoomIDForAliasRequest,
response *api.GetRoomIDForAliasResponse, response *api.GetRoomIDForAliasResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "GetRoomIDForAlias") return httputil.CallInternalRPCAPI(
defer span.Finish() "GetRoomIDForAlias", h.roomserverURL+RoomserverGetRoomIDForAliasPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverGetRoomIDForAliasPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// GetAliasesForRoomID implements RoomserverAliasAPI // GetAliasesForRoomID implements RoomserverAliasAPI
@ -132,11 +129,10 @@ func (h *httpRoomserverInternalAPI) GetAliasesForRoomID(
request *api.GetAliasesForRoomIDRequest, request *api.GetAliasesForRoomIDRequest,
response *api.GetAliasesForRoomIDResponse, response *api.GetAliasesForRoomIDResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "GetAliasesForRoomID") return httputil.CallInternalRPCAPI(
defer span.Finish() "GetAliasesForRoomID", h.roomserverURL+RoomserverGetAliasesForRoomIDPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverGetAliasesForRoomIDPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// RemoveRoomAlias implements RoomserverAliasAPI // RemoveRoomAlias implements RoomserverAliasAPI
@ -145,11 +141,10 @@ func (h *httpRoomserverInternalAPI) RemoveRoomAlias(
request *api.RemoveRoomAliasRequest, request *api.RemoveRoomAliasRequest,
response *api.RemoveRoomAliasResponse, response *api.RemoveRoomAliasResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "RemoveRoomAlias") return httputil.CallInternalRPCAPI(
defer span.Finish() "RemoveRoomAlias", h.roomserverURL+RoomserverRemoveRoomAliasPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverRemoveRoomAliasPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// InputRoomEvents implements RoomserverInputAPI // InputRoomEvents implements RoomserverInputAPI
@ -157,15 +152,14 @@ func (h *httpRoomserverInternalAPI) InputRoomEvents(
ctx context.Context, ctx context.Context,
request *api.InputRoomEventsRequest, request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse, response *api.InputRoomEventsResponse,
) { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "InputRoomEvents") if err := httputil.CallInternalRPCAPI(
defer span.Finish() "InputRoomEvents", h.roomserverURL+RoomserverInputRoomEventsPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverInputRoomEventsPath ); err != nil {
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.ErrMsg = err.Error() response.ErrMsg = err.Error()
} }
return nil
} }
func (h *httpRoomserverInternalAPI) PerformInvite( func (h *httpRoomserverInternalAPI) PerformInvite(
@ -173,45 +167,32 @@ func (h *httpRoomserverInternalAPI) PerformInvite(
request *api.PerformInviteRequest, request *api.PerformInviteRequest,
response *api.PerformInviteResponse, response *api.PerformInviteResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformInvite") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformInvite", h.roomserverURL+RoomserverPerformInvitePath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverPerformInvitePath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpRoomserverInternalAPI) PerformJoin( func (h *httpRoomserverInternalAPI) PerformJoin(
ctx context.Context, ctx context.Context,
request *api.PerformJoinRequest, request *api.PerformJoinRequest,
response *api.PerformJoinResponse, response *api.PerformJoinResponse,
) { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformJoin") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformJoin", h.roomserverURL+RoomserverPerformJoinPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverPerformJoinPath )
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.PerformError{
Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
}
}
} }
func (h *httpRoomserverInternalAPI) PerformPeek( func (h *httpRoomserverInternalAPI) PerformPeek(
ctx context.Context, ctx context.Context,
request *api.PerformPeekRequest, request *api.PerformPeekRequest,
response *api.PerformPeekResponse, response *api.PerformPeekResponse,
) { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPeek") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformPeek", h.roomserverURL+RoomserverPerformPeekPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverPerformPeekPath )
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.PerformError{
Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
}
}
} }
func (h *httpRoomserverInternalAPI) PerformInboundPeek( func (h *httpRoomserverInternalAPI) PerformInboundPeek(
@ -219,45 +200,32 @@ func (h *httpRoomserverInternalAPI) PerformInboundPeek(
request *api.PerformInboundPeekRequest, request *api.PerformInboundPeekRequest,
response *api.PerformInboundPeekResponse, response *api.PerformInboundPeekResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformInboundPeek") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformInboundPeek", h.roomserverURL+RoomserverPerformInboundPeekPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverPerformInboundPeekPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpRoomserverInternalAPI) PerformUnpeek( func (h *httpRoomserverInternalAPI) PerformUnpeek(
ctx context.Context, ctx context.Context,
request *api.PerformUnpeekRequest, request *api.PerformUnpeekRequest,
response *api.PerformUnpeekResponse, response *api.PerformUnpeekResponse,
) { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUnpeek") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformUnpeek", h.roomserverURL+RoomserverPerformUnpeekPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverPerformUnpeekPath )
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.PerformError{
Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
}
}
} }
func (h *httpRoomserverInternalAPI) PerformRoomUpgrade( func (h *httpRoomserverInternalAPI) PerformRoomUpgrade(
ctx context.Context, ctx context.Context,
request *api.PerformRoomUpgradeRequest, request *api.PerformRoomUpgradeRequest,
response *api.PerformRoomUpgradeResponse, response *api.PerformRoomUpgradeResponse,
) { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformRoomUpgrade") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformRoomUpgrade", h.roomserverURL+RoomserverPerformRoomUpgradePath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverPerformRoomUpgradePath )
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.PerformError{
Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
}
}
} }
func (h *httpRoomserverInternalAPI) PerformLeave( func (h *httpRoomserverInternalAPI) PerformLeave(
@ -265,62 +233,43 @@ func (h *httpRoomserverInternalAPI) PerformLeave(
request *api.PerformLeaveRequest, request *api.PerformLeaveRequest,
response *api.PerformLeaveResponse, response *api.PerformLeaveResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLeave") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformLeave", h.roomserverURL+RoomserverPerformLeavePath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverPerformLeavePath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpRoomserverInternalAPI) PerformPublish( func (h *httpRoomserverInternalAPI) PerformPublish(
ctx context.Context, ctx context.Context,
req *api.PerformPublishRequest, request *api.PerformPublishRequest,
res *api.PerformPublishResponse, response *api.PerformPublishResponse,
) { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPublish") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformPublish", h.roomserverURL+RoomserverPerformPublishPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverPerformPublishPath )
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
if err != nil {
res.Error = &api.PerformError{
Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
}
}
} }
func (h *httpRoomserverInternalAPI) PerformAdminEvacuateRoom( func (h *httpRoomserverInternalAPI) PerformAdminEvacuateRoom(
ctx context.Context, ctx context.Context,
req *api.PerformAdminEvacuateRoomRequest, request *api.PerformAdminEvacuateRoomRequest,
res *api.PerformAdminEvacuateRoomResponse, response *api.PerformAdminEvacuateRoomResponse,
) { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAdminEvacuateRoom") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformAdminEvacuateRoom", h.roomserverURL+RoomserverPerformAdminEvacuateRoomPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverPerformAdminEvacuateRoomPath )
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
if err != nil {
res.Error = &api.PerformError{
Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
}
}
} }
func (h *httpRoomserverInternalAPI) PerformAdminEvacuateUser( func (h *httpRoomserverInternalAPI) PerformAdminEvacuateUser(
ctx context.Context, ctx context.Context,
req *api.PerformAdminEvacuateUserRequest, request *api.PerformAdminEvacuateUserRequest,
res *api.PerformAdminEvacuateUserResponse, response *api.PerformAdminEvacuateUserResponse,
) { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAdminEvacuateUser") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformAdminEvacuateUser", h.roomserverURL+RoomserverPerformAdminEvacuateUserPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverPerformAdminEvacuateUserPath )
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
if err != nil {
res.Error = &api.PerformError{
Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
}
}
} }
// QueryLatestEventsAndState implements RoomserverQueryAPI // QueryLatestEventsAndState implements RoomserverQueryAPI
@ -329,11 +278,10 @@ func (h *httpRoomserverInternalAPI) QueryLatestEventsAndState(
request *api.QueryLatestEventsAndStateRequest, request *api.QueryLatestEventsAndStateRequest,
response *api.QueryLatestEventsAndStateResponse, response *api.QueryLatestEventsAndStateResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryLatestEventsAndState") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryLatestEventsAndState", h.roomserverURL+RoomserverQueryLatestEventsAndStatePath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverQueryLatestEventsAndStatePath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// QueryStateAfterEvents implements RoomserverQueryAPI // QueryStateAfterEvents implements RoomserverQueryAPI
@ -342,11 +290,10 @@ func (h *httpRoomserverInternalAPI) QueryStateAfterEvents(
request *api.QueryStateAfterEventsRequest, request *api.QueryStateAfterEventsRequest,
response *api.QueryStateAfterEventsResponse, response *api.QueryStateAfterEventsResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryStateAfterEvents") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryStateAfterEvents", h.roomserverURL+RoomserverQueryStateAfterEventsPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverQueryStateAfterEventsPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// QueryEventsByID implements RoomserverQueryAPI // QueryEventsByID implements RoomserverQueryAPI
@ -355,11 +302,10 @@ func (h *httpRoomserverInternalAPI) QueryEventsByID(
request *api.QueryEventsByIDRequest, request *api.QueryEventsByIDRequest,
response *api.QueryEventsByIDResponse, response *api.QueryEventsByIDResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryEventsByID") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryEventsByID", h.roomserverURL+RoomserverQueryEventsByIDPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverQueryEventsByIDPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpRoomserverInternalAPI) QueryPublishedRooms( func (h *httpRoomserverInternalAPI) QueryPublishedRooms(
@ -367,11 +313,10 @@ func (h *httpRoomserverInternalAPI) QueryPublishedRooms(
request *api.QueryPublishedRoomsRequest, request *api.QueryPublishedRoomsRequest,
response *api.QueryPublishedRoomsResponse, response *api.QueryPublishedRoomsResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPublishedRooms") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryPublishedRooms", h.roomserverURL+RoomserverQueryPublishedRoomsPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverQueryPublishedRoomsPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// QueryMembershipForUser implements RoomserverQueryAPI // QueryMembershipForUser implements RoomserverQueryAPI
@ -380,11 +325,10 @@ func (h *httpRoomserverInternalAPI) QueryMembershipForUser(
request *api.QueryMembershipForUserRequest, request *api.QueryMembershipForUserRequest,
response *api.QueryMembershipForUserResponse, response *api.QueryMembershipForUserResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMembershipForUser") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryMembershipForUser", h.roomserverURL+RoomserverQueryMembershipForUserPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverQueryMembershipForUserPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// QueryMembershipsForRoom implements RoomserverQueryAPI // QueryMembershipsForRoom implements RoomserverQueryAPI
@ -393,11 +337,10 @@ func (h *httpRoomserverInternalAPI) QueryMembershipsForRoom(
request *api.QueryMembershipsForRoomRequest, request *api.QueryMembershipsForRoomRequest,
response *api.QueryMembershipsForRoomResponse, response *api.QueryMembershipsForRoomResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMembershipsForRoom") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryMembershipsForRoom", h.roomserverURL+RoomserverQueryMembershipsForRoomPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverQueryMembershipsForRoomPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// QueryMembershipsForRoom implements RoomserverQueryAPI // QueryMembershipsForRoom implements RoomserverQueryAPI
@ -406,11 +349,10 @@ func (h *httpRoomserverInternalAPI) QueryServerJoinedToRoom(
request *api.QueryServerJoinedToRoomRequest, request *api.QueryServerJoinedToRoomRequest,
response *api.QueryServerJoinedToRoomResponse, response *api.QueryServerJoinedToRoomResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerJoinedToRoom") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryServerJoinedToRoom", h.roomserverURL+RoomserverQueryServerJoinedToRoomPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverQueryServerJoinedToRoomPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// QueryServerAllowedToSeeEvent implements RoomserverQueryAPI // QueryServerAllowedToSeeEvent implements RoomserverQueryAPI
@ -419,11 +361,10 @@ func (h *httpRoomserverInternalAPI) QueryServerAllowedToSeeEvent(
request *api.QueryServerAllowedToSeeEventRequest, request *api.QueryServerAllowedToSeeEventRequest,
response *api.QueryServerAllowedToSeeEventResponse, response *api.QueryServerAllowedToSeeEventResponse,
) (err error) { ) (err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerAllowedToSeeEvent") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryServerAllowedToSeeEvent", h.roomserverURL+RoomserverQueryServerAllowedToSeeEventPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverQueryServerAllowedToSeeEventPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// QueryMissingEvents implements RoomServerQueryAPI // QueryMissingEvents implements RoomServerQueryAPI
@ -432,11 +373,10 @@ func (h *httpRoomserverInternalAPI) QueryMissingEvents(
request *api.QueryMissingEventsRequest, request *api.QueryMissingEventsRequest,
response *api.QueryMissingEventsResponse, response *api.QueryMissingEventsResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMissingEvents") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryMissingEvents", h.roomserverURL+RoomserverQueryMissingEventsPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverQueryMissingEventsPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// QueryStateAndAuthChain implements RoomserverQueryAPI // QueryStateAndAuthChain implements RoomserverQueryAPI
@ -445,11 +385,10 @@ func (h *httpRoomserverInternalAPI) QueryStateAndAuthChain(
request *api.QueryStateAndAuthChainRequest, request *api.QueryStateAndAuthChainRequest,
response *api.QueryStateAndAuthChainResponse, response *api.QueryStateAndAuthChainResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryStateAndAuthChain") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryStateAndAuthChain", h.roomserverURL+RoomserverQueryStateAndAuthChainPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverQueryStateAndAuthChainPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// PerformBackfill implements RoomServerQueryAPI // PerformBackfill implements RoomServerQueryAPI
@ -458,11 +397,10 @@ func (h *httpRoomserverInternalAPI) PerformBackfill(
request *api.PerformBackfillRequest, request *api.PerformBackfillRequest,
response *api.PerformBackfillResponse, response *api.PerformBackfillResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformBackfill") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformBackfill", h.roomserverURL+RoomserverPerformBackfillPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverPerformBackfillPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// QueryRoomVersionCapabilities implements RoomServerQueryAPI // QueryRoomVersionCapabilities implements RoomServerQueryAPI
@ -471,11 +409,10 @@ func (h *httpRoomserverInternalAPI) QueryRoomVersionCapabilities(
request *api.QueryRoomVersionCapabilitiesRequest, request *api.QueryRoomVersionCapabilitiesRequest,
response *api.QueryRoomVersionCapabilitiesResponse, response *api.QueryRoomVersionCapabilitiesResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomVersionCapabilities") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryRoomVersionCapabilities", h.roomserverURL+RoomserverQueryRoomVersionCapabilitiesPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverQueryRoomVersionCapabilitiesPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// QueryRoomVersionForRoom implements RoomServerQueryAPI // QueryRoomVersionForRoom implements RoomServerQueryAPI
@ -488,12 +425,10 @@ func (h *httpRoomserverInternalAPI) QueryRoomVersionForRoom(
response.RoomVersion = roomVersion response.RoomVersion = roomVersion
return nil return nil
} }
err := httputil.CallInternalRPCAPI(
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomVersionForRoom") "QueryRoomVersionForRoom", h.roomserverURL+RoomserverQueryRoomVersionForRoomPath,
defer span.Finish() h.httpClient, ctx, request, response,
)
apiURL := h.roomserverURL + RoomserverQueryRoomVersionForRoomPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err == nil { if err == nil {
h.cache.StoreRoomVersion(request.RoomID, response.RoomVersion) h.cache.StoreRoomVersion(request.RoomID, response.RoomVersion)
} }
@ -505,11 +440,10 @@ func (h *httpRoomserverInternalAPI) QueryCurrentState(
request *api.QueryCurrentStateRequest, request *api.QueryCurrentStateRequest,
response *api.QueryCurrentStateResponse, response *api.QueryCurrentStateResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryCurrentState") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryCurrentState", h.roomserverURL+RoomserverQueryCurrentStatePath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverQueryCurrentStatePath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpRoomserverInternalAPI) QueryRoomsForUser( func (h *httpRoomserverInternalAPI) QueryRoomsForUser(
@ -517,11 +451,10 @@ func (h *httpRoomserverInternalAPI) QueryRoomsForUser(
request *api.QueryRoomsForUserRequest, request *api.QueryRoomsForUserRequest,
response *api.QueryRoomsForUserResponse, response *api.QueryRoomsForUserResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomsForUser") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryRoomsForUser", h.roomserverURL+RoomserverQueryRoomsForUserPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverQueryRoomsForUserPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpRoomserverInternalAPI) QueryBulkStateContent( func (h *httpRoomserverInternalAPI) QueryBulkStateContent(
@ -529,68 +462,82 @@ func (h *httpRoomserverInternalAPI) QueryBulkStateContent(
request *api.QueryBulkStateContentRequest, request *api.QueryBulkStateContentRequest,
response *api.QueryBulkStateContentResponse, response *api.QueryBulkStateContentResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryBulkStateContent") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryBulkStateContent", h.roomserverURL+RoomserverQueryBulkStateContentPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverQueryBulkStateContentPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpRoomserverInternalAPI) QuerySharedUsers( func (h *httpRoomserverInternalAPI) QuerySharedUsers(
ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse, ctx context.Context,
request *api.QuerySharedUsersRequest,
response *api.QuerySharedUsersResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySharedUsers") return httputil.CallInternalRPCAPI(
defer span.Finish() "QuerySharedUsers", h.roomserverURL+RoomserverQuerySharedUsersPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverQuerySharedUsersPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
} }
func (h *httpRoomserverInternalAPI) QueryKnownUsers( func (h *httpRoomserverInternalAPI) QueryKnownUsers(
ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse, ctx context.Context,
request *api.QueryKnownUsersRequest,
response *api.QueryKnownUsersResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKnownUsers") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryKnownUsers", h.roomserverURL+RoomserverQueryKnownUsersPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverQueryKnownUsersPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
} }
func (h *httpRoomserverInternalAPI) QueryAuthChain( func (h *httpRoomserverInternalAPI) QueryAuthChain(
ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse, ctx context.Context,
request *api.QueryAuthChainRequest,
response *api.QueryAuthChainResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryAuthChain") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryAuthChain", h.roomserverURL+RoomserverQueryAuthChainPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverQueryAuthChainPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
} }
func (h *httpRoomserverInternalAPI) QueryServerBannedFromRoom( func (h *httpRoomserverInternalAPI) QueryServerBannedFromRoom(
ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse, ctx context.Context,
request *api.QueryServerBannedFromRoomRequest,
response *api.QueryServerBannedFromRoomResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerBannedFromRoom") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryServerBannedFromRoom", h.roomserverURL+RoomserverQueryServerBannedFromRoomPath,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverQueryServerBannedFromRoomPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
} }
func (h *httpRoomserverInternalAPI) QueryRestrictedJoinAllowed( func (h *httpRoomserverInternalAPI) QueryRestrictedJoinAllowed(
ctx context.Context, req *api.QueryRestrictedJoinAllowedRequest, res *api.QueryRestrictedJoinAllowedResponse, ctx context.Context,
request *api.QueryRestrictedJoinAllowedRequest,
response *api.QueryRestrictedJoinAllowedResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRestrictedJoinAllowed") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryRestrictedJoinAllowed", h.roomserverURL+RoomserverQueryRestrictedJoinAllowed,
h.httpClient, ctx, request, response,
apiURL := h.roomserverURL + RoomserverQueryRestrictedJoinAllowed )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
} }
func (h *httpRoomserverInternalAPI) PerformForget(ctx context.Context, req *api.PerformForgetRequest, res *api.PerformForgetResponse) error { func (h *httpRoomserverInternalAPI) PerformForget(
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformForget") ctx context.Context,
defer span.Finish() request *api.PerformForgetRequest,
response *api.PerformForgetResponse,
apiURL := h.roomserverURL + RoomserverPerformForgetPath ) error {
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.CallInternalRPCAPI(
"PerformForget", h.roomserverURL+RoomserverPerformForgetPath,
h.httpClient, ctx, request, response,
)
} }
func (h *httpRoomserverInternalAPI) QueryMembershipAtEvent(ctx context.Context, request *api.QueryMembershipAtEventRequest, response *api.QueryMembershipAtEventResponse) error {
return httputil.CallInternalRPCAPI(
"QueryMembershiptAtEvent", h.roomserverURL+RoomserverQueryMembershipAtEventPath,
h.httpClient, ctx, request, response,
)
}

View file

@ -1,499 +1,201 @@
package inthttp package inthttp
import ( import (
"encoding/json"
"net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/util"
) )
// AddRoutes adds the RoomserverInternalAPI handlers to the http.ServeMux. // AddRoutes adds the RoomserverInternalAPI handlers to the http.ServeMux.
// nolint: gocyclo // nolint: gocyclo
func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
internalAPIMux.Handle(RoomserverInputRoomEventsPath, internalAPIMux.Handle(
httputil.MakeInternalAPI("inputRoomEvents", func(req *http.Request) util.JSONResponse { RoomserverInputRoomEventsPath,
var request api.InputRoomEventsRequest httputil.MakeInternalRPCAPI("RoomserverInputRoomEvents", r.InputRoomEvents),
var response api.InputRoomEventsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.InputRoomEvents(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(RoomserverPerformInvitePath,
httputil.MakeInternalAPI("performInvite", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
var request api.PerformInviteRequest RoomserverPerformInvitePath,
var response api.PerformInviteResponse httputil.MakeInternalRPCAPI("RoomserverPerformInvite", r.PerformInvite),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.PerformInvite(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(RoomserverPerformJoinPath,
httputil.MakeInternalAPI("performJoin", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
var request api.PerformJoinRequest RoomserverPerformJoinPath,
var response api.PerformJoinResponse httputil.MakeInternalRPCAPI("RoomserverPerformJoin", r.PerformJoin),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.PerformJoin(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(RoomserverPerformLeavePath,
httputil.MakeInternalAPI("performLeave", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
var request api.PerformLeaveRequest RoomserverPerformLeavePath,
var response api.PerformLeaveResponse httputil.MakeInternalRPCAPI("RoomserverPerformLeave", r.PerformLeave),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.PerformLeave(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(RoomserverPerformPeekPath,
httputil.MakeInternalAPI("performPeek", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
var request api.PerformPeekRequest RoomserverPerformPeekPath,
var response api.PerformPeekResponse httputil.MakeInternalRPCAPI("RoomserverPerformPeek", r.PerformPeek),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.PerformPeek(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(RoomserverPerformInboundPeekPath,
httputil.MakeInternalAPI("performInboundPeek", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
var request api.PerformInboundPeekRequest RoomserverPerformInboundPeekPath,
var response api.PerformInboundPeekResponse httputil.MakeInternalRPCAPI("RoomserverPerformInboundPeek", r.PerformInboundPeek),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.PerformInboundPeek(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(RoomserverPerformPeekPath,
httputil.MakeInternalAPI("performUnpeek", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
var request api.PerformUnpeekRequest RoomserverPerformUnpeekPath,
var response api.PerformUnpeekResponse httputil.MakeInternalRPCAPI("RoomserverPerformUnpeek", r.PerformUnpeek),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.PerformUnpeek(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(RoomserverPerformRoomUpgradePath,
httputil.MakeInternalAPI("performRoomUpgrade", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
var request api.PerformRoomUpgradeRequest RoomserverPerformRoomUpgradePath,
var response api.PerformRoomUpgradeResponse httputil.MakeInternalRPCAPI("RoomserverPerformRoomUpgrade", r.PerformRoomUpgrade),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.PerformRoomUpgrade(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(RoomserverPerformPublishPath,
httputil.MakeInternalAPI("performPublish", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
var request api.PerformPublishRequest RoomserverPerformPublishPath,
var response api.PerformPublishResponse httputil.MakeInternalRPCAPI("RoomserverPerformPublish", r.PerformPublish),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.PerformPublish(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(RoomserverPerformAdminEvacuateRoomPath,
httputil.MakeInternalAPI("performAdminEvacuateRoom", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
var request api.PerformAdminEvacuateRoomRequest RoomserverPerformAdminEvacuateRoomPath,
var response api.PerformAdminEvacuateRoomResponse httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateRoom", r.PerformAdminEvacuateRoom),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.PerformAdminEvacuateRoom(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(RoomserverPerformAdminEvacuateUserPath,
httputil.MakeInternalAPI("performAdminEvacuateUser", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
var request api.PerformAdminEvacuateUserRequest RoomserverPerformAdminEvacuateUserPath,
var response api.PerformAdminEvacuateUserResponse httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateUser", r.PerformAdminEvacuateUser),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.PerformAdminEvacuateUser(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
RoomserverQueryPublishedRoomsPath, RoomserverQueryPublishedRoomsPath,
httputil.MakeInternalAPI("queryPublishedRooms", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("RoomserverQueryPublishedRooms", r.QueryPublishedRooms),
var request api.QueryPublishedRoomsRequest
var response api.QueryPublishedRoomsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryPublishedRooms(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
RoomserverQueryLatestEventsAndStatePath, RoomserverQueryLatestEventsAndStatePath,
httputil.MakeInternalAPI("queryLatestEventsAndState", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("RoomserverQueryLatestEventsAndState", r.QueryLatestEventsAndState),
var request api.QueryLatestEventsAndStateRequest
var response api.QueryLatestEventsAndStateResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryLatestEventsAndState(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
RoomserverQueryStateAfterEventsPath, RoomserverQueryStateAfterEventsPath,
httputil.MakeInternalAPI("queryStateAfterEvents", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("RoomserverQueryStateAfterEvents", r.QueryStateAfterEvents),
var request api.QueryStateAfterEventsRequest
var response api.QueryStateAfterEventsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryStateAfterEvents(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
RoomserverQueryEventsByIDPath, RoomserverQueryEventsByIDPath,
httputil.MakeInternalAPI("queryEventsByID", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("RoomserverQueryEventsByID", r.QueryEventsByID),
var request api.QueryEventsByIDRequest
var response api.QueryEventsByIDResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryEventsByID(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
RoomserverQueryMembershipForUserPath, RoomserverQueryMembershipForUserPath,
httputil.MakeInternalAPI("QueryMembershipForUser", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("RoomserverQueryMembershipForUser", r.QueryMembershipForUser),
var request api.QueryMembershipForUserRequest
var response api.QueryMembershipForUserResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryMembershipForUser(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
RoomserverQueryMembershipsForRoomPath, RoomserverQueryMembershipsForRoomPath,
httputil.MakeInternalAPI("queryMembershipsForRoom", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("RoomserverQueryMembershipsForRoom", r.QueryMembershipsForRoom),
var request api.QueryMembershipsForRoomRequest
var response api.QueryMembershipsForRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryMembershipsForRoom(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
RoomserverQueryServerJoinedToRoomPath, RoomserverQueryServerJoinedToRoomPath,
httputil.MakeInternalAPI("queryServerJoinedToRoom", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("RoomserverQueryServerJoinedToRoom", r.QueryServerJoinedToRoom),
var request api.QueryServerJoinedToRoomRequest
var response api.QueryServerJoinedToRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryServerJoinedToRoom(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
RoomserverQueryServerAllowedToSeeEventPath, RoomserverQueryServerAllowedToSeeEventPath,
httputil.MakeInternalAPI("queryServerAllowedToSeeEvent", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("RoomserverQueryServerAllowedToSeeEvent", r.QueryServerAllowedToSeeEvent),
var request api.QueryServerAllowedToSeeEventRequest
var response api.QueryServerAllowedToSeeEventResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryServerAllowedToSeeEvent(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
RoomserverQueryMissingEventsPath, RoomserverQueryMissingEventsPath,
httputil.MakeInternalAPI("queryMissingEvents", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("RoomserverQueryMissingEvents", r.QueryMissingEvents),
var request api.QueryMissingEventsRequest
var response api.QueryMissingEventsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryMissingEvents(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
RoomserverQueryStateAndAuthChainPath, RoomserverQueryStateAndAuthChainPath,
httputil.MakeInternalAPI("queryStateAndAuthChain", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("RoomserverQueryStateAndAuthChain", r.QueryStateAndAuthChain),
var request api.QueryStateAndAuthChainRequest
var response api.QueryStateAndAuthChainResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryStateAndAuthChain(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
RoomserverPerformBackfillPath, RoomserverPerformBackfillPath,
httputil.MakeInternalAPI("PerformBackfill", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("RoomserverPerformBackfill", r.PerformBackfill),
var request api.PerformBackfillRequest
var response api.PerformBackfillResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.PerformBackfill(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
RoomserverPerformForgetPath, RoomserverPerformForgetPath,
httputil.MakeInternalAPI("PerformForget", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("RoomserverPerformForget", r.PerformForget),
var request api.PerformForgetRequest
var response api.PerformForgetResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.PerformForget(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
RoomserverQueryRoomVersionCapabilitiesPath, RoomserverQueryRoomVersionCapabilitiesPath,
httputil.MakeInternalAPI("QueryRoomVersionCapabilities", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("RoomserverQueryRoomVersionCapabilities", r.QueryRoomVersionCapabilities),
var request api.QueryRoomVersionCapabilitiesRequest
var response api.QueryRoomVersionCapabilitiesResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryRoomVersionCapabilities(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
RoomserverQueryRoomVersionForRoomPath, RoomserverQueryRoomVersionForRoomPath,
httputil.MakeInternalAPI("QueryRoomVersionForRoom", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("RoomserverQueryRoomVersionForRoom", r.QueryRoomVersionForRoom),
var request api.QueryRoomVersionForRoomRequest
var response api.QueryRoomVersionForRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryRoomVersionForRoom(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
RoomserverSetRoomAliasPath, RoomserverSetRoomAliasPath,
httputil.MakeInternalAPI("setRoomAlias", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("RoomserverSetRoomAlias", r.SetRoomAlias),
var request api.SetRoomAliasRequest
var response api.SetRoomAliasResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.SetRoomAlias(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
RoomserverGetRoomIDForAliasPath, RoomserverGetRoomIDForAliasPath,
httputil.MakeInternalAPI("GetRoomIDForAlias", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("RoomserverGetRoomIDForAlias", r.GetRoomIDForAlias),
var request api.GetRoomIDForAliasRequest
var response api.GetRoomIDForAliasResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.GetRoomIDForAlias(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
RoomserverGetAliasesForRoomIDPath, RoomserverGetAliasesForRoomIDPath,
httputil.MakeInternalAPI("getAliasesForRoomID", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("RoomserverGetAliasesForRoomID", r.GetAliasesForRoomID),
var request api.GetAliasesForRoomIDRequest
var response api.GetAliasesForRoomIDResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.GetAliasesForRoomID(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(
RoomserverRemoveRoomAliasPath, RoomserverRemoveRoomAliasPath,
httputil.MakeInternalAPI("removeRoomAlias", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("RoomserverRemoveRoomAlias", r.RemoveRoomAlias),
var request api.RemoveRoomAliasRequest
var response api.RemoveRoomAliasResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.RemoveRoomAlias(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(RoomserverQueryCurrentStatePath,
httputil.MakeInternalAPI("queryCurrentState", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.QueryCurrentStateRequest{} RoomserverQueryCurrentStatePath,
response := api.QueryCurrentStateResponse{} httputil.MakeInternalRPCAPI("RoomserverQueryCurrentState", r.QueryCurrentState),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QueryCurrentState(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(RoomserverQueryRoomsForUserPath,
httputil.MakeInternalAPI("queryRoomsForUser", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.QueryRoomsForUserRequest{} RoomserverQueryRoomsForUserPath,
response := api.QueryRoomsForUserResponse{} httputil.MakeInternalRPCAPI("RoomserverQueryRoomsForUser", r.QueryRoomsForUser),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QueryRoomsForUser(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(RoomserverQueryBulkStateContentPath,
httputil.MakeInternalAPI("queryBulkStateContent", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.QueryBulkStateContentRequest{} RoomserverQueryBulkStateContentPath,
response := api.QueryBulkStateContentResponse{} httputil.MakeInternalRPCAPI("RoomserverQueryBulkStateContent", r.QueryBulkStateContent),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QueryBulkStateContent(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(RoomserverQuerySharedUsersPath,
httputil.MakeInternalAPI("querySharedUsers", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.QuerySharedUsersRequest{} RoomserverQuerySharedUsersPath,
response := api.QuerySharedUsersResponse{} httputil.MakeInternalRPCAPI("RoomserverQuerySharedUsers", r.QuerySharedUsers),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QuerySharedUsers(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(RoomserverQueryKnownUsersPath,
httputil.MakeInternalAPI("queryKnownUsers", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.QueryKnownUsersRequest{} RoomserverQueryKnownUsersPath,
response := api.QueryKnownUsersResponse{} httputil.MakeInternalRPCAPI("RoomserverQueryKnownUsers", r.QueryKnownUsers),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QueryKnownUsers(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(RoomserverQueryServerBannedFromRoomPath,
httputil.MakeInternalAPI("queryServerBannedFromRoom", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.QueryServerBannedFromRoomRequest{} RoomserverQueryServerBannedFromRoomPath,
response := api.QueryServerBannedFromRoomResponse{} httputil.MakeInternalRPCAPI("RoomserverQueryServerBannedFromRoom", r.QueryServerBannedFromRoom),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QueryServerBannedFromRoom(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(RoomserverQueryAuthChainPath,
httputil.MakeInternalAPI("queryAuthChain", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.QueryAuthChainRequest{} RoomserverQueryAuthChainPath,
response := api.QueryAuthChainResponse{} httputil.MakeInternalRPCAPI("RoomserverQueryAuthChain", r.QueryAuthChain),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QueryAuthChain(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(RoomserverQueryRestrictedJoinAllowed,
httputil.MakeInternalAPI("queryRestrictedJoinAllowed", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.QueryRestrictedJoinAllowedRequest{} RoomserverQueryRestrictedJoinAllowed,
response := api.QueryRestrictedJoinAllowedResponse{} httputil.MakeInternalRPCAPI("RoomserverQueryRestrictedJoinAllowed", r.QueryRestrictedJoinAllowed),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { )
return util.MessageResponse(http.StatusBadRequest, err.Error()) internalAPIMux.Handle(
} RoomserverQueryMembershipAtEventPath,
if err := r.QueryRestrictedJoinAllowed(req.Context(), &request, &response); err != nil { httputil.MakeInternalRPCAPI("RoomserverQueryMembershipAtEventPath", r.QueryMembershipAtEvent),
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
} }

View file

@ -23,12 +23,11 @@ import (
"sync" "sync"
"time" "time"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
) )
type StateResolutionStorage interface { type StateResolutionStorage interface {
@ -124,6 +123,61 @@ func (v *StateResolution) LoadStateAtEvent(
return stateEntries, nil return stateEntries, nil
} }
func (v *StateResolution) LoadMembershipAtEvent(
ctx context.Context, eventIDs []string, stateKeyNID types.EventStateKeyNID,
) (map[string][]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadMembershipAtEvent")
defer span.Finish()
// De-dupe snapshotNIDs
snapshotNIDMap := make(map[types.StateSnapshotNID][]string) // map from snapshot NID to eventIDs
for i := range eventIDs {
eventID := eventIDs[i]
snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID)
if err != nil {
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %w", eventID, err)
}
if snapshotNID == 0 {
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID)
}
snapshotNIDMap[snapshotNID] = append(snapshotNIDMap[snapshotNID], eventID)
}
snapshotNIDs := make([]types.StateSnapshotNID, 0, len(snapshotNIDMap))
for nid := range snapshotNIDMap {
snapshotNIDs = append(snapshotNIDs, nid)
}
stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, snapshotNIDs)
if err != nil {
return nil, err
}
result := make(map[string][]types.StateEntry)
for _, stateBlockNIDList := range stateBlockNIDLists {
// Query the membership event for the user at the given stateblocks
stateEntryLists, err := v.db.StateEntriesForTuples(ctx, stateBlockNIDList.StateBlockNIDs, []types.StateKeyTuple{
{
EventTypeNID: types.MRoomMemberNID,
EventStateKeyNID: stateKeyNID,
},
})
if err != nil {
return nil, err
}
evIDs := snapshotNIDMap[stateBlockNIDList.StateSnapshotNID]
for _, evID := range evIDs {
for _, x := range stateEntryLists {
result[evID] = append(result[evID], x.StateEntries...)
}
}
}
return result, nil
}
// LoadStateAtEvent loads the full state of a room before a particular event. // LoadStateAtEvent loads the full state of a room before a particular event.
func (v *StateResolution) LoadStateAtEventForHistoryVisibility( func (v *StateResolution) LoadStateAtEventForHistoryVisibility(
ctx context.Context, eventID string, ctx context.Context, eventID string,

View file

@ -23,7 +23,7 @@ import (
// DefaultRoomVersion contains the room version that will, by // DefaultRoomVersion contains the room version that will, by
// default, be used to create new rooms on this server. // default, be used to create new rooms on this server.
func DefaultRoomVersion() gomatrixserverlib.RoomVersion { func DefaultRoomVersion() gomatrixserverlib.RoomVersion {
return gomatrixserverlib.RoomVersionV6 return gomatrixserverlib.RoomVersionV9
} }
// RoomVersions returns a map of all known room versions to this // RoomVersions returns a map of all known room versions to this

View file

@ -164,9 +164,9 @@ func TestMSC2836(t *testing.T) {
// make everyone joined to each other's rooms // make everyone joined to each other's rooms
nopRsAPI := &testRoomserverAPI{ nopRsAPI := &testRoomserverAPI{
userToJoinedRooms: map[string][]string{ userToJoinedRooms: map[string][]string{
alice: []string{roomID}, alice: {roomID},
bob: []string{roomID}, bob: {roomID},
charlie: []string{roomID}, charlie: {roomID},
}, },
events: map[string]*gomatrixserverlib.HeaderedEvent{ events: map[string]*gomatrixserverlib.HeaderedEvent{
eventA.EventID(): eventA, eventA.EventID(): eventA,

View file

@ -45,9 +45,6 @@ const (
ConstCreateEventContentValueSpace = "m.space" ConstCreateEventContentValueSpace = "m.space"
ConstSpaceChildEventType = "m.space.child" ConstSpaceChildEventType = "m.space.child"
ConstSpaceParentEventType = "m.space.parent" ConstSpaceParentEventType = "m.space.parent"
ConstJoinRulePublic = "public"
ConstJoinRuleKnock = "knock"
ConstJoinRuleRestricted = "restricted"
) )
type MSC2946ClientResponse struct { type MSC2946ClientResponse struct {
@ -524,11 +521,11 @@ func (w *walker) authorisedServer(roomID string) bool {
return false return false
} }
if rule == ConstJoinRulePublic || rule == ConstJoinRuleKnock { if rule == gomatrixserverlib.Public || rule == gomatrixserverlib.Knock {
return true return true
} }
if rule == ConstJoinRuleRestricted { if rule == gomatrixserverlib.Restricted {
allowJoinedToRoomIDs = append(allowJoinedToRoomIDs, w.restrictedJoinRuleAllowedRooms(joinRuleEv, "m.room_membership")...) allowJoinedToRoomIDs = append(allowJoinedToRoomIDs, w.restrictedJoinRuleAllowedRooms(joinRuleEv, "m.room_membership")...)
} }
} }
@ -600,9 +597,9 @@ func (w *walker) authorisedUser(roomID, parentRoomID string) (authed bool, isJoi
rule, ruleErr := joinRuleEv.JoinRule() rule, ruleErr := joinRuleEv.JoinRule()
if ruleErr != nil { if ruleErr != nil {
util.GetLogger(w.ctx).WithError(ruleErr).WithField("parent_room_id", parentRoomID).Warn("failed to get join rule") util.GetLogger(w.ctx).WithError(ruleErr).WithField("parent_room_id", parentRoomID).Warn("failed to get join rule")
} else if rule == ConstJoinRulePublic || rule == ConstJoinRuleKnock { } else if rule == gomatrixserverlib.Public || rule == gomatrixserverlib.Knock {
allowed = true allowed = true
} else if rule == ConstJoinRuleRestricted { } else if rule == gomatrixserverlib.Restricted {
allowedRoomIDs := w.restrictedJoinRuleAllowedRooms(joinRuleEv, "m.room_membership") allowedRoomIDs := w.restrictedJoinRuleAllowedRooms(joinRuleEv, "m.room_membership")
// check parent is in the allowed set // check parent is in the allowed set
for _, a := range allowedRoomIDs { for _, a := range allowedRoomIDs {
@ -639,7 +636,7 @@ func (w *walker) authorisedUser(roomID, parentRoomID string) (authed bool, isJoi
func (w *walker) restrictedJoinRuleAllowedRooms(joinRuleEv *gomatrixserverlib.HeaderedEvent, allowType string) (allows []string) { func (w *walker) restrictedJoinRuleAllowedRooms(joinRuleEv *gomatrixserverlib.HeaderedEvent, allowType string) (allows []string) {
rule, _ := joinRuleEv.JoinRule() rule, _ := joinRuleEv.JoinRule()
if rule != ConstJoinRuleRestricted { if rule != gomatrixserverlib.Restricted {
return nil return nil
} }
var jrContent gomatrixserverlib.JoinRuleContent var jrContent gomatrixserverlib.JoinRuleContent

View file

@ -0,0 +1,217 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"math"
"time"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/prometheus/client_golang/prometheus"
"github.com/tidwall/gjson"
)
func init() {
prometheus.MustRegister(calculateHistoryVisibilityDuration)
}
// calculateHistoryVisibilityDuration stores the time it takes to
// calculate the history visibility. In polylith mode the roundtrip
// to the roomserver is included in this time.
var calculateHistoryVisibilityDuration = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: "dendrite",
Subsystem: "syncapi",
Name: "calculateHistoryVisibility_duration_millis",
Help: "How long it takes to calculate the history visibility",
Buckets: []float64{ // milliseconds
5, 10, 25, 50, 75, 100, 250, 500,
1000, 2000, 3000, 4000, 5000, 6000,
7000, 8000, 9000, 10000, 15000, 20000,
},
},
[]string{"api"},
)
var historyVisibilityPriority = map[gomatrixserverlib.HistoryVisibility]uint8{
gomatrixserverlib.WorldReadable: 0,
gomatrixserverlib.HistoryVisibilityShared: 1,
gomatrixserverlib.HistoryVisibilityInvited: 2,
gomatrixserverlib.HistoryVisibilityJoined: 3,
}
// eventVisibility contains the history visibility and membership state at a given event
type eventVisibility struct {
visibility gomatrixserverlib.HistoryVisibility
membershipAtEvent string
membershipCurrent string
}
// allowed checks the eventVisibility if the user is allowed to see the event.
// Rules as defined by https://spec.matrix.org/v1.3/client-server-api/#server-behaviour-5
func (ev eventVisibility) allowed() (allowed bool) {
switch ev.visibility {
case gomatrixserverlib.HistoryVisibilityWorldReadable:
// If the history_visibility was set to world_readable, allow.
return true
case gomatrixserverlib.HistoryVisibilityJoined:
// If the users membership was join, allow.
if ev.membershipAtEvent == gomatrixserverlib.Join {
return true
}
return false
case gomatrixserverlib.HistoryVisibilityShared:
// If the users membership was join, allow.
// If history_visibility was set to shared, and the user joined the room at any point after the event was sent, allow.
if ev.membershipAtEvent == gomatrixserverlib.Join || ev.membershipCurrent == gomatrixserverlib.Join {
return true
}
return false
case gomatrixserverlib.HistoryVisibilityInvited:
// If the users membership was join, allow.
if ev.membershipAtEvent == gomatrixserverlib.Join {
return true
}
if ev.membershipAtEvent == gomatrixserverlib.Invite {
return true
}
return false
default:
return false
}
}
// ApplyHistoryVisibilityFilter applies the room history visibility filter on gomatrixserverlib.HeaderedEvents.
// Returns the filtered events and an error, if any.
func ApplyHistoryVisibilityFilter(
ctx context.Context,
syncDB storage.Database,
rsAPI api.SyncRoomserverAPI,
events []*gomatrixserverlib.HeaderedEvent,
alwaysIncludeEventIDs map[string]struct{},
userID, endpoint string,
) ([]*gomatrixserverlib.HeaderedEvent, error) {
if len(events) == 0 {
return events, nil
}
start := time.Now()
// try to get the current membership of the user
membershipCurrent, _, err := syncDB.SelectMembershipForUser(ctx, events[0].RoomID(), userID, math.MaxInt64)
if err != nil {
return nil, err
}
// Get the mapping from eventID -> eventVisibility
eventsFiltered := make([]*gomatrixserverlib.HeaderedEvent, 0, len(events))
visibilities, err := visibilityForEvents(ctx, rsAPI, events, userID, events[0].RoomID())
if err != nil {
return eventsFiltered, err
}
for _, ev := range events {
evVis := visibilities[ev.EventID()]
evVis.membershipCurrent = membershipCurrent
// Always include specific state events for /sync responses
if alwaysIncludeEventIDs != nil {
if _, ok := alwaysIncludeEventIDs[ev.EventID()]; ok {
eventsFiltered = append(eventsFiltered, ev)
continue
}
}
// NOTSPEC: Always allow user to see their own membership events (spec contains more "rules")
if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(userID) {
eventsFiltered = append(eventsFiltered, ev)
continue
}
// Always allow history evVis events on boundaries. This is done
// by setting the effective evVis to the least restrictive
// of the old vs new.
// https://spec.matrix.org/v1.3/client-server-api/#server-behaviour-5
if hisVis, err := ev.HistoryVisibility(); err == nil {
prevHisVis := gjson.GetBytes(ev.Unsigned(), "prev_content.history_visibility").String()
oldPrio, ok := historyVisibilityPriority[gomatrixserverlib.HistoryVisibility(prevHisVis)]
// if we can't get the previous history visibility, default to shared.
if !ok {
oldPrio = historyVisibilityPriority[gomatrixserverlib.HistoryVisibilityShared]
}
// no OK check, since this should have been validated when setting the value
newPrio := historyVisibilityPriority[hisVis]
if oldPrio < newPrio {
evVis.visibility = gomatrixserverlib.HistoryVisibility(prevHisVis)
}
}
// do the actual check
allowed := evVis.allowed()
if allowed {
eventsFiltered = append(eventsFiltered, ev)
}
}
calculateHistoryVisibilityDuration.With(prometheus.Labels{"api": endpoint}).Observe(float64(time.Since(start).Milliseconds()))
return eventsFiltered, nil
}
// visibilityForEvents returns a map from eventID to eventVisibility containing the visibility and the membership
// of `userID` at the given event.
// Returns an error if the roomserver can't calculate the memberships.
func visibilityForEvents(
ctx context.Context,
rsAPI api.SyncRoomserverAPI,
events []*gomatrixserverlib.HeaderedEvent,
userID, roomID string,
) (map[string]eventVisibility, error) {
eventIDs := make([]string, len(events))
for i := range events {
eventIDs[i] = events[i].EventID()
}
result := make(map[string]eventVisibility, len(eventIDs))
// get the membership events for all eventIDs
membershipResp := &api.QueryMembershipAtEventResponse{}
err := rsAPI.QueryMembershipAtEvent(ctx, &api.QueryMembershipAtEventRequest{
RoomID: roomID,
EventIDs: eventIDs,
UserID: userID,
}, membershipResp)
if err != nil {
return result, err
}
// Create a map from eventID -> eventVisibility
for _, event := range events {
eventID := event.EventID()
vis := eventVisibility{
membershipAtEvent: gomatrixserverlib.Leave, // default to leave, to not expose events by accident
visibility: event.Visibility,
}
membershipEvs, ok := membershipResp.Memberships[eventID]
if !ok {
result[eventID] = vis
continue
}
for _, ev := range membershipEvs {
membership, err := ev.Membership()
if err != nil {
return result, err
}
vis.membershipAtEvent = membership
}
result[eventID] = vis
}
return result, nil
}

View file

@ -31,7 +31,7 @@ import (
// DeviceOTKCounts adds one-time key counts to the /sync response // DeviceOTKCounts adds one-time key counts to the /sync response
func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, deviceID string, res *types.Response) error { func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, deviceID string, res *types.Response) error {
var queryRes keyapi.QueryOneTimeKeysResponse var queryRes keyapi.QueryOneTimeKeysResponse
keyAPI.QueryOneTimeKeys(ctx, &keyapi.QueryOneTimeKeysRequest{ _ = keyAPI.QueryOneTimeKeys(ctx, &keyapi.QueryOneTimeKeysRequest{
UserID: userID, UserID: userID,
DeviceID: deviceID, DeviceID: deviceID,
}, &queryRes) }, &queryRes)
@ -73,7 +73,7 @@ func DeviceListCatchup(
offset = int64(from) offset = int64(from)
} }
var queryRes keyapi.QueryKeyChangesResponse var queryRes keyapi.QueryKeyChangesResponse
keyAPI.QueryKeyChanges(ctx, &keyapi.QueryKeyChangesRequest{ _ = keyAPI.QueryKeyChanges(ctx, &keyapi.QueryKeyChangesRequest{
Offset: offset, Offset: offset,
ToOffset: toOffset, ToOffset: toOffset,
}, &queryRes) }, &queryRes)

View file

@ -22,31 +22,41 @@ var (
type mockKeyAPI struct{} type mockKeyAPI struct{}
func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) { func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) error {
return nil
} }
func (k *mockKeyAPI) SetUserAPI(i userapi.UserInternalAPI) {} func (k *mockKeyAPI) SetUserAPI(i userapi.UserInternalAPI) {}
// PerformClaimKeys claims one-time keys for use in pre-key messages // PerformClaimKeys claims one-time keys for use in pre-key messages
func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) { func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) error {
return nil
} }
func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *keyapi.PerformDeleteKeysRequest, res *keyapi.PerformDeleteKeysResponse) { func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *keyapi.PerformDeleteKeysRequest, res *keyapi.PerformDeleteKeysResponse) error {
return nil
} }
func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *keyapi.PerformUploadDeviceKeysRequest, res *keyapi.PerformUploadDeviceKeysResponse) { func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *keyapi.PerformUploadDeviceKeysRequest, res *keyapi.PerformUploadDeviceKeysResponse) error {
return nil
} }
func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *keyapi.PerformUploadDeviceSignaturesRequest, res *keyapi.PerformUploadDeviceSignaturesResponse) { func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *keyapi.PerformUploadDeviceSignaturesRequest, res *keyapi.PerformUploadDeviceSignaturesResponse) error {
return nil
} }
func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) { func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) error {
return nil
} }
func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) { func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) error {
return nil
} }
func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) { func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) error {
return nil
} }
func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *keyapi.QueryDeviceMessagesRequest, res *keyapi.QueryDeviceMessagesResponse) { func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *keyapi.QueryDeviceMessagesRequest, res *keyapi.QueryDeviceMessagesResponse) error {
return nil
} }
func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *keyapi.QuerySignaturesRequest, res *keyapi.QuerySignaturesResponse) { func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *keyapi.QuerySignaturesRequest, res *keyapi.QuerySignaturesResponse) error {
return nil
} }
type mockRoomserverAPI struct { type mockRoomserverAPI struct {

View file

@ -21,10 +21,12 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"strconv" "strconv"
"time"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
roomserver "github.com/matrix-org/dendrite/roomserver/api" roomserver "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
@ -95,24 +97,6 @@ func Context(
ContainsURL: filter.ContainsURL, ContainsURL: filter.ContainsURL,
} }
// TODO: Get the actual state at the last event returned by SelectContextAfterEvent
state, _ := syncDB.CurrentState(ctx, roomID, &stateFilter, nil)
// verify the user is allowed to see the context for this room/event
for _, x := range state {
var hisVis gomatrixserverlib.HistoryVisibility
hisVis, err = x.HistoryVisibility()
if err != nil {
continue
}
allowed := hisVis == gomatrixserverlib.WorldReadable || membershipRes.Membership == gomatrixserverlib.Join
if !allowed {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("User is not allowed to query context"),
}
}
}
id, requestedEvent, err := syncDB.SelectContextEvent(ctx, roomID, eventID) id, requestedEvent, err := syncDB.SelectContextEvent(ctx, roomID, eventID)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -125,6 +109,24 @@ func Context(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
// verify the user is allowed to see the context for this room/event
startTime := time.Now()
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, []*gomatrixserverlib.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context")
if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter")
return jsonerror.InternalServerError()
}
logrus.WithFields(logrus.Fields{
"duration": time.Since(startTime),
"room_id": roomID,
}).Debug("applied history visibility (context)")
if len(filteredEvents) == 0 {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("User is not allowed to query context"),
}
}
eventsBefore, err := syncDB.SelectContextBeforeEvent(ctx, id, roomID, filter) eventsBefore, err := syncDB.SelectContextBeforeEvent(ctx, id, roomID, filter)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
logrus.WithError(err).Error("unable to fetch before events") logrus.WithError(err).Error("unable to fetch before events")
@ -137,8 +139,27 @@ func Context(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
eventsBeforeClient := gomatrixserverlib.HeaderedToClientEvents(eventsBefore, gomatrixserverlib.FormatAll) startTime = time.Now()
eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfter, gomatrixserverlib.FormatAll) eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, syncDB, rsAPI, eventsBefore, eventsAfter, device.UserID)
if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter")
return jsonerror.InternalServerError()
}
logrus.WithFields(logrus.Fields{
"duration": time.Since(startTime),
"room_id": roomID,
}).Debug("applied history visibility (context eventsBefore/eventsAfter)")
// TODO: Get the actual state at the last event returned by SelectContextAfterEvent
state, err := syncDB.CurrentState(ctx, roomID, &stateFilter, nil)
if err != nil {
logrus.WithError(err).Error("unable to fetch current room state")
return jsonerror.InternalServerError()
}
eventsBeforeClient := gomatrixserverlib.HeaderedToClientEvents(eventsBeforeFiltered, gomatrixserverlib.FormatAll)
eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfterFiltered, gomatrixserverlib.FormatAll)
newState := applyLazyLoadMembers(device, filter, eventsAfterClient, eventsBeforeClient, state, lazyLoadCache) newState := applyLazyLoadMembers(device, filter, eventsAfterClient, eventsBeforeClient, state, lazyLoadCache)
response := ContextRespsonse{ response := ContextRespsonse{
@ -162,6 +183,44 @@ func Context(
} }
} }
// applyHistoryVisibilityOnContextEvents is a helper function to avoid roundtrips to the roomserver
// by combining the events before and after the context event. Returns the filtered events,
// and an error, if any.
func applyHistoryVisibilityOnContextEvents(
ctx context.Context, syncDB storage.Database, rsAPI roomserver.SyncRoomserverAPI,
eventsBefore, eventsAfter []*gomatrixserverlib.HeaderedEvent,
userID string,
) (filteredBefore, filteredAfter []*gomatrixserverlib.HeaderedEvent, err error) {
eventIDsBefore := make(map[string]struct{}, len(eventsBefore))
eventIDsAfter := make(map[string]struct{}, len(eventsAfter))
// Remember before/after eventIDs, so we can restore them
// after applying history visibility checks
for _, ev := range eventsBefore {
eventIDsBefore[ev.EventID()] = struct{}{}
}
for _, ev := range eventsAfter {
eventIDsAfter[ev.EventID()] = struct{}{}
}
allEvents := append(eventsBefore, eventsAfter...)
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, allEvents, nil, userID, "context")
if err != nil {
return nil, nil, err
}
// "Restore" events in the correct context
for _, ev := range filteredEvents {
if _, ok := eventIDsBefore[ev.EventID()]; ok {
filteredBefore = append(filteredBefore, ev)
}
if _, ok := eventIDsAfter[ev.EventID()]; ok {
filteredAfter = append(filteredAfter, ev)
}
}
return filteredBefore, filteredAfter, nil
}
func getStartEnd(ctx context.Context, syncDB storage.Database, startEvents, endEvents []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) { func getStartEnd(ctx context.Context, syncDB storage.Database, startEvents, endEvents []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) {
if len(startEvents) > 0 { if len(startEvents) > 0 {
start, err = syncDB.EventPositionInTopology(ctx, startEvents[0].EventID()) start, err = syncDB.EventPositionInTopology(ctx, startEvents[0].EventID())

View file

@ -19,6 +19,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"sort" "sort"
"time"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -28,6 +29,7 @@ import (
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync" "github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
@ -324,6 +326,9 @@ func (r *messagesReq) retrieveEvents() (
// reliable way to define it), it would be easier and less troublesome to // reliable way to define it), it would be easier and less troublesome to
// only have to change it in one place, i.e. the database. // only have to change it in one place, i.e. the database.
start, end, err = r.getStartEnd(events) start, end, err = r.getStartEnd(events)
if err != nil {
return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, err
}
// Sort the events to ensure we send them in the right order. // Sort the events to ensure we send them in the right order.
if r.backwardOrdering { if r.backwardOrdering {
@ -337,97 +342,18 @@ func (r *messagesReq) retrieveEvents() (
} }
events = reversed(events) events = reversed(events)
} }
events = r.filterHistoryVisible(events)
if len(events) == 0 { if len(events) == 0 {
return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, nil return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, nil
} }
// Convert all of the events into client events. // Apply room history visibility filter
clientEvents = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatAll) startTime := time.Now()
return clientEvents, start, end, err filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.db, r.rsAPI, events, nil, r.device.UserID, "messages")
} logrus.WithFields(logrus.Fields{
"duration": time.Since(startTime),
func (r *messagesReq) filterHistoryVisible(events []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent { "room_id": r.roomID,
// TODO FIXME: We don't fully implement history visibility yet. To avoid leaking events which the }).Debug("applied history visibility (messages)")
// user shouldn't see, we check the recent events and remove any prior to the join event of the user return gomatrixserverlib.HeaderedToClientEvents(filteredEvents, gomatrixserverlib.FormatAll), start, end, err
// which is equiv to history_visibility: joined
joinEventIndex := -1
for i, ev := range events {
if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(r.device.UserID) {
membership, _ := ev.Membership()
if membership == "join" {
joinEventIndex = i
break
}
}
}
var result []*gomatrixserverlib.HeaderedEvent
var eventsToCheck []*gomatrixserverlib.HeaderedEvent
if joinEventIndex != -1 {
if r.backwardOrdering {
result = events[:joinEventIndex+1]
eventsToCheck = append(eventsToCheck, result[0])
} else {
result = events[joinEventIndex:]
eventsToCheck = append(eventsToCheck, result[len(result)-1])
}
} else {
eventsToCheck = []*gomatrixserverlib.HeaderedEvent{events[0], events[len(events)-1]}
result = events
}
// make sure the user was in the room for both the earliest and latest events, we need this because
// some backpagination results will not have the join event (e.g if they hit /messages at the join event itself)
wasJoined := true
for _, ev := range eventsToCheck {
var queryRes api.QueryStateAfterEventsResponse
err := r.rsAPI.QueryStateAfterEvents(r.ctx, &api.QueryStateAfterEventsRequest{
RoomID: ev.RoomID(),
PrevEventIDs: ev.PrevEventIDs(),
StateToFetch: []gomatrixserverlib.StateKeyTuple{
{EventType: gomatrixserverlib.MRoomMember, StateKey: r.device.UserID},
{EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: ""},
},
}, &queryRes)
if err != nil {
wasJoined = false
break
}
var hisVisEvent, membershipEvent *gomatrixserverlib.HeaderedEvent
for i := range queryRes.StateEvents {
switch queryRes.StateEvents[i].Type() {
case gomatrixserverlib.MRoomMember:
membershipEvent = queryRes.StateEvents[i]
case gomatrixserverlib.MRoomHistoryVisibility:
hisVisEvent = queryRes.StateEvents[i]
}
}
if hisVisEvent == nil {
return events // apply no filtering as it defaults to Shared.
}
hisVis, _ := hisVisEvent.HistoryVisibility()
if hisVis == "shared" || hisVis == "world_readable" {
return events // apply no filtering
}
if membershipEvent == nil {
wasJoined = false
break
}
membership, err := membershipEvent.Membership()
if err != nil {
wasJoined = false
break
}
if membership != "join" {
wasJoined = false
break
}
}
if !wasJoined {
util.GetLogger(r.ctx).WithField("num_events", len(events)).Warnf("%s was not joined to room during these events, omitting them", r.device.UserID)
return []*gomatrixserverlib.HeaderedEvent{}
}
return result
} }
func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) { func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) {

View file

@ -161,6 +161,10 @@ type Database interface {
IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error)
UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
// string as the membership.
SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
} }
type Presence interface { type Presence interface {

View file

@ -17,7 +17,10 @@ package deltas
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"fmt" "fmt"
"github.com/matrix-org/gomatrixserverlib"
) )
func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.Tx) error { func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.Tx) error {
@ -31,6 +34,27 @@ func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.T
return nil return nil
} }
// UpSetHistoryVisibility sets the history visibility for already stored events.
// Requires current_room_state and output_room_events to be created.
func UpSetHistoryVisibility(ctx context.Context, tx *sql.Tx) error {
// get the current room history visibilities
historyVisibilities, err := currentHistoryVisibilities(ctx, tx)
if err != nil {
return err
}
// update the history visibility
for roomID, hisVis := range historyVisibilities {
_, err = tx.ExecContext(ctx, `UPDATE syncapi_output_room_events SET history_visibility = $1
WHERE type IN ('m.room.message', 'm.room.encrypted') AND room_id = $2 AND history_visibility <> $1`, hisVis, roomID)
if err != nil {
return fmt.Errorf("failed to update history visibility: %w", err)
}
}
return nil
}
func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.Tx) error { func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, ` _, err := tx.ExecContext(ctx, `
ALTER TABLE syncapi_current_room_state ADD COLUMN IF NOT EXISTS history_visibility SMALLINT NOT NULL DEFAULT 2; ALTER TABLE syncapi_current_room_state ADD COLUMN IF NOT EXISTS history_visibility SMALLINT NOT NULL DEFAULT 2;
@ -39,9 +63,40 @@ func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.T
if err != nil { if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err) return fmt.Errorf("failed to execute upgrade: %w", err)
} }
return nil return nil
} }
// currentHistoryVisibilities returns a map from roomID to current history visibility.
// If the history visibility was changed after room creation, defaults to joined.
func currentHistoryVisibilities(ctx context.Context, tx *sql.Tx) (map[string]gomatrixserverlib.HistoryVisibility, error) {
rows, err := tx.QueryContext(ctx, `SELECT DISTINCT room_id, headered_event_json FROM syncapi_current_room_state
WHERE type = 'm.room.history_visibility' AND state_key = '';
`)
if err != nil {
return nil, fmt.Errorf("failed to query current room state: %w", err)
}
defer rows.Close() // nolint: errcheck
var eventBytes []byte
var roomID string
var event gomatrixserverlib.HeaderedEvent
var hisVis gomatrixserverlib.HistoryVisibility
historyVisibilities := make(map[string]gomatrixserverlib.HistoryVisibility)
for rows.Next() {
if err = rows.Scan(&roomID, &eventBytes); err != nil {
return nil, fmt.Errorf("failed to scan row: %w", err)
}
if err = json.Unmarshal(eventBytes, &event); err != nil {
return nil, fmt.Errorf("failed to unmarshal event: %w", err)
}
historyVisibilities[roomID] = gomatrixserverlib.HistoryVisibilityJoined
if hisVis, err = event.HistoryVisibility(); err == nil && event.Depth() < 10 {
historyVisibilities[roomID] = hisVis
}
}
return historyVisibilities, nil
}
func DownAddHistoryVisibilityColumn(ctx context.Context, tx *sql.Tx) error { func DownAddHistoryVisibilityColumn(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, ` _, err := tx.ExecContext(ctx, `
ALTER TABLE syncapi_output_room_events DROP COLUMN IF EXISTS history_visibility; ALTER TABLE syncapi_output_room_events DROP COLUMN IF EXISTS history_visibility;

View file

@ -66,10 +66,14 @@ const selectMembershipCountSQL = "" +
const selectHeroesSQL = "" + const selectHeroesSQL = "" +
"SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership = ANY($3) LIMIT 5" "SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership = ANY($3) LIMIT 5"
const selectMembershipBeforeSQL = "" +
"SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1"
type membershipsStatements struct { type membershipsStatements struct {
upsertMembershipStmt *sql.Stmt upsertMembershipStmt *sql.Stmt
selectMembershipCountStmt *sql.Stmt selectMembershipCountStmt *sql.Stmt
selectHeroesStmt *sql.Stmt selectHeroesStmt *sql.Stmt
selectMembershipForUserStmt *sql.Stmt
} }
func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
@ -82,6 +86,7 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
{&s.upsertMembershipStmt, upsertMembershipSQL}, {&s.upsertMembershipStmt, upsertMembershipSQL},
{&s.selectMembershipCountStmt, selectMembershipCountSQL}, {&s.selectMembershipCountStmt, selectMembershipCountSQL},
{&s.selectHeroesStmt, selectHeroesSQL}, {&s.selectHeroesStmt, selectHeroesSQL},
{&s.selectMembershipForUserStmt, selectMembershipBeforeSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -132,3 +137,20 @@ func (s *membershipsStatements) SelectHeroes(
} }
return heroes, rows.Err() return heroes, rows.Err()
} }
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
// string as the membership.
func (s *membershipsStatements) SelectMembershipForUser(
ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64,
) (membership string, topologyPos int, err error) {
stmt := sqlutil.TxStmt(txn, s.selectMembershipForUserStmt)
err = stmt.QueryRowContext(ctx, roomID, userID, pos).Scan(&membership, &topologyPos)
if err != nil {
if err == sql.ErrNoRows {
return "leave", 0, nil
}
return "", 0, err
}
return membership, topologyPos, nil
}

View file

@ -191,10 +191,12 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
} }
m := sqlutil.NewMigrator(db) m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{ m.AddMigrations(
sqlutil.Migration{
Version: "syncapi: add history visibility column (output_room_events)", Version: "syncapi: add history visibility column (output_room_events)",
Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents, Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents,
}) },
)
err = m.Up(context.Background()) err = m.Up(context.Background())
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/shared" "github.com/matrix-org/dendrite/syncapi/storage/shared"
) )
@ -97,6 +98,20 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// apply migrations which need multiple tables
m := sqlutil.NewMigrator(d.db)
m.AddMigrations(
sqlutil.Migration{
Version: "syncapi: set history visibility for existing events",
Up: deltas.UpSetHistoryVisibility, // Requires current_room_state and output_room_events to be created.
},
)
err = m.Up(base.Context())
if err != nil {
return nil, err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Writer: d.writer, Writer: d.writer,

View file

@ -231,7 +231,7 @@ func (d *Database) AddPeek(
return return
} }
// DeletePeeks tracks the fact that a user has stopped peeking from the specified // DeletePeek tracks the fact that a user has stopped peeking from the specified
// device. If the peeks was successfully deleted this returns the stream ID it was // device. If the peeks was successfully deleted this returns the stream ID it was
// stored at. Returns an error if there was a problem communicating with the database. // stored at. Returns an error if there was a problem communicating with the database.
func (d *Database) DeletePeek( func (d *Database) DeletePeek(
@ -372,6 +372,7 @@ func (d *Database) WriteEvent(
) (pduPosition types.StreamPosition, returnErr error) { ) (pduPosition types.StreamPosition, returnErr error) {
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error var err error
ev.Visibility = historyVisibility
pos, err := d.OutputEvents.InsertEvent( pos, err := d.OutputEvents.InsertEvent(
ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, historyVisibility, ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, historyVisibility,
) )
@ -563,7 +564,7 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda
return err return err
} }
// Retrieve the backward topology position, i.e. the position of the // GetBackwardTopologyPos retrieves the backward topology position, i.e. the position of the
// oldest event in the room's topology. // oldest event in the room's topology.
func (d *Database) GetBackwardTopologyPos( func (d *Database) GetBackwardTopologyPos(
ctx context.Context, ctx context.Context,
@ -674,7 +675,7 @@ func (d *Database) fetchMissingStateEvents(
return events, nil return events, nil
} }
// getStateDeltas returns the state deltas between fromPos and toPos, // GetStateDeltas returns the state deltas between fromPos and toPos,
// exclusive of oldPos, inclusive of newPos, for the rooms in which // exclusive of oldPos, inclusive of newPos, for the rooms in which
// the user has new membership events. // the user has new membership events.
// A list of joined room IDs is also returned in case the caller needs it. // A list of joined room IDs is also returned in case the caller needs it.
@ -812,7 +813,7 @@ func (d *Database) GetStateDeltas(
return deltas, joinedRoomIDs, nil return deltas, joinedRoomIDs, nil
} }
// getStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync // GetStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync
// requests with full_state=true. // requests with full_state=true.
// Fetches full state for all joined rooms and uses selectStateInRange to get // Fetches full state for all joined rooms and uses selectStateInRange to get
// updates for other rooms. // updates for other rooms.
@ -1039,37 +1040,41 @@ func (d *Database) GetUserUnreadNotificationCounts(ctx context.Context, userID s
return d.NotificationData.SelectUserUnreadCounts(ctx, userID, from, to) return d.NotificationData.SelectUserUnreadCounts(ctx, userID, from, to)
} }
func (s *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) {
return s.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID) return d.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID)
} }
func (s *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) { func (d *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) {
return s.OutputEvents.SelectContextBeforeEvent(ctx, nil, id, roomID, filter) return d.OutputEvents.SelectContextBeforeEvent(ctx, nil, id, roomID, filter)
} }
func (s *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) { func (d *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) {
return s.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter) return d.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter)
} }
func (s *Database) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) { func (d *Database) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) {
return s.Ignores.SelectIgnores(ctx, userID) return d.Ignores.SelectIgnores(ctx, userID)
} }
func (s *Database) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error { func (d *Database) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error {
return s.Ignores.UpsertIgnores(ctx, userID, ignores) return d.Ignores.UpsertIgnores(ctx, userID, ignores)
} }
func (s *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) { func (d *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) {
return s.Presence.UpsertPresence(ctx, nil, userID, statusMsg, presence, lastActiveTS, fromSync) return d.Presence.UpsertPresence(ctx, nil, userID, statusMsg, presence, lastActiveTS, fromSync)
} }
func (s *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
return s.Presence.GetPresenceForUser(ctx, nil, userID) return d.Presence.GetPresenceForUser(ctx, nil, userID)
} }
func (s *Database) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { func (d *Database) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {
return s.Presence.GetPresenceAfter(ctx, nil, after, filter) return d.Presence.GetPresenceAfter(ctx, nil, after, filter)
} }
func (s *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) { func (d *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) {
return s.Presence.GetMaxPresenceID(ctx, nil) return d.Presence.GetMaxPresenceID(ctx, nil)
}
func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) {
return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos)
} }

View file

@ -17,7 +17,10 @@ package deltas
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"fmt" "fmt"
"github.com/matrix-org/gomatrixserverlib"
) )
func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.Tx) error { func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.Tx) error {
@ -37,6 +40,27 @@ func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.T
return nil return nil
} }
// UpSetHistoryVisibility sets the history visibility for already stored events.
// Requires current_room_state and output_room_events to be created.
func UpSetHistoryVisibility(ctx context.Context, tx *sql.Tx) error {
// get the current room history visibilities
historyVisibilities, err := currentHistoryVisibilities(ctx, tx)
if err != nil {
return err
}
// update the history visibility
for roomID, hisVis := range historyVisibilities {
_, err = tx.ExecContext(ctx, `UPDATE syncapi_output_room_events SET history_visibility = $1
WHERE type IN ('m.room.message', 'm.room.encrypted') AND room_id = $2 AND history_visibility <> $1`, hisVis, roomID)
if err != nil {
return fmt.Errorf("failed to update history visibility: %w", err)
}
}
return nil
}
func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.Tx) error { func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.Tx) error {
// SQLite doesn't have "if exists", so check if the column exists. If the query doesn't return an error, it already exists. // SQLite doesn't have "if exists", so check if the column exists. If the query doesn't return an error, it already exists.
// Required for unit tests, as otherwise a duplicate column error will show up. // Required for unit tests, as otherwise a duplicate column error will show up.
@ -51,9 +75,40 @@ func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.T
if err != nil { if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err) return fmt.Errorf("failed to execute upgrade: %w", err)
} }
return nil return nil
} }
// currentHistoryVisibilities returns a map from roomID to current history visibility.
// If the history visibility was changed after room creation, defaults to joined.
func currentHistoryVisibilities(ctx context.Context, tx *sql.Tx) (map[string]gomatrixserverlib.HistoryVisibility, error) {
rows, err := tx.QueryContext(ctx, `SELECT DISTINCT room_id, headered_event_json FROM syncapi_current_room_state
WHERE type = 'm.room.history_visibility' AND state_key = '';
`)
if err != nil {
return nil, fmt.Errorf("failed to query current room state: %w", err)
}
defer rows.Close() // nolint: errcheck
var eventBytes []byte
var roomID string
var event gomatrixserverlib.HeaderedEvent
var hisVis gomatrixserverlib.HistoryVisibility
historyVisibilities := make(map[string]gomatrixserverlib.HistoryVisibility)
for rows.Next() {
if err = rows.Scan(&roomID, &eventBytes); err != nil {
return nil, fmt.Errorf("failed to scan row: %w", err)
}
if err = json.Unmarshal(eventBytes, &event); err != nil {
return nil, fmt.Errorf("failed to unmarshal event: %w", err)
}
historyVisibilities[roomID] = gomatrixserverlib.HistoryVisibilityJoined
if hisVis, err = event.HistoryVisibility(); err == nil && event.Depth() < 10 {
historyVisibilities[roomID] = hisVis
}
}
return historyVisibilities, nil
}
func DownAddHistoryVisibilityColumn(ctx context.Context, tx *sql.Tx) error { func DownAddHistoryVisibilityColumn(ctx context.Context, tx *sql.Tx) error {
// SQLite doesn't have "if exists", so check if the column exists. // SQLite doesn't have "if exists", so check if the column exists.
_, err := tx.QueryContext(ctx, "SELECT history_visibility FROM syncapi_output_room_events LIMIT 1") _, err := tx.QueryContext(ctx, "SELECT history_visibility FROM syncapi_output_room_events LIMIT 1")

View file

@ -66,11 +66,15 @@ const selectMembershipCountSQL = "" +
const selectHeroesSQL = "" + const selectHeroesSQL = "" +
"SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership IN ($3) LIMIT 5" "SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership IN ($3) LIMIT 5"
const selectMembershipBeforeSQL = "" +
"SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1"
type membershipsStatements struct { type membershipsStatements struct {
db *sql.DB db *sql.DB
upsertMembershipStmt *sql.Stmt upsertMembershipStmt *sql.Stmt
selectMembershipCountStmt *sql.Stmt selectMembershipCountStmt *sql.Stmt
//selectHeroesStmt *sql.Stmt - prepared at runtime due to variadic //selectHeroesStmt *sql.Stmt - prepared at runtime due to variadic
selectMembershipForUserStmt *sql.Stmt
} }
func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
@ -84,6 +88,7 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.upsertMembershipStmt, upsertMembershipSQL}, {&s.upsertMembershipStmt, upsertMembershipSQL},
{&s.selectMembershipCountStmt, selectMembershipCountSQL}, {&s.selectMembershipCountStmt, selectMembershipCountSQL},
{&s.selectMembershipForUserStmt, selectMembershipBeforeSQL},
// {&s.selectHeroesStmt, selectHeroesSQL}, - prepared at runtime due to variadic // {&s.selectHeroesStmt, selectHeroesSQL}, - prepared at runtime due to variadic
}.Prepare(db) }.Prepare(db)
} }
@ -148,3 +153,20 @@ func (s *membershipsStatements) SelectHeroes(
} }
return heroes, rows.Err() return heroes, rows.Err()
} }
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
// string as the membership.
func (s *membershipsStatements) SelectMembershipForUser(
ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64,
) (membership string, topologyPos int, err error) {
stmt := sqlutil.TxStmt(txn, s.selectMembershipForUserStmt)
err = stmt.QueryRowContext(ctx, roomID, userID, pos).Scan(&membership, &topologyPos)
if err != nil {
if err == sql.ErrNoRows {
return "leave", 0, nil
}
return "", 0, err
}
return membership, topologyPos, nil
}

View file

@ -139,10 +139,12 @@ func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Even
} }
m := sqlutil.NewMigrator(db) m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{ m.AddMigrations(
sqlutil.Migration{
Version: "syncapi: add history visibility column (output_room_events)", Version: "syncapi: add history visibility column (output_room_events)",
Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents, Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents,
}) },
)
err = m.Up(context.Background()) err = m.Up(context.Background())
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -16,12 +16,14 @@
package sqlite3 package sqlite3
import ( import (
"context"
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/shared" "github.com/matrix-org/dendrite/syncapi/storage/shared"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas"
) )
// SyncServerDatasource represents a sync server datasource which manages // SyncServerDatasource represents a sync server datasource which manages
@ -41,13 +43,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil {
return nil, err return nil, err
} }
if err = d.prepare(); err != nil { if err = d.prepare(base.Context()); err != nil {
return nil, err return nil, err
} }
return &d, nil return &d, nil
} }
func (d *SyncServerDatasource) prepare() (err error) { func (d *SyncServerDatasource) prepare(ctx context.Context) (err error) {
if err = d.streamID.Prepare(d.db); err != nil { if err = d.streamID.Prepare(d.db); err != nil {
return err return err
} }
@ -107,6 +109,19 @@ func (d *SyncServerDatasource) prepare() (err error) {
if err != nil { if err != nil {
return err return err
} }
// apply migrations which need multiple tables
m := sqlutil.NewMigrator(d.db)
m.AddMigrations(
sqlutil.Migration{
Version: "syncapi: set history visibility for existing events",
Up: deltas.UpSetHistoryVisibility, // Requires current_room_state and output_room_events to be created.
},
)
err = m.Up(ctx)
if err != nil {
return err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Writer: d.writer, Writer: d.writer,

View file

@ -12,20 +12,22 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
var ctx = context.Background() var ctx = context.Background()
func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func(), func()) {
connStr, close := test.PrepareDBConnectionString(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewSyncServerDatasource(nil, &config.DatabaseOptions{ base, closeBase := testrig.CreateBaseDendrite(t, dbType)
db, err := storage.NewSyncServerDatasource(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr), ConnectionString: config.DataSource(connStr),
}) })
if err != nil { if err != nil {
t.Fatalf("NewSyncServerDatasource returned %s", err) t.Fatalf("NewSyncServerDatasource returned %s", err)
} }
return db, close return db, close, closeBase
} }
func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (positions []types.StreamPosition) { func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (positions []types.StreamPosition) {
@ -51,8 +53,9 @@ func TestWriteEvents(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
alice := test.NewUser(t) alice := test.NewUser(t)
r := test.NewRoom(t, alice) r := test.NewRoom(t, alice)
db, close := MustCreateDatabase(t, dbType) db, close, closeBase := MustCreateDatabase(t, dbType)
defer close() defer close()
defer closeBase()
MustWriteEvents(t, db, r.Events()) MustWriteEvents(t, db, r.Events())
}) })
} }
@ -60,8 +63,9 @@ func TestWriteEvents(t *testing.T) {
// These tests assert basic functionality of RecentEvents for PDUs // These tests assert basic functionality of RecentEvents for PDUs
func TestRecentEventsPDU(t *testing.T) { func TestRecentEventsPDU(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := MustCreateDatabase(t, dbType) db, close, closeBase := MustCreateDatabase(t, dbType)
defer close() defer close()
defer closeBase()
alice := test.NewUser(t) alice := test.NewUser(t)
// dummy room to make sure SQL queries are filtering on room ID // dummy room to make sure SQL queries are filtering on room ID
MustWriteEvents(t, db, test.NewRoom(t, alice).Events()) MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
@ -163,8 +167,9 @@ func TestRecentEventsPDU(t *testing.T) {
// The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token // The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token
func TestGetEventsInRangeWithTopologyToken(t *testing.T) { func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := MustCreateDatabase(t, dbType) db, close, closeBase := MustCreateDatabase(t, dbType)
defer close() defer close()
defer closeBase()
alice := test.NewUser(t) alice := test.NewUser(t)
r := test.NewRoom(t, alice) r := test.NewRoom(t, alice)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
@ -404,8 +409,9 @@ func TestSendToDeviceBehaviour(t *testing.T) {
bob := test.NewUser(t) bob := test.NewUser(t)
deviceID := "one" deviceID := "one"
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := MustCreateDatabase(t, dbType) db, close, closeBase := MustCreateDatabase(t, dbType)
defer close() defer close()
defer closeBase()
// At this point there should be no messages. We haven't sent anything // At this point there should be no messages. We haven't sent anything
// yet. // yet.
_, events, err := db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100) _, events, err := db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100)

View file

@ -185,6 +185,7 @@ type Memberships interface {
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error) SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error)
SelectHeroes(ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string) (heroes []string, err error) SelectHeroes(ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string) (heroes []string, err error)
SelectMembershipForUser(ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
} }
type NotificationData interface { type NotificationData interface {

View file

@ -10,10 +10,13 @@ import (
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"go.uber.org/atomic" "go.uber.org/atomic"
@ -123,7 +126,7 @@ func (p *PDUStreamProvider) CompleteSync(
defer reqWaitGroup.Done() defer reqWaitGroup.Done()
jr, jerr := p.getJoinResponseForCompleteSync( jr, jerr := p.getJoinResponseForCompleteSync(
ctx, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, ctx, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, false,
) )
if jerr != nil { if jerr != nil {
req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed") req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed")
@ -149,7 +152,7 @@ func (p *PDUStreamProvider) CompleteSync(
if !peek.Deleted { if !peek.Deleted {
var jr *types.JoinResponse var jr *types.JoinResponse
jr, err = p.getJoinResponseForCompleteSync( jr, err = p.getJoinResponseForCompleteSync(
ctx, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, ctx, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, true,
) )
if err != nil { if err != nil {
req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed") req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed")
@ -281,12 +284,6 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
} }
} }
} }
if len(recentEvents) > 0 {
updateLatestPosition(recentEvents[len(recentEvents)-1].EventID())
}
if len(delta.StateEvents) > 0 {
updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID())
}
if stateFilter.LazyLoadMembers { if stateFilter.LazyLoadMembers {
delta.StateEvents, err = p.lazyLoadMembers( delta.StateEvents, err = p.lazyLoadMembers(
@ -306,6 +303,19 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
} }
} }
// Applies the history visibility rules
events, err := applyHistoryVisibilityFilter(ctx, p.DB, p.rsAPI, delta.RoomID, device.UserID, eventFilter.Limit, recentEvents)
if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter")
}
if len(events) > 0 {
updateLatestPosition(events[len(events)-1].EventID())
}
if len(delta.StateEvents) > 0 {
updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID())
}
switch delta.Membership { switch delta.Membership {
case gomatrixserverlib.Join: case gomatrixserverlib.Join:
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
@ -313,14 +323,17 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
p.addRoomSummary(ctx, jr, delta.RoomID, device.UserID, latestPosition) p.addRoomSummary(ctx, jr, delta.RoomID, device.UserID, latestPosition)
} }
jr.Timeline.PrevBatch = &prevBatch jr.Timeline.PrevBatch = &prevBatch
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited // If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited.
jr.Timeline.Limited = limited && len(events) == len(recentEvents)
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync) jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Join[delta.RoomID] = *jr res.Rooms.Join[delta.RoomID] = *jr
case gomatrixserverlib.Peek: case gomatrixserverlib.Peek:
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = &prevBatch jr.Timeline.PrevBatch = &prevBatch
// TODO: Apply history visibility on peeked rooms
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited jr.Timeline.Limited = limited
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync) jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync)
@ -330,12 +343,12 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
fallthrough // transitions to leave are the same as ban fallthrough // transitions to leave are the same as ban
case gomatrixserverlib.Ban: case gomatrixserverlib.Ban:
// TODO: recentEvents may contain events that this user is not allowed to see because they are
// no longer in the room.
lr := types.NewLeaveResponse() lr := types.NewLeaveResponse()
lr.Timeline.PrevBatch = &prevBatch lr.Timeline.PrevBatch = &prevBatch
lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true // If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited.
lr.Timeline.Limited = limited && len(events) == len(recentEvents)
lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync) lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Leave[delta.RoomID] = *lr res.Rooms.Leave[delta.RoomID] = *lr
} }
@ -343,6 +356,41 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
return latestPosition, nil return latestPosition, nil
} }
// applyHistoryVisibilityFilter gets the current room state and supplies it to ApplyHistoryVisibilityFilter, to make
// sure we always return the required events in the timeline.
func applyHistoryVisibilityFilter(
ctx context.Context,
db storage.Database,
rsAPI roomserverAPI.SyncRoomserverAPI,
roomID, userID string,
limit int,
recentEvents []*gomatrixserverlib.HeaderedEvent,
) ([]*gomatrixserverlib.HeaderedEvent, error) {
// We need to make sure we always include the latest states events, if they are in the timeline.
// We grep at least limit * 2 events, to ensure we really get the needed events.
stateEvents, err := db.CurrentState(ctx, roomID, &gomatrixserverlib.StateFilter{Limit: limit * 2}, nil)
if err != nil {
// Not a fatal error, we can continue without the stateEvents,
// they are only needed if there are state events in the timeline.
logrus.WithError(err).Warnf("failed to get current room state")
}
alwaysIncludeIDs := make(map[string]struct{}, len(stateEvents))
for _, ev := range stateEvents {
alwaysIncludeIDs[ev.EventID()] = struct{}{}
}
startTime := time.Now()
events, err := internal.ApplyHistoryVisibilityFilter(ctx, db, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync")
if err != nil {
return nil, err
}
logrus.WithFields(logrus.Fields{
"duration": time.Since(startTime),
"room_id": roomID,
}).Debug("applied history visibility (sync)")
return events, nil
}
func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) { func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) {
// Work out how many members are in the room. // Work out how many members are in the room.
joinedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition) joinedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition)
@ -390,6 +438,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
eventFilter *gomatrixserverlib.RoomEventFilter, eventFilter *gomatrixserverlib.RoomEventFilter,
wantFullState bool, wantFullState bool,
device *userapi.Device, device *userapi.Device,
isPeek bool,
) (jr *types.JoinResponse, err error) { ) (jr *types.JoinResponse, err error) {
jr = types.NewJoinResponse() jr = types.NewJoinResponse()
// TODO: When filters are added, we may need to call this multiple times to get enough events. // TODO: When filters are added, we may need to call this multiple times to get enough events.
@ -404,33 +453,6 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
return return
} }
// TODO FIXME: We don't fully implement history visibility yet. To avoid leaking events which the
// user shouldn't see, we check the recent events and remove any prior to the join event of the user
// which is equiv to history_visibility: joined
joinEventIndex := -1
for i := len(recentStreamEvents) - 1; i >= 0; i-- {
ev := recentStreamEvents[i]
if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(device.UserID) {
membership, _ := ev.Membership()
if membership == "join" {
joinEventIndex = i
if i > 0 {
// the create event happens before the first join, so we should cut it at that point instead
if recentStreamEvents[i-1].Type() == gomatrixserverlib.MRoomCreate && recentStreamEvents[i-1].StateKeyEquals("") {
joinEventIndex = i - 1
break
}
}
break
}
}
}
if joinEventIndex != -1 {
// cut all events earlier than the join (but not the join itself)
recentStreamEvents = recentStreamEvents[joinEventIndex:]
limited = false // so clients know not to try to backpaginate
}
// Work our way through the timeline events and pick out the event IDs // Work our way through the timeline events and pick out the event IDs
// of any state events that appear in the timeline. We'll specifically // of any state events that appear in the timeline. We'll specifically
// exclude them at the next step, so that we don't get duplicate state // exclude them at the next step, so that we don't get duplicate state
@ -474,6 +496,19 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents) recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents) stateEvents = removeDuplicates(stateEvents, recentEvents)
events := recentEvents
// Only apply history visibility checks if the response is for joined rooms
if !isPeek {
events, err = applyHistoryVisibilityFilter(ctx, p.DB, p.rsAPI, roomID, device.UserID, eventFilter.Limit, recentEvents)
if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter")
}
}
// If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited.
limited = limited && len(events) == len(recentEvents)
if stateFilter.LazyLoadMembers { if stateFilter.LazyLoadMembers {
if err != nil { if err != nil {
return nil, err return nil, err
@ -488,8 +523,10 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
} }
jr.Timeline.PrevBatch = prevBatch jr.Timeline.PrevBatch = prevBatch
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited // If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited.
jr.Timeline.Limited = limited && len(events) == len(recentEvents)
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync) jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync)
return jr, nil return jr, nil
} }

View file

@ -12,6 +12,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
keyapi "github.com/matrix-org/dendrite/keyserver/api" keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
@ -54,6 +55,16 @@ func (s *syncRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *rsap
return nil return nil
} }
func (s *syncRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *rsapi.QueryMembershipForUserRequest, res *rsapi.QueryMembershipForUserResponse) error {
res.IsRoomForgotten = false
res.RoomExists = true
return nil
}
func (s *syncRoomserverAPI) QueryMembershipAtEvent(ctx context.Context, req *rsapi.QueryMembershipAtEventRequest, res *rsapi.QueryMembershipAtEventResponse) error {
return nil
}
type syncUserAPI struct { type syncUserAPI struct {
userapi.SyncUserAPI userapi.SyncUserAPI
accounts []userapi.Device accounts []userapi.Device
@ -78,10 +89,11 @@ type syncKeyAPI struct {
keyapi.SyncKeyAPI keyapi.SyncKeyAPI
} }
func (s *syncKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) { func (s *syncKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) error {
return nil
} }
func (s *syncKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) { func (s *syncKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) error {
return nil
} }
func TestSyncAPIAccessTokens(t *testing.T) { func TestSyncAPIAccessTokens(t *testing.T) {
@ -106,7 +118,7 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
msgs := toNATSMsgs(t, base, room.Events()) msgs := toNATSMsgs(t, base, room.Events()...)
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
testrig.MustPublishMsgs(t, jsctx, msgs...) testrig.MustPublishMsgs(t, jsctx, msgs...)
@ -199,7 +211,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
// m.room.power_levels // m.room.power_levels
// m.room.join_rules // m.room.join_rules
// m.room.history_visibility // m.room.history_visibility
msgs := toNATSMsgs(t, base, room.Events()) msgs := toNATSMsgs(t, base, room.Events()...)
sinceTokens := make([]string, len(msgs)) sinceTokens := make([]string, len(msgs))
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
for i, msg := range msgs { for i, msg := range msgs {
@ -314,6 +326,174 @@ func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) {
} }
// This is mainly what Sytest is doing in "test_history_visibility"
func TestMessageHistoryVisibility(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
testHistoryVisibility(t, dbType)
})
}
func testHistoryVisibility(t *testing.T, dbType test.DBType) {
type result struct {
seeWithoutJoin bool
seeBeforeJoin bool
seeAfterInvite bool
}
// create the users
alice := test.NewUser(t)
bob := test.NewUser(t)
bobDev := userapi.Device{
ID: "BOBID",
UserID: bob.ID,
AccessToken: "BOD_BEARER_TOKEN",
DisplayName: "BOB",
}
ctx := context.Background()
// check guest and normal user accounts
for _, accType := range []userapi.AccountType{userapi.AccountTypeGuest, userapi.AccountTypeUser} {
testCases := []struct {
historyVisibility gomatrixserverlib.HistoryVisibility
wantResult result
}{
{
historyVisibility: gomatrixserverlib.HistoryVisibilityWorldReadable,
wantResult: result{
seeWithoutJoin: true,
seeBeforeJoin: true,
seeAfterInvite: true,
},
},
{
historyVisibility: gomatrixserverlib.HistoryVisibilityShared,
wantResult: result{
seeWithoutJoin: false,
seeBeforeJoin: true,
seeAfterInvite: true,
},
},
{
historyVisibility: gomatrixserverlib.HistoryVisibilityInvited,
wantResult: result{
seeWithoutJoin: false,
seeBeforeJoin: false,
seeAfterInvite: true,
},
},
{
historyVisibility: gomatrixserverlib.HistoryVisibilityJoined,
wantResult: result{
seeWithoutJoin: false,
seeBeforeJoin: false,
seeAfterInvite: false,
},
},
}
bobDev.AccountType = accType
userType := "guest"
if accType == userapi.AccountTypeUser {
userType = "real user"
}
base, close := testrig.CreateBaseDendrite(t, dbType)
defer close()
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
// Use the actual internal roomserver API
rsAPI := roomserver.NewInternalAPI(base)
rsAPI.SetFederationAPI(nil, nil)
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{bobDev}}, rsAPI, &syncKeyAPI{})
for _, tc := range testCases {
testname := fmt.Sprintf("%s - %s", tc.historyVisibility, userType)
t.Run(testname, func(t *testing.T) {
// create a room with the given visibility
room := test.NewRoom(t, alice, test.RoomHistoryVisibility(tc.historyVisibility))
// send the events/messages to NATS to create the rooms
beforeJoinEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("Before invite in a %s room", tc.historyVisibility)})
eventsToSend := append(room.Events(), beforeJoinEv)
if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil {
t.Fatalf("failed to send events: %v", err)
}
// There is only one event, we expect only to be able to see this, if the room is world_readable
w := httptest.NewRecorder()
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/messages", room.ID), test.WithQueryParams(map[string]string{
"access_token": bobDev.AccessToken,
"dir": "b",
})))
if w.Code != 200 {
t.Logf("%s", w.Body.String())
t.Fatalf("got HTTP %d want %d", w.Code, 200)
}
// We only care about the returned events at this point
var res struct {
Chunk []gomatrixserverlib.ClientEvent `json:"chunk"`
}
if err := json.NewDecoder(w.Body).Decode(&res); err != nil {
t.Errorf("failed to decode response body: %s", err)
}
verifyEventVisible(t, tc.wantResult.seeWithoutJoin, beforeJoinEv, res.Chunk)
// Create invite, a message, join the room and create another message.
inviteEv := room.CreateAndInsert(t, alice, "m.room.member", map[string]interface{}{"membership": "invite"}, test.WithStateKey(bob.ID))
afterInviteEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After invite in a %s room", tc.historyVisibility)})
joinEv := room.CreateAndInsert(t, bob, "m.room.member", map[string]interface{}{"membership": "join"}, test.WithStateKey(bob.ID))
msgEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After join in a %s room", tc.historyVisibility)})
eventsToSend = append([]*gomatrixserverlib.HeaderedEvent{}, inviteEv, afterInviteEv, joinEv, msgEv)
if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil {
t.Fatalf("failed to send events: %v", err)
}
// Verify the messages after/before invite are visible or not
w = httptest.NewRecorder()
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/messages", room.ID), test.WithQueryParams(map[string]string{
"access_token": bobDev.AccessToken,
"dir": "b",
})))
if w.Code != 200 {
t.Logf("%s", w.Body.String())
t.Fatalf("got HTTP %d want %d", w.Code, 200)
}
if err := json.NewDecoder(w.Body).Decode(&res); err != nil {
t.Errorf("failed to decode response body: %s", err)
}
// verify results
verifyEventVisible(t, tc.wantResult.seeBeforeJoin, beforeJoinEv, res.Chunk)
verifyEventVisible(t, tc.wantResult.seeAfterInvite, afterInviteEv, res.Chunk)
})
}
}
}
func verifyEventVisible(t *testing.T, wantVisible bool, wantVisibleEvent *gomatrixserverlib.HeaderedEvent, chunk []gomatrixserverlib.ClientEvent) {
t.Helper()
if wantVisible {
for _, ev := range chunk {
if ev.EventID == wantVisibleEvent.EventID() {
return
}
}
t.Fatalf("expected to see event %s but didn't: %+v", wantVisibleEvent.EventID(), chunk)
} else {
for _, ev := range chunk {
if ev.EventID == wantVisibleEvent.EventID() {
t.Fatalf("expected not to see event %s: %+v", wantVisibleEvent.EventID(), string(ev.Content))
}
}
}
}
func TestSendToDevice(t *testing.T) { func TestSendToDevice(t *testing.T) {
test.WithAllDatabases(t, testSendToDevice) test.WithAllDatabases(t, testSendToDevice)
} }
@ -447,7 +627,7 @@ func testSendToDevice(t *testing.T, dbType test.DBType) {
} }
} }
func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverlib.HeaderedEvent) []*nats.Msg { func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input ...*gomatrixserverlib.HeaderedEvent) []*nats.Msg {
result := make([]*nats.Msg, len(input)) result := make([]*nats.Msg, len(input))
for i, ev := range input { for i, ev := range input {
var addsStateIDs []string var addsStateIDs []string
@ -459,6 +639,7 @@ func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverli
NewRoomEvent: &rsapi.OutputNewRoomEvent{ NewRoomEvent: &rsapi.OutputNewRoomEvent{
Event: ev, Event: ev,
AddsStateEventIDs: addsStateIDs, AddsStateEventIDs: addsStateIDs,
HistoryVisibility: ev.Visibility,
}, },
}) })
} }

View file

@ -49,7 +49,3 @@ Notifications can be viewed with GET /notifications
If remote user leaves room we no longer receive device updates If remote user leaves room we no longer receive device updates
Guest users can join guest_access rooms Guest users can join guest_access rooms
# You'll be shocked to discover this is flakey too
Inbound /v1/send_join rejects joins from other servers

View file

@ -111,8 +111,6 @@ Newly joined room includes presence in incremental sync
User is offline if they set_presence=offline in their sync User is offline if they set_presence=offline in their sync
Changes to state are included in an incremental sync Changes to state are included in an incremental sync
A change to displayname should appear in incremental /sync A change to displayname should appear in incremental /sync
Current state appears in timeline in private history
Current state appears in timeline in private history with many messages before
Rooms a user is invited to appear in an initial sync Rooms a user is invited to appear in an initial sync
Rooms a user is invited to appear in an incremental sync Rooms a user is invited to appear in an incremental sync
Sync can be polled for updates Sync can be polled for updates
@ -459,7 +457,6 @@ After changing password, a different session no longer works by default
Read markers appear in incremental v2 /sync Read markers appear in incremental v2 /sync
Read markers appear in initial v2 /sync Read markers appear in initial v2 /sync
Read markers can be updated Read markers can be updated
Local users can peek into world_readable rooms by room ID
We can't peek into rooms with shared history_visibility We can't peek into rooms with shared history_visibility
We can't peek into rooms with invited history_visibility We can't peek into rooms with invited history_visibility
We can't peek into rooms with joined history_visibility We can't peek into rooms with joined history_visibility
@ -721,4 +718,30 @@ Setting state twice is idempotent
Joining room twice is idempotent Joining room twice is idempotent
Inbound federation can return missing events for shared visibility Inbound federation can return missing events for shared visibility
Inbound federation ignores redactions from invalid servers room > v3 Inbound federation ignores redactions from invalid servers room > v3
Joining room twice is idempotent
Getting messages going forward is limited for a departed room (SPEC-216)
m.room.history_visibility == "shared" allows/forbids appropriately for Guest users
m.room.history_visibility == "invited" allows/forbids appropriately for Guest users
m.room.history_visibility == "default" allows/forbids appropriately for Guest users
m.room.history_visibility == "shared" allows/forbids appropriately for Real users
m.room.history_visibility == "invited" allows/forbids appropriately for Real users
m.room.history_visibility == "default" allows/forbids appropriately for Real users
Guest users can sync from world_readable guest_access rooms if joined
Guest users can sync from shared guest_access rooms if joined
Guest users can sync from invited guest_access rooms if joined
Guest users can sync from joined guest_access rooms if joined
Guest users can sync from default guest_access rooms if joined
Real users can sync from world_readable guest_access rooms if joined
Real users can sync from shared guest_access rooms if joined
Real users can sync from invited guest_access rooms if joined
Real users can sync from joined guest_access rooms if joined
Real users can sync from default guest_access rooms if joined
Only see history_visibility changes on boundaries
Current state appears in timeline in private history
Current state appears in timeline in private history with many messages before
Local users can peek into world_readable rooms by room ID
Newly joined room includes presence in incremental sync Newly joined room includes presence in incremental sync
User in private room doesn't appear in user directory
User joining then leaving public room appears and dissappears from directory
User in remote room doesn't appear in user directory after server left room
User in shared private room does appear in user directory until leave

View file

@ -40,6 +40,7 @@ type Room struct {
ID string ID string
Version gomatrixserverlib.RoomVersion Version gomatrixserverlib.RoomVersion
preset Preset preset Preset
visibility gomatrixserverlib.HistoryVisibility
creator *User creator *User
authEvents gomatrixserverlib.AuthEvents authEvents gomatrixserverlib.AuthEvents
@ -61,6 +62,7 @@ func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room {
preset: PresetPublicChat, preset: PresetPublicChat,
Version: gomatrixserverlib.RoomVersionV9, Version: gomatrixserverlib.RoomVersionV9,
currentState: make(map[string]*gomatrixserverlib.HeaderedEvent), currentState: make(map[string]*gomatrixserverlib.HeaderedEvent),
visibility: gomatrixserverlib.HistoryVisibilityShared,
} }
for _, m := range modifiers { for _, m := range modifiers {
m(t, r) m(t, r)
@ -97,10 +99,14 @@ func (r *Room) insertCreateEvents(t *testing.T) {
fallthrough fallthrough
case PresetPrivateChat: case PresetPrivateChat:
joinRule.JoinRule = "invite" joinRule.JoinRule = "invite"
hisVis.HistoryVisibility = "shared" hisVis.HistoryVisibility = gomatrixserverlib.HistoryVisibilityShared
case PresetPublicChat: case PresetPublicChat:
joinRule.JoinRule = "public" joinRule.JoinRule = "public"
hisVis.HistoryVisibility = "shared" hisVis.HistoryVisibility = gomatrixserverlib.HistoryVisibilityShared
}
if r.visibility != "" {
hisVis.HistoryVisibility = r.visibility
} }
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{ r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{
@ -183,7 +189,9 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten
if err = gomatrixserverlib.Allowed(ev, &r.authEvents); err != nil { if err = gomatrixserverlib.Allowed(ev, &r.authEvents); err != nil {
t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err) t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err)
} }
return ev.Headered(r.Version) headeredEvent := ev.Headered(r.Version)
headeredEvent.Visibility = r.visibility
return headeredEvent
} }
// Add a new event to this room DAG. Not thread-safe. // Add a new event to this room DAG. Not thread-safe.
@ -242,6 +250,12 @@ func RoomPreset(p Preset) roomModifier {
} }
} }
func RoomHistoryVisibility(vis gomatrixserverlib.HistoryVisibility) roomModifier {
return func(t *testing.T, r *Room) {
r.visibility = vis
}
}
func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier { func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier {
return func(t *testing.T, r *Room) { return func(t *testing.T, r *Room) {
r.Version = ver r.Version = ver

View file

@ -20,6 +20,7 @@ import (
"sync/atomic" "sync/atomic"
"testing" "testing"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -46,6 +47,7 @@ var (
type User struct { type User struct {
ID string ID string
accountType api.AccountType
// key ID and private key of the server who has this user, if known. // key ID and private key of the server who has this user, if known.
keyID gomatrixserverlib.KeyID keyID gomatrixserverlib.KeyID
privKey ed25519.PrivateKey privKey ed25519.PrivateKey
@ -62,6 +64,12 @@ func WithSigningServer(srvName gomatrixserverlib.ServerName, keyID gomatrixserve
} }
} }
func WithAccountType(accountType api.AccountType) UserOpt {
return func(u *User) {
u.accountType = accountType
}
}
func NewUser(t *testing.T, opts ...UserOpt) *User { func NewUser(t *testing.T, opts ...UserOpt) *User {
counter := atomic.AddInt64(&userIDCounter, 1) counter := atomic.AddInt64(&userIDCounter, 1)
var u User var u User

View file

@ -100,7 +100,7 @@ type ClientUserAPI interface {
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) error
QueryThreePIDsForLocalpart(ctx context.Context, req *QueryThreePIDsForLocalpartRequest, res *QueryThreePIDsForLocalpartResponse) error QueryThreePIDsForLocalpart(ctx context.Context, req *QueryThreePIDsForLocalpartRequest, res *QueryThreePIDsForLocalpartResponse) error
QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error
@ -336,6 +336,7 @@ type PerformAccountCreationResponse struct {
type PerformPasswordUpdateRequest struct { type PerformPasswordUpdateRequest struct {
Localpart string // Required: The localpart for this account. Localpart string // Required: The localpart for this account.
Password string // Required: The new password to set. Password string // Required: The new password to set.
LogoutDevices bool // Optional: Whether to log out all user devices.
} }
// PerformAccountCreationResponse is the response for PerformAccountCreation // PerformAccountCreationResponse is the response for PerformAccountCreation

View file

@ -94,9 +94,10 @@ func (t *UserInternalAPITrace) PerformPushRulesPut(ctx context.Context, req *Per
util.GetLogger(ctx).Infof("PerformPushRulesPut req=%+v res=%+v", js(req), js(res)) util.GetLogger(ctx).Infof("PerformPushRulesPut req=%+v res=%+v", js(req), js(res))
return err return err
} }
func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) { func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) error {
t.Impl.QueryKeyBackup(ctx, req, res) err := t.Impl.QueryKeyBackup(ctx, req, res)
util.GetLogger(ctx).Infof("QueryKeyBackup req=%+v res=%+v", js(req), js(res)) util.GetLogger(ctx).Infof("QueryKeyBackup req=%+v res=%+v", js(req), js(res))
return err
} }
func (t *UserInternalAPITrace) QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error { func (t *UserInternalAPITrace) QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error {
err := t.Impl.QueryProfile(ctx, req, res) err := t.Impl.QueryProfile(ctx, req, res)

View file

@ -139,6 +139,11 @@ func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.Pe
if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil { if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
return err return err
} }
if req.LogoutDevices {
if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, ""); err != nil {
return err
}
}
res.PasswordUpdated = true res.PasswordUpdated = true
return nil return nil
} }
@ -192,7 +197,9 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deleteReq.KeyIDs = append(deleteReq.KeyIDs, gomatrixserverlib.KeyID(keyID)) deleteReq.KeyIDs = append(deleteReq.KeyIDs, gomatrixserverlib.KeyID(keyID))
} }
deleteRes := &keyapi.PerformDeleteKeysResponse{} deleteRes := &keyapi.PerformDeleteKeysResponse{}
a.KeyAPI.PerformDeleteKeys(ctx, deleteReq, deleteRes) if err := a.KeyAPI.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil {
return err
}
if err := deleteRes.Error; err != nil { if err := deleteRes.Error; err != nil {
return fmt.Errorf("a.KeyAPI.PerformDeleteKeys: %w", err) return fmt.Errorf("a.KeyAPI.PerformDeleteKeys: %w", err)
} }
@ -211,10 +218,12 @@ func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) er
} }
var uploadRes keyapi.PerformUploadKeysResponse var uploadRes keyapi.PerformUploadKeysResponse
a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
UserID: userID, UserID: userID,
DeviceKeys: deviceKeys, DeviceKeys: deviceKeys,
}, &uploadRes) }, &uploadRes); err != nil {
return err
}
if uploadRes.Error != nil { if uploadRes.Error != nil {
return fmt.Errorf("failed to delete device keys: %v", uploadRes.Error) return fmt.Errorf("failed to delete device keys: %v", uploadRes.Error)
} }
@ -268,7 +277,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
if req.DisplayName != nil && dev.DisplayName != *req.DisplayName { if req.DisplayName != nil && dev.DisplayName != *req.DisplayName {
// display name has changed: update the device key // display name has changed: update the device key
var uploadRes keyapi.PerformUploadKeysResponse var uploadRes keyapi.PerformUploadKeysResponse
a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
UserID: req.RequestingUserID, UserID: req.RequestingUserID,
DeviceKeys: []keyapi.DeviceKeys{ DeviceKeys: []keyapi.DeviceKeys{
{ {
@ -279,7 +288,9 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
}, },
}, },
OnlyDisplayNameUpdates: true, OnlyDisplayNameUpdates: true,
}, &uploadRes) }, &uploadRes); err != nil {
return err
}
if uploadRes.Error != nil { if uploadRes.Error != nil {
return fmt.Errorf("failed to update device key display name: %v", uploadRes.Error) return fmt.Errorf("failed to update device key display name: %v", uploadRes.Error)
} }
@ -479,7 +490,9 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName),
} }
evacuateRes := &rsapi.PerformAdminEvacuateUserResponse{} evacuateRes := &rsapi.PerformAdminEvacuateUserResponse{}
a.RSAPI.PerformAdminEvacuateUser(ctx, evacuateReq, evacuateRes) if err := a.RSAPI.PerformAdminEvacuateUser(ctx, evacuateReq, evacuateRes); err != nil {
return err
}
if err := evacuateRes.Error; err != nil { if err := evacuateRes.Error; err != nil {
logrus.WithError(err).Errorf("Failed to evacuate user after account deactivation") logrus.WithError(err).Errorf("Failed to evacuate user after account deactivation")
} }
@ -538,9 +551,6 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
if req.Version == "" { if req.Version == "" {
res.BadInput = true res.BadInput = true
res.Error = "must specify a version to delete" res.Error = "must specify a version to delete"
if res.Error != "" {
return fmt.Errorf(res.Error)
}
return nil return nil
} }
exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version) exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version)
@ -549,9 +559,6 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
} }
res.Exists = exists res.Exists = exists
res.Version = req.Version res.Version = req.Version
if res.Error != "" {
return fmt.Errorf(res.Error)
}
return nil return nil
} }
// Create metadata // Create metadata
@ -562,9 +569,6 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
} }
res.Exists = err == nil res.Exists = err == nil
res.Version = version res.Version = version
if res.Error != "" {
return fmt.Errorf(res.Error)
}
return nil return nil
} }
// Update metadata // Update metadata
@ -575,16 +579,10 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
} }
res.Exists = err == nil res.Exists = err == nil
res.Version = req.Version res.Version = req.Version
if res.Error != "" {
return fmt.Errorf(res.Error)
}
return nil return nil
} }
// Upload Keys for a specific version metadata // Upload Keys for a specific version metadata
a.uploadBackupKeys(ctx, req, res) a.uploadBackupKeys(ctx, req, res)
if res.Error != "" {
return fmt.Errorf(res.Error)
}
return nil return nil
} }
@ -627,16 +625,16 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
res.KeyETag = etag res.KeyETag = etag
} }
func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) { func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) error {
version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version) version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version)
res.Version = version res.Version = version
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
res.Exists = false res.Exists = false
return return nil
} }
res.Error = fmt.Sprintf("failed to query key backup: %s", err) res.Error = fmt.Sprintf("failed to query key backup: %s", err)
return return nil
} }
res.Algorithm = algorithm res.Algorithm = algorithm
res.AuthData = authData res.AuthData = authData
@ -648,15 +646,16 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to count keys: %s", err) res.Error = fmt.Sprintf("failed to count keys: %s", err)
} }
return return nil
} }
result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID) result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to query keys: %s", err) res.Error = fmt.Sprintf("failed to query keys: %s", err)
return return nil
} }
res.Keys = result res.Keys = result
return nil
} }
func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error { func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error {

View file

@ -21,7 +21,6 @@ import (
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/opentracing/opentracing-go"
) )
// HTTP paths for the internal HTTP APIs // HTTP paths for the internal HTTP APIs
@ -84,11 +83,10 @@ type httpUserInternalAPI struct {
} }
func (h *httpUserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { func (h *httpUserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "InputAccountData") return httputil.CallInternalRPCAPI(
defer span.Finish() "InputAccountData", h.apiURL+InputAccountDataPath,
h.httpClient, ctx, req, res,
apiURL := h.apiURL + InputAccountDataPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
} }
func (h *httpUserInternalAPI) PerformAccountCreation( func (h *httpUserInternalAPI) PerformAccountCreation(
@ -96,11 +94,10 @@ func (h *httpUserInternalAPI) PerformAccountCreation(
request *api.PerformAccountCreationRequest, request *api.PerformAccountCreationRequest,
response *api.PerformAccountCreationResponse, response *api.PerformAccountCreationResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAccountCreation") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformAccountCreation", h.apiURL+PerformAccountCreationPath,
h.httpClient, ctx, request, response,
apiURL := h.apiURL + PerformAccountCreationPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpUserInternalAPI) PerformPasswordUpdate( func (h *httpUserInternalAPI) PerformPasswordUpdate(
@ -108,11 +105,10 @@ func (h *httpUserInternalAPI) PerformPasswordUpdate(
request *api.PerformPasswordUpdateRequest, request *api.PerformPasswordUpdateRequest,
response *api.PerformPasswordUpdateResponse, response *api.PerformPasswordUpdateResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPasswordUpdate") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformPasswordUpdate", h.apiURL+PerformPasswordUpdatePath,
h.httpClient, ctx, request, response,
apiURL := h.apiURL + PerformPasswordUpdatePath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpUserInternalAPI) PerformDeviceCreation( func (h *httpUserInternalAPI) PerformDeviceCreation(
@ -120,11 +116,10 @@ func (h *httpUserInternalAPI) PerformDeviceCreation(
request *api.PerformDeviceCreationRequest, request *api.PerformDeviceCreationRequest,
response *api.PerformDeviceCreationResponse, response *api.PerformDeviceCreationResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceCreation") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformDeviceCreation", h.apiURL+PerformDeviceCreationPath,
h.httpClient, ctx, request, response,
apiURL := h.apiURL + PerformDeviceCreationPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpUserInternalAPI) PerformDeviceDeletion( func (h *httpUserInternalAPI) PerformDeviceDeletion(
@ -132,47 +127,54 @@ func (h *httpUserInternalAPI) PerformDeviceDeletion(
request *api.PerformDeviceDeletionRequest, request *api.PerformDeviceDeletionRequest,
response *api.PerformDeviceDeletionResponse, response *api.PerformDeviceDeletionResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceDeletion") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformDeviceDeletion", h.apiURL+PerformDeviceDeletionPath,
h.httpClient, ctx, request, response,
apiURL := h.apiURL + PerformDeviceDeletionPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpUserInternalAPI) PerformLastSeenUpdate( func (h *httpUserInternalAPI) PerformLastSeenUpdate(
ctx context.Context, ctx context.Context,
req *api.PerformLastSeenUpdateRequest, request *api.PerformLastSeenUpdateRequest,
res *api.PerformLastSeenUpdateResponse, response *api.PerformLastSeenUpdateResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLastSeen") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformLastSeen", h.apiURL+PerformLastSeenUpdatePath,
h.httpClient, ctx, request, response,
apiURL := h.apiURL + PerformLastSeenUpdatePath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
} }
func (h *httpUserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error { func (h *httpUserInternalAPI) PerformDeviceUpdate(
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceUpdate") ctx context.Context,
defer span.Finish() request *api.PerformDeviceUpdateRequest,
response *api.PerformDeviceUpdateResponse,
apiURL := h.apiURL + PerformDeviceUpdatePath ) error {
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.CallInternalRPCAPI(
"PerformDeviceUpdate", h.apiURL+PerformDeviceUpdatePath,
h.httpClient, ctx, request, response,
)
} }
func (h *httpUserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error { func (h *httpUserInternalAPI) PerformAccountDeactivation(
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAccountDeactivation") ctx context.Context,
defer span.Finish() request *api.PerformAccountDeactivationRequest,
response *api.PerformAccountDeactivationResponse,
apiURL := h.apiURL + PerformAccountDeactivationPath ) error {
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.CallInternalRPCAPI(
"PerformAccountDeactivation", h.apiURL+PerformAccountDeactivationPath,
h.httpClient, ctx, request, response,
)
} }
func (h *httpUserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, request *api.PerformOpenIDTokenCreationRequest, response *api.PerformOpenIDTokenCreationResponse) error { func (h *httpUserInternalAPI) PerformOpenIDTokenCreation(
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformOpenIDTokenCreation") ctx context.Context,
defer span.Finish() request *api.PerformOpenIDTokenCreationRequest,
response *api.PerformOpenIDTokenCreationResponse,
apiURL := h.apiURL + PerformOpenIDTokenCreationPath ) error {
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) return httputil.CallInternalRPCAPI(
"PerformOpenIDTokenCreation", h.apiURL+PerformOpenIDTokenCreationPath,
h.httpClient, ctx, request, response,
)
} }
func (h *httpUserInternalAPI) QueryProfile( func (h *httpUserInternalAPI) QueryProfile(
@ -180,11 +182,10 @@ func (h *httpUserInternalAPI) QueryProfile(
request *api.QueryProfileRequest, request *api.QueryProfileRequest,
response *api.QueryProfileResponse, response *api.QueryProfileResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryProfile") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryProfile", h.apiURL+QueryProfilePath,
h.httpClient, ctx, request, response,
apiURL := h.apiURL + QueryProfilePath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpUserInternalAPI) QueryDeviceInfos( func (h *httpUserInternalAPI) QueryDeviceInfos(
@ -192,11 +193,10 @@ func (h *httpUserInternalAPI) QueryDeviceInfos(
request *api.QueryDeviceInfosRequest, request *api.QueryDeviceInfosRequest,
response *api.QueryDeviceInfosResponse, response *api.QueryDeviceInfosResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDeviceInfos") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryDeviceInfos", h.apiURL+QueryDeviceInfosPath,
h.httpClient, ctx, request, response,
apiURL := h.apiURL + QueryDeviceInfosPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpUserInternalAPI) QueryAccessToken( func (h *httpUserInternalAPI) QueryAccessToken(
@ -204,72 +204,87 @@ func (h *httpUserInternalAPI) QueryAccessToken(
request *api.QueryAccessTokenRequest, request *api.QueryAccessTokenRequest,
response *api.QueryAccessTokenResponse, response *api.QueryAccessTokenResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryAccessToken") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryAccessToken", h.apiURL+QueryAccessTokenPath,
h.httpClient, ctx, request, response,
apiURL := h.apiURL + QueryAccessTokenPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpUserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevicesRequest, res *api.QueryDevicesResponse) error { func (h *httpUserInternalAPI) QueryDevices(
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDevices") ctx context.Context,
defer span.Finish() request *api.QueryDevicesRequest,
response *api.QueryDevicesResponse,
apiURL := h.apiURL + QueryDevicesPath ) error {
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.CallInternalRPCAPI(
"QueryDevices", h.apiURL+QueryDevicesPath,
h.httpClient, ctx, request, response,
)
} }
func (h *httpUserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAccountDataRequest, res *api.QueryAccountDataResponse) error { func (h *httpUserInternalAPI) QueryAccountData(
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryAccountData") ctx context.Context,
defer span.Finish() request *api.QueryAccountDataRequest,
response *api.QueryAccountDataResponse,
apiURL := h.apiURL + QueryAccountDataPath ) error {
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.CallInternalRPCAPI(
"QueryAccountData", h.apiURL+QueryAccountDataPath,
h.httpClient, ctx, request, response,
)
} }
func (h *httpUserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error { func (h *httpUserInternalAPI) QuerySearchProfiles(
span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySearchProfiles") ctx context.Context,
defer span.Finish() request *api.QuerySearchProfilesRequest,
response *api.QuerySearchProfilesResponse,
apiURL := h.apiURL + QuerySearchProfilesPath ) error {
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.CallInternalRPCAPI(
"QuerySearchProfiles", h.apiURL+QuerySearchProfilesPath,
h.httpClient, ctx, request, response,
)
} }
func (h *httpUserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error { func (h *httpUserInternalAPI) QueryOpenIDToken(
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOpenIDToken") ctx context.Context,
defer span.Finish() request *api.QueryOpenIDTokenRequest,
response *api.QueryOpenIDTokenResponse,
apiURL := h.apiURL + QueryOpenIDTokenPath ) error {
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.CallInternalRPCAPI(
"QueryOpenIDToken", h.apiURL+QueryOpenIDTokenPath,
h.httpClient, ctx, request, response,
)
} }
func (h *httpUserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) error { func (h *httpUserInternalAPI) PerformKeyBackup(
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformKeyBackup") ctx context.Context,
defer span.Finish() request *api.PerformKeyBackupRequest,
response *api.PerformKeyBackupResponse,
apiURL := h.apiURL + PerformKeyBackupPath ) error {
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.CallInternalRPCAPI(
if err != nil { "PerformKeyBackup", h.apiURL+PerformKeyBackupPath,
res.Error = err.Error() h.httpClient, ctx, request, response,
} )
return nil
}
func (h *httpUserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup")
defer span.Finish()
apiURL := h.apiURL + QueryKeyBackupPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
if err != nil {
res.Error = err.Error()
}
} }
func (h *httpUserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error { func (h *httpUserInternalAPI) QueryKeyBackup(
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryNotifications") ctx context.Context,
defer span.Finish() request *api.QueryKeyBackupRequest,
response *api.QueryKeyBackupResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryKeyBackup", h.apiURL+QueryKeyBackupPath,
h.httpClient, ctx, request, response,
)
}
return httputil.PostJSON(ctx, span, h.httpClient, h.apiURL+QueryNotificationsPath, req, res) func (h *httpUserInternalAPI) QueryNotifications(
ctx context.Context,
request *api.QueryNotificationsRequest,
response *api.QueryNotificationsResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryNotifications", h.apiURL+QueryNotificationsPath,
h.httpClient, ctx, request, response,
)
} }
func (h *httpUserInternalAPI) PerformPusherSet( func (h *httpUserInternalAPI) PerformPusherSet(
@ -277,27 +292,32 @@ func (h *httpUserInternalAPI) PerformPusherSet(
request *api.PerformPusherSetRequest, request *api.PerformPusherSetRequest,
response *struct{}, response *struct{},
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherSet") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformPusherSet", h.apiURL+PerformPusherSetPath,
h.httpClient, ctx, request, response,
apiURL := h.apiURL + PerformPusherSetPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpUserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error { func (h *httpUserInternalAPI) PerformPusherDeletion(
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherDeletion") ctx context.Context,
defer span.Finish() request *api.PerformPusherDeletionRequest,
response *struct{},
apiURL := h.apiURL + PerformPusherDeletionPath ) error {
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.CallInternalRPCAPI(
"PerformPusherDeletion", h.apiURL+PerformPusherDeletionPath,
h.httpClient, ctx, request, response,
)
} }
func (h *httpUserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error { func (h *httpUserInternalAPI) QueryPushers(
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushers") ctx context.Context,
defer span.Finish() request *api.QueryPushersRequest,
response *api.QueryPushersResponse,
apiURL := h.apiURL + QueryPushersPath ) error {
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.CallInternalRPCAPI(
"QueryPushers", h.apiURL+QueryPushersPath,
h.httpClient, ctx, request, response,
)
} }
func (h *httpUserInternalAPI) PerformPushRulesPut( func (h *httpUserInternalAPI) PerformPushRulesPut(
@ -305,89 +325,117 @@ func (h *httpUserInternalAPI) PerformPushRulesPut(
request *api.PerformPushRulesPutRequest, request *api.PerformPushRulesPutRequest,
response *struct{}, response *struct{},
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPushRulesPut") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformPushRulesPut", h.apiURL+PerformPushRulesPutPath,
h.httpClient, ctx, request, response,
apiURL := h.apiURL + PerformPushRulesPutPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpUserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error { func (h *httpUserInternalAPI) QueryPushRules(
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushRules") ctx context.Context,
defer span.Finish() request *api.QueryPushRulesRequest,
response *api.QueryPushRulesResponse,
apiURL := h.apiURL + QueryPushRulesPath ) error {
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.CallInternalRPCAPI(
"QueryPushRules", h.apiURL+QueryPushRulesPath,
h.httpClient, ctx, request, response,
)
} }
func (h *httpUserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error { func (h *httpUserInternalAPI) SetAvatarURL(
span, ctx := opentracing.StartSpanFromContext(ctx, PerformSetAvatarURLPath) ctx context.Context,
defer span.Finish() request *api.PerformSetAvatarURLRequest,
response *api.PerformSetAvatarURLResponse,
apiURL := h.apiURL + PerformSetAvatarURLPath ) error {
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.CallInternalRPCAPI(
"SetAvatarURL", h.apiURL+PerformSetAvatarURLPath,
h.httpClient, ctx, request, response,
)
} }
func (h *httpUserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error { func (h *httpUserInternalAPI) QueryNumericLocalpart(
span, ctx := opentracing.StartSpanFromContext(ctx, QueryNumericLocalpartPath) ctx context.Context,
defer span.Finish() response *api.QueryNumericLocalpartResponse,
) error {
apiURL := h.apiURL + QueryNumericLocalpartPath return httputil.CallInternalRPCAPI(
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, struct{}{}, res) "QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath,
h.httpClient, ctx, &struct{}{}, response,
)
} }
func (h *httpUserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error { func (h *httpUserInternalAPI) QueryAccountAvailability(
span, ctx := opentracing.StartSpanFromContext(ctx, QueryAccountAvailabilityPath) ctx context.Context,
defer span.Finish() request *api.QueryAccountAvailabilityRequest,
response *api.QueryAccountAvailabilityResponse,
apiURL := h.apiURL + QueryAccountAvailabilityPath ) error {
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.CallInternalRPCAPI(
"QueryAccountAvailability", h.apiURL+QueryAccountAvailabilityPath,
h.httpClient, ctx, request, response,
)
} }
func (h *httpUserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error { func (h *httpUserInternalAPI) QueryAccountByPassword(
span, ctx := opentracing.StartSpanFromContext(ctx, QueryAccountByPasswordPath) ctx context.Context,
defer span.Finish() request *api.QueryAccountByPasswordRequest,
response *api.QueryAccountByPasswordResponse,
apiURL := h.apiURL + QueryAccountByPasswordPath ) error {
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.CallInternalRPCAPI(
"QueryAccountByPassword", h.apiURL+QueryAccountByPasswordPath,
h.httpClient, ctx, request, response,
)
} }
func (h *httpUserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *struct{}) error { func (h *httpUserInternalAPI) SetDisplayName(
span, ctx := opentracing.StartSpanFromContext(ctx, PerformSetDisplayNamePath) ctx context.Context,
defer span.Finish() request *api.PerformUpdateDisplayNameRequest,
response *struct{},
apiURL := h.apiURL + PerformSetDisplayNamePath ) error {
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.CallInternalRPCAPI(
"SetDisplayName", h.apiURL+PerformSetDisplayNamePath,
h.httpClient, ctx, request, response,
)
} }
func (h *httpUserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error { func (h *httpUserInternalAPI) QueryLocalpartForThreePID(
span, ctx := opentracing.StartSpanFromContext(ctx, QueryLocalpartForThreePIDPath) ctx context.Context,
defer span.Finish() request *api.QueryLocalpartForThreePIDRequest,
response *api.QueryLocalpartForThreePIDResponse,
apiURL := h.apiURL + QueryLocalpartForThreePIDPath ) error {
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.CallInternalRPCAPI(
"QueryLocalpartForThreePID", h.apiURL+QueryLocalpartForThreePIDPath,
h.httpClient, ctx, request, response,
)
} }
func (h *httpUserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error { func (h *httpUserInternalAPI) QueryThreePIDsForLocalpart(
span, ctx := opentracing.StartSpanFromContext(ctx, QueryThreePIDsForLocalpartPath) ctx context.Context,
defer span.Finish() request *api.QueryThreePIDsForLocalpartRequest,
response *api.QueryThreePIDsForLocalpartResponse,
apiURL := h.apiURL + QueryThreePIDsForLocalpartPath ) error {
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.CallInternalRPCAPI(
"QueryThreePIDsForLocalpart", h.apiURL+QueryThreePIDsForLocalpartPath,
h.httpClient, ctx, request, response,
)
} }
func (h *httpUserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.PerformForgetThreePIDRequest, res *struct{}) error { func (h *httpUserInternalAPI) PerformForgetThreePID(
span, ctx := opentracing.StartSpanFromContext(ctx, PerformForgetThreePIDPath) ctx context.Context,
defer span.Finish() request *api.PerformForgetThreePIDRequest,
response *struct{},
apiURL := h.apiURL + PerformForgetThreePIDPath ) error {
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.CallInternalRPCAPI(
"PerformForgetThreePID", h.apiURL+PerformForgetThreePIDPath,
h.httpClient, ctx, request, response,
)
} }
func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error { func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(
span, ctx := opentracing.StartSpanFromContext(ctx, PerformSaveThreePIDAssociationPath) ctx context.Context,
defer span.Finish() request *api.PerformSaveThreePIDAssociationRequest,
response *struct{},
apiURL := h.apiURL + PerformSaveThreePIDAssociationPath ) error {
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.CallInternalRPCAPI(
"PerformSaveThreePIDAssociation", h.apiURL+PerformSaveThreePIDAssociationPath,
h.httpClient, ctx, request, response,
)
} }

View file

@ -19,7 +19,6 @@ import (
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/opentracing/opentracing-go"
) )
const ( const (
@ -33,11 +32,10 @@ func (h *httpUserInternalAPI) PerformLoginTokenCreation(
request *api.PerformLoginTokenCreationRequest, request *api.PerformLoginTokenCreationRequest,
response *api.PerformLoginTokenCreationResponse, response *api.PerformLoginTokenCreationResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLoginTokenCreation") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformLoginTokenCreation", h.apiURL+PerformLoginTokenCreationPath,
h.httpClient, ctx, request, response,
apiURL := h.apiURL + PerformLoginTokenCreationPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpUserInternalAPI) PerformLoginTokenDeletion( func (h *httpUserInternalAPI) PerformLoginTokenDeletion(
@ -45,11 +43,10 @@ func (h *httpUserInternalAPI) PerformLoginTokenDeletion(
request *api.PerformLoginTokenDeletionRequest, request *api.PerformLoginTokenDeletionRequest,
response *api.PerformLoginTokenDeletionResponse, response *api.PerformLoginTokenDeletionResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLoginTokenDeletion") return httputil.CallInternalRPCAPI(
defer span.Finish() "PerformLoginTokenDeletion", h.apiURL+PerformLoginTokenDeletionPath,
h.httpClient, ctx, request, response,
apiURL := h.apiURL + PerformLoginTokenDeletionPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpUserInternalAPI) QueryLoginToken( func (h *httpUserInternalAPI) QueryLoginToken(
@ -57,9 +54,8 @@ func (h *httpUserInternalAPI) QueryLoginToken(
request *api.QueryLoginTokenRequest, request *api.QueryLoginTokenRequest,
response *api.QueryLoginTokenResponse, response *api.QueryLoginTokenResponse,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryLoginToken") return httputil.CallInternalRPCAPI(
defer span.Finish() "QueryLoginToken", h.apiURL+QueryLoginTokenPath,
h.httpClient, ctx, request, response,
apiURL := h.apiURL + QueryLoginTokenPath )
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }

View file

@ -15,8 +15,6 @@
package inthttp package inthttp
import ( import (
"encoding/json"
"fmt"
"net/http" "net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -29,339 +27,134 @@ import (
func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
addRoutesLoginToken(internalAPIMux, s) addRoutesLoginToken(internalAPIMux, s)
internalAPIMux.Handle(PerformAccountCreationPath, internalAPIMux.Handle(
httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse { PerformAccountCreationPath,
request := api.PerformAccountCreationRequest{} httputil.MakeInternalRPCAPI("UserAPIPerformAccountCreation", s.PerformAccountCreation),
response := api.PerformAccountCreationResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformAccountCreation(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformPasswordUpdatePath,
httputil.MakeInternalAPI("performPasswordUpdate", func(req *http.Request) util.JSONResponse {
request := api.PerformPasswordUpdateRequest{}
response := api.PerformPasswordUpdateResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformPasswordUpdate(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformDeviceCreationPath,
httputil.MakeInternalAPI("performDeviceCreation", func(req *http.Request) util.JSONResponse {
request := api.PerformDeviceCreationRequest{}
response := api.PerformDeviceCreationResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformDeviceCreation(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformLastSeenUpdatePath,
httputil.MakeInternalAPI("performLastSeenUpdate", func(req *http.Request) util.JSONResponse {
request := api.PerformLastSeenUpdateRequest{}
response := api.PerformLastSeenUpdateResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformLastSeenUpdate(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformDeviceUpdatePath,
httputil.MakeInternalAPI("performDeviceUpdate", func(req *http.Request) util.JSONResponse {
request := api.PerformDeviceUpdateRequest{}
response := api.PerformDeviceUpdateResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformDeviceUpdate(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformDeviceDeletionPath,
httputil.MakeInternalAPI("performDeviceDeletion", func(req *http.Request) util.JSONResponse {
request := api.PerformDeviceDeletionRequest{}
response := api.PerformDeviceDeletionResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformDeviceDeletion(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformAccountDeactivationPath,
httputil.MakeInternalAPI("performAccountDeactivation", func(req *http.Request) util.JSONResponse {
request := api.PerformAccountDeactivationRequest{}
response := api.PerformAccountDeactivationResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformAccountDeactivation(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformOpenIDTokenCreationPath,
httputil.MakeInternalAPI("performOpenIDTokenCreation", func(req *http.Request) util.JSONResponse {
request := api.PerformOpenIDTokenCreationRequest{}
response := api.PerformOpenIDTokenCreationResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformOpenIDTokenCreation(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryProfilePath,
httputil.MakeInternalAPI("queryProfile", func(req *http.Request) util.JSONResponse {
request := api.QueryProfileRequest{}
response := api.QueryProfileResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryProfile(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryAccessTokenPath,
httputil.MakeInternalAPI("queryAccessToken", func(req *http.Request) util.JSONResponse {
request := api.QueryAccessTokenRequest{}
response := api.QueryAccessTokenResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryAccessToken(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryDevicesPath,
httputil.MakeInternalAPI("queryDevices", func(req *http.Request) util.JSONResponse {
request := api.QueryDevicesRequest{}
response := api.QueryDevicesResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryDevices(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryAccountDataPath,
httputil.MakeInternalAPI("queryAccountData", func(req *http.Request) util.JSONResponse {
request := api.QueryAccountDataRequest{}
response := api.QueryAccountDataResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryAccountData(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryDeviceInfosPath,
httputil.MakeInternalAPI("queryDeviceInfos", func(req *http.Request) util.JSONResponse {
request := api.QueryDeviceInfosRequest{}
response := api.QueryDeviceInfosResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryDeviceInfos(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QuerySearchProfilesPath,
httputil.MakeInternalAPI("querySearchProfiles", func(req *http.Request) util.JSONResponse {
request := api.QuerySearchProfilesRequest{}
response := api.QuerySearchProfilesResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QuerySearchProfiles(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryOpenIDTokenPath,
httputil.MakeInternalAPI("queryOpenIDToken", func(req *http.Request) util.JSONResponse {
request := api.QueryOpenIDTokenRequest{}
response := api.QueryOpenIDTokenResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryOpenIDToken(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(InputAccountDataPath,
httputil.MakeInternalAPI("inputAccountDataPath", func(req *http.Request) util.JSONResponse {
request := api.InputAccountDataRequest{}
response := api.InputAccountDataResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.InputAccountData(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryKeyBackupPath,
httputil.MakeInternalAPI("queryKeyBackup", func(req *http.Request) util.JSONResponse {
request := api.QueryKeyBackupRequest{}
response := api.QueryKeyBackupResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.QueryKeyBackup(req.Context(), &request, &response)
if response.Error != "" {
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", response.Error))
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformKeyBackupPath,
httputil.MakeInternalAPI("performKeyBackup", func(req *http.Request) util.JSONResponse {
request := api.PerformKeyBackupRequest{}
response := api.PerformKeyBackupResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
err := s.PerformKeyBackup(req.Context(), &request, &response)
if err != nil {
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryNotificationsPath,
httputil.MakeInternalAPI("queryNotifications", func(req *http.Request) util.JSONResponse {
var request api.QueryNotificationsRequest
var response api.QueryNotificationsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryNotifications(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(PerformPusherSetPath, internalAPIMux.Handle(
httputil.MakeInternalAPI("performPusherSet", func(req *http.Request) util.JSONResponse { PerformPasswordUpdatePath,
request := api.PerformPusherSetRequest{} httputil.MakeInternalRPCAPI("UserAPIPerformPasswordUpdate", s.PerformPasswordUpdate),
response := struct{}{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformPusherSet(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformPusherDeletionPath,
httputil.MakeInternalAPI("performPusherDeletion", func(req *http.Request) util.JSONResponse {
request := api.PerformPusherDeletionRequest{}
response := struct{}{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformPusherDeletion(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(QueryPushersPath, internalAPIMux.Handle(
httputil.MakeInternalAPI("queryPushers", func(req *http.Request) util.JSONResponse { PerformDeviceCreationPath,
request := api.QueryPushersRequest{} httputil.MakeInternalRPCAPI("UserAPIPerformDeviceCreation", s.PerformDeviceCreation),
response := api.QueryPushersResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryPushers(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(PerformPushRulesPutPath, internalAPIMux.Handle(
httputil.MakeInternalAPI("performPushRulesPut", func(req *http.Request) util.JSONResponse { PerformLastSeenUpdatePath,
request := api.PerformPushRulesPutRequest{} httputil.MakeInternalRPCAPI("UserAPIPerformLastSeenUpdate", s.PerformLastSeenUpdate),
response := struct{}{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformPushRulesPut(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(QueryPushRulesPath, internalAPIMux.Handle(
httputil.MakeInternalAPI("queryPushRules", func(req *http.Request) util.JSONResponse { PerformDeviceUpdatePath,
request := api.QueryPushRulesRequest{} httputil.MakeInternalRPCAPI("UserAPIPerformDeviceUpdate", s.PerformDeviceUpdate),
response := api.QueryPushRulesResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryPushRules(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(PerformSetAvatarURLPath,
httputil.MakeInternalAPI("performSetAvatarURL", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.PerformSetAvatarURLRequest{} PerformDeviceDeletionPath,
response := api.PerformSetAvatarURLResponse{} httputil.MakeInternalRPCAPI("UserAPIPerformDeviceDeletion", s.PerformDeviceDeletion),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.SetAvatarURL(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(
PerformAccountDeactivationPath,
httputil.MakeInternalRPCAPI("UserAPIPerformAccountDeactivation", s.PerformAccountDeactivation),
)
internalAPIMux.Handle(
PerformOpenIDTokenCreationPath,
httputil.MakeInternalRPCAPI("UserAPIPerformOpenIDTokenCreation", s.PerformOpenIDTokenCreation),
)
internalAPIMux.Handle(
QueryProfilePath,
httputil.MakeInternalRPCAPI("UserAPIQueryProfile", s.QueryProfile),
)
internalAPIMux.Handle(
QueryAccessTokenPath,
httputil.MakeInternalRPCAPI("UserAPIQueryAccessToken", s.QueryAccessToken),
)
internalAPIMux.Handle(
QueryDevicesPath,
httputil.MakeInternalRPCAPI("UserAPIQueryDevices", s.QueryDevices),
)
internalAPIMux.Handle(
QueryAccountDataPath,
httputil.MakeInternalRPCAPI("UserAPIQueryAccountData", s.QueryAccountData),
)
internalAPIMux.Handle(
QueryDeviceInfosPath,
httputil.MakeInternalRPCAPI("UserAPIQueryDeviceInfos", s.QueryDeviceInfos),
)
internalAPIMux.Handle(
QuerySearchProfilesPath,
httputil.MakeInternalRPCAPI("UserAPIQuerySearchProfiles", s.QuerySearchProfiles),
)
internalAPIMux.Handle(
QueryOpenIDTokenPath,
httputil.MakeInternalRPCAPI("UserAPIQueryOpenIDToken", s.QueryOpenIDToken),
)
internalAPIMux.Handle(
InputAccountDataPath,
httputil.MakeInternalRPCAPI("UserAPIInputAccountData", s.InputAccountData),
)
internalAPIMux.Handle(
QueryKeyBackupPath,
httputil.MakeInternalRPCAPI("UserAPIQueryKeyBackup", s.QueryKeyBackup),
)
internalAPIMux.Handle(
PerformKeyBackupPath,
httputil.MakeInternalRPCAPI("UserAPIPerformKeyBackup", s.PerformKeyBackup),
)
internalAPIMux.Handle(
QueryNotificationsPath,
httputil.MakeInternalRPCAPI("UserAPIQueryNotifications", s.QueryNotifications),
)
internalAPIMux.Handle(
PerformPusherSetPath,
httputil.MakeInternalRPCAPI("UserAPIPerformPusherSet", s.PerformPusherSet),
)
internalAPIMux.Handle(
PerformPusherDeletionPath,
httputil.MakeInternalRPCAPI("UserAPIPerformPusherDeletion", s.PerformPusherDeletion),
)
internalAPIMux.Handle(
QueryPushersPath,
httputil.MakeInternalRPCAPI("UserAPIQueryPushers", s.QueryPushers),
)
internalAPIMux.Handle(
PerformPushRulesPutPath,
httputil.MakeInternalRPCAPI("UserAPIPerformPushRulesPut", s.PerformPushRulesPut),
)
internalAPIMux.Handle(
QueryPushRulesPath,
httputil.MakeInternalRPCAPI("UserAPIQueryPushRules", s.QueryPushRules),
)
internalAPIMux.Handle(
PerformSetAvatarURLPath,
httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL),
)
// TODO: Look at the shape of this
internalAPIMux.Handle(QueryNumericLocalpartPath, internalAPIMux.Handle(QueryNumericLocalpartPath,
httputil.MakeInternalAPI("queryNumericLocalpart", func(req *http.Request) util.JSONResponse { httputil.MakeInternalAPI("UserAPIQueryNumericLocalpart", func(req *http.Request) util.JSONResponse {
response := api.QueryNumericLocalpartResponse{} response := api.QueryNumericLocalpartResponse{}
if err := s.QueryNumericLocalpart(req.Context(), &response); err != nil { if err := s.QueryNumericLocalpart(req.Context(), &response); err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -369,92 +162,39 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(QueryAccountAvailabilityPath,
httputil.MakeInternalAPI("queryAccountAvailability", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.QueryAccountAvailabilityRequest{} QueryAccountAvailabilityPath,
response := api.QueryAccountAvailabilityResponse{} httputil.MakeInternalRPCAPI("UserAPIQueryAccountAvailability", s.QueryAccountAvailability),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryAccountAvailability(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(QueryAccountByPasswordPath,
httputil.MakeInternalAPI("queryAccountByPassword", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.QueryAccountByPasswordRequest{} QueryAccountByPasswordPath,
response := api.QueryAccountByPasswordResponse{} httputil.MakeInternalRPCAPI("UserAPIQueryAccountByPassword", s.QueryAccountByPassword),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryAccountByPassword(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(PerformSetDisplayNamePath,
httputil.MakeInternalAPI("performSetDisplayName", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.PerformUpdateDisplayNameRequest{} PerformSetDisplayNamePath,
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { httputil.MakeInternalRPCAPI("UserAPISetDisplayName", s.SetDisplayName),
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.SetDisplayName(req.Context(), &request, &struct{}{}); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}}
}),
) )
internalAPIMux.Handle(QueryLocalpartForThreePIDPath,
httputil.MakeInternalAPI("queryLocalpartForThreePID", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.QueryLocalpartForThreePIDRequest{} QueryLocalpartForThreePIDPath,
response := api.QueryLocalpartForThreePIDResponse{} httputil.MakeInternalRPCAPI("UserAPIQueryLocalpartForThreePID", s.QueryLocalpartForThreePID),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryLocalpartForThreePID(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(QueryThreePIDsForLocalpartPath,
httputil.MakeInternalAPI("queryThreePIDsForLocalpart", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.QueryThreePIDsForLocalpartRequest{} QueryThreePIDsForLocalpartPath,
response := api.QueryThreePIDsForLocalpartResponse{} httputil.MakeInternalRPCAPI("UserAPIQueryThreePIDsForLocalpart", s.QueryThreePIDsForLocalpart),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryThreePIDsForLocalpart(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(PerformForgetThreePIDPath,
httputil.MakeInternalAPI("performForgetThreePID", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.PerformForgetThreePIDRequest{} PerformForgetThreePIDPath,
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { httputil.MakeInternalRPCAPI("UserAPIPerformForgetThreePID", s.PerformForgetThreePID),
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformForgetThreePID(req.Context(), &request, &struct{}{}); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}}
}),
) )
internalAPIMux.Handle(PerformSaveThreePIDAssociationPath,
httputil.MakeInternalAPI("performSaveThreePIDAssociation", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.PerformSaveThreePIDAssociationRequest{} PerformSaveThreePIDAssociationPath,
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", s.PerformSaveThreePIDAssociation),
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformSaveThreePIDAssociation(req.Context(), &request, &struct{}{}); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}}
}),
) )
} }

View file

@ -15,54 +15,25 @@
package inthttp package inthttp
import ( import (
"encoding/json"
"net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
) )
// addRoutesLoginToken adds routes for all login token API calls. // addRoutesLoginToken adds routes for all login token API calls.
func addRoutesLoginToken(internalAPIMux *mux.Router, s api.UserInternalAPI) { func addRoutesLoginToken(internalAPIMux *mux.Router, s api.UserInternalAPI) {
internalAPIMux.Handle(PerformLoginTokenCreationPath, internalAPIMux.Handle(
httputil.MakeInternalAPI("performLoginTokenCreation", func(req *http.Request) util.JSONResponse { PerformLoginTokenCreationPath,
request := api.PerformLoginTokenCreationRequest{} httputil.MakeInternalRPCAPI("UserAPIPerformLoginTokenCreation", s.PerformLoginTokenCreation),
response := api.PerformLoginTokenCreationResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformLoginTokenCreation(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(PerformLoginTokenDeletionPath,
httputil.MakeInternalAPI("performLoginTokenDeletion", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.PerformLoginTokenDeletionRequest{} PerformLoginTokenDeletionPath,
response := api.PerformLoginTokenDeletionResponse{} httputil.MakeInternalRPCAPI("UserAPIPerformLoginTokenDeletion", s.PerformLoginTokenDeletion),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformLoginTokenDeletion(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle(QueryLoginTokenPath,
httputil.MakeInternalAPI("queryLoginToken", func(req *http.Request) util.JSONResponse { internalAPIMux.Handle(
request := api.QueryLoginTokenRequest{} QueryLoginTokenPath,
response := api.QueryLoginTokenResponse{} httputil.MakeInternalRPCAPI("UserAPIQueryLoginToken", s.QueryLoginToken),
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryLoginToken(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
} }

View file

@ -117,16 +117,20 @@ func TestQueryProfile(t *testing.T) {
}, },
} }
runCases := func(testAPI api.UserInternalAPI) { runCases := func(testAPI api.UserInternalAPI, http bool) {
mode := "monolith"
if http {
mode = "HTTP"
}
for _, tc := range testCases { for _, tc := range testCases {
var gotRes api.QueryProfileResponse var gotRes api.QueryProfileResponse
gotErr := testAPI.QueryProfile(context.TODO(), &tc.req, &gotRes) gotErr := testAPI.QueryProfile(context.TODO(), &tc.req, &gotRes)
if tc.wantErr == nil && gotErr != nil || tc.wantErr != nil && gotErr == nil { if tc.wantErr == nil && gotErr != nil || tc.wantErr != nil && gotErr == nil {
t.Errorf("QueryProfile error, got %s want %s", gotErr, tc.wantErr) t.Errorf("QueryProfile %s error, got %s want %s", mode, gotErr, tc.wantErr)
continue continue
} }
if !reflect.DeepEqual(tc.wantRes, gotRes) { if !reflect.DeepEqual(tc.wantRes, gotRes) {
t.Errorf("QueryProfile response got %+v want %+v", gotRes, tc.wantRes) t.Errorf("QueryProfile %s response got %+v want %+v", mode, gotRes, tc.wantRes)
} }
} }
} }
@ -140,10 +144,10 @@ func TestQueryProfile(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to create HTTP client") t.Fatalf("failed to create HTTP client")
} }
runCases(httpAPI) runCases(httpAPI, true)
}) })
t.Run("Monolith", func(t *testing.T) { t.Run("Monolith", func(t *testing.T) {
runCases(userAPI) runCases(userAPI, false)
}) })
} }