Merge branch 'master' into cross-compile-docker

This commit is contained in:
Caleb Xavier Berger 2021-01-20 06:35:56 -05:00 committed by GitHub
commit 9aa5cd7c1f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
30 changed files with 666 additions and 236 deletions

View file

@ -7,7 +7,8 @@ on:
types: [published] types: [published]
env: env:
DOCKER_HUB_USER: matrixdotorg DOCKER_NAMESPACE: matrixdotorg
DOCKER_HUB_USER: dendritegithub
PLATFORMS: linux/amd64,linux/arm64,linux/arm/v7 PLATFORMS: linux/amd64,linux/arm64,linux/arm/v7
jobs: jobs:
@ -37,8 +38,8 @@ jobs:
platforms: ${{ env.PLATFORMS }} platforms: ${{ env.PLATFORMS }}
push: true push: true
tags: | tags: |
${{ env.DOCKER_HUB_USER }}/dendrite-monolith:latest ${{ env.DOCKER_NAMESPACE }}/dendrite-monolith:latest
${{ env.DOCKER_HUB_USER }}/dendrite-monolith:${{ env.RELEASE_VERSION }} ${{ env.DOCKER_NAMESPACE }}/dendrite-monolith:${{ env.RELEASE_VERSION }}
Polylith: Polylith:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@ -66,5 +67,5 @@ jobs:
platforms: ${{ env.PLATFORMS }} platforms: ${{ env.PLATFORMS }}
push: true push: true
tags: | tags: |
${{ env.DOCKER_HUB_USER }}/dendrite-polylith:latest ${{ env.DOCKER_NAMESPACE }}/dendrite-polylith:latest
${{ env.DOCKER_HUB_USER }}/dendrite-polylith:${{ env.RELEASE_VERSION }} ${{ env.DOCKER_NAMESPACE }}/dendrite-polylith:${{ env.RELEASE_VERSION }}

View file

@ -17,6 +17,8 @@ else
export FLAGS="" export FLAGS=""
fi fi
mkdir -p bin
CGO_ENABLED=1 go build -trimpath -ldflags "$FLAGS" -v -o "bin/" ./cmd/... CGO_ENABLED=1 go build -trimpath -ldflags "$FLAGS" -v -o "bin/" ./cmd/...
CGO_ENABLED=0 GOOS=js GOARCH=wasm go build -trimpath -ldflags "$FLAGS" -o bin/main.wasm ./cmd/dendritejs CGO_ENABLED=0 GOOS=js GOARCH=wasm go build -trimpath -ldflags "$FLAGS" -o bin/main.wasm ./cmd/dendritejs

View file

@ -22,6 +22,7 @@ type FederationClient interface {
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
GetServerKeys(ctx context.Context, matrixServer gomatrixserverlib.ServerName) (gomatrixserverlib.ServerKeys, error) GetServerKeys(ctx context.Context, matrixServer gomatrixserverlib.ServerName) (gomatrixserverlib.ServerKeys, error)
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
} }

View file

@ -244,3 +244,17 @@ func (a *FederationSenderInternalAPI) MSC2836EventRelationships(
} }
return ires.(gomatrixserverlib.MSC2836EventRelationshipsResponse), nil return ires.(gomatrixserverlib.MSC2836EventRelationshipsResponse), nil
} }
func (a *FederationSenderInternalAPI) MSC2946Spaces(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest,
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
ires, err := a.doRequest(s, func() (interface{}, error) {
return a.federation.MSC2946Spaces(ctx, s, roomID, r)
})
if err != nil {
return res, err
}
return ires.(gomatrixserverlib.MSC2946SpacesResponse), nil
}

View file

@ -33,6 +33,7 @@ const (
FederationSenderGetServerKeysPath = "/federationsender/client/getServerKeys" FederationSenderGetServerKeysPath = "/federationsender/client/getServerKeys"
FederationSenderLookupServerKeysPath = "/federationsender/client/lookupServerKeys" FederationSenderLookupServerKeysPath = "/federationsender/client/lookupServerKeys"
FederationSenderEventRelationshipsPath = "/federationsender/client/msc2836eventRelationships" FederationSenderEventRelationshipsPath = "/federationsender/client/msc2836eventRelationships"
FederationSenderSpacesSummaryPath = "/federationsender/client/msc2946spacesSummary"
) )
// NewFederationSenderClient creates a FederationSenderInternalAPI implemented by talking to a HTTP POST API. // NewFederationSenderClient creates a FederationSenderInternalAPI implemented by talking to a HTTP POST API.
@ -449,3 +450,34 @@ func (h *httpFederationSenderInternalAPI) MSC2836EventRelationships(
} }
return response.Res, nil return response.Res, nil
} }
type spacesReq struct {
S gomatrixserverlib.ServerName
Req gomatrixserverlib.MSC2946SpacesRequest
RoomID string
Res gomatrixserverlib.MSC2946SpacesResponse
Err *api.FederationClientError
}
func (h *httpFederationSenderInternalAPI) MSC2946Spaces(
ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest,
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2946Spaces")
defer span.Finish()
request := spacesReq{
S: dst,
Req: r,
RoomID: roomID,
}
var response spacesReq
apiURL := h.federationSenderURL + FederationSenderSpacesSummaryPath
err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return res, err
}
if response.Err != nil {
return res, response.Err
}
return response.Res, nil
}

View file

@ -329,4 +329,26 @@ func AddRoutes(intAPI api.FederationSenderInternalAPI, internalAPIMux *mux.Route
return util.JSONResponse{Code: http.StatusOK, JSON: request} return util.JSONResponse{Code: http.StatusOK, JSON: request}
}), }),
) )
internalAPIMux.Handle(
FederationSenderSpacesSummaryPath,
httputil.MakeInternalAPI("MSC2946SpacesSummary", func(req *http.Request) util.JSONResponse {
var request spacesReq
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.Req)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
)
} }

11
go.mod
View file

@ -22,7 +22,7 @@ require (
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd
github.com/matrix-org/gomatrixserverlib v0.0.0-20210113173004-b1c67ac867cc github.com/matrix-org/gomatrixserverlib v0.0.0-20210119115951-bd57c7cff614
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.2 github.com/mattn/go-sqlite3 v1.14.2
@ -33,16 +33,15 @@ require (
github.com/pressly/goose v2.7.0-rc5+incompatible github.com/pressly/goose v2.7.0-rc5+incompatible
github.com/prometheus/client_golang v1.7.1 github.com/prometheus/client_golang v1.7.1
github.com/sirupsen/logrus v1.7.0 github.com/sirupsen/logrus v1.7.0
github.com/tidwall/gjson v1.6.3 github.com/tidwall/gjson v1.6.7
github.com/tidwall/match v1.0.2 // indirect github.com/tidwall/sjson v1.1.4
github.com/tidwall/sjson v1.1.2
github.com/uber/jaeger-client-go v2.25.0+incompatible github.com/uber/jaeger-client-go v2.25.0+incompatible
github.com/uber/jaeger-lib v2.2.0+incompatible github.com/uber/jaeger-lib v2.2.0+incompatible
github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20201006093556-760d9a7fd5ee github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20201006093556-760d9a7fd5ee
go.uber.org/atomic v1.6.0 go.uber.org/atomic v1.6.0
golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9 golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad
golang.org/x/net v0.0.0-20200528225125-3c3fba18258b golang.org/x/net v0.0.0-20200528225125-3c3fba18258b
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 // indirect golang.org/x/sys v0.0.0-20210113181707-4bcb84eeeb78 // indirect
gopkg.in/h2non/bimg.v1 v1.1.4 gopkg.in/h2non/bimg.v1 v1.1.4
gopkg.in/yaml.v2 v2.3.0 gopkg.in/yaml.v2 v2.3.0
) )

29
go.sum
View file

@ -567,8 +567,12 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg=
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20210113173004-b1c67ac867cc h1:n2Hnbg8RZ4102Qmxie1riLkIyrqeqShJUILg1miSmDI= github.com/matrix-org/gomatrixserverlib v0.0.0-20210115150839-9ba5f3e11086 h1:nfGXVXx+cg1iBAWatukPsBe5OKsW+TdmF/qydnt04eg=
github.com/matrix-org/gomatrixserverlib v0.0.0-20210113173004-b1c67ac867cc/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/gomatrixserverlib v0.0.0-20210115150839-9ba5f3e11086/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
github.com/matrix-org/gomatrixserverlib v0.0.0-20210115152401-7c4619994337 h1:HJ9iH00PwMDaXsH7vWpO7nRucz+d92QLoH0PNW7hs58=
github.com/matrix-org/gomatrixserverlib v0.0.0-20210115152401-7c4619994337/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
github.com/matrix-org/gomatrixserverlib v0.0.0-20210119115951-bd57c7cff614 h1:X5FP1YOiGmPfpK4IAc8KyX8lOW4nC81/YZPTbOWAyKs=
github.com/matrix-org/gomatrixserverlib v0.0.0-20210119115951-bd57c7cff614/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4=
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo=
@ -810,13 +814,12 @@ github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpP
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA=
github.com/tidwall/gjson v1.6.0 h1:9VEQWz6LLMUsUl6PueE49ir4Ka6CzLymOAZDxpFsTDc= github.com/tidwall/gjson v1.6.0 h1:9VEQWz6LLMUsUl6PueE49ir4Ka6CzLymOAZDxpFsTDc=
github.com/tidwall/gjson v1.6.0/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls= github.com/tidwall/gjson v1.6.0/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls=
github.com/tidwall/gjson v1.6.1/go.mod h1:BaHyNc5bjzYkPqgLq7mdVzeiRtULKULXLgZFKsxEHI0= github.com/tidwall/gjson v1.6.7 h1:Mb1M9HZCRWEcXQ8ieJo7auYyyiSux6w9XN3AdTpxJrE=
github.com/tidwall/gjson v1.6.3 h1:aHoiiem0dr7GHkW001T1SMTJ7X5PvyekH5WX0whWGnI= github.com/tidwall/gjson v1.6.7/go.mod h1:zeFuBCIqD4sN/gmqBzZ4j7Jd6UcA2Fc56x7QFsv+8fI=
github.com/tidwall/gjson v1.6.3/go.mod h1:BaHyNc5bjzYkPqgLq7mdVzeiRtULKULXLgZFKsxEHI0=
github.com/tidwall/match v1.0.1 h1:PnKP62LPNxHKTwvHHZZzdOAOCtsJTjo6dZLCwpKm5xc= github.com/tidwall/match v1.0.1 h1:PnKP62LPNxHKTwvHHZZzdOAOCtsJTjo6dZLCwpKm5xc=
github.com/tidwall/match v1.0.1/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E= github.com/tidwall/match v1.0.1/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E=
github.com/tidwall/match v1.0.2 h1:uuqvHuBGSedK7awZ2YoAtpnimfwBGFjHuWLuLqQj+bU= github.com/tidwall/match v1.0.3 h1:FQUVvBImDutD8wJLN6c5eMzWtjgONK9MwIBCOrUJKeE=
github.com/tidwall/match v1.0.2/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E= github.com/tidwall/match v1.0.3/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/tidwall/pretty v1.0.1 h1:WE4RBSZ1x6McVVC8S/Md+Qse8YUv6HRObAx6ke00NY8= github.com/tidwall/pretty v1.0.1 h1:WE4RBSZ1x6McVVC8S/Md+Qse8YUv6HRObAx6ke00NY8=
github.com/tidwall/pretty v1.0.1/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/tidwall/pretty v1.0.1/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
@ -824,8 +827,8 @@ github.com/tidwall/pretty v1.0.2 h1:Z7S3cePv9Jwm1KwS0513MRaoUe3S01WPbLNV40pwWZU=
github.com/tidwall/pretty v1.0.2/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/tidwall/pretty v1.0.2/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/tidwall/sjson v1.0.3 h1:DeF+0LZqvIt4fKYw41aPB29ZGlvwVkHKktoXJ1YW9Y8= github.com/tidwall/sjson v1.0.3 h1:DeF+0LZqvIt4fKYw41aPB29ZGlvwVkHKktoXJ1YW9Y8=
github.com/tidwall/sjson v1.0.3/go.mod h1:bURseu1nuBkFpIES5cz6zBtjmYeOQmEESshn7VpF15Y= github.com/tidwall/sjson v1.0.3/go.mod h1:bURseu1nuBkFpIES5cz6zBtjmYeOQmEESshn7VpF15Y=
github.com/tidwall/sjson v1.1.2 h1:NC5okI+tQ8OG/oyzchvwXXxRxCV/FVdhODbPKkQ25jQ= github.com/tidwall/sjson v1.1.4 h1:bTSsPLdAYF5QNLSwYsKfBKKTnlGbIuhqL3CpRsjzGhg=
github.com/tidwall/sjson v1.1.2/go.mod h1:SEzaDwxiPzKzNfUEO4HbYF/m4UCSJDsGgNqsS1LvdoY= github.com/tidwall/sjson v1.1.4/go.mod h1:wXpKXu8CtDjKAZ+3DrKY5ROCorDFahq8l0tey/Lx1fg=
github.com/uber/jaeger-client-go v2.25.0+incompatible h1:IxcNZ7WRY1Y3G4poYlx24szfsn/3LvK9QHCq9oQw8+U= github.com/uber/jaeger-client-go v2.25.0+incompatible h1:IxcNZ7WRY1Y3G4poYlx24szfsn/3LvK9QHCq9oQw8+U=
github.com/uber/jaeger-client-go v2.25.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk= github.com/uber/jaeger-client-go v2.25.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk=
github.com/uber/jaeger-lib v2.2.0+incompatible h1:MxZXOiR2JuoANZ3J6DE/U0kSFv/eJ/GfSYVCjK7dyaw= github.com/uber/jaeger-lib v2.2.0+incompatible h1:MxZXOiR2JuoANZ3J6DE/U0kSFv/eJ/GfSYVCjK7dyaw=
@ -906,8 +909,8 @@ golang.org/x/crypto v0.0.0-20200423211502-4bdfaf469ed5 h1:Q7tZBpemrlsc2I7IyODzht
golang.org/x/crypto v0.0.0-20200423211502-4bdfaf469ed5/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200423211502-4bdfaf469ed5/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37 h1:cg5LA/zNPRzIXIWSCxQW10Rvpy94aQh3LT/ShoCpkHw= golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37 h1:cg5LA/zNPRzIXIWSCxQW10Rvpy94aQh3LT/ShoCpkHw=
golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9 h1:phUcVbl53swtrUN8kQEXFhUxPlIlWyBfKmidCu7P95o= golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad h1:DN0cp81fZ3njFcrLCytUHRSUkqBjfTo4Tx9RJTWs0EY=
golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
@ -996,8 +999,8 @@ golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1 h1:ogLJMz+qpzav7lGMh10LMvAkM/fAoGlaiiHYiFYdm80= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1 h1:ogLJMz+qpzav7lGMh10LMvAkM/fAoGlaiiHYiFYdm80=
golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw= golang.org/x/sys v0.0.0-20210113181707-4bcb84eeeb78 h1:nVuTkr9L6Bq62qpUqKo/RnZCFfzDBL0bYo6w9OJUqZY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210113181707-4bcb84eeeb78/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View file

@ -17,13 +17,17 @@ package msc2946
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
"time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
chttputil "github.com/matrix-org/dendrite/clientapi/httputil" chttputil "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
fs "github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/internal/hooks"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
roomserver "github.com/matrix-org/dendrite/roomserver/api" roomserver "github.com/matrix-org/dendrite/roomserver/api"
@ -40,38 +44,16 @@ const (
ConstSpaceParentEventType = "org.matrix.msc1772.space.parent" ConstSpaceParentEventType = "org.matrix.msc1772.space.parent"
) )
// SpacesRequest is the request body to POST /_matrix/client/r0/rooms/{roomID}/spaces
type SpacesRequest struct {
MaxRoomsPerSpace int `json:"max_rooms_per_space"`
Limit int `json:"limit"`
Batch string `json:"batch"`
}
// Defaults sets the request defaults // Defaults sets the request defaults
func (r *SpacesRequest) Defaults() { func Defaults(r *gomatrixserverlib.MSC2946SpacesRequest) {
r.Limit = 100 r.Limit = 100
r.MaxRoomsPerSpace = -1 r.MaxRoomsPerSpace = -1
} }
// SpacesResponse is the response body to POST /_matrix/client/r0/rooms/{roomID}/spaces
type SpacesResponse struct {
NextBatch string `json:"next_batch"`
// Rooms are nodes on the space graph.
Rooms []Room `json:"rooms"`
// Events are edges on the space graph, exclusively m.space.child or m.space.parent events
Events []gomatrixserverlib.ClientEvent `json:"events"`
}
// Room is a node on the space graph
type Room struct {
gomatrixserverlib.PublicRoom
NumRefs int `json:"num_refs"`
RoomType string `json:"room_type"`
}
// Enable this MSC // Enable this MSC
func Enable( func Enable(
base *setup.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI, base *setup.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI,
fsAPI fs.FederationSenderInternalAPI, keyRing gomatrixserverlib.JSONVerifier,
) error { ) error {
db, err := NewDatabase(&base.Cfg.MSCs.Database) db, err := NewDatabase(&base.Cfg.MSCs.Database)
if err != nil { if err != nil {
@ -89,12 +71,69 @@ func Enable(
}) })
base.PublicClientAPIMux.Handle("/unstable/rooms/{roomID}/spaces", base.PublicClientAPIMux.Handle("/unstable/rooms/{roomID}/spaces",
httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(db, rsAPI)), httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(db, rsAPI, fsAPI, base.Cfg.Global.ServerName)),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
base.PublicFederationAPIMux.Handle("/unstable/spaces/{roomID}", httputil.MakeExternalAPI(
"msc2946_fed_spaces", func(req *http.Request) util.JSONResponse {
fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest(
req, time.Now(), base.Cfg.Global.ServerName, keyRing,
)
if fedReq == nil {
return errResp
}
// Extract the room ID from the request. Sanity check request data.
params, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
roomID := params["roomID"]
return federatedSpacesHandler(req.Context(), fedReq, roomID, db, rsAPI, fsAPI, base.Cfg.Global.ServerName)
},
)).Methods(http.MethodPost, http.MethodOptions)
return nil return nil
} }
func spacesHandler(db Database, rsAPI roomserver.RoomserverInternalAPI) func(*http.Request, *userapi.Device) util.JSONResponse { func federatedSpacesHandler(
ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string, db Database,
rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationSenderInternalAPI,
thisServer gomatrixserverlib.ServerName,
) util.JSONResponse {
inMemoryBatchCache := make(map[string]set)
var r gomatrixserverlib.MSC2946SpacesRequest
Defaults(&r)
if err := json.Unmarshal(fedReq.Content(), &r); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()),
}
}
if r.Limit > 100 {
r.Limit = 100
}
w := walker{
req: &r,
rootRoomID: roomID,
serverName: fedReq.Origin(),
thisServer: thisServer,
ctx: ctx,
db: db,
rsAPI: rsAPI,
fsAPI: fsAPI,
inMemoryBatchCache: inMemoryBatchCache,
}
res := w.walk()
return util.JSONResponse{
Code: 200,
JSON: res,
}
}
func spacesHandler(
db Database, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationSenderInternalAPI,
thisServer gomatrixserverlib.ServerName,
) func(*http.Request, *userapi.Device) util.JSONResponse {
return func(req *http.Request, device *userapi.Device) util.JSONResponse { return func(req *http.Request, device *userapi.Device) util.JSONResponse {
inMemoryBatchCache := make(map[string]set) inMemoryBatchCache := make(map[string]set)
// Extract the room ID from the request. Sanity check request data. // Extract the room ID from the request. Sanity check request data.
@ -103,8 +142,8 @@ func spacesHandler(db Database, rsAPI roomserver.RoomserverInternalAPI) func(*ht
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
roomID := params["roomID"] roomID := params["roomID"]
var r SpacesRequest var r gomatrixserverlib.MSC2946SpacesRequest
r.Defaults() Defaults(&r)
if resErr := chttputil.UnmarshalJSONRequest(req, &r); resErr != nil { if resErr := chttputil.UnmarshalJSONRequest(req, &r); resErr != nil {
return *resErr return *resErr
} }
@ -115,10 +154,12 @@ func spacesHandler(db Database, rsAPI roomserver.RoomserverInternalAPI) func(*ht
req: &r, req: &r,
rootRoomID: roomID, rootRoomID: roomID,
caller: device, caller: device,
thisServer: thisServer,
ctx: req.Context(), ctx: req.Context(),
db: db, db: db,
rsAPI: rsAPI, rsAPI: rsAPI,
fsAPI: fsAPI,
inMemoryBatchCache: inMemoryBatchCache, inMemoryBatchCache: inMemoryBatchCache,
} }
res := w.walk() res := w.walk()
@ -130,11 +171,14 @@ func spacesHandler(db Database, rsAPI roomserver.RoomserverInternalAPI) func(*ht
} }
type walker struct { type walker struct {
req *SpacesRequest req *gomatrixserverlib.MSC2946SpacesRequest
rootRoomID string rootRoomID string
caller *userapi.Device caller *userapi.Device
serverName gomatrixserverlib.ServerName
thisServer gomatrixserverlib.ServerName
db Database db Database
rsAPI roomserver.RoomserverInternalAPI rsAPI roomserver.RoomserverInternalAPI
fsAPI fs.FederationSenderInternalAPI
ctx context.Context ctx context.Context
// user ID|device ID|batch_num => event/room IDs sent to client // user ID|device ID|batch_num => event/room IDs sent to client
@ -142,10 +186,26 @@ type walker struct {
mu sync.Mutex mu sync.Mutex
} }
func (w *walker) roomIsExcluded(roomID string) bool {
for _, exclRoom := range w.req.ExcludeRooms {
if exclRoom == roomID {
return true
}
}
return false
}
func (w *walker) callerID() string {
if w.caller != nil {
return w.caller.UserID + "|" + w.caller.ID
}
return string(w.serverName)
}
func (w *walker) alreadySent(id string) bool { func (w *walker) alreadySent(id string) bool {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
m, ok := w.inMemoryBatchCache[w.caller.UserID+"|"+w.caller.ID] m, ok := w.inMemoryBatchCache[w.callerID()]
if !ok { if !ok {
return false return false
} }
@ -155,17 +215,17 @@ func (w *walker) alreadySent(id string) bool {
func (w *walker) markSent(id string) { func (w *walker) markSent(id string) {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
m := w.inMemoryBatchCache[w.caller.UserID+"|"+w.caller.ID] m := w.inMemoryBatchCache[w.callerID()]
if m == nil { if m == nil {
m = make(set) m = make(set)
} }
m[id] = true m[id] = true
w.inMemoryBatchCache[w.caller.UserID+"|"+w.caller.ID] = m w.inMemoryBatchCache[w.callerID()] = m
} }
// nolint:gocyclo // nolint:gocyclo
func (w *walker) walk() *SpacesResponse { func (w *walker) walk() *gomatrixserverlib.MSC2946SpacesResponse {
var res SpacesResponse var res gomatrixserverlib.MSC2946SpacesResponse
// Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms // Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms
unvisited := []string{w.rootRoomID} unvisited := []string{w.rootRoomID}
processed := make(set) processed := make(set)
@ -178,9 +238,20 @@ func (w *walker) walk() *SpacesResponse {
} }
// Mark this room as processed. // Mark this room as processed.
processed[roomID] = true processed[roomID] = true
// Is the caller currently joined to the room or is the room `world_readable` // Is the caller currently joined to the room or is the room `world_readable`
// If no, skip this room. If yes, continue. // If no, skip this room. If yes, continue.
if !w.authorised(roomID) { if !w.roomExists(roomID) || !w.authorised(roomID) {
// attempt to query this room over federation, as either we've never heard of it before
// or we've left it and hence are not authorised (but info may be exposed regardless)
fedRes, err := w.federatedRoomInfo(roomID)
if err != nil {
util.GetLogger(w.ctx).WithError(err).WithField("room_id", roomID).Errorf("failed to query federated spaces")
continue
}
if fedRes != nil {
res = combineResponses(res, *fedRes)
}
continue continue
} }
// Get all `m.space.child` and `m.space.parent` state events for the room. *In addition*, get // Get all `m.space.child` and `m.space.parent` state events for the room. *In addition*, get
@ -194,7 +265,7 @@ func (w *walker) walk() *SpacesResponse {
// If this room has not ever been in `rooms` (across multiple requests), extract the // If this room has not ever been in `rooms` (across multiple requests), extract the
// `PublicRoomsChunk` for this room. // `PublicRoomsChunk` for this room.
if !w.alreadySent(roomID) { if !w.alreadySent(roomID) && !w.roomIsExcluded(roomID) {
pubRoom := w.publicRoomsChunk(roomID) pubRoom := w.publicRoomsChunk(roomID)
roomType := "" roomType := ""
create := w.stateEvent(roomID, gomatrixserverlib.MRoomCreate, "") create := w.stateEvent(roomID, gomatrixserverlib.MRoomCreate, "")
@ -204,11 +275,12 @@ func (w *walker) walk() *SpacesResponse {
} }
// Add the total number of events to `PublicRoomsChunk` under `num_refs`. Add `PublicRoomsChunk` to `rooms`. // Add the total number of events to `PublicRoomsChunk` under `num_refs`. Add `PublicRoomsChunk` to `rooms`.
res.Rooms = append(res.Rooms, Room{ res.Rooms = append(res.Rooms, gomatrixserverlib.MSC2946Room{
PublicRoom: *pubRoom, PublicRoom: *pubRoom,
NumRefs: refs.len(), NumRefs: refs.len(),
RoomType: roomType, RoomType: roomType,
}) })
w.markSent(roomID)
} }
uniqueRooms := make(set) uniqueRooms := make(set)
@ -218,9 +290,11 @@ func (w *walker) walk() *SpacesResponse {
if w.rootRoomID == roomID { if w.rootRoomID == roomID {
for _, ev := range refs.events() { for _, ev := range refs.events() {
if !w.alreadySent(ev.EventID()) { if !w.alreadySent(ev.EventID()) {
res.Events = append(res.Events, gomatrixserverlib.HeaderedToClientEvent( strip := stripped(ev.Event)
ev, gomatrixserverlib.FormatAll, if strip == nil {
)) continue
}
res.Events = append(res.Events, *strip)
uniqueRooms[ev.RoomID()] = true uniqueRooms[ev.RoomID()] = true
uniqueRooms[SpaceTarget(ev)] = true uniqueRooms[SpaceTarget(ev)] = true
w.markSent(ev.EventID()) w.markSent(ev.EventID())
@ -240,9 +314,16 @@ func (w *walker) walk() *SpacesResponse {
if w.alreadySent(ev.EventID()) { if w.alreadySent(ev.EventID()) {
continue continue
} }
res.Events = append(res.Events, gomatrixserverlib.HeaderedToClientEvent( // Skip the room if it's part of exclude_rooms but ONLY IF the source matches, as we still
ev, gomatrixserverlib.FormatAll, // want to catch arrows which point to excluded rooms.
)) if w.roomIsExcluded(ev.RoomID()) {
continue
}
strip := stripped(ev.Event)
if strip == nil {
continue
}
res.Events = append(res.Events, *strip)
uniqueRooms[ev.RoomID()] = true uniqueRooms[ev.RoomID()] = true
uniqueRooms[SpaceTarget(ev)] = true uniqueRooms[SpaceTarget(ev)] = true
w.markSent(ev.EventID()) w.markSent(ev.EventID())
@ -289,8 +370,120 @@ func (w *walker) publicRoomsChunk(roomID string) *gomatrixserverlib.PublicRoom {
return &pubRooms[0] return &pubRooms[0]
} }
// federatedRoomInfo returns more of the spaces graph from another server. Returns nil if this was
// unsuccessful.
func (w *walker) federatedRoomInfo(roomID string) (*gomatrixserverlib.MSC2946SpacesResponse, error) {
// only do federated requests for client requests
if w.caller == nil {
return nil, nil
}
// extract events which point to this room ID and extract their vias
events, err := w.db.References(w.ctx, roomID)
if err != nil {
return nil, fmt.Errorf("failed to get References events: %w", err)
}
vias := make(set)
for _, ev := range events {
if ev.StateKeyEquals(roomID) {
// event points at this room, extract vias
content := struct {
Vias []string `json:"via"`
}{}
if err = json.Unmarshal(ev.Content(), &content); err != nil {
continue // silently ignore corrupted state events
}
for _, v := range content.Vias {
vias[v] = true
}
}
}
util.GetLogger(w.ctx).Infof("Querying federatedRoomInfo via %+v", vias)
ctx := context.Background()
// query more of the spaces graph using these servers
for serverName := range vias {
if serverName == string(w.thisServer) {
continue
}
res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, gomatrixserverlib.MSC2946SpacesRequest{
Limit: w.req.Limit,
MaxRoomsPerSpace: w.req.MaxRoomsPerSpace,
})
if err != nil {
util.GetLogger(w.ctx).WithError(err).Warnf("failed to call MSC2946Spaces on server %s", serverName)
continue
}
return &res, nil
}
return nil, nil
}
func (w *walker) roomExists(roomID string) bool {
var queryRes roomserver.QueryServerJoinedToRoomResponse
err := w.rsAPI.QueryServerJoinedToRoom(w.ctx, &roomserver.QueryServerJoinedToRoomRequest{
RoomID: roomID,
ServerName: w.thisServer,
}, &queryRes)
if err != nil {
util.GetLogger(w.ctx).WithError(err).Error("failed to QueryServerJoinedToRoom")
return false
}
// if the room exists but we aren't in the room then we might have stale data so we want to fetch
// it fresh via federation
return queryRes.RoomExists && queryRes.IsInRoom
}
// authorised returns true iff the user is joined this room or the room is world_readable // authorised returns true iff the user is joined this room or the room is world_readable
func (w *walker) authorised(roomID string) bool { func (w *walker) authorised(roomID string) bool {
if w.caller != nil {
return w.authorisedUser(roomID)
}
return w.authorisedServer(roomID)
}
// authorisedServer returns true iff the server is joined this room or the room is world_readable
func (w *walker) authorisedServer(roomID string) bool {
// Check history visibility first
hisVisTuple := gomatrixserverlib.StateKeyTuple{
EventType: gomatrixserverlib.MRoomHistoryVisibility,
StateKey: "",
}
var queryRoomRes roomserver.QueryCurrentStateResponse
err := w.rsAPI.QueryCurrentState(w.ctx, &roomserver.QueryCurrentStateRequest{
RoomID: roomID,
StateTuples: []gomatrixserverlib.StateKeyTuple{
hisVisTuple,
},
}, &queryRoomRes)
if err != nil {
util.GetLogger(w.ctx).WithError(err).Error("failed to QueryCurrentState")
return false
}
hisVisEv := queryRoomRes.StateEvents[hisVisTuple]
if hisVisEv != nil {
hisVis, _ := hisVisEv.HistoryVisibility()
if hisVis == "world_readable" {
return true
}
}
// check if server is joined to the room
var queryRes fs.QueryJoinedHostServerNamesInRoomResponse
err = w.fsAPI.QueryJoinedHostServerNamesInRoom(w.ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{
RoomID: roomID,
}, &queryRes)
if err != nil {
util.GetLogger(w.ctx).WithError(err).Error("failed to QueryJoinedHostServerNamesInRoom")
return false
}
for _, srv := range queryRes.ServerNames {
if srv == w.serverName {
return true
}
}
return false
}
// authorisedUser returns true iff the user is joined this room or the room is world_readable
func (w *walker) authorisedUser(roomID string) bool {
hisVisTuple := gomatrixserverlib.StateKeyTuple{ hisVisTuple := gomatrixserverlib.StateKeyTuple{
EventType: gomatrixserverlib.MRoomHistoryVisibility, EventType: gomatrixserverlib.MRoomHistoryVisibility,
StateKey: "", StateKey: "",
@ -374,3 +567,41 @@ func (el eventLookup) events() (events []*gomatrixserverlib.HeaderedEvent) {
} }
type set map[string]bool type set map[string]bool
func stripped(ev *gomatrixserverlib.Event) *gomatrixserverlib.MSC2946StrippedEvent {
if ev.StateKey() == nil {
return nil
}
return &gomatrixserverlib.MSC2946StrippedEvent{
Type: ev.Type(),
StateKey: *ev.StateKey(),
Content: ev.Content(),
Sender: ev.Sender(),
RoomID: ev.RoomID(),
}
}
func combineResponses(local, remote gomatrixserverlib.MSC2946SpacesResponse) gomatrixserverlib.MSC2946SpacesResponse {
knownRooms := make(set)
for _, room := range local.Rooms {
knownRooms[room.RoomID] = true
}
knownEvents := make(set)
for _, event := range local.Events {
knownEvents[event.RoomID+event.Type+event.StateKey] = true
}
// mux in remote entries if and only if they aren't present already
for _, room := range remote.Rooms {
if knownRooms[room.RoomID] {
continue
}
local.Rooms = append(local.Rooms, room)
}
for _, event := range remote.Events {
if knownEvents[event.RoomID+event.Type+event.StateKey] {
continue
}
local.Events = append(local.Events, event)
}
return local
}

View file

@ -41,6 +41,7 @@ var (
client = &http.Client{ client = &http.Client{
Timeout: 10 * time.Second, Timeout: 10 * time.Second,
} }
roomVer = gomatrixserverlib.RoomVersionV6
) )
// Basic sanity check of MSC2946 logic. Tests a single room with a few state events // Basic sanity check of MSC2946 logic. Tests a single room with a few state events
@ -269,13 +270,13 @@ func TestMSC2946(t *testing.T) {
}) })
} }
func newReq(t *testing.T, jsonBody map[string]interface{}) *msc2946.SpacesRequest { func newReq(t *testing.T, jsonBody map[string]interface{}) *gomatrixserverlib.MSC2946SpacesRequest {
t.Helper() t.Helper()
b, err := json.Marshal(jsonBody) b, err := json.Marshal(jsonBody)
if err != nil { if err != nil {
t.Fatalf("Failed to marshal request: %s", err) t.Fatalf("Failed to marshal request: %s", err)
} }
var r msc2946.SpacesRequest var r gomatrixserverlib.MSC2946SpacesRequest
if err := json.Unmarshal(b, &r); err != nil { if err := json.Unmarshal(b, &r); err != nil {
t.Fatalf("Failed to unmarshal request: %s", err) t.Fatalf("Failed to unmarshal request: %s", err)
} }
@ -299,10 +300,10 @@ func runServer(t *testing.T, router *mux.Router) func() {
} }
} }
func postSpaces(t *testing.T, expectCode int, accessToken, roomID string, req *msc2946.SpacesRequest) *msc2946.SpacesResponse { func postSpaces(t *testing.T, expectCode int, accessToken, roomID string, req *gomatrixserverlib.MSC2946SpacesRequest) *gomatrixserverlib.MSC2946SpacesResponse {
t.Helper() t.Helper()
var r msc2946.SpacesRequest var r gomatrixserverlib.MSC2946SpacesRequest
r.Defaults() msc2946.Defaults(&r)
data, err := json.Marshal(req) data, err := json.Marshal(req)
if err != nil { if err != nil {
t.Fatalf("failed to marshal request: %s", err) t.Fatalf("failed to marshal request: %s", err)
@ -324,7 +325,7 @@ func postSpaces(t *testing.T, expectCode int, accessToken, roomID string, req *m
t.Fatalf("wrong response code, got %d want %d - body: %s", res.StatusCode, expectCode, string(body)) t.Fatalf("wrong response code, got %d want %d - body: %s", res.StatusCode, expectCode, string(body))
} }
if res.StatusCode == 200 { if res.StatusCode == 200 {
var result msc2946.SpacesResponse var result gomatrixserverlib.MSC2946SpacesResponse
body, err := ioutil.ReadAll(res.Body) body, err := ioutil.ReadAll(res.Body)
if err != nil { if err != nil {
t.Fatalf("response 200 OK but failed to read response body: %s", err) t.Fatalf("response 200 OK but failed to read response body: %s", err)
@ -400,6 +401,12 @@ type testRoomserverAPI struct {
pubRoomState map[string]map[gomatrixserverlib.StateKeyTuple]string pubRoomState map[string]map[gomatrixserverlib.StateKeyTuple]string
} }
func (r *testRoomserverAPI) QueryServerJoinedToRoom(ctx context.Context, req *roomserver.QueryServerJoinedToRoomRequest, res *roomserver.QueryServerJoinedToRoomResponse) error {
res.IsInRoom = true
res.RoomExists = true
return nil
}
func (r *testRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *roomserver.QueryBulkStateContentRequest, res *roomserver.QueryBulkStateContentResponse) error { func (r *testRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *roomserver.QueryBulkStateContentRequest, res *roomserver.QueryBulkStateContentResponse) error {
res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string) res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string)
for _, roomID := range req.RoomIDs { for _, roomID := range req.RoomIDs {
@ -452,7 +459,7 @@ func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserve
PublicFederationAPIMux: mux.NewRouter().PathPrefix(httputil.PublicFederationPathPrefix).Subrouter(), PublicFederationAPIMux: mux.NewRouter().PathPrefix(httputil.PublicFederationPathPrefix).Subrouter(),
} }
err := msc2946.Enable(base, rsAPI, userAPI) err := msc2946.Enable(base, rsAPI, userAPI, nil, nil)
if err != nil { if err != nil {
t.Fatalf("failed to enable MSC2946: %s", err) t.Fatalf("failed to enable MSC2946: %s", err)
} }
@ -472,7 +479,6 @@ type fledglingEvent struct {
func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib.HeaderedEvent) { func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib.HeaderedEvent) {
t.Helper() t.Helper()
roomVer := gomatrixserverlib.RoomVersionV6
seed := make([]byte, ed25519.SeedSize) // zero seed seed := make([]byte, ed25519.SeedSize) // zero seed
key := ed25519.NewKeyFromSeed(seed) key := ed25519.NewKeyFromSeed(seed)
eb := gomatrixserverlib.EventBuilder{ eb := gomatrixserverlib.EventBuilder{

View file

@ -41,7 +41,7 @@ func EnableMSC(base *setup.BaseDendrite, monolith *setup.Monolith, msc string) e
case "msc2836": case "msc2836":
return msc2836.Enable(base, monolith.RoomserverAPI, monolith.FederationSenderAPI, monolith.UserAPI, monolith.KeyRing) return msc2836.Enable(base, monolith.RoomserverAPI, monolith.FederationSenderAPI, monolith.UserAPI, monolith.KeyRing)
case "msc2946": case "msc2946":
return msc2946.Enable(base, monolith.RoomserverAPI, monolith.UserAPI) return msc2946.Enable(base, monolith.RoomserverAPI, monolith.UserAPI, monolith.FederationSenderAPI, monolith.KeyRing)
default: default:
return fmt.Errorf("EnableMSC: unknown msc '%s'", msc) return fmt.Errorf("EnableMSC: unknown msc '%s'", msc)
} }

View file

@ -367,7 +367,6 @@ func newTestSyncRequest(userID, deviceID string, since types.StreamingToken) typ
Timeout: 1 * time.Minute, Timeout: 1 * time.Minute,
Since: since, Since: since,
WantFullState: false, WantFullState: false,
Limit: 20,
Log: util.GetLogger(context.TODO()), Log: util.GetLogger(context.TODO()),
Context: context.TODO(), Context: context.TODO(),
} }

View file

@ -235,12 +235,15 @@ func (r *messagesReq) retrieveEvents() (
clientEvents []gomatrixserverlib.ClientEvent, start, clientEvents []gomatrixserverlib.ClientEvent, start,
end types.TopologyToken, err error, end types.TopologyToken, err error,
) { ) {
eventFilter := gomatrixserverlib.DefaultRoomEventFilter()
eventFilter.Limit = r.limit
// Retrieve the events from the local database. // Retrieve the events from the local database.
var streamEvents []types.StreamEvent var streamEvents []types.StreamEvent
if r.fromStream != nil { if r.fromStream != nil {
toStream := r.to.StreamToken() toStream := r.to.StreamToken()
streamEvents, err = r.db.GetEventsInStreamingRange( streamEvents, err = r.db.GetEventsInStreamingRange(
r.ctx, r.fromStream, &toStream, r.roomID, r.limit, r.backwardOrdering, r.ctx, r.fromStream, &toStream, r.roomID, &eventFilter, r.backwardOrdering,
) )
} else { } else {
streamEvents, err = r.db.GetEventsInTopologicalRange( streamEvents, err = r.db.GetEventsInTopologicalRange(

View file

@ -40,7 +40,7 @@ type Database interface {
GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
RecentEvents(ctx context.Context, roomID string, r types.Range, limit int, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
GetBackwardTopologyPos(ctx context.Context, events []types.StreamEvent) (types.TopologyToken, error) GetBackwardTopologyPos(ctx context.Context, events []types.StreamEvent) (types.TopologyToken, error)
PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error)
@ -105,7 +105,7 @@ type Database interface {
// Returns an error if there was a problem communicating with the database. // Returns an error if there was a problem communicating with the database.
DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error) DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error)
// GetEventsInStreamingRange retrieves all of the events on a given ordering using the given extremities and limit. // GetEventsInStreamingRange retrieves all of the events on a given ordering using the given extremities and limit.
GetEventsInStreamingRange(ctx context.Context, from, to *types.StreamingToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) GetEventsInStreamingRange(ctx context.Context, from, to *types.StreamingToken, roomID string, eventFilter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error)
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit.
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error)
// EventPositionInTopology returns the depth and stream position of the given event. // EventPositionInTopology returns the depth and stream position of the given event.

View file

@ -83,7 +83,7 @@ func (s *filterStatements) SelectFilter(
} }
// Unmarshal JSON into Filter struct // Unmarshal JSON into Filter struct
var filter gomatrixserverlib.Filter filter := gomatrixserverlib.DefaultFilter()
if err = json.Unmarshal(filterData, &filter); err != nil { if err = json.Unmarshal(filterData, &filter); err != nil {
return nil, err return nil, err
} }

View file

@ -84,17 +84,29 @@ const selectEventsSQL = "" +
const selectRecentEventsSQL = "" + const selectRecentEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" + " WHERE room_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id DESC LIMIT $4" " AND ( $4::text[] IS NULL OR sender = ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" +
" AND ( $6::text[] IS NULL OR type LIKE ANY($6) )" +
" AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" +
" ORDER BY id DESC LIMIT $8"
const selectRecentEventsForSyncSQL = "" + const selectRecentEventsForSyncSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" + " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" +
" ORDER BY id DESC LIMIT $4" " AND ( $4::text[] IS NULL OR sender = ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" +
" AND ( $6::text[] IS NULL OR type LIKE ANY($6) )" +
" AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" +
" ORDER BY id DESC LIMIT $8"
const selectEarlyEventsSQL = "" + const selectEarlyEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" + " WHERE room_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id ASC LIMIT $4" " AND ( $4::text[] IS NULL OR sender = ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" +
" AND ( $6::text[] IS NULL OR type LIKE ANY($6) )" +
" AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" +
" ORDER BY id ASC LIMIT $8"
const selectMaxEventIDSQL = "" + const selectMaxEventIDSQL = "" +
"SELECT MAX(id) FROM syncapi_output_room_events" "SELECT MAX(id) FROM syncapi_output_room_events"
@ -322,7 +334,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
// from sync. // from sync.
func (s *outputRoomEventsStatements) SelectRecentEvents( func (s *outputRoomEventsStatements) SelectRecentEvents(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, r types.Range, limit int, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
chronologicalOrder bool, onlySyncEvents bool, chronologicalOrder bool, onlySyncEvents bool,
) ([]types.StreamEvent, bool, error) { ) ([]types.StreamEvent, bool, error) {
var stmt *sql.Stmt var stmt *sql.Stmt
@ -331,7 +343,14 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
} else { } else {
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt) stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt)
} }
rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit+1) rows, err := stmt.QueryContext(
ctx, roomID, r.Low(), r.High(),
pq.StringArray(eventFilter.Senders),
pq.StringArray(eventFilter.NotSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)),
eventFilter.Limit+1,
)
if err != nil { if err != nil {
return nil, false, err return nil, false, err
} }
@ -350,7 +369,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
} }
// we queried for 1 more than the limit, so if we returned one more mark limited=true // we queried for 1 more than the limit, so if we returned one more mark limited=true
limited := false limited := false
if len(events) > limit { if len(events) > eventFilter.Limit {
limited = true limited = true
// re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last. // re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last.
if chronologicalOrder { if chronologicalOrder {
@ -367,10 +386,17 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
// from a given position, up to a maximum of 'limit'. // from a given position, up to a maximum of 'limit'.
func (s *outputRoomEventsStatements) SelectEarlyEvents( func (s *outputRoomEventsStatements) SelectEarlyEvents(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, r types.Range, limit int, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt) stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt)
rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit) rows, err := stmt.QueryContext(
ctx, roomID, r.Low(), r.High(),
pq.StringArray(eventFilter.Senders),
pq.StringArray(eventFilter.NotSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)),
eventFilter.Limit,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -110,8 +110,8 @@ func (d *Database) RoomIDsWithMembership(ctx context.Context, userID string, mem
return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership) return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership)
} }
func (d *Database) RecentEvents(ctx context.Context, roomID string, r types.Range, limit int, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) { func (d *Database) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) {
return d.OutputEvents.SelectRecentEvents(ctx, nil, roomID, r, limit, chronologicalOrder, onlySyncEvents) return d.OutputEvents.SelectRecentEvents(ctx, nil, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents)
} }
func (d *Database) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) { func (d *Database) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) {
@ -151,7 +151,7 @@ func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixse
func (d *Database) GetEventsInStreamingRange( func (d *Database) GetEventsInStreamingRange(
ctx context.Context, ctx context.Context,
from, to *types.StreamingToken, from, to *types.StreamingToken,
roomID string, limit int, roomID string, eventFilter *gomatrixserverlib.RoomEventFilter,
backwardOrdering bool, backwardOrdering bool,
) (events []types.StreamEvent, err error) { ) (events []types.StreamEvent, err error) {
r := types.Range{ r := types.Range{
@ -162,14 +162,14 @@ func (d *Database) GetEventsInStreamingRange(
if backwardOrdering { if backwardOrdering {
// When using backward ordering, we want the most recent events first. // When using backward ordering, we want the most recent events first.
if events, _, err = d.OutputEvents.SelectRecentEvents( if events, _, err = d.OutputEvents.SelectRecentEvents(
ctx, nil, roomID, r, limit, false, false, ctx, nil, roomID, r, eventFilter, false, false,
); err != nil { ); err != nil {
return return
} }
} else { } else {
// When using forward ordering, we want the least recent events first. // When using forward ordering, we want the least recent events first.
if events, err = d.OutputEvents.SelectEarlyEvents( if events, err = d.OutputEvents.SelectEarlyEvents(
ctx, nil, roomID, r, limit, ctx, nil, roomID, r, eventFilter,
); err != nil { ); err != nil {
return return
} }

View file

@ -82,7 +82,7 @@ func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
userID, roomID, dataType string, userID, roomID, dataType string,
) (pos types.StreamPosition, err error) { ) (pos types.StreamPosition, err error) {
pos, err = s.streamIDStatements.nextStreamID(ctx, txn) pos, err = s.streamIDStatements.nextAccountDataID(ctx, txn)
if err != nil { if err != nil {
return return
} }

View file

@ -19,6 +19,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt"
"strings" "strings"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
@ -66,13 +67,8 @@ const selectRoomIDsWithMembershipSQL = "" +
"SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2"
const selectCurrentStateSQL = "" + const selectCurrentStateSQL = "" +
"SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" + "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1"
" AND ( $2 IS NULL OR sender IN ($2) )" + // WHEN, ORDER BY and LIMIT will be added by prepareWithFilter
" AND ( $3 IS NULL OR NOT(sender IN ($3)) )" +
" AND ( $4 IS NULL OR type IN ($4) )" +
" AND ( $5 IS NULL OR NOT(type IN ($5)) )" +
" AND ( $6 IS NULL OR contains_url = $6 )" +
" LIMIT $7"
const selectJoinedUsersSQL = "" + const selectJoinedUsersSQL = "" +
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'" "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'"
@ -95,7 +91,6 @@ type currentRoomStateStatements struct {
deleteRoomStateByEventIDStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt
DeleteRoomStateForRoomStmt *sql.Stmt DeleteRoomStateForRoomStmt *sql.Stmt
selectRoomIDsWithMembershipStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt
selectCurrentStateStmt *sql.Stmt
selectJoinedUsersStmt *sql.Stmt selectJoinedUsersStmt *sql.Stmt
selectStateEventStmt *sql.Stmt selectStateEventStmt *sql.Stmt
} }
@ -121,9 +116,6 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (t
if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil {
return nil, err return nil, err
} }
if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil {
return nil, err
}
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
return nil, err return nil, err
} }
@ -185,17 +177,22 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
// CurrentState returns all the current state events for the given room. // CurrentState returns all the current state events for the given room.
func (s *currentRoomStateStatements) SelectCurrentState( func (s *currentRoomStateStatements) SelectCurrentState(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
stateFilterPart *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
) ([]*gomatrixserverlib.HeaderedEvent, error) { ) ([]*gomatrixserverlib.HeaderedEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt) stmt, params, err := prepareWithFilters(
rows, err := stmt.QueryContext(ctx, roomID, s.db, txn, selectCurrentStateSQL,
nil, // FIXME: pq.StringArray(stateFilterPart.Senders), []interface{}{
nil, // FIXME: pq.StringArray(stateFilterPart.NotSenders), roomID,
nil, // FIXME: pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)), },
nil, // FIXME: pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.NotTypes)), stateFilter.Senders, stateFilter.NotSenders,
stateFilterPart.ContainsURL, stateFilter.Types, stateFilter.NotTypes,
stateFilterPart.Limit, stateFilter.Limit, FilterOrderNone,
) )
if err != nil {
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
}
rows, err := stmt.QueryContext(ctx, params...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -87,7 +87,7 @@ func (s *filterStatements) SelectFilter(
} }
// Unmarshal JSON into Filter struct // Unmarshal JSON into Filter struct
var filter gomatrixserverlib.Filter filter := gomatrixserverlib.DefaultFilter()
if err = json.Unmarshal(filterData, &filter); err != nil { if err = json.Unmarshal(filterData, &filter); err != nil {
return nil, err return nil, err
} }

View file

@ -0,0 +1,76 @@
package sqlite3
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
type FilterOrder int
const (
FilterOrderNone = iota
FilterOrderAsc
FilterOrderDesc
)
// prepareWithFilters returns a prepared statement with the
// relevant filters included. It also includes an []interface{}
// list of all the relevant parameters to pass straight to
// QueryContext, QueryRowContext etc.
// We don't take the filter object directly here because the
// fields might come from either a StateFilter or an EventFilter,
// and it's easier just to have the caller extract the relevant
// parts.
func prepareWithFilters(
db *sql.DB, txn *sql.Tx, query string, params []interface{},
senders, notsenders, types, nottypes []string,
limit int, order FilterOrder,
) (*sql.Stmt, []interface{}, error) {
offset := len(params)
if count := len(senders); count > 0 {
query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range senders {
params, offset = append(params, v), offset+1
}
}
if count := len(notsenders); count > 0 {
query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range notsenders {
params, offset = append(params, v), offset+1
}
}
if count := len(types); count > 0 {
query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range types {
params, offset = append(params, v), offset+1
}
}
if count := len(nottypes); count > 0 {
query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range nottypes {
params, offset = append(params, v), offset+1
}
}
switch order {
case FilterOrderAsc:
query += " ORDER BY id ASC"
case FilterOrderDesc:
query += " ORDER BY id DESC"
}
query += fmt.Sprintf(" LIMIT $%d", offset+1)
params = append(params, limit)
var stmt *sql.Stmt
var err error
if txn != nil {
stmt, err = txn.Prepare(query)
} else {
stmt, err = db.Prepare(query)
}
if err != nil {
return nil, nil, fmt.Errorf("s.db.Prepare: %w", err)
}
return stmt, params, nil
}

View file

@ -93,7 +93,7 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Inv
func (s *inviteEventsStatements) InsertInviteEvent( func (s *inviteEventsStatements) InsertInviteEvent(
ctx context.Context, txn *sql.Tx, inviteEvent *gomatrixserverlib.HeaderedEvent, ctx context.Context, txn *sql.Tx, inviteEvent *gomatrixserverlib.HeaderedEvent,
) (streamPos types.StreamPosition, err error) { ) (streamPos types.StreamPosition, err error) {
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) streamPos, err = s.streamIDStatements.nextInviteID(ctx, txn)
if err != nil { if err != nil {
return return
} }
@ -119,7 +119,7 @@ func (s *inviteEventsStatements) InsertInviteEvent(
func (s *inviteEventsStatements) DeleteInviteEvent( func (s *inviteEventsStatements) DeleteInviteEvent(
ctx context.Context, txn *sql.Tx, inviteEventID string, ctx context.Context, txn *sql.Tx, inviteEventID string,
) (types.StreamPosition, error) { ) (types.StreamPosition, error) {
streamPos, err := s.streamIDStatements.nextStreamID(ctx, txn) streamPos, err := s.streamIDStatements.nextInviteID(ctx, txn)
if err != nil { if err != nil {
return streamPos, err return streamPos, err
} }

View file

@ -19,6 +19,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt"
"sort" "sort"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
@ -60,18 +61,18 @@ const selectEventsSQL = "" +
const selectRecentEventsSQL = "" + const selectRecentEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" + " WHERE room_id = $1 AND id > $2 AND id <= $3"
" ORDER BY id DESC LIMIT $4" // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectRecentEventsForSyncSQL = "" + const selectRecentEventsForSyncSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" + " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE"
" ORDER BY id DESC LIMIT $4" // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectEarlyEventsSQL = "" + const selectEarlyEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" + " WHERE room_id = $1 AND id > $2 AND id <= $3"
" ORDER BY id ASC LIMIT $4" // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectMaxEventIDSQL = "" + const selectMaxEventIDSQL = "" +
"SELECT MAX(id) FROM syncapi_output_room_events" "SELECT MAX(id) FROM syncapi_output_room_events"
@ -79,45 +80,24 @@ const selectMaxEventIDSQL = "" +
const updateEventJSONSQL = "" + const updateEventJSONSQL = "" +
"UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2" "UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2"
// In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id).
/*
$1 = oldPos,
$2 = newPos,
$3 = pq.StringArray(stateFilterPart.Senders),
$4 = pq.StringArray(stateFilterPart.NotSenders),
$5 = pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)),
$6 = pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.NotTypes)),
$7 = stateFilterPart.ContainsURL,
$8 = stateFilterPart.Limit,
*/
const selectStateInRangeSQL = "" + const selectStateInRangeSQL = "" +
"SELECT id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" + "SELECT id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" +
" FROM syncapi_output_room_events" + " FROM syncapi_output_room_events" +
" WHERE (id > $1 AND id <= $2)" + // old/new pos " WHERE (id > $1 AND id <= $2)" +
" AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" + " AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))"
/* " AND ( $3 IS NULL OR sender IN ($3) )" + // sender // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
" AND ( $4 IS NULL OR NOT(sender IN ($4)) )" + // not sender
" AND ( $5 IS NULL OR type IN ($5) )" + // type
" AND ( $6 IS NULL OR NOT(type IN ($6)) )" + // not type
" AND ( $7 IS NULL OR contains_url = $7)" + // contains URL? */
" ORDER BY id ASC" +
" LIMIT $8" // limit
const deleteEventsForRoomSQL = "" + const deleteEventsForRoomSQL = "" +
"DELETE FROM syncapi_output_room_events WHERE room_id = $1" "DELETE FROM syncapi_output_room_events WHERE room_id = $1"
type outputRoomEventsStatements struct { type outputRoomEventsStatements struct {
db *sql.DB db *sql.DB
streamIDStatements *streamIDStatements streamIDStatements *streamIDStatements
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt selectEventsStmt *sql.Stmt
selectMaxEventIDStmt *sql.Stmt selectMaxEventIDStmt *sql.Stmt
selectRecentEventsStmt *sql.Stmt updateEventJSONStmt *sql.Stmt
selectRecentEventsForSyncStmt *sql.Stmt deleteEventsForRoomStmt *sql.Stmt
selectEarlyEventsStmt *sql.Stmt
selectStateInRangeStmt *sql.Stmt
updateEventJSONStmt *sql.Stmt
deleteEventsForRoomStmt *sql.Stmt
} }
func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) { func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
@ -138,18 +118,6 @@ func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Even
if s.selectMaxEventIDStmt, err = db.Prepare(selectMaxEventIDSQL); err != nil { if s.selectMaxEventIDStmt, err = db.Prepare(selectMaxEventIDSQL); err != nil {
return nil, err return nil, err
} }
if s.selectRecentEventsStmt, err = db.Prepare(selectRecentEventsSQL); err != nil {
return nil, err
}
if s.selectRecentEventsForSyncStmt, err = db.Prepare(selectRecentEventsForSyncSQL); err != nil {
return nil, err
}
if s.selectEarlyEventsStmt, err = db.Prepare(selectEarlyEventsSQL); err != nil {
return nil, err
}
if s.selectStateInRangeStmt, err = db.Prepare(selectStateInRangeSQL); err != nil {
return nil, err
}
if s.updateEventJSONStmt, err = db.Prepare(updateEventJSONSQL); err != nil { if s.updateEventJSONStmt, err = db.Prepare(updateEventJSONSQL); err != nil {
return nil, err return nil, err
} }
@ -173,19 +141,22 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event
// two positions, only the most recent state is returned. // two positions, only the most recent state is returned.
func (s *outputRoomEventsStatements) SelectStateInRange( func (s *outputRoomEventsStatements) SelectStateInRange(
ctx context.Context, txn *sql.Tx, r types.Range, ctx context.Context, txn *sql.Tx, r types.Range,
stateFilterPart *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
) (map[string]map[string]bool, map[string]types.StreamEvent, error) { ) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt) stmt, params, err := prepareWithFilters(
s.db, txn, selectStateInRangeSQL,
rows, err := stmt.QueryContext( []interface{}{
ctx, r.Low(), r.High(), r.Low(), r.High(),
/*pq.StringArray(stateFilterPart.Senders), },
pq.StringArray(stateFilterPart.NotSenders), stateFilter.Senders, stateFilter.NotSenders,
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)), stateFilter.Types, stateFilter.NotTypes,
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.NotTypes)), stateFilter.Limit, FilterOrderAsc,
stateFilterPart.ContainsURL,*/
stateFilterPart.Limit,
) )
if err != nil {
return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err)
}
rows, err := stmt.QueryContext(ctx, params...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -298,16 +269,21 @@ func (s *outputRoomEventsStatements) InsertEvent(
return 0, err return 0, err
} }
addStateJSON, err := json.Marshal(addState) var addStateJSON, removeStateJSON []byte
if err != nil { if len(addState) > 0 {
return 0, err addStateJSON, err = json.Marshal(addState)
} }
removeStateJSON, err := json.Marshal(removeState)
if err != nil { if err != nil {
return 0, err return 0, fmt.Errorf("json.Marshal(addState): %w", err)
}
if len(removeState) > 0 {
removeStateJSON, err = json.Marshal(removeState)
}
if err != nil {
return 0, fmt.Errorf("json.Marshal(removeState): %w", err)
} }
streamPos, err := s.streamIDStatements.nextStreamID(ctx, txn) streamPos, err := s.streamIDStatements.nextPDUID(ctx, txn)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -333,17 +309,30 @@ func (s *outputRoomEventsStatements) InsertEvent(
func (s *outputRoomEventsStatements) SelectRecentEvents( func (s *outputRoomEventsStatements) SelectRecentEvents(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, r types.Range, limit int, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
chronologicalOrder bool, onlySyncEvents bool, chronologicalOrder bool, onlySyncEvents bool,
) ([]types.StreamEvent, bool, error) { ) ([]types.StreamEvent, bool, error) {
var stmt *sql.Stmt var query string
if onlySyncEvents { if onlySyncEvents {
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsForSyncStmt) query = selectRecentEventsForSyncSQL
} else { } else {
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt) query = selectRecentEventsSQL
} }
rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit+1) stmt, params, err := prepareWithFilters(
s.db, txn, query,
[]interface{}{
roomID, r.Low(), r.High(),
},
eventFilter.Senders, eventFilter.NotSenders,
eventFilter.Types, eventFilter.NotTypes,
eventFilter.Limit+1, FilterOrderDesc,
)
if err != nil {
return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err)
}
rows, err := stmt.QueryContext(ctx, params...)
if err != nil { if err != nil {
return nil, false, err return nil, false, err
} }
@ -362,7 +351,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
} }
// we queried for 1 more than the limit, so if we returned one more mark limited=true // we queried for 1 more than the limit, so if we returned one more mark limited=true
limited := false limited := false
if len(events) > limit { if len(events) > eventFilter.Limit {
limited = true limited = true
// re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last. // re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last.
if chronologicalOrder { if chronologicalOrder {
@ -376,10 +365,21 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
func (s *outputRoomEventsStatements) SelectEarlyEvents( func (s *outputRoomEventsStatements) SelectEarlyEvents(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, r types.Range, limit int, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt) stmt, params, err := prepareWithFilters(
rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit) s.db, txn, selectEarlyEventsSQL,
[]interface{}{
roomID, r.Low(), r.High(),
},
eventFilter.Senders, eventFilter.NotSenders,
eventFilter.Types, eventFilter.NotTypes,
eventFilter.Limit, FilterOrderAsc,
)
if err != nil {
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
}
rows, err := stmt.QueryContext(ctx, params...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -108,7 +108,7 @@ func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks
func (s *peekStatements) InsertPeek( func (s *peekStatements) InsertPeek(
ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string, ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string,
) (streamPos types.StreamPosition, err error) { ) (streamPos types.StreamPosition, err error) {
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) streamPos, err = s.streamIDStatements.nextPDUID(ctx, txn)
if err != nil { if err != nil {
return return
} }
@ -120,7 +120,7 @@ func (s *peekStatements) InsertPeek(
func (s *peekStatements) DeletePeek( func (s *peekStatements) DeletePeek(
ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string, ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string,
) (streamPos types.StreamPosition, err error) { ) (streamPos types.StreamPosition, err error) {
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) streamPos, err = s.streamIDStatements.nextPDUID(ctx, txn)
if err != nil { if err != nil {
return return
} }
@ -131,7 +131,7 @@ func (s *peekStatements) DeletePeek(
func (s *peekStatements) DeletePeeks( func (s *peekStatements) DeletePeeks(
ctx context.Context, txn *sql.Tx, roomID, userID string, ctx context.Context, txn *sql.Tx, roomID, userID string,
) (types.StreamPosition, error) { ) (types.StreamPosition, error) {
streamPos, err := s.streamIDStatements.nextStreamID(ctx, txn) streamPos, err := s.streamIDStatements.nextPDUID(ctx, txn)
if err != nil { if err != nil {
return 0, err return 0, err
} }

View file

@ -20,6 +20,10 @@ INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("global", 0)
ON CONFLICT DO NOTHING; ON CONFLICT DO NOTHING;
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("receipt", 0) INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("receipt", 0)
ON CONFLICT DO NOTHING; ON CONFLICT DO NOTHING;
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("accountdata", 0)
ON CONFLICT DO NOTHING;
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("invite", 0)
ON CONFLICT DO NOTHING;
` `
const increaseStreamIDStmt = "" + const increaseStreamIDStmt = "" +
@ -49,7 +53,7 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
return return
} }
func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { func (s *streamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil { if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil {
@ -68,3 +72,23 @@ func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (po
err = selectStmt.QueryRowContext(ctx, "receipt").Scan(&pos) err = selectStmt.QueryRowContext(ctx, "receipt").Scan(&pos)
return return
} }
func (s *streamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
if _, err = increaseStmt.ExecContext(ctx, "invite"); err != nil {
return
}
err = selectStmt.QueryRowContext(ctx, "invite").Scan(&pos)
return
}
func (s *streamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
if _, err = increaseStmt.ExecContext(ctx, "accountdata"); err != nil {
return
}
err = selectStmt.QueryRowContext(ctx, "accountdata").Scan(&pos)
return
}

View file

@ -56,9 +56,9 @@ type Events interface {
// SelectRecentEvents returns events between the two stream positions: exclusive of low and inclusive of high. // SelectRecentEvents returns events between the two stream positions: exclusive of low and inclusive of high.
// If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude from sync. // If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude from sync.
// Returns up to `limit` events. Returns `limited=true` if there are more events in this range but we hit the `limit`. // Returns up to `limit` events. Returns `limited=true` if there are more events in this range but we hit the `limit`.
SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, limit int, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
// SelectEarlyEvents returns the earliest events in the given room. // SelectEarlyEvents returns the earliest events in the given room.
SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, limit int) ([]types.StreamEvent, error) SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error)
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error) SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error)
UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error
// DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely. // DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely.

View file

@ -48,13 +48,14 @@ func (p *PDUStreamProvider) CompleteSync(
return from return from
} }
stateFilter := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request stateFilter := req.Filter.Room.State
eventFilter := req.Filter.Room.Timeline
// Build up a /sync response. Add joined rooms. // Build up a /sync response. Add joined rooms.
for _, roomID := range joinedRoomIDs { for _, roomID := range joinedRoomIDs {
var jr *types.JoinResponse var jr *types.JoinResponse
jr, err = p.getJoinResponseForCompleteSync( jr, err = p.getJoinResponseForCompleteSync(
ctx, roomID, r, &stateFilter, req.Limit, req.Device, ctx, roomID, r, &stateFilter, &eventFilter, req.Device,
) )
if err != nil { if err != nil {
req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed") req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed")
@ -74,7 +75,7 @@ func (p *PDUStreamProvider) CompleteSync(
if !peek.Deleted { if !peek.Deleted {
var jr *types.JoinResponse var jr *types.JoinResponse
jr, err = p.getJoinResponseForCompleteSync( jr, err = p.getJoinResponseForCompleteSync(
ctx, peek.RoomID, r, &stateFilter, req.Limit, req.Device, ctx, peek.RoomID, r, &stateFilter, &eventFilter, req.Device,
) )
if err != nil { if err != nil {
req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed") req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed")
@ -104,8 +105,8 @@ func (p *PDUStreamProvider) IncrementalSync(
var stateDeltas []types.StateDelta var stateDeltas []types.StateDelta
var joinedRooms []string var joinedRooms []string
// TODO: use filter provided in request stateFilter := req.Filter.Room.State
stateFilter := gomatrixserverlib.DefaultStateFilter() eventFilter := req.Filter.Room.Timeline
if req.WantFullState { if req.WantFullState {
if stateDeltas, joinedRooms, err = p.DB.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { if stateDeltas, joinedRooms, err = p.DB.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
@ -124,7 +125,7 @@ func (p *PDUStreamProvider) IncrementalSync(
} }
for _, delta := range stateDeltas { for _, delta := range stateDeltas {
if err = p.addRoomDeltaToResponse(ctx, req.Device, r, delta, req.Limit, req.Response); err != nil { if err = p.addRoomDeltaToResponse(ctx, req.Device, r, delta, &eventFilter, req.Response); err != nil {
req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed") req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed")
return newPos return newPos
} }
@ -138,7 +139,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
device *userapi.Device, device *userapi.Device,
r types.Range, r types.Range,
delta types.StateDelta, delta types.StateDelta,
numRecentEventsPerRoom int, eventFilter *gomatrixserverlib.RoomEventFilter,
res *types.Response, res *types.Response,
) error { ) error {
if delta.MembershipPos > 0 && delta.Membership == gomatrixserverlib.Leave { if delta.MembershipPos > 0 && delta.Membership == gomatrixserverlib.Leave {
@ -152,7 +153,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
} }
recentStreamEvents, limited, err := p.DB.RecentEvents( recentStreamEvents, limited, err := p.DB.RecentEvents(
ctx, delta.RoomID, r, ctx, delta.RoomID, r,
numRecentEventsPerRoom, true, true, eventFilter, true, true,
) )
if err != nil { if err != nil {
return err return err
@ -209,7 +210,8 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
roomID string, roomID string,
r types.Range, r types.Range,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
numRecentEventsPerRoom int, device *userapi.Device, eventFilter *gomatrixserverlib.RoomEventFilter,
device *userapi.Device,
) (jr *types.JoinResponse, err error) { ) (jr *types.JoinResponse, err error) {
var stateEvents []*gomatrixserverlib.HeaderedEvent var stateEvents []*gomatrixserverlib.HeaderedEvent
stateEvents, err = p.DB.CurrentState(ctx, roomID, stateFilter) stateEvents, err = p.DB.CurrentState(ctx, roomID, stateFilter)
@ -221,7 +223,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
var recentStreamEvents []types.StreamEvent var recentStreamEvents []types.StreamEvent
var limited bool var limited bool
recentStreamEvents, limited, err = p.DB.RecentEvents( recentStreamEvents, limited, err = p.DB.RecentEvents(
ctx, roomID, r, numRecentEventsPerRoom, true, true, ctx, roomID, r, eventFilter, true, true,
) )
if err != nil { if err != nil {
return return

View file

@ -16,6 +16,7 @@ package sync
import ( import (
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"strconv" "strconv"
"time" "time"
@ -31,14 +32,6 @@ import (
const defaultSyncTimeout = time.Duration(0) const defaultSyncTimeout = time.Duration(0)
const DefaultTimelineLimit = 20 const DefaultTimelineLimit = 20
type filter struct {
Room struct {
Timeline struct {
Limit *int `json:"limit"`
} `json:"timeline"`
} `json:"room"`
}
func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Database) (*types.SyncRequest, error) { func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Database) (*types.SyncRequest, error) {
timeout := getTimeout(req.URL.Query().Get("timeout")) timeout := getTimeout(req.URL.Query().Get("timeout"))
fullState := req.URL.Query().Get("full_state") fullState := req.URL.Query().Get("full_state")
@ -51,41 +44,37 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
return nil, err return nil, err
} }
} }
timelineLimit := DefaultTimelineLimit
// TODO: read from stored filters too // TODO: read from stored filters too
filter := gomatrixserverlib.DefaultFilter()
filterQuery := req.URL.Query().Get("filter") filterQuery := req.URL.Query().Get("filter")
if filterQuery != "" { if filterQuery != "" {
if filterQuery[0] == '{' { if filterQuery[0] == '{' {
// attempt to parse the timeline limit at least // Parse the filter from the query string
var f filter if err := json.Unmarshal([]byte(filterQuery), &filter); err != nil {
err := json.Unmarshal([]byte(filterQuery), &f) return nil, fmt.Errorf("json.Unmarshal: %w", err)
if err == nil && f.Room.Timeline.Limit != nil {
timelineLimit = *f.Room.Timeline.Limit
} }
} else { } else {
// attempt to load the filter ID // Try to load the filter from the database
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return nil, err return nil, fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
} }
f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery) if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err != nil {
if err == nil { util.GetLogger(req.Context()).WithError(err).Error("syncDB.GetFilter failed")
timelineLimit = f.Room.Timeline.Limit return nil, fmt.Errorf("syncDB.GetFilter: %w", err)
} else {
filter = *f
} }
} }
} }
filter := gomatrixserverlib.DefaultEventFilter()
filter.Limit = timelineLimit
// TODO: Additional query params: set_presence, filter
logger := util.GetLogger(req.Context()).WithFields(logrus.Fields{ logger := util.GetLogger(req.Context()).WithFields(logrus.Fields{
"user_id": device.UserID, "user_id": device.UserID,
"device_id": device.ID, "device_id": device.ID,
"since": since, "since": since,
"timeout": timeout, "timeout": timeout,
"limit": timelineLimit, "limit": filter.Room.Timeline.Limit,
}) })
return &types.SyncRequest{ return &types.SyncRequest{
@ -96,7 +85,6 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
Filter: filter, // Filter: filter, //
Since: since, // Since: since, //
Timeout: timeout, // Timeout: timeout, //
Limit: timelineLimit, //
Rooms: make(map[string]string), // Populated by the PDU stream Rooms: make(map[string]string), // Populated by the PDU stream
WantFullState: wantFullState, // WantFullState: wantFullState, //
}, nil }, nil

View file

@ -14,9 +14,8 @@ type SyncRequest struct {
Log *logrus.Entry Log *logrus.Entry
Device *userapi.Device Device *userapi.Device
Response *Response Response *Response
Filter gomatrixserverlib.EventFilter Filter gomatrixserverlib.Filter
Since StreamingToken Since StreamingToken
Limit int
Timeout time.Duration Timeout time.Duration
WantFullState bool WantFullState bool

View file

@ -503,3 +503,8 @@ A next_batch token can be used in the v1 messages API
Users receive device_list updates for their own devices Users receive device_list updates for their own devices
m.room.history_visibility == "world_readable" allows/forbids appropriately for Guest users m.room.history_visibility == "world_readable" allows/forbids appropriately for Guest users
m.room.history_visibility == "world_readable" allows/forbids appropriately for Real users m.room.history_visibility == "world_readable" allows/forbids appropriately for Real users
State is included in the timeline in the initial sync
State from remote users is included in the state in the initial sync
Changes to state are included in an gapped incremental sync
A full_state incremental update returns all state
Can pass a JSON filter as a query parameter