Merge branch 'main' into austin.ellis/fix-msc2946

This commit is contained in:
Neil Alexander 2022-08-12 09:20:11 +01:00 committed by GitHub
commit ab96cb486e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
103 changed files with 3071 additions and 2736 deletions

View file

@ -19,10 +19,10 @@ jobs:
runs-on: ubuntu-latest
if: ${{ false }} # disable for now
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Install Go
uses: actions/setup-go@v2
uses: actions/setup-go@v3
with:
go-version: 1.18
@ -66,8 +66,12 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Install Go
uses: actions/setup-go@v3
with:
go-version: 1.18
- name: golangci-lint
uses: golangci/golangci-lint-action@v2
uses: golangci/golangci-lint-action@v3
# run go test with different go versions
test:
@ -101,7 +105,7 @@ jobs:
steps:
- uses: actions/checkout@v3
- name: Setup go
uses: actions/setup-go@v2
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.go }}
- uses: actions/cache@v3
@ -133,7 +137,7 @@ jobs:
steps:
- uses: actions/checkout@v3
- name: Setup go
uses: actions/setup-go@v2
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.go }}
- name: Install dependencies x86
@ -167,7 +171,7 @@ jobs:
steps:
- uses: actions/checkout@v3
- name: Setup Go ${{ matrix.go }}
uses: actions/setup-go@v2
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.go }}
- name: Install dependencies
@ -208,7 +212,7 @@ jobs:
steps:
- uses: actions/checkout@v3
- name: Setup go
uses: actions/setup-go@v2
uses: actions/setup-go@v3
with:
go-version: "1.18"
- uses: actions/cache@v3
@ -233,7 +237,7 @@ jobs:
steps:
- uses: actions/checkout@v3
- name: Setup go
uses: actions/setup-go@v2
uses: actions/setup-go@v3
with:
go-version: "1.18"
- uses: actions/cache@v3

View file

@ -7,7 +7,6 @@ import (
"github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/opentracing/opentracing-go"
)
// HTTP paths for the internal HTTP APIs
@ -42,11 +41,10 @@ func (h *httpAppServiceQueryAPI) RoomAliasExists(
request *api.RoomAliasExistsRequest,
response *api.RoomAliasExistsResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "appserviceRoomAliasExists")
defer span.Finish()
apiURL := h.appserviceURL + AppServiceRoomAliasExistsPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"RoomAliasExists", h.appserviceURL+AppServiceRoomAliasExistsPath,
h.httpClient, ctx, request, response,
)
}
// UserIDExists implements AppServiceQueryAPI
@ -55,9 +53,8 @@ func (h *httpAppServiceQueryAPI) UserIDExists(
request *api.UserIDExistsRequest,
response *api.UserIDExistsResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "appserviceUserIDExists")
defer span.Finish()
apiURL := h.appserviceURL + AppServiceUserIDExistsPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"UserIDExists", h.appserviceURL+AppServiceUserIDExistsPath,
h.httpClient, ctx, request, response,
)
}

View file

@ -1,43 +1,20 @@
package inthttp
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/util"
)
// AddRoutes adds the AppServiceQueryAPI handlers to the http.ServeMux.
func AddRoutes(a api.AppServiceInternalAPI, internalAPIMux *mux.Router) {
internalAPIMux.Handle(
AppServiceRoomAliasExistsPath,
httputil.MakeInternalAPI("appserviceRoomAliasExists", func(req *http.Request) util.JSONResponse {
var request api.RoomAliasExistsRequest
var response api.RoomAliasExistsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := a.RoomAliasExists(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("AppserviceRoomAliasExists", a.RoomAliasExists),
)
internalAPIMux.Handle(
AppServiceUserIDExistsPath,
httputil.MakeInternalAPI("appserviceUserIDExists", func(req *http.Request) util.JSONResponse {
var request api.UserIDExistsRequest
var response api.UserIDExistsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := a.UserIDExists(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("AppserviceUserIDExists", a.UserIDExists),
)
}

View file

@ -18,6 +18,7 @@ import (
"context"
"encoding/json"
"net/http"
"sync"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/setup/config"
@ -102,6 +103,7 @@ type userInteractiveFlow struct {
// the user already has a valid access token, but we want to double-check
// that it isn't stolen by re-authenticating them.
type UserInteractive struct {
sync.RWMutex
Flows []userInteractiveFlow
// Map of login type to implementation
Types map[string]Type
@ -128,6 +130,8 @@ func NewUserInteractive(userAccountAPI api.UserLoginAPI, cfg *config.ClientAPI)
}
func (u *UserInteractive) IsSingleStageFlow(authType string) bool {
u.RLock()
defer u.RUnlock()
for _, f := range u.Flows {
if len(f.Stages) == 1 && f.Stages[0] == authType {
return true
@ -137,8 +141,10 @@ func (u *UserInteractive) IsSingleStageFlow(authType string) bool {
}
func (u *UserInteractive) AddCompletedStage(sessionID, authType string) {
u.Lock()
// TODO: Handle multi-stage flows
delete(u.Sessions, sessionID)
u.Unlock()
}
type Challenge struct {
@ -150,12 +156,17 @@ type Challenge struct {
}
// Challenge returns an HTTP 401 with the supported flows for authenticating
func (u *UserInteractive) Challenge(sessionID string) *util.JSONResponse {
func (u *UserInteractive) challenge(sessionID string) *util.JSONResponse {
u.RLock()
completed := u.Sessions[sessionID]
flows := u.Flows
u.RUnlock()
return &util.JSONResponse{
Code: 401,
JSON: Challenge{
Completed: u.Sessions[sessionID],
Flows: u.Flows,
Completed: completed,
Flows: flows,
Session: sessionID,
Params: make(map[string]interface{}),
},
@ -170,8 +181,10 @@ func (u *UserInteractive) NewSession() *util.JSONResponse {
res := jsonerror.InternalServerError()
return &res
}
u.Lock()
u.Sessions[sessionID] = []string{}
return u.Challenge(sessionID)
u.Unlock()
return u.challenge(sessionID)
}
// ResponseWithChallenge mixes together a JSON body (e.g an error with errcode/message) with the
@ -184,7 +197,7 @@ func (u *UserInteractive) ResponseWithChallenge(sessionID string, response inter
return &ise
}
_ = json.Unmarshal(b, &mixedObjects)
challenge := u.Challenge(sessionID)
challenge := u.challenge(sessionID)
b, err = json.Marshal(challenge.JSON)
if err != nil {
ise := jsonerror.InternalServerError()
@ -213,7 +226,11 @@ func (u *UserInteractive) Verify(ctx context.Context, bodyBytes []byte, device *
// extract the type so we know which login type to use
authType := gjson.GetBytes(bodyBytes, "auth.type").Str
u.RLock()
loginType, ok := u.Types[authType]
u.RUnlock()
if !ok {
return nil, &util.JSONResponse{
Code: http.StatusBadRequest,
@ -223,7 +240,12 @@ func (u *UserInteractive) Verify(ctx context.Context, bodyBytes []byte, device *
// retrieve the session
sessionID := gjson.GetBytes(bodyBytes, "auth.session").Str
if _, ok = u.Sessions[sessionID]; !ok {
u.RLock()
_, ok = u.Sessions[sessionID]
u.RUnlock()
if !ok {
// if the login type is part of a single stage flow then allow them to omit the session ID
if !u.IsSingleStageFlow(authType) {
return nil, &util.JSONResponse{

View file

@ -15,11 +15,13 @@
package jsonerror
import (
"context"
"fmt"
"net/http"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
)
// MatrixError represents the "standard error response" in Matrix.
@ -213,3 +215,15 @@ func NotTrusted(serverName string) *MatrixError {
Err: fmt.Sprintf("Untrusted server '%s'", serverName),
}
}
// InternalAPIError is returned when Dendrite failed to reach an internal API.
func InternalAPIError(ctx context.Context, err error) util.JSONResponse {
logrus.WithContext(ctx).WithError(err).Error("Error reaching an internal API")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: &MatrixError{
ErrCode: "M_INTERNAL_SERVER_ERROR",
Err: "Dendrite encountered an error reaching an internal API.",
},
}
}

View file

@ -30,13 +30,15 @@ func AdminEvacuateRoom(req *http.Request, device *userapi.Device, rsAPI roomserv
}
}
res := &roomserverAPI.PerformAdminEvacuateRoomResponse{}
rsAPI.PerformAdminEvacuateRoom(
if err := rsAPI.PerformAdminEvacuateRoom(
req.Context(),
&roomserverAPI.PerformAdminEvacuateRoomRequest{
RoomID: roomID,
},
res,
)
); err != nil {
return util.ErrorResponse(err)
}
if err := res.Error; err != nil {
return err.JSONResponse()
}
@ -67,13 +69,15 @@ func AdminEvacuateUser(req *http.Request, device *userapi.Device, rsAPI roomserv
}
}
res := &roomserverAPI.PerformAdminEvacuateUserResponse{}
rsAPI.PerformAdminEvacuateUser(
if err := rsAPI.PerformAdminEvacuateUser(
req.Context(),
&roomserverAPI.PerformAdminEvacuateUserRequest{
UserID: userID,
},
res,
)
); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if err := res.Error; err != nil {
return err.JSONResponse()
}

View file

@ -556,10 +556,12 @@ func createRoom(
if r.Visibility == "public" {
// expose this room in the published room list
var pubRes roomserverAPI.PerformPublishResponse
rsAPI.PerformPublish(ctx, &roomserverAPI.PerformPublishRequest{
if err := rsAPI.PerformPublish(ctx, &roomserverAPI.PerformPublishRequest{
RoomID: roomID,
Visibility: "public",
}, &pubRes)
}, &pubRes); err != nil {
return jsonerror.InternalAPIError(ctx, err)
}
if pubRes.Error != nil {
// treat as non-fatal since the room is already made by this point
util.GetLogger(ctx).WithError(pubRes.Error).Error("failed to visibility:public")

View file

@ -302,10 +302,12 @@ func SetVisibility(
}
var publishRes roomserverAPI.PerformPublishResponse
rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{
if err := rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{
RoomID: roomID,
Visibility: v.Visibility,
}, &publishRes)
}, &publishRes); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if publishRes.Error != nil {
util.GetLogger(req.Context()).WithError(publishRes.Error).Error("PerformPublish failed")
return publishRes.Error.JSONResponse()

View file

@ -81,8 +81,9 @@ func JoinRoomByIDOrAlias(
done := make(chan util.JSONResponse, 1)
go func() {
defer close(done)
rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes)
if joinRes.Error != nil {
if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil {
done <- jsonerror.InternalAPIError(req.Context(), err)
} else if joinRes.Error != nil {
done <- joinRes.Error.JSONResponse()
} else {
done <- util.JSONResponse{

View file

@ -91,10 +91,12 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de
// Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version}
func KeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse {
var queryResp userapi.QueryKeyBackupResponse
userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
if err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
UserID: device.UserID,
Version: version,
}, &queryResp)
}, &queryResp); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if queryResp.Error != "" {
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error))
}
@ -233,13 +235,15 @@ func GetBackupKeys(
req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version, roomID, sessionID string,
) util.JSONResponse {
var queryResp userapi.QueryKeyBackupResponse
userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
if err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
UserID: device.UserID,
Version: version,
ReturnKeys: true,
KeysForRoomID: roomID,
KeysForSessionID: sessionID,
}, &queryResp)
}, &queryResp); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if queryResp.Error != "" {
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error))
}

View file

@ -72,7 +72,9 @@ func UploadCrossSigningDeviceKeys(
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
uploadReq.UserID = device.UserID
keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes)
if err := keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if err := uploadRes.Error; err != nil {
switch {
@ -114,7 +116,9 @@ func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.Clie
}
uploadReq.UserID = device.UserID
keyserverAPI.PerformUploadDeviceSignatures(req.Context(), uploadReq, uploadRes)
if err := keyserverAPI.PerformUploadDeviceSignatures(req.Context(), uploadReq, uploadRes); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if err := uploadRes.Error; err != nil {
switch {

View file

@ -62,7 +62,9 @@ func UploadKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Devi
}
var uploadRes api.PerformUploadKeysResponse
keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes)
if err := keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes); err != nil {
return util.ErrorResponse(err)
}
if uploadRes.Error != nil {
util.GetLogger(req.Context()).WithError(uploadRes.Error).Error("Failed to PerformUploadKeys")
return jsonerror.InternalServerError()
@ -107,12 +109,14 @@ func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Devic
return *resErr
}
queryRes := api.QueryKeysResponse{}
keyAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{
if err := keyAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{
UserID: device.UserID,
UserToDevices: r.DeviceKeys,
Timeout: r.GetTimeout(),
// TODO: Token?
}, &queryRes)
}, &queryRes); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{
Code: 200,
JSON: map[string]interface{}{
@ -145,10 +149,12 @@ func ClaimKeys(req *http.Request, keyAPI api.ClientKeyAPI) util.JSONResponse {
return *resErr
}
claimRes := api.PerformClaimKeysResponse{}
keyAPI.PerformClaimKeys(req.Context(), &api.PerformClaimKeysRequest{
if err := keyAPI.PerformClaimKeys(req.Context(), &api.PerformClaimKeysRequest{
OneTimeKeys: r.OneTimeKeys,
Timeout: r.GetTimeout(),
}, &claimRes)
}, &claimRes); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if claimRes.Error != nil {
util.GetLogger(req.Context()).WithError(claimRes.Error).Error("failed to PerformClaimKeys")
return jsonerror.InternalServerError()

View file

@ -17,6 +17,7 @@ package routing
import (
"net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
@ -54,7 +55,9 @@ func PeekRoomByIDOrAlias(
}
// Ask the roomserver to perform the peek.
rsAPI.PerformPeek(req.Context(), &peekReq, &peekRes)
if err := rsAPI.PerformPeek(req.Context(), &peekReq, &peekRes); err != nil {
return util.ErrorResponse(err)
}
if peekRes.Error != nil {
return peekRes.Error.JSONResponse()
}
@ -89,7 +92,9 @@ func UnpeekRoomByID(
}
unpeekRes := roomserverAPI.PerformUnpeekResponse{}
rsAPI.PerformUnpeek(req.Context(), &unpeekReq, &unpeekRes)
if err := rsAPI.PerformUnpeek(req.Context(), &unpeekReq, &unpeekRes); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if unpeekRes.Error != nil {
return unpeekRes.Error.JSONResponse()
}

View file

@ -64,7 +64,9 @@ func UpgradeRoom(
}
upgradeResp := roomserverAPI.PerformRoomUpgradeResponse{}
rsAPI.PerformRoomUpgrade(req.Context(), &upgradeReq, &upgradeResp)
if err := rsAPI.PerformRoomUpgrade(req.Context(), &upgradeReq, &upgradeResp); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if upgradeResp.Error != nil {
if upgradeResp.Error.Code == roomserverAPI.PerformErrorNoRoom {

View file

@ -110,7 +110,7 @@ type FederationClientError struct {
Blacklisted bool
}
func (e *FederationClientError) Error() string {
func (e FederationClientError) Error() string {
return fmt.Sprintf("%s - (retry_after=%s, blacklisted=%v)", e.Err, e.RetryAfter.String(), e.Blacklisted)
}

View file

@ -15,6 +15,8 @@
package federationapi
import (
"time"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/federationapi/api"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
@ -167,5 +169,16 @@ func NewInternalAPI(
if err = presenceConsumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start presence consumer")
}
var cleanExpiredEDUs func()
cleanExpiredEDUs = func() {
logrus.Infof("Cleaning expired EDUs")
if err := federationDB.DeleteExpiredEDUs(base.Context()); err != nil {
logrus.WithError(err).Error("Failed to clean expired EDUs")
}
time.AfterFunc(time.Hour, cleanExpiredEDUs)
}
time.AfterFunc(time.Minute, cleanExpiredEDUs)
return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, stats, caches, queues, keyRing)
}

View file

@ -32,11 +32,12 @@ type fedRoomserverAPI struct {
}
// PerformJoin will call this function
func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) {
func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) error {
if f.inputRoomEvents == nil {
return
return nil
}
f.inputRoomEvents(ctx, req, res)
return nil
}
// keychange consumer calls this

View file

@ -10,7 +10,6 @@ import (
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/opentracing/opentracing-go"
)
// HTTP paths for the internal HTTP API
@ -48,7 +47,11 @@ func NewFederationAPIClient(federationSenderURL string, httpClient *http.Client,
if httpClient == nil {
return nil, errors.New("NewFederationInternalAPIHTTP: httpClient is <nil>")
}
return &httpFederationInternalAPI{federationSenderURL, httpClient, cache}, nil
return &httpFederationInternalAPI{
federationAPIURL: federationSenderURL,
httpClient: httpClient,
cache: cache,
}, nil
}
type httpFederationInternalAPI struct {
@ -63,11 +66,10 @@ func (h *httpFederationInternalAPI) PerformLeave(
request *api.PerformLeaveRequest,
response *api.PerformLeaveResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLeaveRequest")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIPerformLeaveRequestPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformLeave", h.federationAPIURL+FederationAPIPerformLeaveRequestPath,
h.httpClient, ctx, request, response,
)
}
// Handle sending an invite to a remote server.
@ -76,11 +78,10 @@ func (h *httpFederationInternalAPI) PerformInvite(
request *api.PerformInviteRequest,
response *api.PerformInviteResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformInviteRequest")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIPerformInviteRequestPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformInvite", h.federationAPIURL+FederationAPIPerformInviteRequestPath,
h.httpClient, ctx, request, response,
)
}
// Handle starting a peek on a remote server.
@ -89,11 +90,10 @@ func (h *httpFederationInternalAPI) PerformOutboundPeek(
request *api.PerformOutboundPeekRequest,
response *api.PerformOutboundPeekResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformOutboundPeekRequest")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIPerformOutboundPeekRequestPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformOutboundPeek", h.federationAPIURL+FederationAPIPerformOutboundPeekRequestPath,
h.httpClient, ctx, request, response,
)
}
// QueryJoinedHostServerNamesInRoom implements FederationInternalAPI
@ -102,11 +102,10 @@ func (h *httpFederationInternalAPI) QueryJoinedHostServerNamesInRoom(
request *api.QueryJoinedHostServerNamesInRoomRequest,
response *api.QueryJoinedHostServerNamesInRoomResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryJoinedHostServerNamesInRoom")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIQueryJoinedHostServerNamesInRoomPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryJoinedHostServerNamesInRoom", h.federationAPIURL+FederationAPIQueryJoinedHostServerNamesInRoomPath,
h.httpClient, ctx, request, response,
)
}
// Handle an instruction to make_join & send_join with a remote server.
@ -115,12 +114,10 @@ func (h *httpFederationInternalAPI) PerformJoin(
request *api.PerformJoinRequest,
response *api.PerformJoinResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformJoinRequest")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIPerformJoinRequestPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
if err := httputil.CallInternalRPCAPI(
"PerformJoinRequest", h.federationAPIURL+FederationAPIPerformJoinRequestPath,
h.httpClient, ctx, request, response,
); err != nil {
response.LastError = &gomatrix.HTTPError{
Message: err.Error(),
Code: 0,
@ -135,11 +132,10 @@ func (h *httpFederationInternalAPI) PerformDirectoryLookup(
request *api.PerformDirectoryLookupRequest,
response *api.PerformDirectoryLookupResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDirectoryLookup")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIPerformDirectoryLookupRequestPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformDirectoryLookup", h.federationAPIURL+FederationAPIPerformDirectoryLookupRequestPath,
h.httpClient, ctx, request, response,
)
}
// Handle an instruction to broadcast an EDU to all servers in rooms we are joined to.
@ -148,101 +144,61 @@ func (h *httpFederationInternalAPI) PerformBroadcastEDU(
request *api.PerformBroadcastEDURequest,
response *api.PerformBroadcastEDUResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformBroadcastEDU")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIPerformBroadcastEDUPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformBroadcastEDU", h.federationAPIURL+FederationAPIPerformBroadcastEDUPath,
h.httpClient, ctx, request, response,
)
}
type getUserDevices struct {
S gomatrixserverlib.ServerName
UserID string
Res *gomatrixserverlib.RespUserDevices
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) GetUserDevices(
ctx context.Context, s gomatrixserverlib.ServerName, userID string,
) (gomatrixserverlib.RespUserDevices, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "GetUserDevices")
defer span.Finish()
var result gomatrixserverlib.RespUserDevices
request := getUserDevices{
S: s,
UserID: userID,
}
var response getUserDevices
apiURL := h.federationAPIURL + FederationAPIGetUserDevicesPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return result, err
}
if response.Err != nil {
return result, response.Err
}
return *response.Res, nil
return httputil.CallInternalProxyAPI[getUserDevices, gomatrixserverlib.RespUserDevices, *api.FederationClientError](
"GetUserDevices", h.federationAPIURL+FederationAPIGetUserDevicesPath, h.httpClient,
ctx, &getUserDevices{
S: s,
UserID: userID,
},
)
}
type claimKeys struct {
S gomatrixserverlib.ServerName
OneTimeKeys map[string]map[string]string
Res *gomatrixserverlib.RespClaimKeys
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) ClaimKeys(
ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string,
) (gomatrixserverlib.RespClaimKeys, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "ClaimKeys")
defer span.Finish()
var result gomatrixserverlib.RespClaimKeys
request := claimKeys{
S: s,
OneTimeKeys: oneTimeKeys,
}
var response claimKeys
apiURL := h.federationAPIURL + FederationAPIClaimKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return result, err
}
if response.Err != nil {
return result, response.Err
}
return *response.Res, nil
return httputil.CallInternalProxyAPI[claimKeys, gomatrixserverlib.RespClaimKeys, *api.FederationClientError](
"ClaimKeys", h.federationAPIURL+FederationAPIClaimKeysPath, h.httpClient,
ctx, &claimKeys{
S: s,
OneTimeKeys: oneTimeKeys,
},
)
}
type queryKeys struct {
S gomatrixserverlib.ServerName
Keys map[string][]string
Res *gomatrixserverlib.RespQueryKeys
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) QueryKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string,
) (gomatrixserverlib.RespQueryKeys, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeys")
defer span.Finish()
var result gomatrixserverlib.RespQueryKeys
request := queryKeys{
S: s,
Keys: keys,
}
var response queryKeys
apiURL := h.federationAPIURL + FederationAPIQueryKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return result, err
}
if response.Err != nil {
return result, response.Err
}
return *response.Res, nil
return httputil.CallInternalProxyAPI[queryKeys, gomatrixserverlib.RespQueryKeys, *api.FederationClientError](
"QueryKeys", h.federationAPIURL+FederationAPIQueryKeysPath, h.httpClient,
ctx, &queryKeys{
S: s,
Keys: keys,
},
)
}
type backfill struct {
@ -250,32 +206,20 @@ type backfill struct {
RoomID string
Limit int
EventIDs []string
Res *gomatrixserverlib.Transaction
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) Backfill(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string,
) (gomatrixserverlib.Transaction, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "Backfill")
defer span.Finish()
request := backfill{
S: s,
RoomID: roomID,
Limit: limit,
EventIDs: eventIDs,
}
var response backfill
apiURL := h.federationAPIURL + FederationAPIBackfillPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return gomatrixserverlib.Transaction{}, err
}
if response.Err != nil {
return gomatrixserverlib.Transaction{}, response.Err
}
return *response.Res, nil
return httputil.CallInternalProxyAPI[backfill, gomatrixserverlib.Transaction, *api.FederationClientError](
"Backfill", h.federationAPIURL+FederationAPIBackfillPath, h.httpClient,
ctx, &backfill{
S: s,
RoomID: roomID,
Limit: limit,
EventIDs: eventIDs,
},
)
}
type lookupState struct {
@ -283,63 +227,39 @@ type lookupState struct {
RoomID string
EventID string
RoomVersion gomatrixserverlib.RoomVersion
Res *gomatrixserverlib.RespState
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) LookupState(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
) (gomatrixserverlib.RespState, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "LookupState")
defer span.Finish()
request := lookupState{
S: s,
RoomID: roomID,
EventID: eventID,
RoomVersion: roomVersion,
}
var response lookupState
apiURL := h.federationAPIURL + FederationAPILookupStatePath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return gomatrixserverlib.RespState{}, err
}
if response.Err != nil {
return gomatrixserverlib.RespState{}, response.Err
}
return *response.Res, nil
return httputil.CallInternalProxyAPI[lookupState, gomatrixserverlib.RespState, *api.FederationClientError](
"LookupState", h.federationAPIURL+FederationAPILookupStatePath, h.httpClient,
ctx, &lookupState{
S: s,
RoomID: roomID,
EventID: eventID,
RoomVersion: roomVersion,
},
)
}
type lookupStateIDs struct {
S gomatrixserverlib.ServerName
RoomID string
EventID string
Res *gomatrixserverlib.RespStateIDs
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) LookupStateIDs(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string,
) (gomatrixserverlib.RespStateIDs, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "LookupStateIDs")
defer span.Finish()
request := lookupStateIDs{
S: s,
RoomID: roomID,
EventID: eventID,
}
var response lookupStateIDs
apiURL := h.federationAPIURL + FederationAPILookupStateIDsPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return gomatrixserverlib.RespStateIDs{}, err
}
if response.Err != nil {
return gomatrixserverlib.RespStateIDs{}, response.Err
}
return *response.Res, nil
return httputil.CallInternalProxyAPI[lookupStateIDs, gomatrixserverlib.RespStateIDs, *api.FederationClientError](
"LookupStateIDs", h.federationAPIURL+FederationAPILookupStateIDsPath, h.httpClient,
ctx, &lookupStateIDs{
S: s,
RoomID: roomID,
EventID: eventID,
},
)
}
type lookupMissingEvents struct {
@ -347,64 +267,38 @@ type lookupMissingEvents struct {
RoomID string
Missing gomatrixserverlib.MissingEvents
RoomVersion gomatrixserverlib.RoomVersion
Res struct {
Events []gomatrixserverlib.RawJSON `json:"events"`
}
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) LookupMissingEvents(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string,
missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.RespMissingEvents, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "LookupMissingEvents")
defer span.Finish()
request := lookupMissingEvents{
S: s,
RoomID: roomID,
Missing: missing,
RoomVersion: roomVersion,
}
apiURL := h.federationAPIURL + FederationAPILookupMissingEventsPath
err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &request)
if err != nil {
return res, err
}
if request.Err != nil {
return res, request.Err
}
res.Events = request.Res.Events
return res, nil
return httputil.CallInternalProxyAPI[lookupMissingEvents, gomatrixserverlib.RespMissingEvents, *api.FederationClientError](
"LookupMissingEvents", h.federationAPIURL+FederationAPILookupMissingEventsPath, h.httpClient,
ctx, &lookupMissingEvents{
S: s,
RoomID: roomID,
Missing: missing,
RoomVersion: roomVersion,
},
)
}
type getEvent struct {
S gomatrixserverlib.ServerName
EventID string
Res *gomatrixserverlib.Transaction
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) GetEvent(
ctx context.Context, s gomatrixserverlib.ServerName, eventID string,
) (gomatrixserverlib.Transaction, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "GetEvent")
defer span.Finish()
request := getEvent{
S: s,
EventID: eventID,
}
var response getEvent
apiURL := h.federationAPIURL + FederationAPIGetEventPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return gomatrixserverlib.Transaction{}, err
}
if response.Err != nil {
return gomatrixserverlib.Transaction{}, response.Err
}
return *response.Res, nil
return httputil.CallInternalProxyAPI[getEvent, gomatrixserverlib.Transaction, *api.FederationClientError](
"GetEvent", h.federationAPIURL+FederationAPIGetEventPath, h.httpClient,
ctx, &getEvent{
S: s,
EventID: eventID,
},
)
}
type getEventAuth struct {
@ -412,135 +306,86 @@ type getEventAuth struct {
RoomVersion gomatrixserverlib.RoomVersion
RoomID string
EventID string
Res *gomatrixserverlib.RespEventAuth
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) GetEventAuth(
ctx context.Context, s gomatrixserverlib.ServerName,
roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string,
) (gomatrixserverlib.RespEventAuth, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "GetEventAuth")
defer span.Finish()
request := getEventAuth{
S: s,
RoomVersion: roomVersion,
RoomID: roomID,
EventID: eventID,
}
var response getEventAuth
apiURL := h.federationAPIURL + FederationAPIGetEventAuthPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return gomatrixserverlib.RespEventAuth{}, err
}
if response.Err != nil {
return gomatrixserverlib.RespEventAuth{}, response.Err
}
return *response.Res, nil
return httputil.CallInternalProxyAPI[getEventAuth, gomatrixserverlib.RespEventAuth, *api.FederationClientError](
"GetEventAuth", h.federationAPIURL+FederationAPIGetEventAuthPath, h.httpClient,
ctx, &getEventAuth{
S: s,
RoomVersion: roomVersion,
RoomID: roomID,
EventID: eventID,
},
)
}
func (h *httpFederationInternalAPI) QueryServerKeys(
ctx context.Context, req *api.QueryServerKeysRequest, res *api.QueryServerKeysResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerKeys")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIQueryServerKeysPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
return httputil.CallInternalRPCAPI(
"QueryServerKeys", h.federationAPIURL+FederationAPIQueryServerKeysPath,
h.httpClient, ctx, req, res,
)
}
type lookupServerKeys struct {
S gomatrixserverlib.ServerName
KeyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp
ServerKeys []gomatrixserverlib.ServerKeys
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) LookupServerKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) ([]gomatrixserverlib.ServerKeys, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "LookupServerKeys")
defer span.Finish()
request := lookupServerKeys{
S: s,
KeyRequests: keyRequests,
}
var response lookupServerKeys
apiURL := h.federationAPIURL + FederationAPILookupServerKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return []gomatrixserverlib.ServerKeys{}, err
}
if response.Err != nil {
return []gomatrixserverlib.ServerKeys{}, response.Err
}
return response.ServerKeys, nil
return httputil.CallInternalProxyAPI[lookupServerKeys, []gomatrixserverlib.ServerKeys, *api.FederationClientError](
"LookupServerKeys", h.federationAPIURL+FederationAPILookupServerKeysPath, h.httpClient,
ctx, &lookupServerKeys{
S: s,
KeyRequests: keyRequests,
},
)
}
type eventRelationships struct {
S gomatrixserverlib.ServerName
Req gomatrixserverlib.MSC2836EventRelationshipsRequest
RoomVer gomatrixserverlib.RoomVersion
Res gomatrixserverlib.MSC2836EventRelationshipsResponse
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) MSC2836EventRelationships(
ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest,
roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2836EventRelationships")
defer span.Finish()
request := eventRelationships{
S: s,
Req: r,
RoomVer: roomVersion,
}
var response eventRelationships
apiURL := h.federationAPIURL + FederationAPIEventRelationshipsPath
err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return res, err
}
if response.Err != nil {
return res, response.Err
}
return response.Res, nil
return httputil.CallInternalProxyAPI[eventRelationships, gomatrixserverlib.MSC2836EventRelationshipsResponse, *api.FederationClientError](
"MSC2836EventRelationships", h.federationAPIURL+FederationAPIEventRelationshipsPath, h.httpClient,
ctx, &eventRelationships{
S: s,
Req: r,
RoomVer: roomVersion,
},
)
}
type spacesReq struct {
S gomatrixserverlib.ServerName
SuggestedOnly bool
RoomID string
Res gomatrixserverlib.MSC2946SpacesResponse
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) MSC2946Spaces(
ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool,
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2946Spaces")
defer span.Finish()
request := spacesReq{
S: dst,
SuggestedOnly: suggestedOnly,
RoomID: roomID,
}
var response spacesReq
apiURL := h.federationAPIURL + FederationAPISpacesSummaryPath
err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return res, err
}
if response.Err != nil {
return res, response.Err
}
return response.Res, nil
return httputil.CallInternalProxyAPI[spacesReq, gomatrixserverlib.MSC2946SpacesResponse, *api.FederationClientError](
"MSC2836EventRelationships", h.federationAPIURL+FederationAPISpacesSummaryPath, h.httpClient,
ctx, &spacesReq{
S: dst,
SuggestedOnly: suggestedOnly,
RoomID: roomID,
},
)
}
func (s *httpFederationInternalAPI) KeyRing() *gomatrixserverlib.KeyRing {
@ -614,11 +459,10 @@ func (h *httpFederationInternalAPI) InputPublicKeys(
request *api.InputPublicKeysRequest,
response *api.InputPublicKeysResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "InputPublicKey")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIInputPublicKeyPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"InputPublicKey", h.federationAPIURL+FederationAPIInputPublicKeyPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpFederationInternalAPI) QueryPublicKeys(
@ -626,9 +470,8 @@ func (h *httpFederationInternalAPI) QueryPublicKeys(
request *api.QueryPublicKeysRequest,
response *api.QueryPublicKeysResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPublicKey")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIQueryPublicKeyPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryPublicKeys", h.federationAPIURL+FederationAPIQueryPublicKeyPath,
h.httpClient, ctx, request, response,
)
}

View file

@ -1,12 +1,14 @@
package inthttp
import (
"context"
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
@ -15,372 +17,180 @@ import (
func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
internalAPIMux.Handle(
FederationAPIQueryJoinedHostServerNamesInRoomPath,
httputil.MakeInternalAPI("QueryJoinedHostServerNamesInRoom", func(req *http.Request) util.JSONResponse {
var request api.QueryJoinedHostServerNamesInRoomRequest
var response api.QueryJoinedHostServerNamesInRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := intAPI.QueryJoinedHostServerNamesInRoom(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(
FederationAPIPerformJoinRequestPath,
httputil.MakeInternalAPI("PerformJoinRequest", func(req *http.Request) util.JSONResponse {
var request api.PerformJoinRequest
var response api.PerformJoinResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
intAPI.PerformJoin(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(
FederationAPIPerformLeaveRequestPath,
httputil.MakeInternalAPI("PerformLeaveRequest", func(req *http.Request) util.JSONResponse {
var request api.PerformLeaveRequest
var response api.PerformLeaveResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.PerformLeave(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("FederationAPIQueryJoinedHostServerNamesInRoom", intAPI.QueryJoinedHostServerNamesInRoom),
)
internalAPIMux.Handle(
FederationAPIPerformInviteRequestPath,
httputil.MakeInternalAPI("PerformInviteRequest", func(req *http.Request) util.JSONResponse {
var request api.PerformInviteRequest
var response api.PerformInviteResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.PerformInvite(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("FederationAPIPerformInvite", intAPI.PerformInvite),
)
internalAPIMux.Handle(
FederationAPIPerformLeaveRequestPath,
httputil.MakeInternalRPCAPI("FederationAPIPerformLeave", intAPI.PerformLeave),
)
internalAPIMux.Handle(
FederationAPIPerformDirectoryLookupRequestPath,
httputil.MakeInternalAPI("PerformDirectoryLookupRequest", func(req *http.Request) util.JSONResponse {
var request api.PerformDirectoryLookupRequest
var response api.PerformDirectoryLookupResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.PerformDirectoryLookup(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("FederationAPIPerformDirectoryLookupRequest", intAPI.PerformDirectoryLookup),
)
internalAPIMux.Handle(
FederationAPIPerformBroadcastEDUPath,
httputil.MakeInternalAPI("PerformBroadcastEDU", func(req *http.Request) util.JSONResponse {
var request api.PerformBroadcastEDURequest
var response api.PerformBroadcastEDUResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.PerformBroadcastEDU(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", intAPI.PerformBroadcastEDU),
)
internalAPIMux.Handle(
FederationAPIPerformJoinRequestPath,
httputil.MakeInternalRPCAPI(
"FederationAPIPerformJoinRequest",
func(ctx context.Context, req *api.PerformJoinRequest, res *api.PerformJoinResponse) error {
intAPI.PerformJoin(ctx, req, res)
return nil
},
),
)
internalAPIMux.Handle(
FederationAPIGetUserDevicesPath,
httputil.MakeInternalAPI("GetUserDevices", func(req *http.Request) util.JSONResponse {
var request getUserDevices
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.GetUserDevices(req.Context(), request.S, request.UserID)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
httputil.MakeInternalProxyAPI(
"FederationAPIGetUserDevices",
func(ctx context.Context, req *getUserDevices) (*gomatrixserverlib.RespUserDevices, error) {
res, err := intAPI.GetUserDevices(ctx, req.S, req.UserID)
return &res, federationClientError(err)
},
),
)
internalAPIMux.Handle(
FederationAPIClaimKeysPath,
httputil.MakeInternalAPI("ClaimKeys", func(req *http.Request) util.JSONResponse {
var request claimKeys
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.ClaimKeys(req.Context(), request.S, request.OneTimeKeys)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
httputil.MakeInternalProxyAPI(
"FederationAPIClaimKeys",
func(ctx context.Context, req *claimKeys) (*gomatrixserverlib.RespClaimKeys, error) {
res, err := intAPI.ClaimKeys(ctx, req.S, req.OneTimeKeys)
return &res, federationClientError(err)
},
),
)
internalAPIMux.Handle(
FederationAPIQueryKeysPath,
httputil.MakeInternalAPI("QueryKeys", func(req *http.Request) util.JSONResponse {
var request queryKeys
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.QueryKeys(req.Context(), request.S, request.Keys)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
httputil.MakeInternalProxyAPI(
"FederationAPIQueryKeys",
func(ctx context.Context, req *queryKeys) (*gomatrixserverlib.RespQueryKeys, error) {
res, err := intAPI.QueryKeys(ctx, req.S, req.Keys)
return &res, federationClientError(err)
},
),
)
internalAPIMux.Handle(
FederationAPIBackfillPath,
httputil.MakeInternalAPI("Backfill", func(req *http.Request) util.JSONResponse {
var request backfill
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.Backfill(req.Context(), request.S, request.RoomID, request.Limit, request.EventIDs)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
httputil.MakeInternalProxyAPI(
"FederationAPIBackfill",
func(ctx context.Context, req *backfill) (*gomatrixserverlib.Transaction, error) {
res, err := intAPI.Backfill(ctx, req.S, req.RoomID, req.Limit, req.EventIDs)
return &res, federationClientError(err)
},
),
)
internalAPIMux.Handle(
FederationAPILookupStatePath,
httputil.MakeInternalAPI("LookupState", func(req *http.Request) util.JSONResponse {
var request lookupState
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.LookupState(req.Context(), request.S, request.RoomID, request.EventID, request.RoomVersion)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
httputil.MakeInternalProxyAPI(
"FederationAPILookupState",
func(ctx context.Context, req *lookupState) (*gomatrixserverlib.RespState, error) {
res, err := intAPI.LookupState(ctx, req.S, req.RoomID, req.EventID, req.RoomVersion)
return &res, federationClientError(err)
},
),
)
internalAPIMux.Handle(
FederationAPILookupStateIDsPath,
httputil.MakeInternalAPI("LookupStateIDs", func(req *http.Request) util.JSONResponse {
var request lookupStateIDs
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.LookupStateIDs(req.Context(), request.S, request.RoomID, request.EventID)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
httputil.MakeInternalProxyAPI(
"FederationAPILookupStateIDs",
func(ctx context.Context, req *lookupStateIDs) (*gomatrixserverlib.RespStateIDs, error) {
res, err := intAPI.LookupStateIDs(ctx, req.S, req.RoomID, req.EventID)
return &res, federationClientError(err)
},
),
)
internalAPIMux.Handle(
FederationAPILookupMissingEventsPath,
httputil.MakeInternalAPI("LookupMissingEvents", func(req *http.Request) util.JSONResponse {
var request lookupMissingEvents
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.LookupMissingEvents(req.Context(), request.S, request.RoomID, request.Missing, request.RoomVersion)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
for _, event := range res.Events {
js, err := json.Marshal(event)
if err != nil {
return util.MessageResponse(http.StatusInternalServerError, err.Error())
}
request.Res.Events = append(request.Res.Events, js)
}
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
httputil.MakeInternalProxyAPI(
"FederationAPILookupMissingEvents",
func(ctx context.Context, req *lookupMissingEvents) (*gomatrixserverlib.RespMissingEvents, error) {
res, err := intAPI.LookupMissingEvents(ctx, req.S, req.RoomID, req.Missing, req.RoomVersion)
return &res, federationClientError(err)
},
),
)
internalAPIMux.Handle(
FederationAPIGetEventPath,
httputil.MakeInternalAPI("GetEvent", func(req *http.Request) util.JSONResponse {
var request getEvent
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.GetEvent(req.Context(), request.S, request.EventID)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
httputil.MakeInternalProxyAPI(
"FederationAPIGetEvent",
func(ctx context.Context, req *getEvent) (*gomatrixserverlib.Transaction, error) {
res, err := intAPI.GetEvent(ctx, req.S, req.EventID)
return &res, federationClientError(err)
},
),
)
internalAPIMux.Handle(
FederationAPIGetEventAuthPath,
httputil.MakeInternalAPI("GetEventAuth", func(req *http.Request) util.JSONResponse {
var request getEventAuth
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.GetEventAuth(req.Context(), request.S, request.RoomVersion, request.RoomID, request.EventID)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
httputil.MakeInternalProxyAPI(
"FederationAPIGetEventAuth",
func(ctx context.Context, req *getEventAuth) (*gomatrixserverlib.RespEventAuth, error) {
res, err := intAPI.GetEventAuth(ctx, req.S, req.RoomVersion, req.RoomID, req.EventID)
return &res, federationClientError(err)
},
),
)
internalAPIMux.Handle(
FederationAPIQueryServerKeysPath,
httputil.MakeInternalAPI("QueryServerKeys", func(req *http.Request) util.JSONResponse {
var request api.QueryServerKeysRequest
var response api.QueryServerKeysResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.QueryServerKeys(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("FederationAPIQueryServerKeys", intAPI.QueryServerKeys),
)
internalAPIMux.Handle(
FederationAPILookupServerKeysPath,
httputil.MakeInternalAPI("LookupServerKeys", func(req *http.Request) util.JSONResponse {
var request lookupServerKeys
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.LookupServerKeys(req.Context(), request.S, request.KeyRequests)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.ServerKeys = res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
httputil.MakeInternalProxyAPI(
"FederationAPILookupServerKeys",
func(ctx context.Context, req *lookupServerKeys) (*[]gomatrixserverlib.ServerKeys, error) {
res, err := intAPI.LookupServerKeys(ctx, req.S, req.KeyRequests)
return &res, federationClientError(err)
},
),
)
internalAPIMux.Handle(
FederationAPIEventRelationshipsPath,
httputil.MakeInternalAPI("MSC2836EventRelationships", func(req *http.Request) util.JSONResponse {
var request eventRelationships
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.MSC2836EventRelationships(req.Context(), request.S, request.Req, request.RoomVer)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
httputil.MakeInternalProxyAPI(
"FederationAPIMSC2836EventRelationships",
func(ctx context.Context, req *eventRelationships) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) {
res, err := intAPI.MSC2836EventRelationships(ctx, req.S, req.Req, req.RoomVer)
return &res, federationClientError(err)
},
),
)
internalAPIMux.Handle(
FederationAPISpacesSummaryPath,
httputil.MakeInternalAPI("MSC2946SpacesSummary", func(req *http.Request) util.JSONResponse {
var request spacesReq
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.SuggestedOnly)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
httputil.MakeInternalProxyAPI(
"FederationAPIMSC2946SpacesSummary",
func(ctx context.Context, req *spacesReq) (*gomatrixserverlib.MSC2946SpacesResponse, error) {
res, err := intAPI.MSC2946Spaces(ctx, req.S, req.RoomID, req.SuggestedOnly)
return &res, federationClientError(err)
},
),
)
// TODO: Look at this shape
internalAPIMux.Handle(FederationAPIQueryPublicKeyPath,
httputil.MakeInternalAPI("queryPublicKeys", func(req *http.Request) util.JSONResponse {
httputil.MakeInternalAPI("FederationAPIQueryPublicKeys", func(req *http.Request) util.JSONResponse {
request := api.QueryPublicKeysRequest{}
response := api.QueryPublicKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -394,8 +204,10 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
// TODO: Look at this shape
internalAPIMux.Handle(FederationAPIInputPublicKeyPath,
httputil.MakeInternalAPI("inputPublicKeys", func(req *http.Request) util.JSONResponse {
httputil.MakeInternalAPI("FederationAPIInputPublicKeys", func(req *http.Request) util.JSONResponse {
request := api.InputPublicKeysRequest{}
response := api.InputPublicKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -408,3 +220,18 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
}),
)
}
func federationClientError(err error) error {
switch ferr := err.(type) {
case nil:
return nil
case api.FederationClientError:
return &ferr
case *api.FederationClientError:
return ferr
default:
return &api.FederationClientError{
Err: err.Error(),
}
}
}

View file

@ -127,6 +127,7 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
oq.destination, // the destination server name
receipt, // NIDs from federationapi_queue_json table
event.Type,
nil, // this will use the default expireEDUTypes map
); err != nil {
logrus.WithError(err).Errorf("failed to associate EDU with destination %q", oq.destination)
return

View file

@ -30,9 +30,11 @@ func GetUserDevices(
userID string,
) util.JSONResponse {
var res keyapi.QueryDeviceMessagesResponse
keyAPI.QueryDeviceMessages(req.Context(), &keyapi.QueryDeviceMessagesRequest{
if err := keyAPI.QueryDeviceMessages(req.Context(), &keyapi.QueryDeviceMessagesRequest{
UserID: userID,
}, &res)
}, &res); err != nil {
return util.ErrorResponse(err)
}
if res.Error != nil {
util.GetLogger(req.Context()).WithError(res.Error).Error("keyAPI.QueryDeviceMessages failed")
return jsonerror.InternalServerError()
@ -47,7 +49,9 @@ func GetUserDevices(
for _, dev := range res.Devices {
sigReq.TargetIDs[userID] = append(sigReq.TargetIDs[userID], gomatrixserverlib.KeyID(dev.DeviceID))
}
keyAPI.QuerySignatures(req.Context(), sigReq, sigRes)
if err := keyAPI.QuerySignatures(req.Context(), sigReq, sigRes); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
response := gomatrixserverlib.RespUserDevices{
UserID: userID,

View file

@ -392,7 +392,7 @@ func SendJoin(
// the room, so set SendAsServer to cfg.Matrix.ServerName
if !alreadyJoined {
var response api.InputRoomEventsResponse
rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{
if err := rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{
InputRoomEvents: []api.InputRoomEvent{
{
Kind: api.KindNew,
@ -401,7 +401,9 @@ func SendJoin(
TransactionID: nil,
},
},
}, &response)
}, &response); err != nil {
return jsonerror.InternalAPIError(httpReq.Context(), err)
}
if response.ErrMsg != "" {
util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, response.ErrMsg).Error("SendEvents failed")
if response.NotAllowed {

View file

@ -19,7 +19,7 @@ import (
"net/http"
"time"
"github.com/matrix-org/dendrite/clientapi/httputil"
clienthttputil "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/keyserver/api"
@ -61,9 +61,11 @@ func QueryDeviceKeys(
}
var queryRes api.QueryKeysResponse
keyAPI.QueryKeys(httpReq.Context(), &api.QueryKeysRequest{
if err := keyAPI.QueryKeys(httpReq.Context(), &api.QueryKeysRequest{
UserToDevices: qkr.DeviceKeys,
}, &queryRes)
}, &queryRes); err != nil {
return jsonerror.InternalAPIError(httpReq.Context(), err)
}
if queryRes.Error != nil {
util.GetLogger(httpReq.Context()).WithError(queryRes.Error).Error("Failed to QueryKeys")
return jsonerror.InternalServerError()
@ -113,9 +115,11 @@ func ClaimOneTimeKeys(
}
var claimRes api.PerformClaimKeysResponse
keyAPI.PerformClaimKeys(httpReq.Context(), &api.PerformClaimKeysRequest{
if err := keyAPI.PerformClaimKeys(httpReq.Context(), &api.PerformClaimKeysRequest{
OneTimeKeys: cor.OneTimeKeys,
}, &claimRes)
}, &claimRes); err != nil {
return jsonerror.InternalAPIError(httpReq.Context(), err)
}
if claimRes.Error != nil {
util.GetLogger(httpReq.Context()).WithError(claimRes.Error).Error("Failed to PerformClaimKeys")
return jsonerror.InternalServerError()
@ -184,7 +188,7 @@ func NotaryKeys(
) util.JSONResponse {
if req == nil {
req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{}
if reqErr := httputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil {
if reqErr := clienthttputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil {
return *reqErr
}
}

View file

@ -277,7 +277,7 @@ func SendLeave(
// We are responsible for notifying other servers that the user has left
// the room, so set SendAsServer to cfg.Matrix.ServerName
var response api.InputRoomEventsResponse
rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{
if err := rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{
InputRoomEvents: []api.InputRoomEvent{
{
Kind: api.KindNew,
@ -286,7 +286,9 @@ func SendLeave(
TransactionID: nil,
},
},
}, &response)
}, &response); err != nil {
return jsonerror.InternalAPIError(httpReq.Context(), err)
}
if response.ErrMsg != "" {
util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, response.ErrMsg).WithField("not_allowed", response.NotAllowed).Error("producer.SendEvents failed")

View file

@ -458,7 +458,9 @@ func (t *txnReq) processSigningKeyUpdate(ctx context.Context, e gomatrixserverli
UserID: updatePayload.UserID,
}
uploadRes := &keyapi.PerformUploadDeviceKeysResponse{}
t.keyAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes)
if err := t.keyAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil {
return err
}
if uploadRes.Error != nil {
return uploadRes.Error
}

View file

@ -64,11 +64,12 @@ func (t *testRoomserverAPI) InputRoomEvents(
ctx context.Context,
request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse,
) {
) error {
t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...)
for _, ire := range request.InputRoomEvents {
fmt.Println("InputRoomEvents: ", ire.Event.EventID())
}
return nil
}
// Query the latest events and state for a room from the room server.

View file

@ -16,6 +16,7 @@ package storage
import (
"context"
"time"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/federationapi/types"
@ -38,7 +39,7 @@ type Database interface {
GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error)
AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error
AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string) error
AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
@ -70,4 +71,6 @@ type Database interface {
// Query the notary for the server keys for the given server. If `optKeyIDs` is not empty, multiple server keys may be returned (between 1 - len(optKeyIDs))
// such that the combination of all server keys will include all the `optKeyIDs`.
GetNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error)
// DeleteExpiredEDUs cleans up expired EDUs
DeleteExpiredEDUs(ctx context.Context) error
}

View file

@ -0,0 +1,44 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/matrix-org/gomatrixserverlib"
)
func UpAddexpiresat(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, "ALTER TABLE federationsender_queue_edus ADD COLUMN IF NOT EXISTS expires_at BIGINT NOT NULL DEFAULT 0;")
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
_, err = tx.ExecContext(ctx, "UPDATE federationsender_queue_edus SET expires_at = $1 WHERE edu_type != 'm.direct_to_device'", gomatrixserverlib.AsTimestamp(time.Now().Add(time.Hour*24)))
if err != nil {
return fmt.Errorf("failed to update queue_edus: %w", err)
}
return nil
}
func DownAddexpiresat(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, "ALTER TABLE federationsender_queue_edus DROP COLUMN expires_at;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -19,9 +19,11 @@ import (
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/federationapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const queueEDUsSchema = `
@ -31,7 +33,9 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_edus (
-- The domain part of the user ID the EDU event is for.
server_name TEXT NOT NULL,
-- The JSON NID from the federationsender_queue_edus_json table.
json_nid BIGINT NOT NULL
json_nid BIGINT NOT NULL,
-- The expiry time of this edu, if any.
expires_at BIGINT NOT NULL DEFAULT 0
);
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx
@ -43,8 +47,8 @@ CREATE INDEX IF NOT EXISTS federationsender_queue_edus_server_name_idx
`
const insertQueueEDUSQL = "" +
"INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" +
" VALUES ($1, $2, $3)"
"INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid, expires_at)" +
" VALUES ($1, $2, $3, $4)"
const deleteQueueEDUSQL = "" +
"DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid = ANY($2)"
@ -65,6 +69,12 @@ const selectQueueEDUCountSQL = "" +
const selectQueueServerNamesSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_queue_edus"
const selectExpiredEDUsSQL = "" +
"SELECT DISTINCT json_nid FROM federationsender_queue_edus WHERE expires_at > 0 AND expires_at <= $1"
const deleteExpiredEDUsSQL = "" +
"DELETE FROM federationsender_queue_edus WHERE expires_at > 0 AND expires_at <= $1"
type queueEDUsStatements struct {
db *sql.DB
insertQueueEDUStmt *sql.Stmt
@ -73,6 +83,8 @@ type queueEDUsStatements struct {
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
selectQueueEDUCountStmt *sql.Stmt
selectQueueEDUServerNamesStmt *sql.Stmt
selectExpiredEDUsStmt *sql.Stmt
deleteExpiredEDUsStmt *sql.Stmt
}
func NewPostgresQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
@ -81,27 +93,34 @@ func NewPostgresQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
}
_, err = s.db.Exec(queueEDUsSchema)
if err != nil {
return
return s, err
}
if s.insertQueueEDUStmt, err = s.db.Prepare(insertQueueEDUSQL); err != nil {
return
m := sqlutil.NewMigrator(db)
m.AddMigrations(
sqlutil.Migration{
Version: "federationapi: add expiresat column",
Up: deltas.UpAddexpiresat,
},
)
if err := m.Up(context.Background()); err != nil {
return s, err
}
if s.deleteQueueEDUStmt, err = s.db.Prepare(deleteQueueEDUSQL); err != nil {
return
}
if s.selectQueueEDUStmt, err = s.db.Prepare(selectQueueEDUSQL); err != nil {
return
}
if s.selectQueueEDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueueEDUReferenceJSONCountSQL); err != nil {
return
}
if s.selectQueueEDUCountStmt, err = s.db.Prepare(selectQueueEDUCountSQL); err != nil {
return
}
if s.selectQueueEDUServerNamesStmt, err = s.db.Prepare(selectQueueServerNamesSQL); err != nil {
return
}
return
return s, nil
}
func (s *queueEDUsStatements) Prepare() error {
return sqlutil.StatementList{
{&s.insertQueueEDUStmt, insertQueueEDUSQL},
{&s.deleteQueueEDUStmt, deleteQueueEDUSQL},
{&s.selectQueueEDUStmt, selectQueueEDUSQL},
{&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL},
{&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL},
{&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL},
{&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL},
{&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL},
}.Prepare(s.db)
}
func (s *queueEDUsStatements) InsertQueueEDU(
@ -110,6 +129,7 @@ func (s *queueEDUsStatements) InsertQueueEDU(
eduType string,
serverName gomatrixserverlib.ServerName,
nid int64,
expiresAt gomatrixserverlib.Timestamp,
) error {
stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt)
_, err := stmt.ExecContext(
@ -117,6 +137,7 @@ func (s *queueEDUsStatements) InsertQueueEDU(
eduType, // the EDU type
serverName, // destination server name
nid, // JSON blob NID
expiresAt, // timestamp of expiry
)
return err
}
@ -150,7 +171,7 @@ func (s *queueEDUsStatements) SelectQueueEDUs(
}
result = append(result, nid)
}
return result, nil
return result, rows.Err()
}
func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount(
@ -200,3 +221,33 @@ func (s *queueEDUsStatements) SelectQueueEDUServerNames(
return result, rows.Err()
}
func (s *queueEDUsStatements) SelectExpiredEDUs(
ctx context.Context, txn *sql.Tx,
expiredBefore gomatrixserverlib.Timestamp,
) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectExpiredEDUsStmt)
rows, err := stmt.QueryContext(ctx, expiredBefore)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectExpiredEDUs: rows.close() failed")
var result []int64
var nid int64
for rows.Next() {
if err = rows.Scan(&nid); err != nil {
return nil, err
}
result = append(result, nid)
}
return result, rows.Err()
}
func (s *queueEDUsStatements) DeleteExpiredEDUs(
ctx context.Context, txn *sql.Tx,
expiredBefore gomatrixserverlib.Timestamp,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteExpiredEDUsStmt)
_, err := stmt.ExecContext(ctx, expiredBefore)
return err
}

View file

@ -91,6 +91,9 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil {
return nil, err
}
if err = queueEDUs.Prepare(); err != nil {
return nil, err
}
d.Database = shared.Database{
DB: d.db,
ServerName: serverName,

View file

@ -20,10 +20,21 @@ import (
"encoding/json"
"errors"
"fmt"
"time"
"github.com/matrix-org/gomatrixserverlib"
)
// defaultExpiry for EDUs if not listed below
var defaultExpiry = time.Hour * 24
// defaultExpireEDUTypes contains EDUs which can/should be expired after a given time
// if the target server isn't reachable for some reason.
var defaultExpireEDUTypes = map[string]time.Duration{
gomatrixserverlib.MTyping: time.Minute,
gomatrixserverlib.MPresence: time.Minute * 10,
}
// AssociateEDUWithDestination creates an association that the
// destination queues will use to determine which JSON blobs to send
// to which servers.
@ -32,7 +43,21 @@ func (d *Database) AssociateEDUWithDestination(
serverName gomatrixserverlib.ServerName,
receipt *Receipt,
eduType string,
expireEDUTypes map[string]time.Duration,
) error {
if expireEDUTypes == nil {
expireEDUTypes = defaultExpireEDUTypes
}
expiresAt := gomatrixserverlib.AsTimestamp(time.Now().Add(defaultExpiry))
if duration, ok := expireEDUTypes[eduType]; ok {
// Keep EDUs for at least x minutes before deleting them
expiresAt = gomatrixserverlib.AsTimestamp(time.Now().Add(duration))
}
// We forcibly set m.direct_to_device events to 0, as we always want them
// to be delivered. (required for E2EE)
if eduType == gomatrixserverlib.MDirectToDevice {
expiresAt = 0
}
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.FederationQueueEDUs.InsertQueueEDU(
ctx, // context
@ -40,6 +65,7 @@ func (d *Database) AssociateEDUWithDestination(
eduType, // EDU type for coalescing
serverName, // destination server name
receipt.nid, // NID from the federationapi_queue_json table
expiresAt, // The timestamp this EDU will expire
); err != nil {
return fmt.Errorf("InsertQueueEDU: %w", err)
}
@ -150,3 +176,26 @@ func (d *Database) GetPendingEDUServerNames(
) ([]gomatrixserverlib.ServerName, error) {
return d.FederationQueueEDUs.SelectQueueEDUServerNames(ctx, nil)
}
// DeleteExpiredEDUs deletes expired EDUs
func (d *Database) DeleteExpiredEDUs(ctx context.Context) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
expiredBefore := gomatrixserverlib.AsTimestamp(time.Now())
jsonNIDs, err := d.FederationQueueEDUs.SelectExpiredEDUs(ctx, txn, expiredBefore)
if err != nil {
return err
}
if len(jsonNIDs) == 0 {
return nil
}
for i := range jsonNIDs {
d.Cache.EvictFederationQueuedEDU(jsonNIDs[i])
}
if err = d.FederationQueueJSON.DeleteQueueJSON(ctx, txn, jsonNIDs); err != nil {
return err
}
return d.FederationQueueEDUs.DeleteExpiredEDUs(ctx, txn, expiredBefore)
})
}

View file

@ -0,0 +1,68 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/matrix-org/gomatrixserverlib"
)
func UpAddexpiresat(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, "ALTER TABLE federationsender_queue_edus RENAME TO federationsender_queue_edus_old;")
if err != nil {
return fmt.Errorf("failed to rename table: %w", err)
}
_, err = tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS federationsender_queue_edus (
edu_type TEXT NOT NULL,
server_name TEXT NOT NULL,
json_nid BIGINT NOT NULL,
expires_at BIGINT NOT NULL DEFAULT 0
);
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx
ON federationsender_queue_edus (json_nid, server_name);
`)
if err != nil {
return fmt.Errorf("failed to create new table: %w", err)
}
_, err = tx.ExecContext(ctx, `
INSERT
INTO federationsender_queue_edus (
edu_type, server_name, json_nid, expires_at
) SELECT edu_type, server_name, json_nid, 0 FROM federationsender_queue_edus_old;
`)
if err != nil {
return fmt.Errorf("failed to update queue_edus: %w", err)
}
_, err = tx.ExecContext(ctx, "UPDATE federationsender_queue_edus SET expires_at = $1 WHERE edu_type != 'm.direct_to_device'", gomatrixserverlib.AsTimestamp(time.Now().Add(time.Hour*24)))
if err != nil {
return fmt.Errorf("failed to update queue_edus: %w", err)
}
return nil
}
func DownAddexpiresat(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, "ALTER TABLE federationsender_queue_edus DROP COLUMN expires_at;")
if err != nil {
return fmt.Errorf("failed to rename table: %w", err)
}
return nil
}

View file

@ -20,9 +20,11 @@ import (
"fmt"
"strings"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const queueEDUsSchema = `
@ -32,7 +34,9 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_edus (
-- The domain part of the user ID the EDU event is for.
server_name TEXT NOT NULL,
-- The JSON NID from the federationsender_queue_edus_json table.
json_nid BIGINT NOT NULL
json_nid BIGINT NOT NULL,
-- The expiry time of this edu, if any.
expires_at BIGINT NOT NULL DEFAULT 0
);
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx
@ -44,8 +48,8 @@ CREATE INDEX IF NOT EXISTS federationsender_queue_edus_server_name_idx
`
const insertQueueEDUSQL = "" +
"INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" +
" VALUES ($1, $2, $3)"
"INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid, expires_at)" +
" VALUES ($1, $2, $3, $4)"
const deleteQueueEDUsSQL = "" +
"DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid IN ($2)"
@ -66,13 +70,22 @@ const selectQueueEDUCountSQL = "" +
const selectQueueServerNamesSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_queue_edus"
const selectExpiredEDUsSQL = "" +
"SELECT DISTINCT json_nid FROM federationsender_queue_edus WHERE expires_at > 0 AND expires_at <= $1"
const deleteExpiredEDUsSQL = "" +
"DELETE FROM federationsender_queue_edus WHERE expires_at > 0 AND expires_at <= $1"
type queueEDUsStatements struct {
db *sql.DB
insertQueueEDUStmt *sql.Stmt
db *sql.DB
insertQueueEDUStmt *sql.Stmt
// deleteQueueEDUStmt *sql.Stmt - prepared at runtime due to variadic
selectQueueEDUStmt *sql.Stmt
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
selectQueueEDUCountStmt *sql.Stmt
selectQueueEDUServerNamesStmt *sql.Stmt
selectExpiredEDUsStmt *sql.Stmt
deleteExpiredEDUsStmt *sql.Stmt
}
func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
@ -81,24 +94,33 @@ func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
}
_, err = db.Exec(queueEDUsSchema)
if err != nil {
return
return s, err
}
if s.insertQueueEDUStmt, err = db.Prepare(insertQueueEDUSQL); err != nil {
return
m := sqlutil.NewMigrator(db)
m.AddMigrations(
sqlutil.Migration{
Version: "federationapi: add expiresat column",
Up: deltas.UpAddexpiresat,
},
)
if err := m.Up(context.Background()); err != nil {
return s, err
}
if s.selectQueueEDUStmt, err = db.Prepare(selectQueueEDUSQL); err != nil {
return
}
if s.selectQueueEDUReferenceJSONCountStmt, err = db.Prepare(selectQueueEDUReferenceJSONCountSQL); err != nil {
return
}
if s.selectQueueEDUCountStmt, err = db.Prepare(selectQueueEDUCountSQL); err != nil {
return
}
if s.selectQueueEDUServerNamesStmt, err = db.Prepare(selectQueueServerNamesSQL); err != nil {
return
}
return
return s, nil
}
func (s *queueEDUsStatements) Prepare() error {
return sqlutil.StatementList{
{&s.insertQueueEDUStmt, insertQueueEDUSQL},
{&s.selectQueueEDUStmt, selectQueueEDUSQL},
{&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL},
{&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL},
{&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL},
{&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL},
{&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL},
}.Prepare(s.db)
}
func (s *queueEDUsStatements) InsertQueueEDU(
@ -107,6 +129,7 @@ func (s *queueEDUsStatements) InsertQueueEDU(
eduType string,
serverName gomatrixserverlib.ServerName,
nid int64,
expiresAt gomatrixserverlib.Timestamp,
) error {
stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt)
_, err := stmt.ExecContext(
@ -114,6 +137,7 @@ func (s *queueEDUsStatements) InsertQueueEDU(
eduType, // the EDU type
serverName, // destination server name
nid, // JSON blob NID
expiresAt, // timestamp of expiry
)
return err
}
@ -159,7 +183,7 @@ func (s *queueEDUsStatements) SelectQueueEDUs(
}
result = append(result, nid)
}
return result, nil
return result, rows.Err()
}
func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount(
@ -209,3 +233,33 @@ func (s *queueEDUsStatements) SelectQueueEDUServerNames(
return result, rows.Err()
}
func (s *queueEDUsStatements) SelectExpiredEDUs(
ctx context.Context, txn *sql.Tx,
expiredBefore gomatrixserverlib.Timestamp,
) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectExpiredEDUsStmt)
rows, err := stmt.QueryContext(ctx, expiredBefore)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectExpiredEDUs: rows.close() failed")
var result []int64
var nid int64
for rows.Next() {
if err = rows.Scan(&nid); err != nil {
return nil, err
}
result = append(result, nid)
}
return result, rows.Err()
}
func (s *queueEDUsStatements) DeleteExpiredEDUs(
ctx context.Context, txn *sql.Tx,
expiredBefore gomatrixserverlib.Timestamp,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteExpiredEDUsStmt)
_, err := stmt.ExecContext(ctx, expiredBefore)
return err
}

View file

@ -90,6 +90,9 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil {
return nil, err
}
if err = queueEDUs.Prepare(); err != nil {
return nil, err
}
d.Database = shared.Database{
DB: d.db,
ServerName: serverName,

View file

@ -0,0 +1,81 @@
package storage_test
import (
"context"
"testing"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
)
func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
b, baseClose := testrig.CreateBaseDendrite(t, dbType)
connStr, dbClose := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewDatabase(b, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, b.Caches, b.Cfg.Global.ServerName)
if err != nil {
t.Fatalf("NewDatabase returned %s", err)
}
return db, func() {
dbClose()
baseClose()
}
}
func TestExpireEDUs(t *testing.T) {
var expireEDUTypes = map[string]time.Duration{
gomatrixserverlib.MReceipt: time.Millisecond,
}
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateFederationDatabase(t, dbType)
defer close()
// insert some data
for i := 0; i < 100; i++ {
receipt, err := db.StoreJSON(ctx, "{}")
assert.NoError(t, err)
err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MReceipt, expireEDUTypes)
assert.NoError(t, err)
}
// add data without expiry
receipt, err := db.StoreJSON(ctx, "{}")
assert.NoError(t, err)
// m.read_marker gets the default expiry of 24h, so won't be deleted further down in this test
err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, "m.read_marker", expireEDUTypes)
assert.NoError(t, err)
// Delete expired EDUs
err = db.DeleteExpiredEDUs(ctx)
assert.NoError(t, err)
// verify the data is gone
data, err := db.GetPendingEDUs(ctx, "localhost", 100)
assert.NoError(t, err)
assert.Equal(t, 1, len(data))
// check that m.direct_to_device is never expired
receipt, err = db.StoreJSON(ctx, "{}")
assert.NoError(t, err)
err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MDirectToDevice, expireEDUTypes)
assert.NoError(t, err)
err = db.DeleteExpiredEDUs(ctx)
assert.NoError(t, err)
// We should get two EDUs, the m.read_marker and the m.direct_to_device
data, err = db.GetPendingEDUs(ctx, "localhost", 100)
assert.NoError(t, err)
assert.Equal(t, 2, len(data))
})
}

View file

@ -34,12 +34,15 @@ type FederationQueuePDUs interface {
}
type FederationQueueEDUs interface {
InsertQueueEDU(ctx context.Context, txn *sql.Tx, eduType string, serverName gomatrixserverlib.ServerName, nid int64) error
InsertQueueEDU(ctx context.Context, txn *sql.Tx, eduType string, serverName gomatrixserverlib.ServerName, nid int64, expiresAt gomatrixserverlib.Timestamp) error
DeleteQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
SelectQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
SelectQueueEDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error)
SelectQueueEDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error)
SelectExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) ([]int64, error)
DeleteExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) error
Prepare() error
}
type FederationQueueJSON interface {

8
go.mod
View file

@ -21,7 +21,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
github.com/matrix-org/gomatrixserverlib v0.0.0-20220725104114-b6003e522771
github.com/matrix-org/gomatrixserverlib v0.0.0-20220801083850-5ff38e2c2839
github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.13
@ -34,7 +34,7 @@ require (
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.12.2
github.com/sirupsen/logrus v1.8.1
github.com/sirupsen/logrus v1.9.0
github.com/stretchr/testify v1.7.1
github.com/tidwall/gjson v1.14.1
github.com/tidwall/sjson v1.2.4
@ -42,7 +42,7 @@ require (
github.com/uber/jaeger-lib v2.4.1+incompatible
github.com/yggdrasil-network/yggdrasil-go v0.4.3
go.uber.org/atomic v1.9.0
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa
golang.org/x/image v0.0.0-20220413100746-70e8d0d3baa9
golang.org/x/mobile v0.0.0-20220518205345-8578da9835fd
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e
@ -99,7 +99,7 @@ require (
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect
golang.org/x/sys v0.0.0-20220731174439-a90be440212d // indirect
golang.org/x/text v0.3.8-0.20211004125949-5bd84dd9b33b // indirect
golang.org/x/time v0.0.0-20220411224347-583f2d630306 // indirect
golang.org/x/tools v0.1.10 // indirect

16
go.sum
View file

@ -343,8 +343,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220725104114-b6003e522771 h1:ZIPHFIPNDS9dmEbPEiJbNmyCGJtn9exfpLC7JOcn/bE=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220725104114-b6003e522771/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220801083850-5ff38e2c2839 h1:QEFxKWH8PlEt3ZQKl31yJNAm8lvpNUwT51IMNTl9v1k=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220801083850-5ff38e2c2839/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk=
github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9 h1:ed8yvWhTLk7+sNeK/eOZRTvESFTOHDRevoRoyeqPtvY=
github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9/go.mod h1:P4MqPf+u83OPulPJ+XTbSDbbWrdFYNY4LZ/B1PIduFE=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
@ -493,8 +493,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE=
github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0=
github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s=
github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE=
@ -569,8 +569,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e h1:T8NU3HyQ8ClP4SEE+KbFlg6n0NhuTsN4MyznaarGsZM=
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c=
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -748,8 +748,10 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20211102192858-4dd72447c267/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a h1:dGzPydgVsqGcTRVwiLJ1jVbufYwmzD3LfVPLKsKg+0k=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220731174439-a90be440212d h1:Sv5ogFZatcgIMMtBSTTAgMYsicp25MXBubjXNDKwm80=
golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467 h1:CBpWXWQpIRjzmkkA+M7q9Fqnwd2mZr3AFqexg8YTfoM=

View file

@ -146,7 +146,7 @@ func (c *RistrettoCostedCachePartition[K, V]) Set(key K, value V) {
}
type RistrettoCachePartition[K keyable, V any] struct {
cache *ristretto.Cache
cache *ristretto.Cache //nolint:all,unused
Prefix byte
Mutable bool
MaxAge time.Duration

View file

@ -19,19 +19,21 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"github.com/matrix-org/dendrite/userapi/api"
opentracing "github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
)
// PostJSON performs a POST request with JSON on an internal HTTP API
func PostJSON(
// PostJSON performs a POST request with JSON on an internal HTTP API.
// The error will match the errtype if returned from the remote API, or
// will be a different type if there was a problem reaching the API.
func PostJSON[reqtype, restype any, errtype error](
ctx context.Context, span opentracing.Span, httpClient *http.Client,
apiURL string, request, response interface{},
apiURL string, request *reqtype, response *restype,
) error {
jsonBytes, err := json.Marshal(request)
if err != nil {
@ -69,17 +71,23 @@ func PostJSON(
if err != nil {
return err
}
if res.StatusCode != http.StatusOK {
var errorBody struct {
Message string `json:"message"`
}
if _, ok := response.(*api.PerformKeyBackupResponse); ok { // TODO: remove this, once cross-boundary errors are a thing
return nil
}
if msgerr := json.NewDecoder(res.Body).Decode(&errorBody); msgerr == nil {
return fmt.Errorf("internal API: %d from %s: %s", res.StatusCode, apiURL, errorBody.Message)
}
return fmt.Errorf("internal API: %d from %s", res.StatusCode, apiURL)
var body []byte
body, err = io.ReadAll(res.Body)
if err != nil {
return err
}
return json.NewDecoder(res.Body).Decode(response)
if res.StatusCode != http.StatusOK {
if len(body) == 0 {
return fmt.Errorf("HTTP %d from %s (no response body)", res.StatusCode, apiURL)
}
var reserr errtype
if err = json.Unmarshal(body, reserr); err != nil {
return fmt.Errorf("HTTP %d from %s", res.StatusCode, apiURL)
}
return reserr
}
if err = json.Unmarshal(body, response); err != nil {
return fmt.Errorf("json.Unmarshal: %w", err)
}
return nil
}

View file

@ -0,0 +1,93 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httputil
import (
"context"
"encoding/json"
"fmt"
"net/http"
"reflect"
"github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go"
)
type InternalAPIError struct {
Type string
Message string
}
func (e InternalAPIError) Error() string {
return fmt.Sprintf("internal API returned %q error: %s", e.Type, e.Message)
}
func MakeInternalRPCAPI[reqtype, restype any](metricsName string, f func(context.Context, *reqtype, *restype) error) http.Handler {
return MakeInternalAPI(metricsName, func(req *http.Request) util.JSONResponse {
var request reqtype
var response restype
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := f(req.Context(), &request, &response); err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: &InternalAPIError{
Type: reflect.TypeOf(err).String(),
Message: fmt.Sprintf("%s", err),
},
}
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: &response,
}
})
}
func MakeInternalProxyAPI[reqtype, restype any](metricsName string, f func(context.Context, *reqtype) (*restype, error)) http.Handler {
return MakeInternalAPI(metricsName, func(req *http.Request) util.JSONResponse {
var request reqtype
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
response, err := f(req.Context(), &request)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: err,
}
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: response,
}
})
}
func CallInternalRPCAPI[reqtype, restype any](name, url string, client *http.Client, ctx context.Context, request *reqtype, response *restype) error {
span, ctx := opentracing.StartSpanFromContext(ctx, name)
defer span.Finish()
return PostJSON[reqtype, restype, InternalAPIError](ctx, span, client, url, request, response)
}
func CallInternalProxyAPI[reqtype, restype any, errtype error](name, url string, client *http.Client, ctx context.Context, request *reqtype) (restype, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, name)
defer span.Finish()
var response restype
return response, PostJSON[reqtype, restype, errtype](ctx, span, client, url, request, &response)
}

View file

@ -38,32 +38,32 @@ type KeyInternalAPI interface {
// API functions required by the clientapi
type ClientKeyAPI interface {
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse)
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse)
PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse)
PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse)
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error
PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error
PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error
// PerformClaimKeys claims one-time keys for use in pre-key messages
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse)
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
}
// API functions required by the userapi
type UserKeyAPI interface {
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse)
PerformDeleteKeys(ctx context.Context, req *PerformDeleteKeysRequest, res *PerformDeleteKeysResponse)
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error
PerformDeleteKeys(ctx context.Context, req *PerformDeleteKeysRequest, res *PerformDeleteKeysResponse) error
}
// API functions required by the syncapi
type SyncKeyAPI interface {
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse)
QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse)
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error
QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error
}
type FederationKeyAPI interface {
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse)
QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse)
QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse)
PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse)
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse)
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error
QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error
PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
}
// KeyError is returned if there was a problem performing/querying the server

View file

@ -103,7 +103,7 @@ func sanityCheckKey(key gomatrixserverlib.CrossSigningKey, userID string, purpos
}
// nolint:gocyclo
func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) {
func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error {
// Find the keys to store.
byPurpose := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{}
toStore := types.CrossSigningKeyMap{}
@ -115,7 +115,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
Err: "Master key sanity check failed: " + err.Error(),
IsInvalidParam: true,
}
return
return nil
}
byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster] = req.MasterKey
@ -131,7 +131,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
Err: "Self-signing key sanity check failed: " + err.Error(),
IsInvalidParam: true,
}
return
return nil
}
byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = req.SelfSigningKey
@ -146,7 +146,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
Err: "User-signing key sanity check failed: " + err.Error(),
IsInvalidParam: true,
}
return
return nil
}
byPurpose[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = req.UserSigningKey
@ -161,7 +161,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
Err: "No keys were supplied in the request",
IsMissingParam: true,
}
return
return nil
}
// We can't have a self-signing or user-signing key without a master
@ -174,7 +174,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
res.Error = &api.KeyError{
Err: "Retrieving cross-signing keys from database failed: " + err.Error(),
}
return
return nil
}
// If we still can't find a master key for the user then stop the upload.
@ -185,7 +185,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
Err: "No master key was found",
IsMissingParam: true,
}
return
return nil
}
}
@ -212,7 +212,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
}
}
if !changed {
return
return nil
}
// Store the keys.
@ -220,7 +220,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err),
}
return
return nil
}
// Now upload any signatures that were included with the keys.
@ -238,7 +238,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err),
}
return
return nil
}
}
}
@ -255,17 +255,18 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
update.SelfSigningKey = &ssk
}
if update.MasterKey == nil && update.SelfSigningKey == nil {
return
return nil
}
if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err),
}
return
return nil
}
return nil
}
func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) {
func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error {
// Before we do anything, we need the master and self-signing keys for this user.
// Then we can verify the signatures make sense.
queryReq := &api.QueryKeysRequest{
@ -276,7 +277,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req
for userID := range req.Signatures {
queryReq.UserToDevices[userID] = []string{}
}
a.QueryKeys(ctx, queryReq, queryRes)
_ = a.QueryKeys(ctx, queryReq, queryRes)
selfSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
otherSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
@ -322,14 +323,14 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.processSelfSignatures: %s", err),
}
return
return nil
}
if err := a.processOtherSignatures(ctx, req.UserID, queryRes, otherSignatures); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.processOtherSignatures: %s", err),
}
return
return nil
}
// Finally, generate a notification that we updated the signatures.
@ -345,9 +346,10 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err),
}
return
return nil
}
}
return nil
}
func (a *KeyInternalAPI) processSelfSignatures(
@ -520,7 +522,7 @@ func (a *KeyInternalAPI) crossSigningKeysFromDatabase(
}
}
func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) {
func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error {
for targetUserID, forTargetUser := range req.TargetIDs {
keyMap, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID)
if err != nil && err != sql.ErrNoRows {
@ -559,7 +561,7 @@ func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySign
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err),
}
return
return nil
}
for sourceUserID, forSourceUser := range sigMap {
@ -581,4 +583,5 @@ func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySign
}
}
}
return nil
}

View file

@ -119,7 +119,7 @@ type DeviceListUpdaterDatabase interface {
}
type DeviceListUpdaterAPI interface {
PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse)
PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error
}
// KeyChangeProducer is the interface for producers.KeyChange useful for testing.
@ -421,7 +421,7 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
uploadReq.SelfSigningKey = *res.SelfSigningKey
}
}
u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes)
_ = u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes)
}
err = u.updateDeviceList(&res)
if err != nil {

View file

@ -113,8 +113,8 @@ func (d *mockDeviceListUpdaterDatabase) DeviceKeysJSON(ctx context.Context, keys
type mockDeviceListUpdaterAPI struct {
}
func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) {
func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error {
return nil
}
type roundTripper struct {

View file

@ -48,18 +48,20 @@ func (a *KeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) {
a.UserAPI = i
}
func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) {
func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error {
userIDs, latest, err := a.DB.KeyChanges(ctx, req.Offset, req.ToOffset)
if err != nil {
res.Error = &api.KeyError{
Err: err.Error(),
}
return nil
}
res.Offset = latest
res.UserIDs = userIDs
return nil
}
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error {
res.KeyErrors = make(map[string]map[string]*api.KeyError)
if len(req.DeviceKeys) > 0 {
a.uploadLocalDeviceKeys(ctx, req, res)
@ -67,9 +69,10 @@ func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perform
if len(req.OneTimeKeys) > 0 {
a.uploadOneTimeKeys(ctx, req, res)
}
return nil
}
func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) {
func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error {
res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage)
res.Failures = make(map[string]interface{})
// wrap request map in a top-level by-domain map
@ -113,6 +116,7 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
if len(domainToDeviceKeys) > 0 {
a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
}
return nil
}
func (a *KeyInternalAPI) claimRemoteKeys(
@ -172,32 +176,34 @@ func (a *KeyInternalAPI) claimRemoteKeys(
util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys")
}
func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) {
func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error {
if err := a.DB.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("Failed to delete device keys: %s", err),
}
}
return nil
}
func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) {
func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error {
count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("Failed to query OTK counts: %s", err),
}
return
return nil
}
res.Count = *count
return nil
}
func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) {
func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error {
msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
}
return
return nil
}
maxStreamID := int64(0)
for _, m := range msgs {
@ -215,10 +221,11 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
}
res.Devices = result
res.StreamID = maxStreamID
return nil
}
// nolint:gocyclo
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error {
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
@ -244,7 +251,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query local device keys: %s", err),
}
return
return nil
}
// pull out display names after we have the keys so we handle wildcards correctly
@ -318,7 +325,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
// Stop executing the function if the context was canceled/the deadline was exceeded,
// as we can't continue without a valid context.
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return
return nil
}
logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
continue
@ -344,7 +351,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
// Stop executing the function if the context was canceled/the deadline was exceeded,
// as we can't continue without a valid context.
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return
return nil
}
logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
continue
@ -372,6 +379,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
}
}
}
return nil
}
func (a *KeyInternalAPI) remoteKeysFromDatabase(

View file

@ -22,7 +22,6 @@ import (
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/opentracing/opentracing-go"
)
// HTTP paths for the internal HTTP APIs
@ -68,168 +67,108 @@ func (h *httpKeyInternalAPI) PerformClaimKeys(
ctx context.Context,
request *api.PerformClaimKeysRequest,
response *api.PerformClaimKeysResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformClaimKeys")
defer span.Finish()
apiURL := h.apiURL + PerformClaimKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
) error {
return httputil.CallInternalRPCAPI(
"PerformClaimKeys", h.apiURL+PerformClaimKeysPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpKeyInternalAPI) PerformDeleteKeys(
ctx context.Context,
request *api.PerformDeleteKeysRequest,
response *api.PerformDeleteKeysResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformClaimKeys")
defer span.Finish()
apiURL := h.apiURL + PerformClaimKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
) error {
return httputil.CallInternalRPCAPI(
"PerformDeleteKeys", h.apiURL+PerformDeleteKeysPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpKeyInternalAPI) PerformUploadKeys(
ctx context.Context,
request *api.PerformUploadKeysRequest,
response *api.PerformUploadKeysResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUploadKeys")
defer span.Finish()
apiURL := h.apiURL + PerformUploadKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
) error {
return httputil.CallInternalRPCAPI(
"PerformUploadKeys", h.apiURL+PerformUploadKeysPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpKeyInternalAPI) QueryKeys(
ctx context.Context,
request *api.QueryKeysRequest,
response *api.QueryKeysResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeys")
defer span.Finish()
apiURL := h.apiURL + QueryKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
) error {
return httputil.CallInternalRPCAPI(
"QueryKeys", h.apiURL+QueryKeysPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpKeyInternalAPI) QueryOneTimeKeys(
ctx context.Context,
request *api.QueryOneTimeKeysRequest,
response *api.QueryOneTimeKeysResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOneTimeKeys")
defer span.Finish()
apiURL := h.apiURL + QueryOneTimeKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
) error {
return httputil.CallInternalRPCAPI(
"QueryOneTimeKeys", h.apiURL+QueryOneTimeKeysPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpKeyInternalAPI) QueryDeviceMessages(
ctx context.Context,
request *api.QueryDeviceMessagesRequest,
response *api.QueryDeviceMessagesResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDeviceMessages")
defer span.Finish()
apiURL := h.apiURL + QueryDeviceMessagesPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
) error {
return httputil.CallInternalRPCAPI(
"QueryDeviceMessages", h.apiURL+QueryDeviceMessagesPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpKeyInternalAPI) QueryKeyChanges(
ctx context.Context,
request *api.QueryKeyChangesRequest,
response *api.QueryKeyChangesResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyChanges")
defer span.Finish()
apiURL := h.apiURL + QueryKeyChangesPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
) error {
return httputil.CallInternalRPCAPI(
"QueryKeyChanges", h.apiURL+QueryKeyChangesPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpKeyInternalAPI) PerformUploadDeviceKeys(
ctx context.Context,
request *api.PerformUploadDeviceKeysRequest,
response *api.PerformUploadDeviceKeysResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUploadDeviceKeys")
defer span.Finish()
apiURL := h.apiURL + PerformUploadDeviceKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
) error {
return httputil.CallInternalRPCAPI(
"PerformUploadDeviceKeys", h.apiURL+PerformUploadDeviceKeysPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpKeyInternalAPI) PerformUploadDeviceSignatures(
ctx context.Context,
request *api.PerformUploadDeviceSignaturesRequest,
response *api.PerformUploadDeviceSignaturesResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUploadDeviceSignatures")
defer span.Finish()
apiURL := h.apiURL + PerformUploadDeviceSignaturesPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
) error {
return httputil.CallInternalRPCAPI(
"PerformUploadDeviceSignatures", h.apiURL+PerformUploadDeviceSignaturesPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpKeyInternalAPI) QuerySignatures(
ctx context.Context,
request *api.QuerySignaturesRequest,
response *api.QuerySignaturesResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySignatures")
defer span.Finish()
apiURL := h.apiURL + QuerySignaturesPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
) error {
return httputil.CallInternalRPCAPI(
"QuerySignatures", h.apiURL+QuerySignaturesPath,
h.httpClient, ctx, request, response,
)
}

View file

@ -15,124 +15,59 @@
package inthttp
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/util"
)
func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) {
internalAPIMux.Handle(PerformClaimKeysPath,
httputil.MakeInternalAPI("performClaimKeys", func(req *http.Request) util.JSONResponse {
request := api.PerformClaimKeysRequest{}
response := api.PerformClaimKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.PerformClaimKeys(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
PerformClaimKeysPath,
httputil.MakeInternalRPCAPI("KeyserverPerformClaimKeys", s.PerformClaimKeys),
)
internalAPIMux.Handle(PerformDeleteKeysPath,
httputil.MakeInternalAPI("performDeleteKeys", func(req *http.Request) util.JSONResponse {
request := api.PerformDeleteKeysRequest{}
response := api.PerformDeleteKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.PerformDeleteKeys(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
PerformDeleteKeysPath,
httputil.MakeInternalRPCAPI("KeyserverPerformDeleteKeys", s.PerformDeleteKeys),
)
internalAPIMux.Handle(PerformUploadKeysPath,
httputil.MakeInternalAPI("performUploadKeys", func(req *http.Request) util.JSONResponse {
request := api.PerformUploadKeysRequest{}
response := api.PerformUploadKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.PerformUploadKeys(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
PerformUploadKeysPath,
httputil.MakeInternalRPCAPI("KeyserverPerformUploadKeys", s.PerformUploadKeys),
)
internalAPIMux.Handle(PerformUploadDeviceKeysPath,
httputil.MakeInternalAPI("performUploadDeviceKeys", func(req *http.Request) util.JSONResponse {
request := api.PerformUploadDeviceKeysRequest{}
response := api.PerformUploadDeviceKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.PerformUploadDeviceKeys(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
PerformUploadDeviceKeysPath,
httputil.MakeInternalRPCAPI("KeyserverPerformUploadDeviceKeys", s.PerformUploadDeviceKeys),
)
internalAPIMux.Handle(PerformUploadDeviceSignaturesPath,
httputil.MakeInternalAPI("performUploadDeviceSignatures", func(req *http.Request) util.JSONResponse {
request := api.PerformUploadDeviceSignaturesRequest{}
response := api.PerformUploadDeviceSignaturesResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.PerformUploadDeviceSignatures(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
PerformUploadDeviceSignaturesPath,
httputil.MakeInternalRPCAPI("KeyserverPerformUploadDeviceSignatures", s.PerformUploadDeviceSignatures),
)
internalAPIMux.Handle(QueryKeysPath,
httputil.MakeInternalAPI("queryKeys", func(req *http.Request) util.JSONResponse {
request := api.QueryKeysRequest{}
response := api.QueryKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.QueryKeys(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
QueryKeysPath,
httputil.MakeInternalRPCAPI("KeyserverQueryKeys", s.QueryKeys),
)
internalAPIMux.Handle(QueryOneTimeKeysPath,
httputil.MakeInternalAPI("queryOneTimeKeys", func(req *http.Request) util.JSONResponse {
request := api.QueryOneTimeKeysRequest{}
response := api.QueryOneTimeKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.QueryOneTimeKeys(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
QueryOneTimeKeysPath,
httputil.MakeInternalRPCAPI("KeyserverQueryOneTimeKeys", s.QueryOneTimeKeys),
)
internalAPIMux.Handle(QueryDeviceMessagesPath,
httputil.MakeInternalAPI("queryDeviceMessages", func(req *http.Request) util.JSONResponse {
request := api.QueryDeviceMessagesRequest{}
response := api.QueryDeviceMessagesResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.QueryDeviceMessages(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
QueryDeviceMessagesPath,
httputil.MakeInternalRPCAPI("KeyserverQueryDeviceMessages", s.QueryDeviceMessages),
)
internalAPIMux.Handle(QueryKeyChangesPath,
httputil.MakeInternalAPI("queryKeyChanges", func(req *http.Request) util.JSONResponse {
request := api.QueryKeyChangesRequest{}
response := api.QueryKeyChangesResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.QueryKeyChanges(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
QueryKeyChangesPath,
httputil.MakeInternalRPCAPI("KeyserverQueryKeyChanges", s.QueryKeyChanges),
)
internalAPIMux.Handle(QuerySignaturesPath,
httputil.MakeInternalAPI("querySignatures", func(req *http.Request) util.JSONResponse {
request := api.QuerySignaturesRequest{}
response := api.QuerySignaturesResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.QuerySignatures(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
QuerySignaturesPath,
httputil.MakeInternalRPCAPI("KeyserverQuerySignatures", s.QuerySignatures),
)
}

View file

@ -40,7 +40,7 @@ type InputRoomEventsAPI interface {
ctx context.Context,
req *InputRoomEventsRequest,
res *InputRoomEventsResponse,
)
) error
}
// Query the latest events and state for a room from the room server.
@ -97,6 +97,14 @@ type SyncRoomserverAPI interface {
req *PerformBackfillRequest,
res *PerformBackfillResponse,
) error
// QueryMembershipAtEvent queries the memberships at the given events.
// Returns a map from eventID to a slice of gomatrixserverlib.HeaderedEvent.
QueryMembershipAtEvent(
ctx context.Context,
request *QueryMembershipAtEventRequest,
response *QueryMembershipAtEventResponse,
) error
}
type AppserviceRoomserverAPI interface {
@ -139,15 +147,15 @@ type ClientRoomserverAPI interface {
GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error
// PerformRoomUpgrade upgrades a room to a newer version
PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse)
PerformAdminEvacuateRoom(ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse)
PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse)
PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse)
PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse)
PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse) error
PerformAdminEvacuateRoom(ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse) error
PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error
PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse) error
PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse) error
PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error
PerformJoin(ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse)
PerformJoin(ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse) error
PerformLeave(ctx context.Context, req *PerformLeaveRequest, res *PerformLeaveResponse) error
PerformPublish(ctx context.Context, req *PerformPublishRequest, res *PerformPublishResponse)
PerformPublish(ctx context.Context, req *PerformPublishRequest, res *PerformPublishResponse) error
// PerformForget forgets a rooms history for a specific user
PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error
SetRoomAlias(ctx context.Context, req *SetRoomAliasRequest, res *SetRoomAliasResponse) error
@ -158,7 +166,7 @@ type UserRoomserverAPI interface {
QueryLatestEventsAndStateAPI
QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse)
PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error
}
type FederationRoomserverAPI interface {

View file

@ -35,9 +35,10 @@ func (t *RoomserverInternalAPITrace) InputRoomEvents(
ctx context.Context,
req *InputRoomEventsRequest,
res *InputRoomEventsResponse,
) {
t.Impl.InputRoomEvents(ctx, req, res)
util.GetLogger(ctx).Infof("InputRoomEvents req=%+v res=%+v", js(req), js(res))
) error {
err := t.Impl.InputRoomEvents(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("InputRoomEvents req=%+v res=%+v", js(req), js(res))
return err
}
func (t *RoomserverInternalAPITrace) PerformInvite(
@ -45,44 +46,49 @@ func (t *RoomserverInternalAPITrace) PerformInvite(
req *PerformInviteRequest,
res *PerformInviteResponse,
) error {
util.GetLogger(ctx).Infof("PerformInvite req=%+v res=%+v", js(req), js(res))
return t.Impl.PerformInvite(ctx, req, res)
err := t.Impl.PerformInvite(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("PerformInvite req=%+v res=%+v", js(req), js(res))
return err
}
func (t *RoomserverInternalAPITrace) PerformPeek(
ctx context.Context,
req *PerformPeekRequest,
res *PerformPeekResponse,
) {
t.Impl.PerformPeek(ctx, req, res)
util.GetLogger(ctx).Infof("PerformPeek req=%+v res=%+v", js(req), js(res))
) error {
err := t.Impl.PerformPeek(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("PerformPeek req=%+v res=%+v", js(req), js(res))
return err
}
func (t *RoomserverInternalAPITrace) PerformUnpeek(
ctx context.Context,
req *PerformUnpeekRequest,
res *PerformUnpeekResponse,
) {
t.Impl.PerformUnpeek(ctx, req, res)
util.GetLogger(ctx).Infof("PerformUnpeek req=%+v res=%+v", js(req), js(res))
) error {
err := t.Impl.PerformUnpeek(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("PerformUnpeek req=%+v res=%+v", js(req), js(res))
return err
}
func (t *RoomserverInternalAPITrace) PerformRoomUpgrade(
ctx context.Context,
req *PerformRoomUpgradeRequest,
res *PerformRoomUpgradeResponse,
) {
t.Impl.PerformRoomUpgrade(ctx, req, res)
util.GetLogger(ctx).Infof("PerformRoomUpgrade req=%+v res=%+v", js(req), js(res))
) error {
err := t.Impl.PerformRoomUpgrade(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("PerformRoomUpgrade req=%+v res=%+v", js(req), js(res))
return err
}
func (t *RoomserverInternalAPITrace) PerformJoin(
ctx context.Context,
req *PerformJoinRequest,
res *PerformJoinResponse,
) {
t.Impl.PerformJoin(ctx, req, res)
util.GetLogger(ctx).Infof("PerformJoin req=%+v res=%+v", js(req), js(res))
) error {
err := t.Impl.PerformJoin(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("PerformJoin req=%+v res=%+v", js(req), js(res))
return err
}
func (t *RoomserverInternalAPITrace) PerformLeave(
@ -99,27 +105,30 @@ func (t *RoomserverInternalAPITrace) PerformPublish(
ctx context.Context,
req *PerformPublishRequest,
res *PerformPublishResponse,
) {
t.Impl.PerformPublish(ctx, req, res)
util.GetLogger(ctx).Infof("PerformPublish req=%+v res=%+v", js(req), js(res))
) error {
err := t.Impl.PerformPublish(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("PerformPublish req=%+v res=%+v", js(req), js(res))
return err
}
func (t *RoomserverInternalAPITrace) PerformAdminEvacuateRoom(
ctx context.Context,
req *PerformAdminEvacuateRoomRequest,
res *PerformAdminEvacuateRoomResponse,
) {
t.Impl.PerformAdminEvacuateRoom(ctx, req, res)
util.GetLogger(ctx).Infof("PerformAdminEvacuateRoom req=%+v res=%+v", js(req), js(res))
) error {
err := t.Impl.PerformAdminEvacuateRoom(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("PerformAdminEvacuateRoom req=%+v res=%+v", js(req), js(res))
return err
}
func (t *RoomserverInternalAPITrace) PerformAdminEvacuateUser(
ctx context.Context,
req *PerformAdminEvacuateUserRequest,
res *PerformAdminEvacuateUserResponse,
) {
t.Impl.PerformAdminEvacuateUser(ctx, req, res)
util.GetLogger(ctx).Infof("PerformAdminEvacuateUser req=%+v res=%+v", js(req), js(res))
) error {
err := t.Impl.PerformAdminEvacuateUser(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("PerformAdminEvacuateUser req=%+v res=%+v", js(req), js(res))
return err
}
func (t *RoomserverInternalAPITrace) PerformInboundPeek(
@ -128,7 +137,7 @@ func (t *RoomserverInternalAPITrace) PerformInboundPeek(
res *PerformInboundPeekResponse,
) error {
err := t.Impl.PerformInboundPeek(ctx, req, res)
util.GetLogger(ctx).Infof("PerformInboundPeek req=%+v res=%+v", js(req), js(res))
util.GetLogger(ctx).WithError(err).Infof("PerformInboundPeek req=%+v res=%+v", js(req), js(res))
return err
}
@ -373,6 +382,16 @@ func (t *RoomserverInternalAPITrace) QueryRestrictedJoinAllowed(
return err
}
func (t *RoomserverInternalAPITrace) QueryMembershipAtEvent(
ctx context.Context,
request *QueryMembershipAtEventRequest,
response *QueryMembershipAtEventResponse,
) error {
err := t.Impl.QueryMembershipAtEvent(ctx, request, response)
util.GetLogger(ctx).WithError(err).Infof("QueryMembershipAtEvent req=%+v res=%+v", js(request), js(response))
return err
}
func js(thing interface{}) string {
b, err := json.Marshal(thing)
if err != nil {

View file

@ -427,3 +427,17 @@ func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error {
}
return nil
}
// QueryMembershipAtEventRequest requests the membership events for a user
// for a list of eventIDs.
type QueryMembershipAtEventRequest struct {
RoomID string
EventIDs []string
UserID string
}
// QueryMembershipAtEventResponse is the response to QueryMembershipAtEventRequest.
type QueryMembershipAtEventResponse struct {
// Memberships is a map from eventID to a list of events (if any).
Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"`
}

View file

@ -90,7 +90,9 @@ func SendInputRoomEvents(
Asynchronous: async,
}
var response InputRoomEventsResponse
rsAPI.InputRoomEvents(ctx, &request, &response)
if err := rsAPI.InputRoomEvents(ctx, &request, &response); err != nil {
return err
}
return response.Err()
}

View file

@ -208,6 +208,12 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.Room
return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
}
func MembershipAtEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) {
roomState := state.NewStateResolution(db, info)
// Fetch the state as it was when this event was fired
return roomState.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID)
}
func LoadEvents(
ctx context.Context, db storage.Database, eventNIDs []types.EventNID,
) ([]*gomatrixserverlib.Event, error) {

View file

@ -337,18 +337,18 @@ func (r *Inputer) InputRoomEvents(
ctx context.Context,
request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse,
) {
) error {
// Queue up the event into the roomserver.
replySub, err := r.queueInputRoomEvents(ctx, request)
if err != nil {
response.ErrMsg = err.Error()
return
return nil
}
// If we aren't waiting for synchronous responses then we can
// give up here, there is nothing further to do.
if replySub == nil {
return
return nil
}
// Otherwise, we'll want to sit and wait for the responses
@ -360,12 +360,14 @@ func (r *Inputer) InputRoomEvents(
msg, err := replySub.NextMsgWithContext(ctx)
if err != nil {
response.ErrMsg = err.Error()
return
return nil
}
if len(msg.Data) > 0 {
response.ErrMsg = string(msg.Data)
}
}
return nil
}
var roomserverInputBackpressure = prometheus.NewGaugeVec(

View file

@ -299,7 +299,7 @@ func (r *Inputer) processRoomEvent(
// allowed at the time, and also to get the history visibility. We won't
// bother doing this if the event was already rejected as it just ends up
// burning CPU time.
historyVisibility := gomatrixserverlib.HistoryVisibilityJoined // Default to restrictive.
historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared.
if rejectionErr == nil && !isRejected && !softfail {
var err error
historyVisibility, rejectionErr, err = r.processStateBefore(ctx, input, missingPrev)
@ -429,7 +429,7 @@ func (r *Inputer) processStateBefore(
input *api.InputRoomEvent,
missingPrev bool,
) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) {
historyVisibility = gomatrixserverlib.HistoryVisibilityJoined // Default to restrictive.
historyVisibility = gomatrixserverlib.HistoryVisibilityShared // Default to shared.
event := input.Event.Unwrap()
isCreateEvent := event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("")
var stateBeforeEvent []*gomatrixserverlib.Event

View file

@ -43,21 +43,21 @@ func (r *Admin) PerformAdminEvacuateRoom(
ctx context.Context,
req *api.PerformAdminEvacuateRoomRequest,
res *api.PerformAdminEvacuateRoomResponse,
) {
) error {
roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.RoomInfo: %s", err),
}
return
return nil
}
if roomInfo == nil || roomInfo.IsStub() {
res.Error = &api.PerformError{
Code: api.PerformErrorNoRoom,
Msg: fmt.Sprintf("Room %s not found", req.RoomID),
}
return
return nil
}
memberNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true)
@ -66,7 +66,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.GetMembershipEventNIDsForRoom: %s", err),
}
return
return nil
}
memberEvents, err := r.DB.Events(ctx, memberNIDs)
@ -75,7 +75,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.Events: %s", err),
}
return
return nil
}
inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents))
@ -89,7 +89,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Queryer.QueryLatestEventsAndState: %s", err),
}
return
return nil
}
prevEvents := latestRes.LatestEvents
@ -104,7 +104,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("json.Unmarshal: %s", err),
}
return
return nil
}
memberContent.Membership = gomatrixserverlib.Leave
@ -122,7 +122,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("json.Marshal: %s", err),
}
return
return nil
}
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent)
@ -131,7 +131,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("gomatrixserverlib.StateNeededForEventBuilder: %s", err),
}
return
return nil
}
event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, time.Now(), &eventsNeeded, latestRes)
@ -140,7 +140,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("eventutil.BuildEvent: %s", err),
}
return
return nil
}
inputEvents = append(inputEvents, api.InputRoomEvent{
@ -160,28 +160,28 @@ func (r *Admin) PerformAdminEvacuateRoom(
Asynchronous: true,
}
inputRes := &api.InputRoomEventsResponse{}
r.Inputer.InputRoomEvents(ctx, inputReq, inputRes)
return r.Inputer.InputRoomEvents(ctx, inputReq, inputRes)
}
func (r *Admin) PerformAdminEvacuateUser(
ctx context.Context,
req *api.PerformAdminEvacuateUserRequest,
res *api.PerformAdminEvacuateUserResponse,
) {
) error {
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Malformed user ID: %s", err),
}
return
return nil
}
if domain != r.Cfg.Matrix.ServerName {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: "Can only evacuate local users using this endpoint",
}
return
return nil
}
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, gomatrixserverlib.Join)
@ -190,7 +190,7 @@ func (r *Admin) PerformAdminEvacuateUser(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.GetRoomsByMembership: %s", err),
}
return
return nil
}
inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, gomatrixserverlib.Invite)
@ -199,7 +199,7 @@ func (r *Admin) PerformAdminEvacuateUser(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.GetRoomsByMembership: %s", err),
}
return
return nil
}
for _, roomID := range append(roomIDs, inviteRoomIDs...) {
@ -214,7 +214,7 @@ func (r *Admin) PerformAdminEvacuateUser(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Leaver.PerformLeave: %s", err),
}
return
return nil
}
if len(outputEvents) == 0 {
continue
@ -224,9 +224,10 @@ func (r *Admin) PerformAdminEvacuateUser(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Inputer.WriteOutputEvents: %s", err),
}
return
return nil
}
res.Affected = append(res.Affected, roomID)
}
return nil
}

View file

@ -241,7 +241,9 @@ func (r *Inviter) PerformInvite(
},
}
inputRes := &api.InputRoomEventsResponse{}
r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes)
if err = r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes); err != nil {
return nil, fmt.Errorf("r.Inputer.InputRoomEvents: %w", err)
}
if err = inputRes.Err(); err != nil {
res.Error = &api.PerformError{
Msg: fmt.Sprintf("r.InputRoomEvents: %s", err.Error()),

View file

@ -52,7 +52,7 @@ func (r *Joiner) PerformJoin(
ctx context.Context,
req *rsAPI.PerformJoinRequest,
res *rsAPI.PerformJoinResponse,
) {
) error {
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
"room_id": req.RoomIDOrAlias,
"user_id": req.UserID,
@ -71,11 +71,12 @@ func (r *Joiner) PerformJoin(
Msg: err.Error(),
}
}
return
return nil
}
logger.Info("User joined room successfully")
res.RoomID = roomID
res.JoinedVia = joinedVia
return nil
}
func (r *Joiner) performJoin(
@ -291,7 +292,12 @@ func (r *Joiner) performJoinRoomByID(
},
}
inputRes := rsAPI.InputRoomEventsResponse{}
r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes)
if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil {
return "", "", &rsAPI.PerformError{
Code: rsAPI.PerformErrorNoOperation,
Msg: fmt.Sprintf("InputRoomEvents failed: %s", err),
}
}
if err = inputRes.Err(); err != nil {
return "", "", &rsAPI.PerformError{
Code: rsAPI.PerformErrorNotAllowed,

View file

@ -186,7 +186,9 @@ func (r *Leaver) performLeaveRoomByID(
},
}
inputRes := api.InputRoomEventsResponse{}
r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes)
if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil {
return nil, fmt.Errorf("r.Inputer.InputRoomEvents: %w", err)
}
if err = inputRes.Err(); err != nil {
return nil, fmt.Errorf("r.InputRoomEvents: %w", err)
}

View file

@ -44,7 +44,7 @@ func (r *Peeker) PerformPeek(
ctx context.Context,
req *api.PerformPeekRequest,
res *api.PerformPeekResponse,
) {
) error {
roomID, err := r.performPeek(ctx, req)
if err != nil {
perr, ok := err.(*api.PerformError)
@ -57,6 +57,7 @@ func (r *Peeker) PerformPeek(
}
}
res.RoomID = roomID
return nil
}
func (r *Peeker) performPeek(

View file

@ -29,11 +29,12 @@ func (r *Publisher) PerformPublish(
ctx context.Context,
req *api.PerformPublishRequest,
res *api.PerformPublishResponse,
) {
) error {
err := r.DB.PublishRoom(ctx, req.RoomID, req.Visibility == "public")
if err != nil {
res.Error = &api.PerformError{
Msg: err.Error(),
}
}
return nil
}

View file

@ -41,7 +41,7 @@ func (r *Unpeeker) PerformUnpeek(
ctx context.Context,
req *api.PerformUnpeekRequest,
res *api.PerformUnpeekResponse,
) {
) error {
if err := r.performUnpeek(ctx, req); err != nil {
perr, ok := err.(*api.PerformError)
if ok {
@ -52,6 +52,7 @@ func (r *Unpeeker) PerformUnpeek(
}
}
}
return nil
}
func (r *Unpeeker) performUnpeek(

View file

@ -45,12 +45,13 @@ func (r *Upgrader) PerformRoomUpgrade(
ctx context.Context,
req *api.PerformRoomUpgradeRequest,
res *api.PerformRoomUpgradeResponse,
) {
) error {
res.NewRoomID, res.Error = r.performRoomUpgrade(ctx, req)
if res.Error != nil {
res.NewRoomID = ""
logrus.WithContext(ctx).WithError(res.Error).Error("Room upgrade failed")
}
return nil
}
func (r *Upgrader) performRoomUpgrade(
@ -286,22 +287,24 @@ func publishNewRoomAndUnpublishOldRoom(
) {
// expose this room in the published room list
var pubNewRoomRes api.PerformPublishResponse
URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{
if err := URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{
RoomID: newRoomID,
Visibility: "public",
}, &pubNewRoomRes)
if pubNewRoomRes.Error != nil {
}, &pubNewRoomRes); err != nil {
util.GetLogger(ctx).WithError(err).Error("failed to reach internal API")
} else if pubNewRoomRes.Error != nil {
// treat as non-fatal since the room is already made by this point
util.GetLogger(ctx).WithError(pubNewRoomRes.Error).Error("failed to visibility:public")
}
var unpubOldRoomRes api.PerformPublishResponse
// remove the old room from the published room list
URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{
if err := URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{
RoomID: oldRoomID,
Visibility: "private",
}, &unpubOldRoomRes)
if unpubOldRoomRes.Error != nil {
}, &unpubOldRoomRes); err != nil {
util.GetLogger(ctx).WithError(err).Error("failed to reach internal API")
} else if unpubOldRoomRes.Error != nil {
// treat as non-fatal since the room is already made by this point
util.GetLogger(ctx).WithError(unpubOldRoomRes.Error).Error("failed to visibility:private")
}

View file

@ -204,6 +204,54 @@ func (r *Queryer) QueryMembershipForUser(
return err
}
func (r *Queryer) QueryMembershipAtEvent(
ctx context.Context,
request *api.QueryMembershipAtEventRequest,
response *api.QueryMembershipAtEventResponse,
) error {
response.Memberships = make(map[string][]*gomatrixserverlib.HeaderedEvent)
info, err := r.DB.RoomInfo(ctx, request.RoomID)
if err != nil {
return fmt.Errorf("unable to get roomInfo: %w", err)
}
if info == nil {
return fmt.Errorf("no roomInfo found")
}
// get the users stateKeyNID
stateKeyNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.UserID})
if err != nil {
return fmt.Errorf("unable to get stateKeyNIDs for %s: %w", request.UserID, err)
}
if _, ok := stateKeyNIDs[request.UserID]; !ok {
return fmt.Errorf("requested stateKeyNID for %s was not found", request.UserID)
}
stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, info, request.EventIDs, stateKeyNIDs[request.UserID])
if err != nil {
return fmt.Errorf("unable to get state before event: %w", err)
}
for _, eventID := range request.EventIDs {
stateEntry := stateEntries[eventID]
memberships, err := helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false)
if err != nil {
return fmt.Errorf("unable to get memberships at state: %w", err)
}
res := make([]*gomatrixserverlib.HeaderedEvent, 0, len(memberships))
for i := range memberships {
ev := memberships[i]
if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(request.UserID) {
res = append(res, ev.Headered(info.RoomVersion))
}
}
response.Memberships[eventID] = res
}
return nil
}
// QueryMembershipsForRoom implements api.RoomserverInternalAPI
func (r *Queryer) QueryMembershipsForRoom(
ctx context.Context,

View file

@ -3,18 +3,16 @@ package inthttp
import (
"context"
"errors"
"fmt"
"net/http"
"github.com/matrix-org/gomatrixserverlib"
asAPI "github.com/matrix-org/dendrite/appservice/api"
fsInputAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/opentracing/opentracing-go"
)
const (
@ -63,6 +61,7 @@ const (
RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom"
RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain"
RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed"
RoomserverQueryMembershipAtEventPath = "/roomserver/queryMembershipAtEvent"
)
type httpRoomserverInternalAPI struct {
@ -106,11 +105,10 @@ func (h *httpRoomserverInternalAPI) SetRoomAlias(
request *api.SetRoomAliasRequest,
response *api.SetRoomAliasResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "SetRoomAlias")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverSetRoomAliasPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"SetRoomAlias", h.roomserverURL+RoomserverSetRoomAliasPath,
h.httpClient, ctx, request, response,
)
}
// GetRoomIDForAlias implements RoomserverAliasAPI
@ -119,11 +117,10 @@ func (h *httpRoomserverInternalAPI) GetRoomIDForAlias(
request *api.GetRoomIDForAliasRequest,
response *api.GetRoomIDForAliasResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "GetRoomIDForAlias")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverGetRoomIDForAliasPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"GetRoomIDForAlias", h.roomserverURL+RoomserverGetRoomIDForAliasPath,
h.httpClient, ctx, request, response,
)
}
// GetAliasesForRoomID implements RoomserverAliasAPI
@ -132,11 +129,10 @@ func (h *httpRoomserverInternalAPI) GetAliasesForRoomID(
request *api.GetAliasesForRoomIDRequest,
response *api.GetAliasesForRoomIDResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "GetAliasesForRoomID")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverGetAliasesForRoomIDPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"GetAliasesForRoomID", h.roomserverURL+RoomserverGetAliasesForRoomIDPath,
h.httpClient, ctx, request, response,
)
}
// RemoveRoomAlias implements RoomserverAliasAPI
@ -145,11 +141,10 @@ func (h *httpRoomserverInternalAPI) RemoveRoomAlias(
request *api.RemoveRoomAliasRequest,
response *api.RemoveRoomAliasResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "RemoveRoomAlias")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverRemoveRoomAliasPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"RemoveRoomAlias", h.roomserverURL+RoomserverRemoveRoomAliasPath,
h.httpClient, ctx, request, response,
)
}
// InputRoomEvents implements RoomserverInputAPI
@ -157,15 +152,14 @@ func (h *httpRoomserverInternalAPI) InputRoomEvents(
ctx context.Context,
request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "InputRoomEvents")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverInputRoomEventsPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
) error {
if err := httputil.CallInternalRPCAPI(
"InputRoomEvents", h.roomserverURL+RoomserverInputRoomEventsPath,
h.httpClient, ctx, request, response,
); err != nil {
response.ErrMsg = err.Error()
}
return nil
}
func (h *httpRoomserverInternalAPI) PerformInvite(
@ -173,45 +167,32 @@ func (h *httpRoomserverInternalAPI) PerformInvite(
request *api.PerformInviteRequest,
response *api.PerformInviteResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformInvite")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformInvitePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformInvite", h.roomserverURL+RoomserverPerformInvitePath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) PerformJoin(
ctx context.Context,
request *api.PerformJoinRequest,
response *api.PerformJoinResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformJoin")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformJoinPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.PerformError{
Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
}
}
) error {
return httputil.CallInternalRPCAPI(
"PerformJoin", h.roomserverURL+RoomserverPerformJoinPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) PerformPeek(
ctx context.Context,
request *api.PerformPeekRequest,
response *api.PerformPeekResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPeek")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformPeekPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.PerformError{
Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
}
}
) error {
return httputil.CallInternalRPCAPI(
"PerformPeek", h.roomserverURL+RoomserverPerformPeekPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) PerformInboundPeek(
@ -219,45 +200,32 @@ func (h *httpRoomserverInternalAPI) PerformInboundPeek(
request *api.PerformInboundPeekRequest,
response *api.PerformInboundPeekResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformInboundPeek")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformInboundPeekPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformInboundPeek", h.roomserverURL+RoomserverPerformInboundPeekPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) PerformUnpeek(
ctx context.Context,
request *api.PerformUnpeekRequest,
response *api.PerformUnpeekResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUnpeek")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformUnpeekPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.PerformError{
Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
}
}
) error {
return httputil.CallInternalRPCAPI(
"PerformUnpeek", h.roomserverURL+RoomserverPerformUnpeekPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) PerformRoomUpgrade(
ctx context.Context,
request *api.PerformRoomUpgradeRequest,
response *api.PerformRoomUpgradeResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformRoomUpgrade")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformRoomUpgradePath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.PerformError{
Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
}
}
) error {
return httputil.CallInternalRPCAPI(
"PerformRoomUpgrade", h.roomserverURL+RoomserverPerformRoomUpgradePath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) PerformLeave(
@ -265,62 +233,43 @@ func (h *httpRoomserverInternalAPI) PerformLeave(
request *api.PerformLeaveRequest,
response *api.PerformLeaveResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLeave")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformLeavePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformLeave", h.roomserverURL+RoomserverPerformLeavePath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) PerformPublish(
ctx context.Context,
req *api.PerformPublishRequest,
res *api.PerformPublishResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPublish")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformPublishPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
if err != nil {
res.Error = &api.PerformError{
Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
}
}
request *api.PerformPublishRequest,
response *api.PerformPublishResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformPublish", h.roomserverURL+RoomserverPerformPublishPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) PerformAdminEvacuateRoom(
ctx context.Context,
req *api.PerformAdminEvacuateRoomRequest,
res *api.PerformAdminEvacuateRoomResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAdminEvacuateRoom")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformAdminEvacuateRoomPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
if err != nil {
res.Error = &api.PerformError{
Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
}
}
request *api.PerformAdminEvacuateRoomRequest,
response *api.PerformAdminEvacuateRoomResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformAdminEvacuateRoom", h.roomserverURL+RoomserverPerformAdminEvacuateRoomPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) PerformAdminEvacuateUser(
ctx context.Context,
req *api.PerformAdminEvacuateUserRequest,
res *api.PerformAdminEvacuateUserResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAdminEvacuateUser")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformAdminEvacuateUserPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
if err != nil {
res.Error = &api.PerformError{
Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
}
}
request *api.PerformAdminEvacuateUserRequest,
response *api.PerformAdminEvacuateUserResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformAdminEvacuateUser", h.roomserverURL+RoomserverPerformAdminEvacuateUserPath,
h.httpClient, ctx, request, response,
)
}
// QueryLatestEventsAndState implements RoomserverQueryAPI
@ -329,11 +278,10 @@ func (h *httpRoomserverInternalAPI) QueryLatestEventsAndState(
request *api.QueryLatestEventsAndStateRequest,
response *api.QueryLatestEventsAndStateResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryLatestEventsAndState")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryLatestEventsAndStatePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryLatestEventsAndState", h.roomserverURL+RoomserverQueryLatestEventsAndStatePath,
h.httpClient, ctx, request, response,
)
}
// QueryStateAfterEvents implements RoomserverQueryAPI
@ -342,11 +290,10 @@ func (h *httpRoomserverInternalAPI) QueryStateAfterEvents(
request *api.QueryStateAfterEventsRequest,
response *api.QueryStateAfterEventsResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryStateAfterEvents")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryStateAfterEventsPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryStateAfterEvents", h.roomserverURL+RoomserverQueryStateAfterEventsPath,
h.httpClient, ctx, request, response,
)
}
// QueryEventsByID implements RoomserverQueryAPI
@ -355,11 +302,10 @@ func (h *httpRoomserverInternalAPI) QueryEventsByID(
request *api.QueryEventsByIDRequest,
response *api.QueryEventsByIDResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryEventsByID")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryEventsByIDPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryEventsByID", h.roomserverURL+RoomserverQueryEventsByIDPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) QueryPublishedRooms(
@ -367,11 +313,10 @@ func (h *httpRoomserverInternalAPI) QueryPublishedRooms(
request *api.QueryPublishedRoomsRequest,
response *api.QueryPublishedRoomsResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPublishedRooms")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryPublishedRoomsPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryPublishedRooms", h.roomserverURL+RoomserverQueryPublishedRoomsPath,
h.httpClient, ctx, request, response,
)
}
// QueryMembershipForUser implements RoomserverQueryAPI
@ -380,11 +325,10 @@ func (h *httpRoomserverInternalAPI) QueryMembershipForUser(
request *api.QueryMembershipForUserRequest,
response *api.QueryMembershipForUserResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMembershipForUser")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryMembershipForUserPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryMembershipForUser", h.roomserverURL+RoomserverQueryMembershipForUserPath,
h.httpClient, ctx, request, response,
)
}
// QueryMembershipsForRoom implements RoomserverQueryAPI
@ -393,11 +337,10 @@ func (h *httpRoomserverInternalAPI) QueryMembershipsForRoom(
request *api.QueryMembershipsForRoomRequest,
response *api.QueryMembershipsForRoomResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMembershipsForRoom")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryMembershipsForRoomPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryMembershipsForRoom", h.roomserverURL+RoomserverQueryMembershipsForRoomPath,
h.httpClient, ctx, request, response,
)
}
// QueryMembershipsForRoom implements RoomserverQueryAPI
@ -406,11 +349,10 @@ func (h *httpRoomserverInternalAPI) QueryServerJoinedToRoom(
request *api.QueryServerJoinedToRoomRequest,
response *api.QueryServerJoinedToRoomResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerJoinedToRoom")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryServerJoinedToRoomPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryServerJoinedToRoom", h.roomserverURL+RoomserverQueryServerJoinedToRoomPath,
h.httpClient, ctx, request, response,
)
}
// QueryServerAllowedToSeeEvent implements RoomserverQueryAPI
@ -419,11 +361,10 @@ func (h *httpRoomserverInternalAPI) QueryServerAllowedToSeeEvent(
request *api.QueryServerAllowedToSeeEventRequest,
response *api.QueryServerAllowedToSeeEventResponse,
) (err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerAllowedToSeeEvent")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryServerAllowedToSeeEventPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryServerAllowedToSeeEvent", h.roomserverURL+RoomserverQueryServerAllowedToSeeEventPath,
h.httpClient, ctx, request, response,
)
}
// QueryMissingEvents implements RoomServerQueryAPI
@ -432,11 +373,10 @@ func (h *httpRoomserverInternalAPI) QueryMissingEvents(
request *api.QueryMissingEventsRequest,
response *api.QueryMissingEventsResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMissingEvents")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryMissingEventsPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryMissingEvents", h.roomserverURL+RoomserverQueryMissingEventsPath,
h.httpClient, ctx, request, response,
)
}
// QueryStateAndAuthChain implements RoomserverQueryAPI
@ -445,11 +385,10 @@ func (h *httpRoomserverInternalAPI) QueryStateAndAuthChain(
request *api.QueryStateAndAuthChainRequest,
response *api.QueryStateAndAuthChainResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryStateAndAuthChain")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryStateAndAuthChainPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryStateAndAuthChain", h.roomserverURL+RoomserverQueryStateAndAuthChainPath,
h.httpClient, ctx, request, response,
)
}
// PerformBackfill implements RoomServerQueryAPI
@ -458,11 +397,10 @@ func (h *httpRoomserverInternalAPI) PerformBackfill(
request *api.PerformBackfillRequest,
response *api.PerformBackfillResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformBackfill")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformBackfillPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformBackfill", h.roomserverURL+RoomserverPerformBackfillPath,
h.httpClient, ctx, request, response,
)
}
// QueryRoomVersionCapabilities implements RoomServerQueryAPI
@ -471,11 +409,10 @@ func (h *httpRoomserverInternalAPI) QueryRoomVersionCapabilities(
request *api.QueryRoomVersionCapabilitiesRequest,
response *api.QueryRoomVersionCapabilitiesResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomVersionCapabilities")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryRoomVersionCapabilitiesPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryRoomVersionCapabilities", h.roomserverURL+RoomserverQueryRoomVersionCapabilitiesPath,
h.httpClient, ctx, request, response,
)
}
// QueryRoomVersionForRoom implements RoomServerQueryAPI
@ -488,12 +425,10 @@ func (h *httpRoomserverInternalAPI) QueryRoomVersionForRoom(
response.RoomVersion = roomVersion
return nil
}
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomVersionForRoom")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryRoomVersionForRoomPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
err := httputil.CallInternalRPCAPI(
"QueryRoomVersionForRoom", h.roomserverURL+RoomserverQueryRoomVersionForRoomPath,
h.httpClient, ctx, request, response,
)
if err == nil {
h.cache.StoreRoomVersion(request.RoomID, response.RoomVersion)
}
@ -505,11 +440,10 @@ func (h *httpRoomserverInternalAPI) QueryCurrentState(
request *api.QueryCurrentStateRequest,
response *api.QueryCurrentStateResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryCurrentState")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryCurrentStatePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryCurrentState", h.roomserverURL+RoomserverQueryCurrentStatePath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) QueryRoomsForUser(
@ -517,11 +451,10 @@ func (h *httpRoomserverInternalAPI) QueryRoomsForUser(
request *api.QueryRoomsForUserRequest,
response *api.QueryRoomsForUserResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomsForUser")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryRoomsForUserPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryRoomsForUser", h.roomserverURL+RoomserverQueryRoomsForUserPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) QueryBulkStateContent(
@ -529,68 +462,82 @@ func (h *httpRoomserverInternalAPI) QueryBulkStateContent(
request *api.QueryBulkStateContentRequest,
response *api.QueryBulkStateContentResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryBulkStateContent")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryBulkStateContentPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryBulkStateContent", h.roomserverURL+RoomserverQueryBulkStateContentPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) QuerySharedUsers(
ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse,
ctx context.Context,
request *api.QuerySharedUsersRequest,
response *api.QuerySharedUsersResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySharedUsers")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQuerySharedUsersPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
return httputil.CallInternalRPCAPI(
"QuerySharedUsers", h.roomserverURL+RoomserverQuerySharedUsersPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) QueryKnownUsers(
ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse,
ctx context.Context,
request *api.QueryKnownUsersRequest,
response *api.QueryKnownUsersResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKnownUsers")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryKnownUsersPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
return httputil.CallInternalRPCAPI(
"QueryKnownUsers", h.roomserverURL+RoomserverQueryKnownUsersPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) QueryAuthChain(
ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse,
ctx context.Context,
request *api.QueryAuthChainRequest,
response *api.QueryAuthChainResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryAuthChain")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryAuthChainPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
return httputil.CallInternalRPCAPI(
"QueryAuthChain", h.roomserverURL+RoomserverQueryAuthChainPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) QueryServerBannedFromRoom(
ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse,
ctx context.Context,
request *api.QueryServerBannedFromRoomRequest,
response *api.QueryServerBannedFromRoomResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerBannedFromRoom")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryServerBannedFromRoomPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
return httputil.CallInternalRPCAPI(
"QueryServerBannedFromRoom", h.roomserverURL+RoomserverQueryServerBannedFromRoomPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) QueryRestrictedJoinAllowed(
ctx context.Context, req *api.QueryRestrictedJoinAllowedRequest, res *api.QueryRestrictedJoinAllowedResponse,
ctx context.Context,
request *api.QueryRestrictedJoinAllowedRequest,
response *api.QueryRestrictedJoinAllowedResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRestrictedJoinAllowed")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryRestrictedJoinAllowed
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
return httputil.CallInternalRPCAPI(
"QueryRestrictedJoinAllowed", h.roomserverURL+RoomserverQueryRestrictedJoinAllowed,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) PerformForget(ctx context.Context, req *api.PerformForgetRequest, res *api.PerformForgetResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformForget")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformForgetPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpRoomserverInternalAPI) PerformForget(
ctx context.Context,
request *api.PerformForgetRequest,
response *api.PerformForgetResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformForget", h.roomserverURL+RoomserverPerformForgetPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) QueryMembershipAtEvent(ctx context.Context, request *api.QueryMembershipAtEventRequest, response *api.QueryMembershipAtEventResponse) error {
return httputil.CallInternalRPCAPI(
"QueryMembershiptAtEvent", h.roomserverURL+RoomserverQueryMembershipAtEventPath,
h.httpClient, ctx, request, response,
)
}

View file

@ -1,499 +1,201 @@
package inthttp
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/util"
)
// AddRoutes adds the RoomserverInternalAPI handlers to the http.ServeMux.
// nolint: gocyclo
func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
internalAPIMux.Handle(RoomserverInputRoomEventsPath,
httputil.MakeInternalAPI("inputRoomEvents", func(req *http.Request) util.JSONResponse {
var request api.InputRoomEventsRequest
var response api.InputRoomEventsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.InputRoomEvents(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverInputRoomEventsPath,
httputil.MakeInternalRPCAPI("RoomserverInputRoomEvents", r.InputRoomEvents),
)
internalAPIMux.Handle(RoomserverPerformInvitePath,
httputil.MakeInternalAPI("performInvite", func(req *http.Request) util.JSONResponse {
var request api.PerformInviteRequest
var response api.PerformInviteResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.PerformInvite(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverPerformInvitePath,
httputil.MakeInternalRPCAPI("RoomserverPerformInvite", r.PerformInvite),
)
internalAPIMux.Handle(RoomserverPerformJoinPath,
httputil.MakeInternalAPI("performJoin", func(req *http.Request) util.JSONResponse {
var request api.PerformJoinRequest
var response api.PerformJoinResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.PerformJoin(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverPerformJoinPath,
httputil.MakeInternalRPCAPI("RoomserverPerformJoin", r.PerformJoin),
)
internalAPIMux.Handle(RoomserverPerformLeavePath,
httputil.MakeInternalAPI("performLeave", func(req *http.Request) util.JSONResponse {
var request api.PerformLeaveRequest
var response api.PerformLeaveResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.PerformLeave(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverPerformLeavePath,
httputil.MakeInternalRPCAPI("RoomserverPerformLeave", r.PerformLeave),
)
internalAPIMux.Handle(RoomserverPerformPeekPath,
httputil.MakeInternalAPI("performPeek", func(req *http.Request) util.JSONResponse {
var request api.PerformPeekRequest
var response api.PerformPeekResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.PerformPeek(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverPerformPeekPath,
httputil.MakeInternalRPCAPI("RoomserverPerformPeek", r.PerformPeek),
)
internalAPIMux.Handle(RoomserverPerformInboundPeekPath,
httputil.MakeInternalAPI("performInboundPeek", func(req *http.Request) util.JSONResponse {
var request api.PerformInboundPeekRequest
var response api.PerformInboundPeekResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.PerformInboundPeek(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverPerformInboundPeekPath,
httputil.MakeInternalRPCAPI("RoomserverPerformInboundPeek", r.PerformInboundPeek),
)
internalAPIMux.Handle(RoomserverPerformPeekPath,
httputil.MakeInternalAPI("performUnpeek", func(req *http.Request) util.JSONResponse {
var request api.PerformUnpeekRequest
var response api.PerformUnpeekResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.PerformUnpeek(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverPerformUnpeekPath,
httputil.MakeInternalRPCAPI("RoomserverPerformUnpeek", r.PerformUnpeek),
)
internalAPIMux.Handle(RoomserverPerformRoomUpgradePath,
httputil.MakeInternalAPI("performRoomUpgrade", func(req *http.Request) util.JSONResponse {
var request api.PerformRoomUpgradeRequest
var response api.PerformRoomUpgradeResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.PerformRoomUpgrade(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverPerformRoomUpgradePath,
httputil.MakeInternalRPCAPI("RoomserverPerformRoomUpgrade", r.PerformRoomUpgrade),
)
internalAPIMux.Handle(RoomserverPerformPublishPath,
httputil.MakeInternalAPI("performPublish", func(req *http.Request) util.JSONResponse {
var request api.PerformPublishRequest
var response api.PerformPublishResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.PerformPublish(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverPerformPublishPath,
httputil.MakeInternalRPCAPI("RoomserverPerformPublish", r.PerformPublish),
)
internalAPIMux.Handle(RoomserverPerformAdminEvacuateRoomPath,
httputil.MakeInternalAPI("performAdminEvacuateRoom", func(req *http.Request) util.JSONResponse {
var request api.PerformAdminEvacuateRoomRequest
var response api.PerformAdminEvacuateRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.PerformAdminEvacuateRoom(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverPerformAdminEvacuateRoomPath,
httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateRoom", r.PerformAdminEvacuateRoom),
)
internalAPIMux.Handle(RoomserverPerformAdminEvacuateUserPath,
httputil.MakeInternalAPI("performAdminEvacuateUser", func(req *http.Request) util.JSONResponse {
var request api.PerformAdminEvacuateUserRequest
var response api.PerformAdminEvacuateUserResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.PerformAdminEvacuateUser(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverPerformAdminEvacuateUserPath,
httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateUser", r.PerformAdminEvacuateUser),
)
internalAPIMux.Handle(
RoomserverQueryPublishedRoomsPath,
httputil.MakeInternalAPI("queryPublishedRooms", func(req *http.Request) util.JSONResponse {
var request api.QueryPublishedRoomsRequest
var response api.QueryPublishedRoomsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryPublishedRooms(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverQueryPublishedRooms", r.QueryPublishedRooms),
)
internalAPIMux.Handle(
RoomserverQueryLatestEventsAndStatePath,
httputil.MakeInternalAPI("queryLatestEventsAndState", func(req *http.Request) util.JSONResponse {
var request api.QueryLatestEventsAndStateRequest
var response api.QueryLatestEventsAndStateResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryLatestEventsAndState(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverQueryLatestEventsAndState", r.QueryLatestEventsAndState),
)
internalAPIMux.Handle(
RoomserverQueryStateAfterEventsPath,
httputil.MakeInternalAPI("queryStateAfterEvents", func(req *http.Request) util.JSONResponse {
var request api.QueryStateAfterEventsRequest
var response api.QueryStateAfterEventsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryStateAfterEvents(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverQueryStateAfterEvents", r.QueryStateAfterEvents),
)
internalAPIMux.Handle(
RoomserverQueryEventsByIDPath,
httputil.MakeInternalAPI("queryEventsByID", func(req *http.Request) util.JSONResponse {
var request api.QueryEventsByIDRequest
var response api.QueryEventsByIDResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryEventsByID(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverQueryEventsByID", r.QueryEventsByID),
)
internalAPIMux.Handle(
RoomserverQueryMembershipForUserPath,
httputil.MakeInternalAPI("QueryMembershipForUser", func(req *http.Request) util.JSONResponse {
var request api.QueryMembershipForUserRequest
var response api.QueryMembershipForUserResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryMembershipForUser(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverQueryMembershipForUser", r.QueryMembershipForUser),
)
internalAPIMux.Handle(
RoomserverQueryMembershipsForRoomPath,
httputil.MakeInternalAPI("queryMembershipsForRoom", func(req *http.Request) util.JSONResponse {
var request api.QueryMembershipsForRoomRequest
var response api.QueryMembershipsForRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryMembershipsForRoom(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverQueryMembershipsForRoom", r.QueryMembershipsForRoom),
)
internalAPIMux.Handle(
RoomserverQueryServerJoinedToRoomPath,
httputil.MakeInternalAPI("queryServerJoinedToRoom", func(req *http.Request) util.JSONResponse {
var request api.QueryServerJoinedToRoomRequest
var response api.QueryServerJoinedToRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryServerJoinedToRoom(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverQueryServerJoinedToRoom", r.QueryServerJoinedToRoom),
)
internalAPIMux.Handle(
RoomserverQueryServerAllowedToSeeEventPath,
httputil.MakeInternalAPI("queryServerAllowedToSeeEvent", func(req *http.Request) util.JSONResponse {
var request api.QueryServerAllowedToSeeEventRequest
var response api.QueryServerAllowedToSeeEventResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryServerAllowedToSeeEvent(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverQueryServerAllowedToSeeEvent", r.QueryServerAllowedToSeeEvent),
)
internalAPIMux.Handle(
RoomserverQueryMissingEventsPath,
httputil.MakeInternalAPI("queryMissingEvents", func(req *http.Request) util.JSONResponse {
var request api.QueryMissingEventsRequest
var response api.QueryMissingEventsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryMissingEvents(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverQueryMissingEvents", r.QueryMissingEvents),
)
internalAPIMux.Handle(
RoomserverQueryStateAndAuthChainPath,
httputil.MakeInternalAPI("queryStateAndAuthChain", func(req *http.Request) util.JSONResponse {
var request api.QueryStateAndAuthChainRequest
var response api.QueryStateAndAuthChainResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryStateAndAuthChain(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverQueryStateAndAuthChain", r.QueryStateAndAuthChain),
)
internalAPIMux.Handle(
RoomserverPerformBackfillPath,
httputil.MakeInternalAPI("PerformBackfill", func(req *http.Request) util.JSONResponse {
var request api.PerformBackfillRequest
var response api.PerformBackfillResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.PerformBackfill(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverPerformBackfill", r.PerformBackfill),
)
internalAPIMux.Handle(
RoomserverPerformForgetPath,
httputil.MakeInternalAPI("PerformForget", func(req *http.Request) util.JSONResponse {
var request api.PerformForgetRequest
var response api.PerformForgetResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.PerformForget(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverPerformForget", r.PerformForget),
)
internalAPIMux.Handle(
RoomserverQueryRoomVersionCapabilitiesPath,
httputil.MakeInternalAPI("QueryRoomVersionCapabilities", func(req *http.Request) util.JSONResponse {
var request api.QueryRoomVersionCapabilitiesRequest
var response api.QueryRoomVersionCapabilitiesResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryRoomVersionCapabilities(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverQueryRoomVersionCapabilities", r.QueryRoomVersionCapabilities),
)
internalAPIMux.Handle(
RoomserverQueryRoomVersionForRoomPath,
httputil.MakeInternalAPI("QueryRoomVersionForRoom", func(req *http.Request) util.JSONResponse {
var request api.QueryRoomVersionForRoomRequest
var response api.QueryRoomVersionForRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryRoomVersionForRoom(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverQueryRoomVersionForRoom", r.QueryRoomVersionForRoom),
)
internalAPIMux.Handle(
RoomserverSetRoomAliasPath,
httputil.MakeInternalAPI("setRoomAlias", func(req *http.Request) util.JSONResponse {
var request api.SetRoomAliasRequest
var response api.SetRoomAliasResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.SetRoomAlias(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverSetRoomAlias", r.SetRoomAlias),
)
internalAPIMux.Handle(
RoomserverGetRoomIDForAliasPath,
httputil.MakeInternalAPI("GetRoomIDForAlias", func(req *http.Request) util.JSONResponse {
var request api.GetRoomIDForAliasRequest
var response api.GetRoomIDForAliasResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.GetRoomIDForAlias(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverGetRoomIDForAlias", r.GetRoomIDForAlias),
)
internalAPIMux.Handle(
RoomserverGetAliasesForRoomIDPath,
httputil.MakeInternalAPI("getAliasesForRoomID", func(req *http.Request) util.JSONResponse {
var request api.GetAliasesForRoomIDRequest
var response api.GetAliasesForRoomIDResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.GetAliasesForRoomID(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverGetAliasesForRoomID", r.GetAliasesForRoomID),
)
internalAPIMux.Handle(
RoomserverRemoveRoomAliasPath,
httputil.MakeInternalAPI("removeRoomAlias", func(req *http.Request) util.JSONResponse {
var request api.RemoveRoomAliasRequest
var response api.RemoveRoomAliasResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.RemoveRoomAlias(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
httputil.MakeInternalRPCAPI("RoomserverRemoveRoomAlias", r.RemoveRoomAlias),
)
internalAPIMux.Handle(RoomserverQueryCurrentStatePath,
httputil.MakeInternalAPI("queryCurrentState", func(req *http.Request) util.JSONResponse {
request := api.QueryCurrentStateRequest{}
response := api.QueryCurrentStateResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QueryCurrentState(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverQueryCurrentStatePath,
httputil.MakeInternalRPCAPI("RoomserverQueryCurrentState", r.QueryCurrentState),
)
internalAPIMux.Handle(RoomserverQueryRoomsForUserPath,
httputil.MakeInternalAPI("queryRoomsForUser", func(req *http.Request) util.JSONResponse {
request := api.QueryRoomsForUserRequest{}
response := api.QueryRoomsForUserResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QueryRoomsForUser(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverQueryRoomsForUserPath,
httputil.MakeInternalRPCAPI("RoomserverQueryRoomsForUser", r.QueryRoomsForUser),
)
internalAPIMux.Handle(RoomserverQueryBulkStateContentPath,
httputil.MakeInternalAPI("queryBulkStateContent", func(req *http.Request) util.JSONResponse {
request := api.QueryBulkStateContentRequest{}
response := api.QueryBulkStateContentResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QueryBulkStateContent(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverQueryBulkStateContentPath,
httputil.MakeInternalRPCAPI("RoomserverQueryBulkStateContent", r.QueryBulkStateContent),
)
internalAPIMux.Handle(RoomserverQuerySharedUsersPath,
httputil.MakeInternalAPI("querySharedUsers", func(req *http.Request) util.JSONResponse {
request := api.QuerySharedUsersRequest{}
response := api.QuerySharedUsersResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QuerySharedUsers(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverQuerySharedUsersPath,
httputil.MakeInternalRPCAPI("RoomserverQuerySharedUsers", r.QuerySharedUsers),
)
internalAPIMux.Handle(RoomserverQueryKnownUsersPath,
httputil.MakeInternalAPI("queryKnownUsers", func(req *http.Request) util.JSONResponse {
request := api.QueryKnownUsersRequest{}
response := api.QueryKnownUsersResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QueryKnownUsers(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverQueryKnownUsersPath,
httputil.MakeInternalRPCAPI("RoomserverQueryKnownUsers", r.QueryKnownUsers),
)
internalAPIMux.Handle(RoomserverQueryServerBannedFromRoomPath,
httputil.MakeInternalAPI("queryServerBannedFromRoom", func(req *http.Request) util.JSONResponse {
request := api.QueryServerBannedFromRoomRequest{}
response := api.QueryServerBannedFromRoomResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QueryServerBannedFromRoom(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverQueryServerBannedFromRoomPath,
httputil.MakeInternalRPCAPI("RoomserverQueryServerBannedFromRoom", r.QueryServerBannedFromRoom),
)
internalAPIMux.Handle(RoomserverQueryAuthChainPath,
httputil.MakeInternalAPI("queryAuthChain", func(req *http.Request) util.JSONResponse {
request := api.QueryAuthChainRequest{}
response := api.QueryAuthChainResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QueryAuthChain(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverQueryAuthChainPath,
httputil.MakeInternalRPCAPI("RoomserverQueryAuthChain", r.QueryAuthChain),
)
internalAPIMux.Handle(RoomserverQueryRestrictedJoinAllowed,
httputil.MakeInternalAPI("queryRestrictedJoinAllowed", func(req *http.Request) util.JSONResponse {
request := api.QueryRestrictedJoinAllowedRequest{}
response := api.QueryRestrictedJoinAllowedResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QueryRestrictedJoinAllowed(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
RoomserverQueryRestrictedJoinAllowed,
httputil.MakeInternalRPCAPI("RoomserverQueryRestrictedJoinAllowed", r.QueryRestrictedJoinAllowed),
)
internalAPIMux.Handle(
RoomserverQueryMembershipAtEventPath,
httputil.MakeInternalRPCAPI("RoomserverQueryMembershipAtEventPath", r.QueryMembershipAtEvent),
)
}

View file

@ -23,12 +23,11 @@ import (
"sync"
"time"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/opentracing/opentracing-go"
"github.com/prometheus/client_golang/prometheus"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
type StateResolutionStorage interface {
@ -124,6 +123,61 @@ func (v *StateResolution) LoadStateAtEvent(
return stateEntries, nil
}
func (v *StateResolution) LoadMembershipAtEvent(
ctx context.Context, eventIDs []string, stateKeyNID types.EventStateKeyNID,
) (map[string][]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadMembershipAtEvent")
defer span.Finish()
// De-dupe snapshotNIDs
snapshotNIDMap := make(map[types.StateSnapshotNID][]string) // map from snapshot NID to eventIDs
for i := range eventIDs {
eventID := eventIDs[i]
snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID)
if err != nil {
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %w", eventID, err)
}
if snapshotNID == 0 {
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID)
}
snapshotNIDMap[snapshotNID] = append(snapshotNIDMap[snapshotNID], eventID)
}
snapshotNIDs := make([]types.StateSnapshotNID, 0, len(snapshotNIDMap))
for nid := range snapshotNIDMap {
snapshotNIDs = append(snapshotNIDs, nid)
}
stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, snapshotNIDs)
if err != nil {
return nil, err
}
result := make(map[string][]types.StateEntry)
for _, stateBlockNIDList := range stateBlockNIDLists {
// Query the membership event for the user at the given stateblocks
stateEntryLists, err := v.db.StateEntriesForTuples(ctx, stateBlockNIDList.StateBlockNIDs, []types.StateKeyTuple{
{
EventTypeNID: types.MRoomMemberNID,
EventStateKeyNID: stateKeyNID,
},
})
if err != nil {
return nil, err
}
evIDs := snapshotNIDMap[stateBlockNIDList.StateSnapshotNID]
for _, evID := range evIDs {
for _, x := range stateEntryLists {
result[evID] = append(result[evID], x.StateEntries...)
}
}
}
return result, nil
}
// LoadStateAtEvent loads the full state of a room before a particular event.
func (v *StateResolution) LoadStateAtEventForHistoryVisibility(
ctx context.Context, eventID string,

View file

@ -23,7 +23,7 @@ import (
// DefaultRoomVersion contains the room version that will, by
// default, be used to create new rooms on this server.
func DefaultRoomVersion() gomatrixserverlib.RoomVersion {
return gomatrixserverlib.RoomVersionV6
return gomatrixserverlib.RoomVersionV9
}
// RoomVersions returns a map of all known room versions to this

View file

@ -164,9 +164,9 @@ func TestMSC2836(t *testing.T) {
// make everyone joined to each other's rooms
nopRsAPI := &testRoomserverAPI{
userToJoinedRooms: map[string][]string{
alice: []string{roomID},
bob: []string{roomID},
charlie: []string{roomID},
alice: {roomID},
bob: {roomID},
charlie: {roomID},
},
events: map[string]*gomatrixserverlib.HeaderedEvent{
eventA.EventID(): eventA,

View file

@ -0,0 +1,217 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"math"
"time"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/prometheus/client_golang/prometheus"
"github.com/tidwall/gjson"
)
func init() {
prometheus.MustRegister(calculateHistoryVisibilityDuration)
}
// calculateHistoryVisibilityDuration stores the time it takes to
// calculate the history visibility. In polylith mode the roundtrip
// to the roomserver is included in this time.
var calculateHistoryVisibilityDuration = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: "dendrite",
Subsystem: "syncapi",
Name: "calculateHistoryVisibility_duration_millis",
Help: "How long it takes to calculate the history visibility",
Buckets: []float64{ // milliseconds
5, 10, 25, 50, 75, 100, 250, 500,
1000, 2000, 3000, 4000, 5000, 6000,
7000, 8000, 9000, 10000, 15000, 20000,
},
},
[]string{"api"},
)
var historyVisibilityPriority = map[gomatrixserverlib.HistoryVisibility]uint8{
gomatrixserverlib.WorldReadable: 0,
gomatrixserverlib.HistoryVisibilityShared: 1,
gomatrixserverlib.HistoryVisibilityInvited: 2,
gomatrixserverlib.HistoryVisibilityJoined: 3,
}
// eventVisibility contains the history visibility and membership state at a given event
type eventVisibility struct {
visibility gomatrixserverlib.HistoryVisibility
membershipAtEvent string
membershipCurrent string
}
// allowed checks the eventVisibility if the user is allowed to see the event.
// Rules as defined by https://spec.matrix.org/v1.3/client-server-api/#server-behaviour-5
func (ev eventVisibility) allowed() (allowed bool) {
switch ev.visibility {
case gomatrixserverlib.HistoryVisibilityWorldReadable:
// If the history_visibility was set to world_readable, allow.
return true
case gomatrixserverlib.HistoryVisibilityJoined:
// If the users membership was join, allow.
if ev.membershipAtEvent == gomatrixserverlib.Join {
return true
}
return false
case gomatrixserverlib.HistoryVisibilityShared:
// If the users membership was join, allow.
// If history_visibility was set to shared, and the user joined the room at any point after the event was sent, allow.
if ev.membershipAtEvent == gomatrixserverlib.Join || ev.membershipCurrent == gomatrixserverlib.Join {
return true
}
return false
case gomatrixserverlib.HistoryVisibilityInvited:
// If the users membership was join, allow.
if ev.membershipAtEvent == gomatrixserverlib.Join {
return true
}
if ev.membershipAtEvent == gomatrixserverlib.Invite {
return true
}
return false
default:
return false
}
}
// ApplyHistoryVisibilityFilter applies the room history visibility filter on gomatrixserverlib.HeaderedEvents.
// Returns the filtered events and an error, if any.
func ApplyHistoryVisibilityFilter(
ctx context.Context,
syncDB storage.Database,
rsAPI api.SyncRoomserverAPI,
events []*gomatrixserverlib.HeaderedEvent,
alwaysIncludeEventIDs map[string]struct{},
userID, endpoint string,
) ([]*gomatrixserverlib.HeaderedEvent, error) {
if len(events) == 0 {
return events, nil
}
start := time.Now()
// try to get the current membership of the user
membershipCurrent, _, err := syncDB.SelectMembershipForUser(ctx, events[0].RoomID(), userID, math.MaxInt64)
if err != nil {
return nil, err
}
// Get the mapping from eventID -> eventVisibility
eventsFiltered := make([]*gomatrixserverlib.HeaderedEvent, 0, len(events))
visibilities, err := visibilityForEvents(ctx, rsAPI, events, userID, events[0].RoomID())
if err != nil {
return eventsFiltered, err
}
for _, ev := range events {
evVis := visibilities[ev.EventID()]
evVis.membershipCurrent = membershipCurrent
// Always include specific state events for /sync responses
if alwaysIncludeEventIDs != nil {
if _, ok := alwaysIncludeEventIDs[ev.EventID()]; ok {
eventsFiltered = append(eventsFiltered, ev)
continue
}
}
// NOTSPEC: Always allow user to see their own membership events (spec contains more "rules")
if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(userID) {
eventsFiltered = append(eventsFiltered, ev)
continue
}
// Always allow history evVis events on boundaries. This is done
// by setting the effective evVis to the least restrictive
// of the old vs new.
// https://spec.matrix.org/v1.3/client-server-api/#server-behaviour-5
if hisVis, err := ev.HistoryVisibility(); err == nil {
prevHisVis := gjson.GetBytes(ev.Unsigned(), "prev_content.history_visibility").String()
oldPrio, ok := historyVisibilityPriority[gomatrixserverlib.HistoryVisibility(prevHisVis)]
// if we can't get the previous history visibility, default to shared.
if !ok {
oldPrio = historyVisibilityPriority[gomatrixserverlib.HistoryVisibilityShared]
}
// no OK check, since this should have been validated when setting the value
newPrio := historyVisibilityPriority[hisVis]
if oldPrio < newPrio {
evVis.visibility = gomatrixserverlib.HistoryVisibility(prevHisVis)
}
}
// do the actual check
allowed := evVis.allowed()
if allowed {
eventsFiltered = append(eventsFiltered, ev)
}
}
calculateHistoryVisibilityDuration.With(prometheus.Labels{"api": endpoint}).Observe(float64(time.Since(start).Milliseconds()))
return eventsFiltered, nil
}
// visibilityForEvents returns a map from eventID to eventVisibility containing the visibility and the membership
// of `userID` at the given event.
// Returns an error if the roomserver can't calculate the memberships.
func visibilityForEvents(
ctx context.Context,
rsAPI api.SyncRoomserverAPI,
events []*gomatrixserverlib.HeaderedEvent,
userID, roomID string,
) (map[string]eventVisibility, error) {
eventIDs := make([]string, len(events))
for i := range events {
eventIDs[i] = events[i].EventID()
}
result := make(map[string]eventVisibility, len(eventIDs))
// get the membership events for all eventIDs
membershipResp := &api.QueryMembershipAtEventResponse{}
err := rsAPI.QueryMembershipAtEvent(ctx, &api.QueryMembershipAtEventRequest{
RoomID: roomID,
EventIDs: eventIDs,
UserID: userID,
}, membershipResp)
if err != nil {
return result, err
}
// Create a map from eventID -> eventVisibility
for _, event := range events {
eventID := event.EventID()
vis := eventVisibility{
membershipAtEvent: gomatrixserverlib.Leave, // default to leave, to not expose events by accident
visibility: event.Visibility,
}
membershipEvs, ok := membershipResp.Memberships[eventID]
if !ok {
result[eventID] = vis
continue
}
for _, ev := range membershipEvs {
membership, err := ev.Membership()
if err != nil {
return result, err
}
vis.membershipAtEvent = membership
}
result[eventID] = vis
}
return result, nil
}

View file

@ -31,7 +31,7 @@ import (
// DeviceOTKCounts adds one-time key counts to the /sync response
func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, deviceID string, res *types.Response) error {
var queryRes keyapi.QueryOneTimeKeysResponse
keyAPI.QueryOneTimeKeys(ctx, &keyapi.QueryOneTimeKeysRequest{
_ = keyAPI.QueryOneTimeKeys(ctx, &keyapi.QueryOneTimeKeysRequest{
UserID: userID,
DeviceID: deviceID,
}, &queryRes)
@ -73,7 +73,7 @@ func DeviceListCatchup(
offset = int64(from)
}
var queryRes keyapi.QueryKeyChangesResponse
keyAPI.QueryKeyChanges(ctx, &keyapi.QueryKeyChangesRequest{
_ = keyAPI.QueryKeyChanges(ctx, &keyapi.QueryKeyChangesRequest{
Offset: offset,
ToOffset: toOffset,
}, &queryRes)

View file

@ -22,31 +22,41 @@ var (
type mockKeyAPI struct{}
func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) {
func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) error {
return nil
}
func (k *mockKeyAPI) SetUserAPI(i userapi.UserInternalAPI) {}
// PerformClaimKeys claims one-time keys for use in pre-key messages
func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) {
func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) error {
return nil
}
func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *keyapi.PerformDeleteKeysRequest, res *keyapi.PerformDeleteKeysResponse) {
func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *keyapi.PerformDeleteKeysRequest, res *keyapi.PerformDeleteKeysResponse) error {
return nil
}
func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *keyapi.PerformUploadDeviceKeysRequest, res *keyapi.PerformUploadDeviceKeysResponse) {
func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *keyapi.PerformUploadDeviceKeysRequest, res *keyapi.PerformUploadDeviceKeysResponse) error {
return nil
}
func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *keyapi.PerformUploadDeviceSignaturesRequest, res *keyapi.PerformUploadDeviceSignaturesResponse) {
func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *keyapi.PerformUploadDeviceSignaturesRequest, res *keyapi.PerformUploadDeviceSignaturesResponse) error {
return nil
}
func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) {
func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) error {
return nil
}
func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) {
func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) error {
return nil
}
func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) {
func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) error {
return nil
}
func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *keyapi.QueryDeviceMessagesRequest, res *keyapi.QueryDeviceMessagesResponse) {
func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *keyapi.QueryDeviceMessagesRequest, res *keyapi.QueryDeviceMessagesResponse) error {
return nil
}
func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *keyapi.QuerySignaturesRequest, res *keyapi.QuerySignaturesResponse) {
func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *keyapi.QuerySignaturesRequest, res *keyapi.QuerySignaturesResponse) error {
return nil
}
type mockRoomserverAPI struct {

View file

@ -21,10 +21,12 @@ import (
"fmt"
"net/http"
"strconv"
"time"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/caching"
roomserver "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
@ -95,24 +97,6 @@ func Context(
ContainsURL: filter.ContainsURL,
}
// TODO: Get the actual state at the last event returned by SelectContextAfterEvent
state, _ := syncDB.CurrentState(ctx, roomID, &stateFilter, nil)
// verify the user is allowed to see the context for this room/event
for _, x := range state {
var hisVis gomatrixserverlib.HistoryVisibility
hisVis, err = x.HistoryVisibility()
if err != nil {
continue
}
allowed := hisVis == gomatrixserverlib.WorldReadable || membershipRes.Membership == gomatrixserverlib.Join
if !allowed {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("User is not allowed to query context"),
}
}
}
id, requestedEvent, err := syncDB.SelectContextEvent(ctx, roomID, eventID)
if err != nil {
if err == sql.ErrNoRows {
@ -125,6 +109,24 @@ func Context(
return jsonerror.InternalServerError()
}
// verify the user is allowed to see the context for this room/event
startTime := time.Now()
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, []*gomatrixserverlib.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context")
if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter")
return jsonerror.InternalServerError()
}
logrus.WithFields(logrus.Fields{
"duration": time.Since(startTime),
"room_id": roomID,
}).Debug("applied history visibility (context)")
if len(filteredEvents) == 0 {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("User is not allowed to query context"),
}
}
eventsBefore, err := syncDB.SelectContextBeforeEvent(ctx, id, roomID, filter)
if err != nil && err != sql.ErrNoRows {
logrus.WithError(err).Error("unable to fetch before events")
@ -137,8 +139,27 @@ func Context(
return jsonerror.InternalServerError()
}
eventsBeforeClient := gomatrixserverlib.HeaderedToClientEvents(eventsBefore, gomatrixserverlib.FormatAll)
eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfter, gomatrixserverlib.FormatAll)
startTime = time.Now()
eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, syncDB, rsAPI, eventsBefore, eventsAfter, device.UserID)
if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter")
return jsonerror.InternalServerError()
}
logrus.WithFields(logrus.Fields{
"duration": time.Since(startTime),
"room_id": roomID,
}).Debug("applied history visibility (context eventsBefore/eventsAfter)")
// TODO: Get the actual state at the last event returned by SelectContextAfterEvent
state, err := syncDB.CurrentState(ctx, roomID, &stateFilter, nil)
if err != nil {
logrus.WithError(err).Error("unable to fetch current room state")
return jsonerror.InternalServerError()
}
eventsBeforeClient := gomatrixserverlib.HeaderedToClientEvents(eventsBeforeFiltered, gomatrixserverlib.FormatAll)
eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfterFiltered, gomatrixserverlib.FormatAll)
newState := applyLazyLoadMembers(device, filter, eventsAfterClient, eventsBeforeClient, state, lazyLoadCache)
response := ContextRespsonse{
@ -162,6 +183,44 @@ func Context(
}
}
// applyHistoryVisibilityOnContextEvents is a helper function to avoid roundtrips to the roomserver
// by combining the events before and after the context event. Returns the filtered events,
// and an error, if any.
func applyHistoryVisibilityOnContextEvents(
ctx context.Context, syncDB storage.Database, rsAPI roomserver.SyncRoomserverAPI,
eventsBefore, eventsAfter []*gomatrixserverlib.HeaderedEvent,
userID string,
) (filteredBefore, filteredAfter []*gomatrixserverlib.HeaderedEvent, err error) {
eventIDsBefore := make(map[string]struct{}, len(eventsBefore))
eventIDsAfter := make(map[string]struct{}, len(eventsAfter))
// Remember before/after eventIDs, so we can restore them
// after applying history visibility checks
for _, ev := range eventsBefore {
eventIDsBefore[ev.EventID()] = struct{}{}
}
for _, ev := range eventsAfter {
eventIDsAfter[ev.EventID()] = struct{}{}
}
allEvents := append(eventsBefore, eventsAfter...)
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, allEvents, nil, userID, "context")
if err != nil {
return nil, nil, err
}
// "Restore" events in the correct context
for _, ev := range filteredEvents {
if _, ok := eventIDsBefore[ev.EventID()]; ok {
filteredBefore = append(filteredBefore, ev)
}
if _, ok := eventIDsAfter[ev.EventID()]; ok {
filteredAfter = append(filteredAfter, ev)
}
}
return filteredBefore, filteredAfter, nil
}
func getStartEnd(ctx context.Context, syncDB storage.Database, startEvents, endEvents []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) {
if len(startEvents) > 0 {
start, err = syncDB.EventPositionInTopology(ctx, startEvents[0].EventID())

View file

@ -19,6 +19,7 @@ import (
"fmt"
"net/http"
"sort"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
@ -28,6 +29,7 @@ import (
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types"
@ -324,6 +326,9 @@ func (r *messagesReq) retrieveEvents() (
// reliable way to define it), it would be easier and less troublesome to
// only have to change it in one place, i.e. the database.
start, end, err = r.getStartEnd(events)
if err != nil {
return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, err
}
// Sort the events to ensure we send them in the right order.
if r.backwardOrdering {
@ -337,97 +342,18 @@ func (r *messagesReq) retrieveEvents() (
}
events = reversed(events)
}
events = r.filterHistoryVisible(events)
if len(events) == 0 {
return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, nil
}
// Convert all of the events into client events.
clientEvents = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatAll)
return clientEvents, start, end, err
}
func (r *messagesReq) filterHistoryVisible(events []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
// TODO FIXME: We don't fully implement history visibility yet. To avoid leaking events which the
// user shouldn't see, we check the recent events and remove any prior to the join event of the user
// which is equiv to history_visibility: joined
joinEventIndex := -1
for i, ev := range events {
if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(r.device.UserID) {
membership, _ := ev.Membership()
if membership == "join" {
joinEventIndex = i
break
}
}
}
var result []*gomatrixserverlib.HeaderedEvent
var eventsToCheck []*gomatrixserverlib.HeaderedEvent
if joinEventIndex != -1 {
if r.backwardOrdering {
result = events[:joinEventIndex+1]
eventsToCheck = append(eventsToCheck, result[0])
} else {
result = events[joinEventIndex:]
eventsToCheck = append(eventsToCheck, result[len(result)-1])
}
} else {
eventsToCheck = []*gomatrixserverlib.HeaderedEvent{events[0], events[len(events)-1]}
result = events
}
// make sure the user was in the room for both the earliest and latest events, we need this because
// some backpagination results will not have the join event (e.g if they hit /messages at the join event itself)
wasJoined := true
for _, ev := range eventsToCheck {
var queryRes api.QueryStateAfterEventsResponse
err := r.rsAPI.QueryStateAfterEvents(r.ctx, &api.QueryStateAfterEventsRequest{
RoomID: ev.RoomID(),
PrevEventIDs: ev.PrevEventIDs(),
StateToFetch: []gomatrixserverlib.StateKeyTuple{
{EventType: gomatrixserverlib.MRoomMember, StateKey: r.device.UserID},
{EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: ""},
},
}, &queryRes)
if err != nil {
wasJoined = false
break
}
var hisVisEvent, membershipEvent *gomatrixserverlib.HeaderedEvent
for i := range queryRes.StateEvents {
switch queryRes.StateEvents[i].Type() {
case gomatrixserverlib.MRoomMember:
membershipEvent = queryRes.StateEvents[i]
case gomatrixserverlib.MRoomHistoryVisibility:
hisVisEvent = queryRes.StateEvents[i]
}
}
if hisVisEvent == nil {
return events // apply no filtering as it defaults to Shared.
}
hisVis, _ := hisVisEvent.HistoryVisibility()
if hisVis == "shared" || hisVis == "world_readable" {
return events // apply no filtering
}
if membershipEvent == nil {
wasJoined = false
break
}
membership, err := membershipEvent.Membership()
if err != nil {
wasJoined = false
break
}
if membership != "join" {
wasJoined = false
break
}
}
if !wasJoined {
util.GetLogger(r.ctx).WithField("num_events", len(events)).Warnf("%s was not joined to room during these events, omitting them", r.device.UserID)
return []*gomatrixserverlib.HeaderedEvent{}
}
return result
// Apply room history visibility filter
startTime := time.Now()
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.db, r.rsAPI, events, nil, r.device.UserID, "messages")
logrus.WithFields(logrus.Fields{
"duration": time.Since(startTime),
"room_id": r.roomID,
}).Debug("applied history visibility (messages)")
return gomatrixserverlib.HeaderedToClientEvents(filteredEvents, gomatrixserverlib.FormatAll), start, end, err
}
func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) {

View file

@ -161,6 +161,10 @@ type Database interface {
IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error)
UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
// string as the membership.
SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
}
type Presence interface {

View file

@ -17,7 +17,10 @@ package deltas
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"github.com/matrix-org/gomatrixserverlib"
)
func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.Tx) error {
@ -31,6 +34,27 @@ func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.T
return nil
}
// UpSetHistoryVisibility sets the history visibility for already stored events.
// Requires current_room_state and output_room_events to be created.
func UpSetHistoryVisibility(ctx context.Context, tx *sql.Tx) error {
// get the current room history visibilities
historyVisibilities, err := currentHistoryVisibilities(ctx, tx)
if err != nil {
return err
}
// update the history visibility
for roomID, hisVis := range historyVisibilities {
_, err = tx.ExecContext(ctx, `UPDATE syncapi_output_room_events SET history_visibility = $1
WHERE type IN ('m.room.message', 'm.room.encrypted') AND room_id = $2 AND history_visibility <> $1`, hisVis, roomID)
if err != nil {
return fmt.Errorf("failed to update history visibility: %w", err)
}
}
return nil
}
func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE syncapi_current_room_state ADD COLUMN IF NOT EXISTS history_visibility SMALLINT NOT NULL DEFAULT 2;
@ -39,9 +63,40 @@ func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.T
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
// currentHistoryVisibilities returns a map from roomID to current history visibility.
// If the history visibility was changed after room creation, defaults to joined.
func currentHistoryVisibilities(ctx context.Context, tx *sql.Tx) (map[string]gomatrixserverlib.HistoryVisibility, error) {
rows, err := tx.QueryContext(ctx, `SELECT DISTINCT room_id, headered_event_json FROM syncapi_current_room_state
WHERE type = 'm.room.history_visibility' AND state_key = '';
`)
if err != nil {
return nil, fmt.Errorf("failed to query current room state: %w", err)
}
defer rows.Close() // nolint: errcheck
var eventBytes []byte
var roomID string
var event gomatrixserverlib.HeaderedEvent
var hisVis gomatrixserverlib.HistoryVisibility
historyVisibilities := make(map[string]gomatrixserverlib.HistoryVisibility)
for rows.Next() {
if err = rows.Scan(&roomID, &eventBytes); err != nil {
return nil, fmt.Errorf("failed to scan row: %w", err)
}
if err = json.Unmarshal(eventBytes, &event); err != nil {
return nil, fmt.Errorf("failed to unmarshal event: %w", err)
}
historyVisibilities[roomID] = gomatrixserverlib.HistoryVisibilityJoined
if hisVis, err = event.HistoryVisibility(); err == nil && event.Depth() < 10 {
historyVisibilities[roomID] = hisVis
}
}
return historyVisibilities, nil
}
func DownAddHistoryVisibilityColumn(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE syncapi_output_room_events DROP COLUMN IF EXISTS history_visibility;

View file

@ -66,10 +66,14 @@ const selectMembershipCountSQL = "" +
const selectHeroesSQL = "" +
"SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership = ANY($3) LIMIT 5"
const selectMembershipBeforeSQL = "" +
"SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1"
type membershipsStatements struct {
upsertMembershipStmt *sql.Stmt
selectMembershipCountStmt *sql.Stmt
selectHeroesStmt *sql.Stmt
upsertMembershipStmt *sql.Stmt
selectMembershipCountStmt *sql.Stmt
selectHeroesStmt *sql.Stmt
selectMembershipForUserStmt *sql.Stmt
}
func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
@ -82,6 +86,7 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
{&s.upsertMembershipStmt, upsertMembershipSQL},
{&s.selectMembershipCountStmt, selectMembershipCountSQL},
{&s.selectHeroesStmt, selectHeroesSQL},
{&s.selectMembershipForUserStmt, selectMembershipBeforeSQL},
}.Prepare(db)
}
@ -132,3 +137,20 @@ func (s *membershipsStatements) SelectHeroes(
}
return heroes, rows.Err()
}
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
// string as the membership.
func (s *membershipsStatements) SelectMembershipForUser(
ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64,
) (membership string, topologyPos int, err error) {
stmt := sqlutil.TxStmt(txn, s.selectMembershipForUserStmt)
err = stmt.QueryRowContext(ctx, roomID, userID, pos).Scan(&membership, &topologyPos)
if err != nil {
if err == sql.ErrNoRows {
return "leave", 0, nil
}
return "", 0, err
}
return membership, topologyPos, nil
}

View file

@ -191,10 +191,12 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "syncapi: add history visibility column (output_room_events)",
Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents,
})
m.AddMigrations(
sqlutil.Migration{
Version: "syncapi: add history visibility column (output_room_events)",
Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents,
},
)
err = m.Up(context.Background())
if err != nil {
return nil, err

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/shared"
)
@ -97,6 +98,20 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
if err != nil {
return nil, err
}
// apply migrations which need multiple tables
m := sqlutil.NewMigrator(d.db)
m.AddMigrations(
sqlutil.Migration{
Version: "syncapi: set history visibility for existing events",
Up: deltas.UpSetHistoryVisibility, // Requires current_room_state and output_room_events to be created.
},
)
err = m.Up(base.Context())
if err != nil {
return nil, err
}
d.Database = shared.Database{
DB: d.db,
Writer: d.writer,

View file

@ -231,7 +231,7 @@ func (d *Database) AddPeek(
return
}
// DeletePeeks tracks the fact that a user has stopped peeking from the specified
// DeletePeek tracks the fact that a user has stopped peeking from the specified
// device. If the peeks was successfully deleted this returns the stream ID it was
// stored at. Returns an error if there was a problem communicating with the database.
func (d *Database) DeletePeek(
@ -372,6 +372,7 @@ func (d *Database) WriteEvent(
) (pduPosition types.StreamPosition, returnErr error) {
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error
ev.Visibility = historyVisibility
pos, err := d.OutputEvents.InsertEvent(
ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, historyVisibility,
)
@ -563,7 +564,7 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda
return err
}
// Retrieve the backward topology position, i.e. the position of the
// GetBackwardTopologyPos retrieves the backward topology position, i.e. the position of the
// oldest event in the room's topology.
func (d *Database) GetBackwardTopologyPos(
ctx context.Context,
@ -674,7 +675,7 @@ func (d *Database) fetchMissingStateEvents(
return events, nil
}
// getStateDeltas returns the state deltas between fromPos and toPos,
// GetStateDeltas returns the state deltas between fromPos and toPos,
// exclusive of oldPos, inclusive of newPos, for the rooms in which
// the user has new membership events.
// A list of joined room IDs is also returned in case the caller needs it.
@ -812,7 +813,7 @@ func (d *Database) GetStateDeltas(
return deltas, joinedRoomIDs, nil
}
// getStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync
// GetStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync
// requests with full_state=true.
// Fetches full state for all joined rooms and uses selectStateInRange to get
// updates for other rooms.
@ -1039,37 +1040,41 @@ func (d *Database) GetUserUnreadNotificationCounts(ctx context.Context, userID s
return d.NotificationData.SelectUserUnreadCounts(ctx, userID, from, to)
}
func (s *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) {
return s.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID)
func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) {
return d.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID)
}
func (s *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) {
return s.OutputEvents.SelectContextBeforeEvent(ctx, nil, id, roomID, filter)
func (d *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) {
return d.OutputEvents.SelectContextBeforeEvent(ctx, nil, id, roomID, filter)
}
func (s *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) {
return s.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter)
func (d *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) {
return d.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter)
}
func (s *Database) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) {
return s.Ignores.SelectIgnores(ctx, userID)
func (d *Database) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) {
return d.Ignores.SelectIgnores(ctx, userID)
}
func (s *Database) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error {
return s.Ignores.UpsertIgnores(ctx, userID, ignores)
func (d *Database) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error {
return d.Ignores.UpsertIgnores(ctx, userID, ignores)
}
func (s *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) {
return s.Presence.UpsertPresence(ctx, nil, userID, statusMsg, presence, lastActiveTS, fromSync)
func (d *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) {
return d.Presence.UpsertPresence(ctx, nil, userID, statusMsg, presence, lastActiveTS, fromSync)
}
func (s *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
return s.Presence.GetPresenceForUser(ctx, nil, userID)
func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
return d.Presence.GetPresenceForUser(ctx, nil, userID)
}
func (s *Database) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {
return s.Presence.GetPresenceAfter(ctx, nil, after, filter)
func (d *Database) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {
return d.Presence.GetPresenceAfter(ctx, nil, after, filter)
}
func (s *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) {
return s.Presence.GetMaxPresenceID(ctx, nil)
func (d *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) {
return d.Presence.GetMaxPresenceID(ctx, nil)
}
func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) {
return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos)
}

View file

@ -17,7 +17,10 @@ package deltas
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"github.com/matrix-org/gomatrixserverlib"
)
func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.Tx) error {
@ -37,6 +40,27 @@ func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.T
return nil
}
// UpSetHistoryVisibility sets the history visibility for already stored events.
// Requires current_room_state and output_room_events to be created.
func UpSetHistoryVisibility(ctx context.Context, tx *sql.Tx) error {
// get the current room history visibilities
historyVisibilities, err := currentHistoryVisibilities(ctx, tx)
if err != nil {
return err
}
// update the history visibility
for roomID, hisVis := range historyVisibilities {
_, err = tx.ExecContext(ctx, `UPDATE syncapi_output_room_events SET history_visibility = $1
WHERE type IN ('m.room.message', 'm.room.encrypted') AND room_id = $2 AND history_visibility <> $1`, hisVis, roomID)
if err != nil {
return fmt.Errorf("failed to update history visibility: %w", err)
}
}
return nil
}
func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.Tx) error {
// SQLite doesn't have "if exists", so check if the column exists. If the query doesn't return an error, it already exists.
// Required for unit tests, as otherwise a duplicate column error will show up.
@ -51,9 +75,40 @@ func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.T
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
// currentHistoryVisibilities returns a map from roomID to current history visibility.
// If the history visibility was changed after room creation, defaults to joined.
func currentHistoryVisibilities(ctx context.Context, tx *sql.Tx) (map[string]gomatrixserverlib.HistoryVisibility, error) {
rows, err := tx.QueryContext(ctx, `SELECT DISTINCT room_id, headered_event_json FROM syncapi_current_room_state
WHERE type = 'm.room.history_visibility' AND state_key = '';
`)
if err != nil {
return nil, fmt.Errorf("failed to query current room state: %w", err)
}
defer rows.Close() // nolint: errcheck
var eventBytes []byte
var roomID string
var event gomatrixserverlib.HeaderedEvent
var hisVis gomatrixserverlib.HistoryVisibility
historyVisibilities := make(map[string]gomatrixserverlib.HistoryVisibility)
for rows.Next() {
if err = rows.Scan(&roomID, &eventBytes); err != nil {
return nil, fmt.Errorf("failed to scan row: %w", err)
}
if err = json.Unmarshal(eventBytes, &event); err != nil {
return nil, fmt.Errorf("failed to unmarshal event: %w", err)
}
historyVisibilities[roomID] = gomatrixserverlib.HistoryVisibilityJoined
if hisVis, err = event.HistoryVisibility(); err == nil && event.Depth() < 10 {
historyVisibilities[roomID] = hisVis
}
}
return historyVisibilities, nil
}
func DownAddHistoryVisibilityColumn(ctx context.Context, tx *sql.Tx) error {
// SQLite doesn't have "if exists", so check if the column exists.
_, err := tx.QueryContext(ctx, "SELECT history_visibility FROM syncapi_output_room_events LIMIT 1")

View file

@ -66,11 +66,15 @@ const selectMembershipCountSQL = "" +
const selectHeroesSQL = "" +
"SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership IN ($3) LIMIT 5"
const selectMembershipBeforeSQL = "" +
"SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1"
type membershipsStatements struct {
db *sql.DB
upsertMembershipStmt *sql.Stmt
selectMembershipCountStmt *sql.Stmt
//selectHeroesStmt *sql.Stmt - prepared at runtime due to variadic
selectMembershipForUserStmt *sql.Stmt
}
func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
@ -84,6 +88,7 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
return s, sqlutil.StatementList{
{&s.upsertMembershipStmt, upsertMembershipSQL},
{&s.selectMembershipCountStmt, selectMembershipCountSQL},
{&s.selectMembershipForUserStmt, selectMembershipBeforeSQL},
// {&s.selectHeroesStmt, selectHeroesSQL}, - prepared at runtime due to variadic
}.Prepare(db)
}
@ -148,3 +153,20 @@ func (s *membershipsStatements) SelectHeroes(
}
return heroes, rows.Err()
}
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
// string as the membership.
func (s *membershipsStatements) SelectMembershipForUser(
ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64,
) (membership string, topologyPos int, err error) {
stmt := sqlutil.TxStmt(txn, s.selectMembershipForUserStmt)
err = stmt.QueryRowContext(ctx, roomID, userID, pos).Scan(&membership, &topologyPos)
if err != nil {
if err == sql.ErrNoRows {
return "leave", 0, nil
}
return "", 0, err
}
return membership, topologyPos, nil
}

View file

@ -139,10 +139,12 @@ func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Even
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "syncapi: add history visibility column (output_room_events)",
Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents,
})
m.AddMigrations(
sqlutil.Migration{
Version: "syncapi: add history visibility column (output_room_events)",
Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents,
},
)
err = m.Up(context.Background())
if err != nil {
return nil, err

View file

@ -16,12 +16,14 @@
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/shared"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas"
)
// SyncServerDatasource represents a sync server datasource which manages
@ -41,13 +43,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil {
return nil, err
}
if err = d.prepare(); err != nil {
if err = d.prepare(base.Context()); err != nil {
return nil, err
}
return &d, nil
}
func (d *SyncServerDatasource) prepare() (err error) {
func (d *SyncServerDatasource) prepare(ctx context.Context) (err error) {
if err = d.streamID.Prepare(d.db); err != nil {
return err
}
@ -107,6 +109,19 @@ func (d *SyncServerDatasource) prepare() (err error) {
if err != nil {
return err
}
// apply migrations which need multiple tables
m := sqlutil.NewMigrator(d.db)
m.AddMigrations(
sqlutil.Migration{
Version: "syncapi: set history visibility for existing events",
Up: deltas.UpSetHistoryVisibility, // Requires current_room_state and output_room_events to be created.
},
)
err = m.Up(ctx)
if err != nil {
return err
}
d.Database = shared.Database{
DB: d.db,
Writer: d.writer,

View file

@ -12,20 +12,22 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrixserverlib"
)
var ctx = context.Background()
func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func(), func()) {
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewSyncServerDatasource(nil, &config.DatabaseOptions{
base, closeBase := testrig.CreateBaseDendrite(t, dbType)
db, err := storage.NewSyncServerDatasource(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
})
if err != nil {
t.Fatalf("NewSyncServerDatasource returned %s", err)
}
return db, close
return db, close, closeBase
}
func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (positions []types.StreamPosition) {
@ -51,8 +53,9 @@ func TestWriteEvents(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
alice := test.NewUser(t)
r := test.NewRoom(t, alice)
db, close := MustCreateDatabase(t, dbType)
db, close, closeBase := MustCreateDatabase(t, dbType)
defer close()
defer closeBase()
MustWriteEvents(t, db, r.Events())
})
}
@ -60,8 +63,9 @@ func TestWriteEvents(t *testing.T) {
// These tests assert basic functionality of RecentEvents for PDUs
func TestRecentEventsPDU(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := MustCreateDatabase(t, dbType)
db, close, closeBase := MustCreateDatabase(t, dbType)
defer close()
defer closeBase()
alice := test.NewUser(t)
// dummy room to make sure SQL queries are filtering on room ID
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
@ -163,8 +167,9 @@ func TestRecentEventsPDU(t *testing.T) {
// The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token
func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := MustCreateDatabase(t, dbType)
db, close, closeBase := MustCreateDatabase(t, dbType)
defer close()
defer closeBase()
alice := test.NewUser(t)
r := test.NewRoom(t, alice)
for i := 0; i < 10; i++ {
@ -404,8 +409,9 @@ func TestSendToDeviceBehaviour(t *testing.T) {
bob := test.NewUser(t)
deviceID := "one"
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := MustCreateDatabase(t, dbType)
db, close, closeBase := MustCreateDatabase(t, dbType)
defer close()
defer closeBase()
// At this point there should be no messages. We haven't sent anything
// yet.
_, events, err := db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100)

View file

@ -185,6 +185,7 @@ type Memberships interface {
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error)
SelectHeroes(ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string) (heroes []string, err error)
SelectMembershipForUser(ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
}
type NotificationData interface {

View file

@ -10,10 +10,13 @@ import (
"github.com/matrix-org/dendrite/internal/caching"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"go.uber.org/atomic"
@ -123,7 +126,7 @@ func (p *PDUStreamProvider) CompleteSync(
defer reqWaitGroup.Done()
jr, jerr := p.getJoinResponseForCompleteSync(
ctx, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device,
ctx, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, false,
)
if jerr != nil {
req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed")
@ -149,7 +152,7 @@ func (p *PDUStreamProvider) CompleteSync(
if !peek.Deleted {
var jr *types.JoinResponse
jr, err = p.getJoinResponseForCompleteSync(
ctx, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device,
ctx, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, true,
)
if err != nil {
req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed")
@ -281,12 +284,6 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
}
}
}
if len(recentEvents) > 0 {
updateLatestPosition(recentEvents[len(recentEvents)-1].EventID())
}
if len(delta.StateEvents) > 0 {
updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID())
}
if stateFilter.LazyLoadMembers {
delta.StateEvents, err = p.lazyLoadMembers(
@ -306,6 +303,19 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
}
}
// Applies the history visibility rules
events, err := applyHistoryVisibilityFilter(ctx, p.DB, p.rsAPI, delta.RoomID, device.UserID, eventFilter.Limit, recentEvents)
if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter")
}
if len(events) > 0 {
updateLatestPosition(events[len(events)-1].EventID())
}
if len(delta.StateEvents) > 0 {
updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID())
}
switch delta.Membership {
case gomatrixserverlib.Join:
jr := types.NewJoinResponse()
@ -313,14 +323,17 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
p.addRoomSummary(ctx, jr, delta.RoomID, device.UserID, latestPosition)
}
jr.Timeline.PrevBatch = &prevBatch
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
// If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited.
jr.Timeline.Limited = limited && len(events) == len(recentEvents)
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Join[delta.RoomID] = *jr
case gomatrixserverlib.Peek:
jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = &prevBatch
// TODO: Apply history visibility on peeked rooms
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync)
@ -330,12 +343,12 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
fallthrough // transitions to leave are the same as ban
case gomatrixserverlib.Ban:
// TODO: recentEvents may contain events that this user is not allowed to see because they are
// no longer in the room.
lr := types.NewLeaveResponse()
lr.Timeline.PrevBatch = &prevBatch
lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
// If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited.
lr.Timeline.Limited = limited && len(events) == len(recentEvents)
lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Leave[delta.RoomID] = *lr
}
@ -343,6 +356,41 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
return latestPosition, nil
}
// applyHistoryVisibilityFilter gets the current room state and supplies it to ApplyHistoryVisibilityFilter, to make
// sure we always return the required events in the timeline.
func applyHistoryVisibilityFilter(
ctx context.Context,
db storage.Database,
rsAPI roomserverAPI.SyncRoomserverAPI,
roomID, userID string,
limit int,
recentEvents []*gomatrixserverlib.HeaderedEvent,
) ([]*gomatrixserverlib.HeaderedEvent, error) {
// We need to make sure we always include the latest states events, if they are in the timeline.
// We grep at least limit * 2 events, to ensure we really get the needed events.
stateEvents, err := db.CurrentState(ctx, roomID, &gomatrixserverlib.StateFilter{Limit: limit * 2}, nil)
if err != nil {
// Not a fatal error, we can continue without the stateEvents,
// they are only needed if there are state events in the timeline.
logrus.WithError(err).Warnf("failed to get current room state")
}
alwaysIncludeIDs := make(map[string]struct{}, len(stateEvents))
for _, ev := range stateEvents {
alwaysIncludeIDs[ev.EventID()] = struct{}{}
}
startTime := time.Now()
events, err := internal.ApplyHistoryVisibilityFilter(ctx, db, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync")
if err != nil {
return nil, err
}
logrus.WithFields(logrus.Fields{
"duration": time.Since(startTime),
"room_id": roomID,
}).Debug("applied history visibility (sync)")
return events, nil
}
func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) {
// Work out how many members are in the room.
joinedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition)
@ -390,6 +438,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
eventFilter *gomatrixserverlib.RoomEventFilter,
wantFullState bool,
device *userapi.Device,
isPeek bool,
) (jr *types.JoinResponse, err error) {
jr = types.NewJoinResponse()
// TODO: When filters are added, we may need to call this multiple times to get enough events.
@ -404,33 +453,6 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
return
}
// TODO FIXME: We don't fully implement history visibility yet. To avoid leaking events which the
// user shouldn't see, we check the recent events and remove any prior to the join event of the user
// which is equiv to history_visibility: joined
joinEventIndex := -1
for i := len(recentStreamEvents) - 1; i >= 0; i-- {
ev := recentStreamEvents[i]
if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(device.UserID) {
membership, _ := ev.Membership()
if membership == "join" {
joinEventIndex = i
if i > 0 {
// the create event happens before the first join, so we should cut it at that point instead
if recentStreamEvents[i-1].Type() == gomatrixserverlib.MRoomCreate && recentStreamEvents[i-1].StateKeyEquals("") {
joinEventIndex = i - 1
break
}
}
break
}
}
}
if joinEventIndex != -1 {
// cut all events earlier than the join (but not the join itself)
recentStreamEvents = recentStreamEvents[joinEventIndex:]
limited = false // so clients know not to try to backpaginate
}
// Work our way through the timeline events and pick out the event IDs
// of any state events that appear in the timeline. We'll specifically
// exclude them at the next step, so that we don't get duplicate state
@ -474,6 +496,19 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents)
events := recentEvents
// Only apply history visibility checks if the response is for joined rooms
if !isPeek {
events, err = applyHistoryVisibilityFilter(ctx, p.DB, p.rsAPI, roomID, device.UserID, eventFilter.Limit, recentEvents)
if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter")
}
}
// If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited.
limited = limited && len(events) == len(recentEvents)
if stateFilter.LazyLoadMembers {
if err != nil {
return nil, err
@ -488,8 +523,10 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
}
jr.Timeline.PrevBatch = prevBatch
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
// If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited.
jr.Timeline.Limited = limited && len(events) == len(recentEvents)
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync)
return jr, nil
}

View file

@ -39,21 +39,13 @@ func (p *SendToDeviceStreamProvider) IncrementalSync(
return from
}
if len(events) > 0 {
// Clean up old send-to-device messages from before this stream position.
if err := p.DB.CleanSendToDeviceUpdates(req.Context, req.Device.UserID, req.Device.ID, from); err != nil {
req.Log.WithError(err).Error("p.DB.CleanSendToDeviceUpdates failed")
return from
}
// Add the updates into the sync response.
for _, event := range events {
// skip ignored user events
if _, ok := req.IgnoredUsers.List[event.Sender]; ok {
continue
}
req.Response.ToDevice.Events = append(req.Response.ToDevice.Events, event.SendToDeviceEvent)
// Add the updates into the sync response.
for _, event := range events {
// skip ignored user events
if _, ok := req.IgnoredUsers.List[event.Sender]; ok {
continue
}
req.Response.ToDevice.Events = append(req.Response.ToDevice.Events, event.SendToDeviceEvent)
}
return lastPos

View file

@ -25,6 +25,11 @@ import (
"sync"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
@ -35,10 +40,6 @@ import (
"github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
)
// RequestPool manages HTTP long-poll connections for /sync
@ -251,6 +252,12 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
waitingSyncRequests.Inc()
defer waitingSyncRequests.Dec()
// Clean up old send-to-device messages from before this stream position.
// This is needed to avoid sending the same message multiple times
if err = rp.db.CleanSendToDeviceUpdates(syncReq.Context, syncReq.Device.UserID, syncReq.Device.ID, syncReq.Since.SendToDevicePosition); err != nil {
syncReq.Log.WithError(err).Error("p.DB.CleanSendToDeviceUpdates failed")
}
// loop until we get some data
for {
startTime := time.Now()

View file

@ -12,6 +12,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/producers"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/roomserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
@ -54,6 +55,16 @@ func (s *syncRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *rsap
return nil
}
func (s *syncRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *rsapi.QueryMembershipForUserRequest, res *rsapi.QueryMembershipForUserResponse) error {
res.IsRoomForgotten = false
res.RoomExists = true
return nil
}
func (s *syncRoomserverAPI) QueryMembershipAtEvent(ctx context.Context, req *rsapi.QueryMembershipAtEventRequest, res *rsapi.QueryMembershipAtEventResponse) error {
return nil
}
type syncUserAPI struct {
userapi.SyncUserAPI
accounts []userapi.Device
@ -78,10 +89,11 @@ type syncKeyAPI struct {
keyapi.SyncKeyAPI
}
func (s *syncKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) {
func (s *syncKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) error {
return nil
}
func (s *syncKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) {
func (s *syncKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) error {
return nil
}
func TestSyncAPIAccessTokens(t *testing.T) {
@ -106,7 +118,7 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
msgs := toNATSMsgs(t, base, room.Events())
msgs := toNATSMsgs(t, base, room.Events()...)
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
testrig.MustPublishMsgs(t, jsctx, msgs...)
@ -199,7 +211,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
// m.room.power_levels
// m.room.join_rules
// m.room.history_visibility
msgs := toNATSMsgs(t, base, room.Events())
msgs := toNATSMsgs(t, base, room.Events()...)
sinceTokens := make([]string, len(msgs))
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
for i, msg := range msgs {
@ -314,6 +326,174 @@ func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) {
}
// This is mainly what Sytest is doing in "test_history_visibility"
func TestMessageHistoryVisibility(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
testHistoryVisibility(t, dbType)
})
}
func testHistoryVisibility(t *testing.T, dbType test.DBType) {
type result struct {
seeWithoutJoin bool
seeBeforeJoin bool
seeAfterInvite bool
}
// create the users
alice := test.NewUser(t)
bob := test.NewUser(t)
bobDev := userapi.Device{
ID: "BOBID",
UserID: bob.ID,
AccessToken: "BOD_BEARER_TOKEN",
DisplayName: "BOB",
}
ctx := context.Background()
// check guest and normal user accounts
for _, accType := range []userapi.AccountType{userapi.AccountTypeGuest, userapi.AccountTypeUser} {
testCases := []struct {
historyVisibility gomatrixserverlib.HistoryVisibility
wantResult result
}{
{
historyVisibility: gomatrixserverlib.HistoryVisibilityWorldReadable,
wantResult: result{
seeWithoutJoin: true,
seeBeforeJoin: true,
seeAfterInvite: true,
},
},
{
historyVisibility: gomatrixserverlib.HistoryVisibilityShared,
wantResult: result{
seeWithoutJoin: false,
seeBeforeJoin: true,
seeAfterInvite: true,
},
},
{
historyVisibility: gomatrixserverlib.HistoryVisibilityInvited,
wantResult: result{
seeWithoutJoin: false,
seeBeforeJoin: false,
seeAfterInvite: true,
},
},
{
historyVisibility: gomatrixserverlib.HistoryVisibilityJoined,
wantResult: result{
seeWithoutJoin: false,
seeBeforeJoin: false,
seeAfterInvite: false,
},
},
}
bobDev.AccountType = accType
userType := "guest"
if accType == userapi.AccountTypeUser {
userType = "real user"
}
base, close := testrig.CreateBaseDendrite(t, dbType)
defer close()
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
// Use the actual internal roomserver API
rsAPI := roomserver.NewInternalAPI(base)
rsAPI.SetFederationAPI(nil, nil)
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{bobDev}}, rsAPI, &syncKeyAPI{})
for _, tc := range testCases {
testname := fmt.Sprintf("%s - %s", tc.historyVisibility, userType)
t.Run(testname, func(t *testing.T) {
// create a room with the given visibility
room := test.NewRoom(t, alice, test.RoomHistoryVisibility(tc.historyVisibility))
// send the events/messages to NATS to create the rooms
beforeJoinEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("Before invite in a %s room", tc.historyVisibility)})
eventsToSend := append(room.Events(), beforeJoinEv)
if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil {
t.Fatalf("failed to send events: %v", err)
}
// There is only one event, we expect only to be able to see this, if the room is world_readable
w := httptest.NewRecorder()
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/messages", room.ID), test.WithQueryParams(map[string]string{
"access_token": bobDev.AccessToken,
"dir": "b",
})))
if w.Code != 200 {
t.Logf("%s", w.Body.String())
t.Fatalf("got HTTP %d want %d", w.Code, 200)
}
// We only care about the returned events at this point
var res struct {
Chunk []gomatrixserverlib.ClientEvent `json:"chunk"`
}
if err := json.NewDecoder(w.Body).Decode(&res); err != nil {
t.Errorf("failed to decode response body: %s", err)
}
verifyEventVisible(t, tc.wantResult.seeWithoutJoin, beforeJoinEv, res.Chunk)
// Create invite, a message, join the room and create another message.
inviteEv := room.CreateAndInsert(t, alice, "m.room.member", map[string]interface{}{"membership": "invite"}, test.WithStateKey(bob.ID))
afterInviteEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After invite in a %s room", tc.historyVisibility)})
joinEv := room.CreateAndInsert(t, bob, "m.room.member", map[string]interface{}{"membership": "join"}, test.WithStateKey(bob.ID))
msgEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After join in a %s room", tc.historyVisibility)})
eventsToSend = append([]*gomatrixserverlib.HeaderedEvent{}, inviteEv, afterInviteEv, joinEv, msgEv)
if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil {
t.Fatalf("failed to send events: %v", err)
}
// Verify the messages after/before invite are visible or not
w = httptest.NewRecorder()
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/messages", room.ID), test.WithQueryParams(map[string]string{
"access_token": bobDev.AccessToken,
"dir": "b",
})))
if w.Code != 200 {
t.Logf("%s", w.Body.String())
t.Fatalf("got HTTP %d want %d", w.Code, 200)
}
if err := json.NewDecoder(w.Body).Decode(&res); err != nil {
t.Errorf("failed to decode response body: %s", err)
}
// verify results
verifyEventVisible(t, tc.wantResult.seeBeforeJoin, beforeJoinEv, res.Chunk)
verifyEventVisible(t, tc.wantResult.seeAfterInvite, afterInviteEv, res.Chunk)
})
}
}
}
func verifyEventVisible(t *testing.T, wantVisible bool, wantVisibleEvent *gomatrixserverlib.HeaderedEvent, chunk []gomatrixserverlib.ClientEvent) {
t.Helper()
if wantVisible {
for _, ev := range chunk {
if ev.EventID == wantVisibleEvent.EventID() {
return
}
}
t.Fatalf("expected to see event %s but didn't: %+v", wantVisibleEvent.EventID(), chunk)
} else {
for _, ev := range chunk {
if ev.EventID == wantVisibleEvent.EventID() {
t.Fatalf("expected not to see event %s: %+v", wantVisibleEvent.EventID(), string(ev.Content))
}
}
}
}
func TestSendToDevice(t *testing.T) {
test.WithAllDatabases(t, testSendToDevice)
}
@ -447,7 +627,7 @@ func testSendToDevice(t *testing.T, dbType test.DBType) {
}
}
func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverlib.HeaderedEvent) []*nats.Msg {
func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input ...*gomatrixserverlib.HeaderedEvent) []*nats.Msg {
result := make([]*nats.Msg, len(input))
for i, ev := range input {
var addsStateIDs []string
@ -459,6 +639,7 @@ func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverli
NewRoomEvent: &rsapi.OutputNewRoomEvent{
Event: ev,
AddsStateEventIDs: addsStateIDs,
HistoryVisibility: ev.Visibility,
},
})
}

View file

@ -21,9 +21,10 @@ import (
"strconv"
"strings"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/roomserver/api"
)
var (
@ -330,23 +331,23 @@ type Response struct {
NextBatch StreamingToken `json:"next_batch"`
AccountData struct {
Events []gomatrixserverlib.ClientEvent `json:"events,omitempty"`
} `json:"account_data"`
} `json:"account_data,omitempty"`
Presence struct {
Events []gomatrixserverlib.ClientEvent `json:"events,omitempty"`
} `json:"presence"`
} `json:"presence,omitempty"`
Rooms struct {
Join map[string]JoinResponse `json:"join"`
Peek map[string]JoinResponse `json:"peek"`
Invite map[string]InviteResponse `json:"invite"`
Leave map[string]LeaveResponse `json:"leave"`
} `json:"rooms"`
Join map[string]JoinResponse `json:"join,omitempty"`
Peek map[string]JoinResponse `json:"peek,omitempty"`
Invite map[string]InviteResponse `json:"invite,omitempty"`
Leave map[string]LeaveResponse `json:"leave,omitempty"`
} `json:"rooms,omitempty"`
ToDevice struct {
Events []gomatrixserverlib.SendToDeviceEvent `json:"events"`
} `json:"to_device"`
Events []gomatrixserverlib.SendToDeviceEvent `json:"events,omitempty"`
} `json:"to_device,omitempty"`
DeviceLists struct {
Changed []string `json:"changed,omitempty"`
Left []string `json:"left,omitempty"`
} `json:"device_lists"`
} `json:"device_lists,omitempty"`
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
}

View file

@ -49,7 +49,3 @@ Notifications can be viewed with GET /notifications
If remote user leaves room we no longer receive device updates
Guest users can join guest_access rooms
# You'll be shocked to discover this is flakey too
Inbound /v1/send_join rejects joins from other servers

View file

@ -110,8 +110,6 @@ Newly joined room is included in an incremental sync
User is offline if they set_presence=offline in their sync
Changes to state are included in an incremental sync
A change to displayname should appear in incremental /sync
Current state appears in timeline in private history
Current state appears in timeline in private history with many messages before
Rooms a user is invited to appear in an initial sync
Rooms a user is invited to appear in an incremental sync
Sync can be polled for updates
@ -458,7 +456,6 @@ After changing password, a different session no longer works by default
Read markers appear in incremental v2 /sync
Read markers appear in initial v2 /sync
Read markers can be updated
Local users can peek into world_readable rooms by room ID
We can't peek into rooms with shared history_visibility
We can't peek into rooms with invited history_visibility
We can't peek into rooms with joined history_visibility
@ -720,4 +717,26 @@ Setting state twice is idempotent
Joining room twice is idempotent
Inbound federation can return missing events for shared visibility
Inbound federation ignores redactions from invalid servers room > v3
Joining room twice is idempotent
Getting messages going forward is limited for a departed room (SPEC-216)
m.room.history_visibility == "shared" allows/forbids appropriately for Guest users
m.room.history_visibility == "invited" allows/forbids appropriately for Guest users
m.room.history_visibility == "default" allows/forbids appropriately for Guest users
m.room.history_visibility == "shared" allows/forbids appropriately for Real users
m.room.history_visibility == "invited" allows/forbids appropriately for Real users
m.room.history_visibility == "default" allows/forbids appropriately for Real users
Guest users can sync from world_readable guest_access rooms if joined
Guest users can sync from shared guest_access rooms if joined
Guest users can sync from invited guest_access rooms if joined
Guest users can sync from joined guest_access rooms if joined
Guest users can sync from default guest_access rooms if joined
Real users can sync from world_readable guest_access rooms if joined
Real users can sync from shared guest_access rooms if joined
Real users can sync from invited guest_access rooms if joined
Real users can sync from joined guest_access rooms if joined
Real users can sync from default guest_access rooms if joined
Only see history_visibility changes on boundaries
Current state appears in timeline in private history
Current state appears in timeline in private history with many messages before
Local users can peek into world_readable rooms by room ID
Newly joined room includes presence in incremental sync

View file

@ -37,10 +37,11 @@ var (
)
type Room struct {
ID string
Version gomatrixserverlib.RoomVersion
preset Preset
creator *User
ID string
Version gomatrixserverlib.RoomVersion
preset Preset
visibility gomatrixserverlib.HistoryVisibility
creator *User
authEvents gomatrixserverlib.AuthEvents
currentState map[string]*gomatrixserverlib.HeaderedEvent
@ -61,6 +62,7 @@ func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room {
preset: PresetPublicChat,
Version: gomatrixserverlib.RoomVersionV9,
currentState: make(map[string]*gomatrixserverlib.HeaderedEvent),
visibility: gomatrixserverlib.HistoryVisibilityShared,
}
for _, m := range modifiers {
m(t, r)
@ -97,10 +99,14 @@ func (r *Room) insertCreateEvents(t *testing.T) {
fallthrough
case PresetPrivateChat:
joinRule.JoinRule = "invite"
hisVis.HistoryVisibility = "shared"
hisVis.HistoryVisibility = gomatrixserverlib.HistoryVisibilityShared
case PresetPublicChat:
joinRule.JoinRule = "public"
hisVis.HistoryVisibility = "shared"
hisVis.HistoryVisibility = gomatrixserverlib.HistoryVisibilityShared
}
if r.visibility != "" {
hisVis.HistoryVisibility = r.visibility
}
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{
@ -183,7 +189,9 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten
if err = gomatrixserverlib.Allowed(ev, &r.authEvents); err != nil {
t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err)
}
return ev.Headered(r.Version)
headeredEvent := ev.Headered(r.Version)
headeredEvent.Visibility = r.visibility
return headeredEvent
}
// Add a new event to this room DAG. Not thread-safe.
@ -242,6 +250,12 @@ func RoomPreset(p Preset) roomModifier {
}
}
func RoomHistoryVisibility(vis gomatrixserverlib.HistoryVisibility) roomModifier {
return func(t *testing.T, r *Room) {
r.visibility = vis
}
}
func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier {
return func(t *testing.T, r *Room) {
r.Version = ver

View file

@ -20,6 +20,7 @@ import (
"sync/atomic"
"testing"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
@ -45,7 +46,8 @@ var (
)
type User struct {
ID string
ID string
accountType api.AccountType
// key ID and private key of the server who has this user, if known.
keyID gomatrixserverlib.KeyID
privKey ed25519.PrivateKey
@ -62,6 +64,12 @@ func WithSigningServer(srvName gomatrixserverlib.ServerName, keyID gomatrixserve
}
}
func WithAccountType(accountType api.AccountType) UserOpt {
return func(u *User) {
u.accountType = accountType
}
}
func NewUser(t *testing.T, opts ...UserOpt) *User {
counter := atomic.AddInt64(&userIDCounter, 1)
var u User

View file

@ -100,7 +100,7 @@ type ClientUserAPI interface {
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse)
QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) error
QueryThreePIDsForLocalpart(ctx context.Context, req *QueryThreePIDsForLocalpartRequest, res *QueryThreePIDsForLocalpartResponse) error
QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error

View file

@ -94,9 +94,10 @@ func (t *UserInternalAPITrace) PerformPushRulesPut(ctx context.Context, req *Per
util.GetLogger(ctx).Infof("PerformPushRulesPut req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) {
t.Impl.QueryKeyBackup(ctx, req, res)
func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) error {
err := t.Impl.QueryKeyBackup(ctx, req, res)
util.GetLogger(ctx).Infof("QueryKeyBackup req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error {
err := t.Impl.QueryProfile(ctx, req, res)

View file

@ -192,7 +192,9 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deleteReq.KeyIDs = append(deleteReq.KeyIDs, gomatrixserverlib.KeyID(keyID))
}
deleteRes := &keyapi.PerformDeleteKeysResponse{}
a.KeyAPI.PerformDeleteKeys(ctx, deleteReq, deleteRes)
if err := a.KeyAPI.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil {
return err
}
if err := deleteRes.Error; err != nil {
return fmt.Errorf("a.KeyAPI.PerformDeleteKeys: %w", err)
}
@ -211,10 +213,12 @@ func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) er
}
var uploadRes keyapi.PerformUploadKeysResponse
a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
UserID: userID,
DeviceKeys: deviceKeys,
}, &uploadRes)
}, &uploadRes); err != nil {
return err
}
if uploadRes.Error != nil {
return fmt.Errorf("failed to delete device keys: %v", uploadRes.Error)
}
@ -268,7 +272,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
if req.DisplayName != nil && dev.DisplayName != *req.DisplayName {
// display name has changed: update the device key
var uploadRes keyapi.PerformUploadKeysResponse
a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
UserID: req.RequestingUserID,
DeviceKeys: []keyapi.DeviceKeys{
{
@ -279,7 +283,9 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
},
},
OnlyDisplayNameUpdates: true,
}, &uploadRes)
}, &uploadRes); err != nil {
return err
}
if uploadRes.Error != nil {
return fmt.Errorf("failed to update device key display name: %v", uploadRes.Error)
}
@ -479,7 +485,9 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName),
}
evacuateRes := &rsapi.PerformAdminEvacuateUserResponse{}
a.RSAPI.PerformAdminEvacuateUser(ctx, evacuateReq, evacuateRes)
if err := a.RSAPI.PerformAdminEvacuateUser(ctx, evacuateReq, evacuateRes); err != nil {
return err
}
if err := evacuateRes.Error; err != nil {
logrus.WithError(err).Errorf("Failed to evacuate user after account deactivation")
}
@ -538,9 +546,6 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
if req.Version == "" {
res.BadInput = true
res.Error = "must specify a version to delete"
if res.Error != "" {
return fmt.Errorf(res.Error)
}
return nil
}
exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version)
@ -549,9 +554,6 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
res.Exists = exists
res.Version = req.Version
if res.Error != "" {
return fmt.Errorf(res.Error)
}
return nil
}
// Create metadata
@ -562,9 +564,6 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
res.Exists = err == nil
res.Version = version
if res.Error != "" {
return fmt.Errorf(res.Error)
}
return nil
}
// Update metadata
@ -575,16 +574,10 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
res.Exists = err == nil
res.Version = req.Version
if res.Error != "" {
return fmt.Errorf(res.Error)
}
return nil
}
// Upload Keys for a specific version metadata
a.uploadBackupKeys(ctx, req, res)
if res.Error != "" {
return fmt.Errorf(res.Error)
}
return nil
}
@ -627,16 +620,16 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
res.KeyETag = etag
}
func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) {
func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) error {
version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version)
res.Version = version
if err != nil {
if err == sql.ErrNoRows {
res.Exists = false
return
return nil
}
res.Error = fmt.Sprintf("failed to query key backup: %s", err)
return
return nil
}
res.Algorithm = algorithm
res.AuthData = authData
@ -648,15 +641,16 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
if err != nil {
res.Error = fmt.Sprintf("failed to count keys: %s", err)
}
return
return nil
}
result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
if err != nil {
res.Error = fmt.Sprintf("failed to query keys: %s", err)
return
return nil
}
res.Keys = result
return nil
}
func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error {

View file

@ -21,7 +21,6 @@ import (
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/opentracing/opentracing-go"
)
// HTTP paths for the internal HTTP APIs
@ -84,11 +83,10 @@ type httpUserInternalAPI struct {
}
func (h *httpUserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "InputAccountData")
defer span.Finish()
apiURL := h.apiURL + InputAccountDataPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
return httputil.CallInternalRPCAPI(
"InputAccountData", h.apiURL+InputAccountDataPath,
h.httpClient, ctx, req, res,
)
}
func (h *httpUserInternalAPI) PerformAccountCreation(
@ -96,11 +94,10 @@ func (h *httpUserInternalAPI) PerformAccountCreation(
request *api.PerformAccountCreationRequest,
response *api.PerformAccountCreationResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAccountCreation")
defer span.Finish()
apiURL := h.apiURL + PerformAccountCreationPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformAccountCreation", h.apiURL+PerformAccountCreationPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformPasswordUpdate(
@ -108,11 +105,10 @@ func (h *httpUserInternalAPI) PerformPasswordUpdate(
request *api.PerformPasswordUpdateRequest,
response *api.PerformPasswordUpdateResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPasswordUpdate")
defer span.Finish()
apiURL := h.apiURL + PerformPasswordUpdatePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformPasswordUpdate", h.apiURL+PerformPasswordUpdatePath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformDeviceCreation(
@ -120,11 +116,10 @@ func (h *httpUserInternalAPI) PerformDeviceCreation(
request *api.PerformDeviceCreationRequest,
response *api.PerformDeviceCreationResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceCreation")
defer span.Finish()
apiURL := h.apiURL + PerformDeviceCreationPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformDeviceCreation", h.apiURL+PerformDeviceCreationPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformDeviceDeletion(
@ -132,47 +127,54 @@ func (h *httpUserInternalAPI) PerformDeviceDeletion(
request *api.PerformDeviceDeletionRequest,
response *api.PerformDeviceDeletionResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceDeletion")
defer span.Finish()
apiURL := h.apiURL + PerformDeviceDeletionPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformDeviceDeletion", h.apiURL+PerformDeviceDeletionPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformLastSeenUpdate(
ctx context.Context,
req *api.PerformLastSeenUpdateRequest,
res *api.PerformLastSeenUpdateResponse,
request *api.PerformLastSeenUpdateRequest,
response *api.PerformLastSeenUpdateResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLastSeen")
defer span.Finish()
apiURL := h.apiURL + PerformLastSeenUpdatePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
return httputil.CallInternalRPCAPI(
"PerformLastSeen", h.apiURL+PerformLastSeenUpdatePath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceUpdate")
defer span.Finish()
apiURL := h.apiURL + PerformDeviceUpdatePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) PerformDeviceUpdate(
ctx context.Context,
request *api.PerformDeviceUpdateRequest,
response *api.PerformDeviceUpdateResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformDeviceUpdate", h.apiURL+PerformDeviceUpdatePath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAccountDeactivation")
defer span.Finish()
apiURL := h.apiURL + PerformAccountDeactivationPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) PerformAccountDeactivation(
ctx context.Context,
request *api.PerformAccountDeactivationRequest,
response *api.PerformAccountDeactivationResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformAccountDeactivation", h.apiURL+PerformAccountDeactivationPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, request *api.PerformOpenIDTokenCreationRequest, response *api.PerformOpenIDTokenCreationResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformOpenIDTokenCreation")
defer span.Finish()
apiURL := h.apiURL + PerformOpenIDTokenCreationPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
func (h *httpUserInternalAPI) PerformOpenIDTokenCreation(
ctx context.Context,
request *api.PerformOpenIDTokenCreationRequest,
response *api.PerformOpenIDTokenCreationResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformOpenIDTokenCreation", h.apiURL+PerformOpenIDTokenCreationPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryProfile(
@ -180,11 +182,10 @@ func (h *httpUserInternalAPI) QueryProfile(
request *api.QueryProfileRequest,
response *api.QueryProfileResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryProfile")
defer span.Finish()
apiURL := h.apiURL + QueryProfilePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryProfile", h.apiURL+QueryProfilePath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryDeviceInfos(
@ -192,11 +193,10 @@ func (h *httpUserInternalAPI) QueryDeviceInfos(
request *api.QueryDeviceInfosRequest,
response *api.QueryDeviceInfosResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDeviceInfos")
defer span.Finish()
apiURL := h.apiURL + QueryDeviceInfosPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryDeviceInfos", h.apiURL+QueryDeviceInfosPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryAccessToken(
@ -204,72 +204,87 @@ func (h *httpUserInternalAPI) QueryAccessToken(
request *api.QueryAccessTokenRequest,
response *api.QueryAccessTokenResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryAccessToken")
defer span.Finish()
apiURL := h.apiURL + QueryAccessTokenPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryAccessToken", h.apiURL+QueryAccessTokenPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevicesRequest, res *api.QueryDevicesResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDevices")
defer span.Finish()
apiURL := h.apiURL + QueryDevicesPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) QueryDevices(
ctx context.Context,
request *api.QueryDevicesRequest,
response *api.QueryDevicesResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryDevices", h.apiURL+QueryDevicesPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAccountDataRequest, res *api.QueryAccountDataResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryAccountData")
defer span.Finish()
apiURL := h.apiURL + QueryAccountDataPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) QueryAccountData(
ctx context.Context,
request *api.QueryAccountDataRequest,
response *api.QueryAccountDataResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryAccountData", h.apiURL+QueryAccountDataPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySearchProfiles")
defer span.Finish()
apiURL := h.apiURL + QuerySearchProfilesPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) QuerySearchProfiles(
ctx context.Context,
request *api.QuerySearchProfilesRequest,
response *api.QuerySearchProfilesResponse,
) error {
return httputil.CallInternalRPCAPI(
"QuerySearchProfiles", h.apiURL+QuerySearchProfilesPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOpenIDToken")
defer span.Finish()
apiURL := h.apiURL + QueryOpenIDTokenPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) QueryOpenIDToken(
ctx context.Context,
request *api.QueryOpenIDTokenRequest,
response *api.QueryOpenIDTokenResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryOpenIDToken", h.apiURL+QueryOpenIDTokenPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformKeyBackup")
defer span.Finish()
apiURL := h.apiURL + PerformKeyBackupPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
if err != nil {
res.Error = err.Error()
}
return nil
}
func (h *httpUserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup")
defer span.Finish()
apiURL := h.apiURL + QueryKeyBackupPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
if err != nil {
res.Error = err.Error()
}
func (h *httpUserInternalAPI) PerformKeyBackup(
ctx context.Context,
request *api.PerformKeyBackupRequest,
response *api.PerformKeyBackupResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformKeyBackup", h.apiURL+PerformKeyBackupPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryNotifications")
defer span.Finish()
func (h *httpUserInternalAPI) QueryKeyBackup(
ctx context.Context,
request *api.QueryKeyBackupRequest,
response *api.QueryKeyBackupResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryKeyBackup", h.apiURL+QueryKeyBackupPath,
h.httpClient, ctx, request, response,
)
}
return httputil.PostJSON(ctx, span, h.httpClient, h.apiURL+QueryNotificationsPath, req, res)
func (h *httpUserInternalAPI) QueryNotifications(
ctx context.Context,
request *api.QueryNotificationsRequest,
response *api.QueryNotificationsResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryNotifications", h.apiURL+QueryNotificationsPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformPusherSet(
@ -277,27 +292,32 @@ func (h *httpUserInternalAPI) PerformPusherSet(
request *api.PerformPusherSetRequest,
response *struct{},
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherSet")
defer span.Finish()
apiURL := h.apiURL + PerformPusherSetPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformPusherSet", h.apiURL+PerformPusherSetPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherDeletion")
defer span.Finish()
apiURL := h.apiURL + PerformPusherDeletionPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) PerformPusherDeletion(
ctx context.Context,
request *api.PerformPusherDeletionRequest,
response *struct{},
) error {
return httputil.CallInternalRPCAPI(
"PerformPusherDeletion", h.apiURL+PerformPusherDeletionPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushers")
defer span.Finish()
apiURL := h.apiURL + QueryPushersPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) QueryPushers(
ctx context.Context,
request *api.QueryPushersRequest,
response *api.QueryPushersResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryPushers", h.apiURL+QueryPushersPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformPushRulesPut(
@ -305,89 +325,117 @@ func (h *httpUserInternalAPI) PerformPushRulesPut(
request *api.PerformPushRulesPutRequest,
response *struct{},
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPushRulesPut")
defer span.Finish()
apiURL := h.apiURL + PerformPushRulesPutPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformPushRulesPut", h.apiURL+PerformPushRulesPutPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushRules")
defer span.Finish()
apiURL := h.apiURL + QueryPushRulesPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) QueryPushRules(
ctx context.Context,
request *api.QueryPushRulesRequest,
response *api.QueryPushRulesResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryPushRules", h.apiURL+QueryPushRulesPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, PerformSetAvatarURLPath)
defer span.Finish()
apiURL := h.apiURL + PerformSetAvatarURLPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) SetAvatarURL(
ctx context.Context,
request *api.PerformSetAvatarURLRequest,
response *api.PerformSetAvatarURLResponse,
) error {
return httputil.CallInternalRPCAPI(
"SetAvatarURL", h.apiURL+PerformSetAvatarURLPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, QueryNumericLocalpartPath)
defer span.Finish()
apiURL := h.apiURL + QueryNumericLocalpartPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, struct{}{}, res)
func (h *httpUserInternalAPI) QueryNumericLocalpart(
ctx context.Context,
response *api.QueryNumericLocalpartResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath,
h.httpClient, ctx, &struct{}{}, response,
)
}
func (h *httpUserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, QueryAccountAvailabilityPath)
defer span.Finish()
apiURL := h.apiURL + QueryAccountAvailabilityPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) QueryAccountAvailability(
ctx context.Context,
request *api.QueryAccountAvailabilityRequest,
response *api.QueryAccountAvailabilityResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryAccountAvailability", h.apiURL+QueryAccountAvailabilityPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, QueryAccountByPasswordPath)
defer span.Finish()
apiURL := h.apiURL + QueryAccountByPasswordPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) QueryAccountByPassword(
ctx context.Context,
request *api.QueryAccountByPasswordRequest,
response *api.QueryAccountByPasswordResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryAccountByPassword", h.apiURL+QueryAccountByPasswordPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *struct{}) error {
span, ctx := opentracing.StartSpanFromContext(ctx, PerformSetDisplayNamePath)
defer span.Finish()
apiURL := h.apiURL + PerformSetDisplayNamePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) SetDisplayName(
ctx context.Context,
request *api.PerformUpdateDisplayNameRequest,
response *struct{},
) error {
return httputil.CallInternalRPCAPI(
"SetDisplayName", h.apiURL+PerformSetDisplayNamePath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, QueryLocalpartForThreePIDPath)
defer span.Finish()
apiURL := h.apiURL + QueryLocalpartForThreePIDPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) QueryLocalpartForThreePID(
ctx context.Context,
request *api.QueryLocalpartForThreePIDRequest,
response *api.QueryLocalpartForThreePIDResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryLocalpartForThreePID", h.apiURL+QueryLocalpartForThreePIDPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, QueryThreePIDsForLocalpartPath)
defer span.Finish()
apiURL := h.apiURL + QueryThreePIDsForLocalpartPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) QueryThreePIDsForLocalpart(
ctx context.Context,
request *api.QueryThreePIDsForLocalpartRequest,
response *api.QueryThreePIDsForLocalpartResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryThreePIDsForLocalpart", h.apiURL+QueryThreePIDsForLocalpartPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.PerformForgetThreePIDRequest, res *struct{}) error {
span, ctx := opentracing.StartSpanFromContext(ctx, PerformForgetThreePIDPath)
defer span.Finish()
apiURL := h.apiURL + PerformForgetThreePIDPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) PerformForgetThreePID(
ctx context.Context,
request *api.PerformForgetThreePIDRequest,
response *struct{},
) error {
return httputil.CallInternalRPCAPI(
"PerformForgetThreePID", h.apiURL+PerformForgetThreePIDPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error {
span, ctx := opentracing.StartSpanFromContext(ctx, PerformSaveThreePIDAssociationPath)
defer span.Finish()
apiURL := h.apiURL + PerformSaveThreePIDAssociationPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(
ctx context.Context,
request *api.PerformSaveThreePIDAssociationRequest,
response *struct{},
) error {
return httputil.CallInternalRPCAPI(
"PerformSaveThreePIDAssociation", h.apiURL+PerformSaveThreePIDAssociationPath,
h.httpClient, ctx, request, response,
)
}

View file

@ -19,7 +19,6 @@ import (
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/opentracing/opentracing-go"
)
const (
@ -33,11 +32,10 @@ func (h *httpUserInternalAPI) PerformLoginTokenCreation(
request *api.PerformLoginTokenCreationRequest,
response *api.PerformLoginTokenCreationResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLoginTokenCreation")
defer span.Finish()
apiURL := h.apiURL + PerformLoginTokenCreationPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformLoginTokenCreation", h.apiURL+PerformLoginTokenCreationPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) PerformLoginTokenDeletion(
@ -45,11 +43,10 @@ func (h *httpUserInternalAPI) PerformLoginTokenDeletion(
request *api.PerformLoginTokenDeletionRequest,
response *api.PerformLoginTokenDeletionResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLoginTokenDeletion")
defer span.Finish()
apiURL := h.apiURL + PerformLoginTokenDeletionPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"PerformLoginTokenDeletion", h.apiURL+PerformLoginTokenDeletionPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryLoginToken(
@ -57,9 +54,8 @@ func (h *httpUserInternalAPI) QueryLoginToken(
request *api.QueryLoginTokenRequest,
response *api.QueryLoginTokenResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryLoginToken")
defer span.Finish()
apiURL := h.apiURL + QueryLoginTokenPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
return httputil.CallInternalRPCAPI(
"QueryLoginToken", h.apiURL+QueryLoginTokenPath,
h.httpClient, ctx, request, response,
)
}

Some files were not shown because too many files have changed in this diff Show more