Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/phonehomestats

This commit is contained in:
Till Faelligen 2022-03-02 08:36:49 +01:00
commit 9e14e6afe0
42 changed files with 840 additions and 1030 deletions

View file

@ -63,7 +63,7 @@ jobs:
# Run Complement
- run: |
set -o pipefail &&
go test -v -p 1 -json -tags dendrite_blacklist ./tests/... 2>&1 | gotestfmt
go test -v -json -tags dendrite_blacklist ./tests/... 2>&1 | gotestfmt
shell: bash
name: Run Complement Tests
env:

View file

@ -144,21 +144,23 @@ func (u *UserInteractive) AddCompletedStage(sessionID, authType string) {
delete(u.Sessions, sessionID)
}
// Challenge returns an HTTP 401 with the supported flows for authenticating
func (u *UserInteractive) Challenge(sessionID string) *util.JSONResponse {
return &util.JSONResponse{
Code: 401,
JSON: struct {
type Challenge struct {
Completed []string `json:"completed"`
Flows []userInteractiveFlow `json:"flows"`
Session string `json:"session"`
// TODO: Return any additional `params`
Params map[string]interface{} `json:"params"`
}{
u.Completed,
u.Flows,
sessionID,
make(map[string]interface{}),
}
// Challenge returns an HTTP 401 with the supported flows for authenticating
func (u *UserInteractive) Challenge(sessionID string) *util.JSONResponse {
return &util.JSONResponse{
Code: 401,
JSON: Challenge{
Completed: u.Completed,
Flows: u.Flows,
Session: sessionID,
Params: make(map[string]interface{}),
},
}
}

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/tidwall/gjson"
)
// https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-devices
@ -163,6 +164,15 @@ func DeleteDeviceById(
req *http.Request, userInteractiveAuth *auth.UserInteractive, userAPI api.UserInternalAPI, device *api.Device,
deviceID string,
) util.JSONResponse {
var (
deleteOK bool
sessionID string
)
defer func() {
if deleteOK {
sessions.deleteSession(sessionID)
}
}()
ctx := req.Context()
defer req.Body.Close() // nolint:errcheck
bodyBytes, err := ioutil.ReadAll(req.Body)
@ -172,8 +182,29 @@ func DeleteDeviceById(
JSON: jsonerror.BadJSON("The request body could not be read: " + err.Error()),
}
}
// check that we know this session, and it matches with the device to delete
s := gjson.GetBytes(bodyBytes, "auth.session").Str
if dev, ok := sessions.getDeviceToDelete(s); ok {
if dev != deviceID {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("session & device mismatch"),
}
}
}
if s != "" {
sessionID = s
}
login, errRes := userInteractiveAuth.Verify(ctx, bodyBytes, device)
if errRes != nil {
switch data := errRes.JSON.(type) {
case auth.Challenge:
sessions.addDeviceToDelete(data.Session, deviceID)
default:
}
return *errRes
}
@ -201,6 +232,8 @@ func DeleteDeviceById(
return jsonerror.InternalServerError()
}
deleteOK = true
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},

View file

@ -76,6 +76,10 @@ type sessionsDict struct {
sessions map[string][]authtypes.LoginType
params map[string]registerRequest
timer map[string]*time.Timer
// deleteSessionToDeviceID protects requests to DELETE /devices/{deviceID} from being abused.
// If a UIA session is started by trying to delete device1, and then UIA is completed by deleting device2,
// the delete request will fail for device2 since the UIA was initiated by trying to delete device1.
deleteSessionToDeviceID map[string]string
}
// defaultTimeout is the timeout used to clean up sessions
@ -115,6 +119,7 @@ func (d *sessionsDict) deleteSession(sessionID string) {
defer d.Unlock()
delete(d.params, sessionID)
delete(d.sessions, sessionID)
delete(d.deleteSessionToDeviceID, sessionID)
// stop the timer, e.g. because the registration was completed
if t, ok := d.timer[sessionID]; ok {
if !t.Stop() {
@ -132,6 +137,7 @@ func newSessionsDict() *sessionsDict {
sessions: make(map[string][]authtypes.LoginType),
params: make(map[string]registerRequest),
timer: make(map[string]*time.Timer),
deleteSessionToDeviceID: make(map[string]string),
}
}
@ -165,6 +171,20 @@ func (d *sessionsDict) addCompletedSessionStage(sessionID string, stage authtype
d.sessions[sessionID] = append(sessions.sessions[sessionID], stage)
}
func (d *sessionsDict) addDeviceToDelete(sessionID, deviceID string) {
d.startTimer(defaultTimeOut, sessionID)
d.Lock()
defer d.Unlock()
d.deleteSessionToDeviceID[sessionID] = deviceID
}
func (d *sessionsDict) getDeviceToDelete(sessionID string) (string, bool) {
d.RLock()
defer d.RUnlock()
deviceID, ok := d.deleteSessionToDeviceID[sessionID]
return deviceID, ok
}
var (
sessions = newSessionsDict()
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)

View file

@ -242,6 +242,7 @@ func TestSessionCleanUp(t *testing.T) {
s.addParams(dummySession, registerRequest{Username: "Testing"})
s.addCompletedSessionStage(dummySession, authtypes.LoginTypeRecaptcha)
s.addCompletedSessionStage(dummySession, authtypes.LoginTypeDummy)
s.addDeviceToDelete(dummySession, "dummyDevice")
s.getCompletedStages(dummySession)
// reset the timer with a lower timeout
s.startTimer(time.Millisecond, dummySession)
@ -249,5 +250,14 @@ func TestSessionCleanUp(t *testing.T) {
if data, ok := s.getParams(dummySession); ok {
t.Errorf("expected session to be deleted: %+v", data)
}
if _, ok := s.timer[dummySession]; ok {
t.Error("expected timer to be delete")
}
if _, ok := s.sessions[dummySession]; ok {
t.Error("expected session to be delete")
}
if _, ok := s.getDeviceToDelete(dummySession); ok {
t.Error("expected session to device to be delete")
}
})
}

View file

@ -345,11 +345,6 @@ user_api:
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
device_database:
connection_string: file:userapi_devices.db
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
# The length of time that a token issued for a relying party from
# /_matrix/client/r0/user/{userId}/openid/request_token endpoint
# is considered to be valid in milliseconds.

View file

@ -21,7 +21,7 @@ type FederationClient interface {
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error)
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)

View file

@ -166,12 +166,12 @@ func (a *FederationInternalAPI) MSC2836EventRelationships(
}
func (a *FederationInternalAPI) MSC2946Spaces(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest,
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool,
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.MSC2946Spaces(ctx, s, roomID, r)
return a.federation.MSC2946Spaces(ctx, s, roomID, suggestedOnly)
})
if err != nil {
return res, err

View file

@ -527,21 +527,21 @@ func (h *httpFederationInternalAPI) MSC2836EventRelationships(
type spacesReq struct {
S gomatrixserverlib.ServerName
Req gomatrixserverlib.MSC2946SpacesRequest
SuggestedOnly bool
RoomID string
Res gomatrixserverlib.MSC2946SpacesResponse
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) MSC2946Spaces(
ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest,
ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool,
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2946Spaces")
defer span.Finish()
request := spacesReq{
S: dst,
Req: r,
SuggestedOnly: suggestedOnly,
RoomID: roomID,
}
var response spacesReq

View file

@ -378,7 +378,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.Req)
res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.SuggestedOnly)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {

7
go.mod
View file

@ -18,6 +18,7 @@ require (
github.com/frankban/quicktest v1.14.0 // indirect
github.com/getsentry/sentry-go v0.12.0
github.com/gologme/log v1.3.0
github.com/google/uuid v1.2.0
github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.4.2
github.com/h2non/filetype v1.1.3 // indirect
@ -39,7 +40,7 @@ require (
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
github.com/matrix-org/gomatrixserverlib v0.0.0-20220225115648-d2a338a15438
github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.10
@ -61,11 +62,11 @@ require (
github.com/uber/jaeger-lib v2.4.1+incompatible
github.com/yggdrasil-network/yggdrasil-go v0.4.2
go.uber.org/atomic v1.9.0
golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a
golang.org/x/crypto v0.0.0-20220214200702-86341886e292
golang.org/x/image v0.0.0-20211028202545-6944b10bf410
golang.org/x/mobile v0.0.0-20220112015953-858099ff7816
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd
golang.org/x/sys v0.0.0-20220207234003-57398862261d // indirect
golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 // indirect
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211
gopkg.in/h2non/bimg.v1 v1.1.5
gopkg.in/yaml.v2 v2.4.0

14
go.sum
View file

@ -983,8 +983,10 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220225115648-d2a338a15438 h1:3B0ZEJ5YVVbRKHV7WFgji5z6s262YIVRZvtDotdpbsI=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220225115648-d2a338a15438/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6 h1:CqbM5tPbF1mV2h1J6N0NSF8et6HvFlwA98TLCUcjiSI=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902 h1:WHlrE8BYh/hzn1RKwq3YMAlhHivX47jQKAjZFtkJyPE=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo=
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa h1:rMYFNVto66gp+eWS8XAUzgp4m0qmUBid6l1HX3mHstk=
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
@ -1510,8 +1512,8 @@ golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a h1:atOEWVSedO4ksXBe/UrlbSLVxQQ9RxM/tT2Jy10IaHo=
golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220214200702-86341886e292 h1:f+lwQ+GtmgoY+A2YaQxlSOnDjXcQ7ZRLWOHbC6HtRqE=
golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -1737,8 +1739,8 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220207234003-57398862261d h1:Bm7BNOQt2Qv7ZqysjeLjgCBanX+88Z/OtdvsrEv1Djc=
golang.org/x/sys v0.0.0-20220207234003-57398862261d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 h1:nhht2DYV/Sn3qOayu8lM+cU1ii9sTLUeBQwQQfUHtrs=
golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=

View file

@ -10,6 +10,7 @@ const (
FederationEventCacheName = "federation_event"
FederationEventCacheMaxEntries = 256
FederationEventCacheMutable = true // to allow use of Unset only
FederationEventCacheMaxAge = CacheNoMaxAge
)
// FederationCache contains the subset of functions needed for

View file

@ -1,6 +1,8 @@
package caching
import (
"time"
"github.com/matrix-org/dendrite/roomserver/types"
)
@ -16,6 +18,7 @@ const (
RoomInfoCacheName = "roominfo"
RoomInfoCacheMaxEntries = 1024
RoomInfoCacheMutable = true
RoomInfoCacheMaxAge = time.Minute * 5
)
// RoomInfosCache contains the subset of functions needed for

View file

@ -10,6 +10,7 @@ const (
RoomServerRoomIDsCacheName = "roomserver_room_ids"
RoomServerRoomIDsCacheMaxEntries = 1024
RoomServerRoomIDsCacheMutable = false
RoomServerRoomIDsCacheMaxAge = CacheNoMaxAge
)
type RoomServerCaches interface {

View file

@ -6,6 +6,7 @@ const (
RoomVersionCacheName = "room_versions"
RoomVersionCacheMaxEntries = 1024
RoomVersionCacheMutable = false
RoomVersionCacheMaxAge = CacheNoMaxAge
)
// RoomVersionsCache contains the subset of functions needed for

View file

@ -10,6 +10,7 @@ const (
ServerKeyCacheName = "server_key"
ServerKeyCacheMaxEntries = 4096
ServerKeyCacheMutable = true
ServerKeyCacheMaxAge = CacheNoMaxAge
)
// ServerKeyCache contains the subset of functions needed for

View file

@ -0,0 +1,33 @@
package caching
import (
"time"
"github.com/matrix-org/gomatrixserverlib"
)
const (
SpaceSummaryRoomsCacheName = "space_summary_rooms"
SpaceSummaryRoomsCacheMaxEntries = 100
SpaceSummaryRoomsCacheMutable = true
SpaceSummaryRoomsCacheMaxAge = time.Minute * 5
)
type SpaceSummaryRoomsCache interface {
GetSpaceSummary(roomID string) (r gomatrixserverlib.MSC2946SpacesResponse, ok bool)
StoreSpaceSummary(roomID string, r gomatrixserverlib.MSC2946SpacesResponse)
}
func (c Caches) GetSpaceSummary(roomID string) (r gomatrixserverlib.MSC2946SpacesResponse, ok bool) {
val, found := c.SpaceSummaryRooms.Get(roomID)
if found && val != nil {
if resp, ok := val.(gomatrixserverlib.MSC2946SpacesResponse); ok {
return resp, true
}
}
return r, false
}
func (c Caches) StoreSpaceSummary(roomID string, r gomatrixserverlib.MSC2946SpacesResponse) {
c.SpaceSummaryRooms.Set(roomID, r)
}

View file

@ -1,5 +1,7 @@
package caching
import "time"
// Caches contains a set of references to caches. They may be
// different implementations as long as they satisfy the Cache
// interface.
@ -10,6 +12,7 @@ type Caches struct {
RoomServerRoomIDs Cache // RoomServerNIDsCache
RoomInfos Cache // RoomInfoCache
FederationEvents Cache // FederationEventsCache
SpaceSummaryRooms Cache // SpaceSummaryRoomsCache
}
// Cache is the interface that an implementation must satisfy.
@ -18,3 +21,5 @@ type Cache interface {
Set(key string, value interface{})
Unset(key string)
}
const CacheNoMaxAge = time.Duration(0)

View file

@ -14,6 +14,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
RoomVersionCacheName,
RoomVersionCacheMutable,
RoomVersionCacheMaxEntries,
RoomVersionCacheMaxAge,
enablePrometheus,
)
if err != nil {
@ -23,6 +24,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
ServerKeyCacheName,
ServerKeyCacheMutable,
ServerKeyCacheMaxEntries,
ServerKeyCacheMaxAge,
enablePrometheus,
)
if err != nil {
@ -32,6 +34,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
RoomServerRoomIDsCacheName,
RoomServerRoomIDsCacheMutable,
RoomServerRoomIDsCacheMaxEntries,
RoomServerRoomIDsCacheMaxAge,
enablePrometheus,
)
if err != nil {
@ -41,6 +44,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
RoomInfoCacheName,
RoomInfoCacheMutable,
RoomInfoCacheMaxEntries,
RoomInfoCacheMaxAge,
enablePrometheus,
)
if err != nil {
@ -50,6 +54,17 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
FederationEventCacheName,
FederationEventCacheMutable,
FederationEventCacheMaxEntries,
FederationEventCacheMaxAge,
enablePrometheus,
)
if err != nil {
return nil, err
}
spaceRooms, err := NewInMemoryLRUCachePartition(
SpaceSummaryRoomsCacheName,
SpaceSummaryRoomsCacheMutable,
SpaceSummaryRoomsCacheMaxEntries,
SpaceSummaryRoomsCacheMaxAge,
enablePrometheus,
)
if err != nil {
@ -57,7 +72,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
}
go cacheCleaner(
roomVersions, serverKeys, roomServerRoomIDs,
roomInfos, federationEvents,
roomInfos, federationEvents, spaceRooms,
)
return &Caches{
RoomVersions: roomVersions,
@ -65,6 +80,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
RoomServerRoomIDs: roomServerRoomIDs,
RoomInfos: roomInfos,
FederationEvents: federationEvents,
SpaceSummaryRooms: spaceRooms,
}, nil
}
@ -86,15 +102,22 @@ type InMemoryLRUCachePartition struct {
name string
mutable bool
maxEntries int
maxAge time.Duration
lru *lru.Cache
}
func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, enablePrometheus bool) (*InMemoryLRUCachePartition, error) {
type inMemoryLRUCacheEntry struct {
value interface{}
created time.Time
}
func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, maxAge time.Duration, enablePrometheus bool) (*InMemoryLRUCachePartition, error) {
var err error
cache := InMemoryLRUCachePartition{
name: name,
mutable: mutable,
maxEntries: maxEntries,
maxAge: maxAge,
}
cache.lru, err = lru.New(maxEntries)
if err != nil {
@ -114,11 +137,16 @@ func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, ena
func (c *InMemoryLRUCachePartition) Set(key string, value interface{}) {
if !c.mutable {
if peek, ok := c.lru.Peek(key); ok && peek != value {
if peek, ok := c.lru.Peek(key); ok {
if entry, ok := peek.(*inMemoryLRUCacheEntry); ok && entry.value != value {
panic(fmt.Sprintf("invalid use of immutable cache tries to mutate existing value of %q", key))
}
}
c.lru.Add(key, value)
}
c.lru.Add(key, &inMemoryLRUCacheEntry{
value: value,
created: time.Now(),
})
}
func (c *InMemoryLRUCachePartition) Unset(key string) {
@ -129,5 +157,20 @@ func (c *InMemoryLRUCachePartition) Unset(key string) {
}
func (c *InMemoryLRUCachePartition) Get(key string) (value interface{}, ok bool) {
return c.lru.Get(key)
v, ok := c.lru.Get(key)
if !ok {
return nil, false
}
entry, ok := v.(*inMemoryLRUCacheEntry)
switch {
case ok && c.maxAge == CacheNoMaxAge:
return entry.value, ok // There's no maximum age policy
case ok && time.Since(entry.created) < c.maxAge:
return entry.value, ok // The value for the key isn't stale
default:
// Either the key was found and it was stale, or the key
// wasn't found at all
c.lru.Remove(key)
return nil, false
}
}

View file

@ -166,8 +166,10 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
}
// We can't have a self-signing or user-signing key without a master
// key, so make sure we have one of those.
if !hasMasterKey {
// key, so make sure we have one of those. We will also only actually do
// something if any of the specified keys in the request are different
// to what we've got in the database, to avoid generating key change
// notifications unnecessarily.
existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID)
if err != nil {
res.Error = &api.KeyError{
@ -176,18 +178,43 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
return
}
_, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]
}
// If we still can't find a master key for the user then stop the upload.
// This satisfies the "Fails to upload self-signing key without master key" test.
if !hasMasterKey {
if _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]; !hasMasterKey {
res.Error = &api.KeyError{
Err: "No master key was found",
IsMissingParam: true,
}
return
}
}
// Check if anything actually changed compared to what we have in the database.
changed := false
for _, purpose := range []gomatrixserverlib.CrossSigningKeyPurpose{
gomatrixserverlib.CrossSigningKeyPurposeMaster,
gomatrixserverlib.CrossSigningKeyPurposeSelfSigning,
gomatrixserverlib.CrossSigningKeyPurposeUserSigning,
} {
old, gotOld := existingKeys[purpose]
new, gotNew := toStore[purpose]
if gotOld != gotNew {
// A new key purpose has been specified that we didn't know before,
// or one has been removed.
changed = true
break
}
if !bytes.Equal(old, new) {
// One of the existing keys for a purpose we already knew about has
// changed.
changed = true
break
}
}
if !changed {
return
}
// Store the keys.
if err := a.DB.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil {

View file

@ -48,13 +48,19 @@ type mockDeviceListUpdaterDatabase struct {
staleUsers map[string]bool
prevIDsExist func(string, []int) bool
storedKeys []api.DeviceMessage
mu sync.Mutex // protect staleUsers
}
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
// If no domains are given, all user IDs with stale device lists are returned.
func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
d.mu.Lock()
defer d.mu.Unlock()
var result []string
for userID := range d.staleUsers {
for userID, isStale := range d.staleUsers {
if !isStale {
continue
}
_, remoteServer, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return nil, err
@ -75,6 +81,8 @@ func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, do
// MarkDeviceListStale sets the stale bit for this user to isStale.
func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
d.mu.Lock()
defer d.mu.Unlock()
d.staleUsers[userID] = isStale
return nil
}
@ -247,3 +255,82 @@ func TestUpdateNoPrevID(t *testing.T) {
}
}
// Test that if we make N calls to ManualUpdate for the same user, we only do it once, assuming the
// update is still ongoing.
func TestDebounce(t *testing.T) {
db := &mockDeviceListUpdaterDatabase{
staleUsers: make(map[string]bool),
prevIDsExist: func(string, []int) bool {
return true
},
}
ap := &mockDeviceListUpdaterAPI{}
producer := &mockKeyChangeProducer{}
fedCh := make(chan *http.Response, 1)
srv := gomatrixserverlib.ServerName("example.com")
userID := "@alice:example.com"
keyJSON := `{"user_id":"` + userID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + userID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}`
incomingFedReq := make(chan struct{})
fedClient := newFedClient(func(req *http.Request) (*http.Response, error) {
if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(userID) {
return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path)
}
close(incomingFedReq)
return <-fedCh, nil
})
updater := NewDeviceListUpdater(db, ap, producer, fedClient, 1)
if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err)
}
// hit this 5 times
var wg sync.WaitGroup
wg.Add(5)
for i := 0; i < 5; i++ {
go func() {
defer wg.Done()
if err := updater.ManualUpdate(context.Background(), srv, userID); err != nil {
t.Errorf("ManualUpdate: %s", err)
}
}()
}
// wait until the updater hits federation
select {
case <-incomingFedReq:
case <-time.After(time.Second):
t.Fatalf("timed out waiting for updater to hit federation")
}
// user should be marked as stale
if !db.staleUsers[userID] {
t.Errorf("user %s not marked as stale", userID)
}
// now send the response over federation
fedCh <- &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`
{
"user_id": "` + userID + `",
"stream_id": 5,
"devices": [
{
"device_id": "JLAFKJWSCS",
"keys": ` + keyJSON + `,
"device_display_name": "Mobile Phone"
}
]
}
`)),
}
close(fedCh)
// wait until all 5 ManualUpdates return. If we hit federation again we won't send a response
// and should panic with read on a closed channel
wg.Wait()
// user is no longer stale now
if db.staleUsers[userID] {
t.Errorf("user %s is marked as stale", userID)
}
}

View file

@ -269,6 +269,7 @@ type QueryAuthChainResponse struct {
type QuerySharedUsersRequest struct {
UserID string
OtherUserIDs []string
ExcludeRoomIDs []string
IncludeRoomIDs []string
}
@ -313,6 +314,9 @@ type QueryBulkStateContentResponse struct {
type QueryCurrentStateRequest struct {
RoomID string
AllowWildcards bool
// State key tuples. If a state_key has '*' and AllowWidlcards is true, returns all matching
// state events with that event type.
StateTuples []gomatrixserverlib.StateKeyTuple
}

View file

@ -621,6 +621,18 @@ func (r *Queryer) QueryPublishedRooms(
func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error {
res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent)
for _, tuple := range req.StateTuples {
if tuple.StateKey == "*" && req.AllowWildcards {
events, err := r.DB.GetStateEventsWithEventType(ctx, req.RoomID, tuple.EventType)
if err != nil {
return err
}
for _, e := range events {
res.StateEvents[gomatrixserverlib.StateKeyTuple{
EventType: e.Type(),
StateKey: *e.StateKey(),
}] = e
}
} else {
ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey)
if err != nil {
return err
@ -629,6 +641,7 @@ func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentSt
res.StateEvents[tuple] = ev
}
}
}
return nil
}
@ -696,7 +709,7 @@ func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUser
}
roomIDs = roomIDs[:j]
users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs)
users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs, req.OtherUserIDs)
if err != nil {
return err
}

View file

@ -146,13 +146,14 @@ type Database interface {
// If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error)
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error)
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error)
// JoinedUsersSetInRooms returns how many times each of the given users appears across the given rooms.
JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error)
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error)
// GetServerInRoom returns true if we think a server is in a given room or false otherwise.

View file

@ -66,7 +66,8 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
`
var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid"
@ -306,13 +307,10 @@ func (s *membershipStatements) SelectRoomsWithMembership(
func (s *membershipStatements) SelectJoinedUsersSetForRooms(
ctx context.Context, txn *sql.Tx,
roomNIDs []types.RoomNID,
userNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]int, error) {
roomIDarray := make([]int64, len(roomNIDs))
for i := range roomNIDs {
roomIDarray[i] = int64(roomNIDs[i])
}
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
if err != nil {
return nil, err
}

View file

@ -13,7 +13,6 @@ import (
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
@ -979,6 +978,62 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
return nil, nil
}
// Same as GetStateEvent but returns all matching state events with this event type. Returns no error
// if there are no events with this event type.
func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) {
roomInfo, err := d.RoomInfo(ctx, roomID)
if err != nil {
return nil, err
}
if roomInfo == nil {
return nil, fmt.Errorf("room %s doesn't exist", roomID)
}
// e.g invited rooms
if roomInfo.IsStub {
return nil, nil
}
eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType)
if err == sql.ErrNoRows {
// No rooms have an event of this type, otherwise we'd have an event type NID
return nil, nil
}
if err != nil {
return nil, err
}
entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID)
if err != nil {
return nil, err
}
var eventNIDs []types.EventNID
for _, e := range entries {
if e.EventTypeNID == eventTypeNID {
eventNIDs = append(eventNIDs, e.EventNID)
}
}
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
if err != nil {
eventIDs = map[types.EventNID]string{}
}
// return the events requested
eventPairs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs)
if err != nil {
return nil, err
}
if len(eventPairs) == 0 {
return nil, nil
}
var result []*gomatrixserverlib.HeaderedEvent
for _, pair := range eventPairs {
ev, err := gomatrixserverlib.NewEventFromTrustedJSONWithEventID(eventIDs[pair.EventNID], pair.EventJSON, false, roomInfo.RoomVersion)
if err != nil {
return nil, err
}
result = append(result, ev.Headered(roomInfo.RoomVersion))
}
return result, nil
}
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) {
var membershipState tables.MembershipState
@ -1104,13 +1159,23 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
return result, nil
}
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
// JoinedUsersSetInRooms returns a map of how many times the given users appear in the specified rooms.
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) {
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
if err != nil {
return nil, err
}
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs)
userNIDsMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, userIDs)
if err != nil {
return nil, err
}
userNIDs := make([]types.EventStateKeyNID, 0, len(userNIDsMap))
nidToUserID := make(map[types.EventStateKeyNID]string, len(userNIDsMap))
for id, nid := range userNIDsMap {
userNIDs = append(userNIDs, nid)
nidToUserID[nid] = id
}
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs)
if err != nil {
return nil, err
}
@ -1120,13 +1185,6 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
stateKeyNIDs[i] = nid
i++
}
nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs)
if err != nil {
return nil, err
}
if len(nidToUserID) != len(userNIDToCount) {
logrus.Warnf("SelectJoinedUsersSetForRooms found %d users but BulkSelectEventStateKey only returned state key NIDs for %d of them", len(userNIDToCount), len(nidToUserID))
}
result := make(map[string]int, len(userNIDToCount))
for nid, count := range userNIDToCount {
result[nidToUserID[nid]] = count

View file

@ -42,7 +42,8 @@ const membershipSchema = `
`
var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid IN ($1) AND target_nid IN ($2) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid"
@ -280,18 +281,22 @@ func (s *membershipStatements) SelectRoomsWithMembership(
return roomNIDs, nil
}
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs {
iRoomNIDs[i] = v
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) {
params := make([]interface{}, 0, len(roomNIDs)+len(userNIDs))
for _, v := range roomNIDs {
params = append(params, v)
}
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1)
for _, v := range userNIDs {
params = append(params, v)
}
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1)
var rows *sql.Rows
var err error
if txn != nil {
rows, err = txn.QueryContext(ctx, query, iRoomNIDs...)
rows, err = txn.QueryContext(ctx, query, params...)
} else {
rows, err = s.db.QueryContext(ctx, query, iRoomNIDs...)
rows, err = s.db.QueryContext(ctx, query, params...)
}
if err != nil {
return nil, err

View file

@ -127,9 +127,8 @@ type Membership interface {
SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the
// counts of how many rooms they are joined.
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error)
// SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms.
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error)
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)

View file

@ -18,17 +18,19 @@ package msc2946
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/mux"
chttputil "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
fs "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/hooks"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil"
roomserver "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
@ -40,41 +42,26 @@ import (
const (
ConstCreateEventContentKey = "type"
ConstCreateEventContentValueSpace = "m.space"
ConstSpaceChildEventType = "m.space.child"
ConstSpaceParentEventType = "m.space.parent"
)
// Defaults sets the request defaults
func Defaults(r *gomatrixserverlib.MSC2946SpacesRequest) {
r.Limit = 2000
r.MaxRoomsPerSpace = -1
type MSC2946ClientResponse struct {
Rooms []gomatrixserverlib.MSC2946Room `json:"rooms"`
NextBatch string `json:"next_batch,omitempty"`
}
// Enable this MSC
func Enable(
base *base.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI,
fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier,
fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier, cache caching.SpaceSummaryRoomsCache,
) error {
db, err := NewDatabase(&base.Cfg.MSCs.Database)
if err != nil {
return fmt.Errorf("cannot enable MSC2946: %w", err)
}
hooks.Enable()
hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) {
he := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
hookErr := db.StoreReference(context.Background(), he)
if hookErr != nil {
util.GetLogger(context.Background()).WithError(hookErr).WithField("event_id", he.EventID()).Error(
"failed to StoreReference",
)
}
})
clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, cache, base.Cfg.Global.ServerName))
base.PublicClientAPIMux.Handle("/v1/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions)
base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions)
base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/spaces",
httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(db, rsAPI, fsAPI, base.Cfg.Global.ServerName)),
).Methods(http.MethodPost, http.MethodOptions)
base.PublicFederationAPIMux.Handle("/unstable/org.matrix.msc2946/spaces/{roomID}", httputil.MakeExternalAPI(
fedAPI := httputil.MakeExternalAPI(
"msc2946_fed_spaces", func(req *http.Request) util.JSONResponse {
fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest(
req, time.Now(), base.Cfg.Global.ServerName, keyRing,
@ -88,252 +75,308 @@ func Enable(
return util.ErrorResponse(err)
}
roomID := params["roomID"]
return federatedSpacesHandler(req.Context(), fedReq, roomID, db, rsAPI, fsAPI, base.Cfg.Global.ServerName)
return federatedSpacesHandler(req.Context(), fedReq, roomID, cache, rsAPI, fsAPI, base.Cfg.Global.ServerName)
},
)).Methods(http.MethodPost, http.MethodOptions)
)
base.PublicFederationAPIMux.Handle("/unstable/org.matrix.msc2946/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet)
base.PublicFederationAPIMux.Handle("/v1/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet)
return nil
}
func federatedSpacesHandler(
ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string, db Database,
ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string,
cache caching.SpaceSummaryRoomsCache,
rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI,
thisServer gomatrixserverlib.ServerName,
) util.JSONResponse {
inMemoryBatchCache := make(map[string]set)
var r gomatrixserverlib.MSC2946SpacesRequest
Defaults(&r)
if err := json.Unmarshal(fedReq.Content(), &r); err != nil {
u, err := url.Parse(fedReq.RequestURI())
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()),
Code: 400,
JSON: jsonerror.InvalidParam("bad request uri"),
}
}
w := walker{
req: &r,
rootRoomID: roomID,
serverName: fedReq.Origin(),
thisServer: thisServer,
ctx: ctx,
cache: cache,
suggestedOnly: u.Query().Get("suggested_only") == "true",
limit: 1000,
// The main difference is that it does not recurse into spaces and does not support pagination.
// This is somewhat equivalent to a Client-Server request with a max_depth=1.
maxDepth: 1,
db: db,
rsAPI: rsAPI,
fsAPI: fsAPI,
inMemoryBatchCache: inMemoryBatchCache,
}
res := w.walk()
return util.JSONResponse{
Code: 200,
JSON: res,
// inline cache as we don't have pagination in federation mode
paginationCache: make(map[string]paginationInfo),
}
return w.walk()
}
func spacesHandler(
db Database, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI,
rsAPI roomserver.RoomserverInternalAPI,
fsAPI fs.FederationInternalAPI,
cache caching.SpaceSummaryRoomsCache,
thisServer gomatrixserverlib.ServerName,
) func(*http.Request, *userapi.Device) util.JSONResponse {
// declared outside the returned handler so it persists between calls
// TODO: clear based on... time?
paginationCache := make(map[string]paginationInfo)
return func(req *http.Request, device *userapi.Device) util.JSONResponse {
inMemoryBatchCache := make(map[string]set)
// Extract the room ID from the request. Sanity check request data.
params, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
roomID := params["roomID"]
var r gomatrixserverlib.MSC2946SpacesRequest
Defaults(&r)
if resErr := chttputil.UnmarshalJSONRequest(req, &r); resErr != nil {
return *resErr
}
w := walker{
req: &r,
suggestedOnly: req.URL.Query().Get("suggested_only") == "true",
limit: parseInt(req.URL.Query().Get("limit"), 1000),
maxDepth: parseInt(req.URL.Query().Get("max_depth"), -1),
paginationToken: req.URL.Query().Get("from"),
rootRoomID: roomID,
caller: device,
thisServer: thisServer,
ctx: req.Context(),
cache: cache,
db: db,
rsAPI: rsAPI,
fsAPI: fsAPI,
inMemoryBatchCache: inMemoryBatchCache,
paginationCache: paginationCache,
}
res := w.walk()
return util.JSONResponse{
Code: 200,
JSON: res,
return w.walk()
}
}
type paginationInfo struct {
processed set
unvisited []roomVisit
}
type walker struct {
req *gomatrixserverlib.MSC2946SpacesRequest
rootRoomID string
caller *userapi.Device
serverName gomatrixserverlib.ServerName
thisServer gomatrixserverlib.ServerName
db Database
rsAPI roomserver.RoomserverInternalAPI
fsAPI fs.FederationInternalAPI
ctx context.Context
cache caching.SpaceSummaryRoomsCache
suggestedOnly bool
limit int
maxDepth int
paginationToken string
// user ID|device ID|batch_num => event/room IDs sent to client
inMemoryBatchCache map[string]set
paginationCache map[string]paginationInfo
mu sync.Mutex
}
func (w *walker) roomIsExcluded(roomID string) bool {
for _, exclRoom := range w.req.ExcludeRooms {
if exclRoom == roomID {
return true
func (w *walker) newPaginationCache() (string, paginationInfo) {
p := paginationInfo{
processed: make(set),
unvisited: nil,
}
}
return false
tok := uuid.NewString()
return tok, p
}
func (w *walker) callerID() string {
func (w *walker) loadPaginationCache(paginationToken string) *paginationInfo {
w.mu.Lock()
defer w.mu.Unlock()
p := w.paginationCache[paginationToken]
return &p
}
func (w *walker) storePaginationCache(paginationToken string, cache paginationInfo) {
w.mu.Lock()
defer w.mu.Unlock()
w.paginationCache[paginationToken] = cache
}
type roomVisit struct {
roomID string
depth int
vias []string // vias to query this room by
}
func (w *walker) walk() util.JSONResponse {
if !w.authorised(w.rootRoomID) {
if w.caller != nil {
return w.caller.UserID + "|" + w.caller.ID
// CS API format
return util.JSONResponse{
Code: 403,
JSON: jsonerror.Forbidden("room is unknown/forbidden"),
}
} else {
// SS API format
return util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound("room is unknown/forbidden"),
}
}
return string(w.serverName)
}
func (w *walker) alreadySent(id string) bool {
w.mu.Lock()
defer w.mu.Unlock()
m, ok := w.inMemoryBatchCache[w.callerID()]
if !ok {
return false
}
return m[id]
}
func (w *walker) markSent(id string) {
w.mu.Lock()
defer w.mu.Unlock()
m := w.inMemoryBatchCache[w.callerID()]
if m == nil {
m = make(set)
}
m[id] = true
w.inMemoryBatchCache[w.callerID()] = m
}
func (w *walker) walk() *gomatrixserverlib.MSC2946SpacesResponse {
var res gomatrixserverlib.MSC2946SpacesResponse
// Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms
unvisited := []string{w.rootRoomID}
processed := make(set)
for len(unvisited) > 0 {
roomID := unvisited[0]
unvisited = unvisited[1:]
// If this room has already been processed, skip. NB: do not remember this between calls
if processed[roomID] || roomID == "" {
continue
}
// Mark this room as processed.
processed[roomID] = true
// Collect rooms/events to send back (either locally or fetched via federation)
var discoveredRooms []gomatrixserverlib.MSC2946Room
var discoveredEvents []gomatrixserverlib.MSC2946StrippedEvent
// If we know about this room and the caller is authorised (joined/world_readable) then pull
// events locally
if w.roomExists(roomID) && w.authorised(roomID) {
// Get all `m.space.child` and `m.space.parent` state events for the room. *In addition*, get
// all `m.space.child` and `m.space.parent` state events which *point to* (via `state_key` or `content.room_id`)
// this room. This requires servers to store reverse lookups.
events, err := w.references(roomID)
if err != nil {
util.GetLogger(w.ctx).WithError(err).WithField("room_id", roomID).Error("failed to extract references for room")
var cache *paginationInfo
if w.paginationToken != "" {
cache = w.loadPaginationCache(w.paginationToken)
if cache == nil {
return util.JSONResponse{
Code: 400,
JSON: jsonerror.InvalidArgumentValue("invalid from"),
}
}
} else {
tok, c := w.newPaginationCache()
cache = &c
w.paginationToken = tok
// Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms
c.unvisited = append(c.unvisited, roomVisit{
roomID: w.rootRoomID,
depth: 0,
})
}
processed := cache.processed
unvisited := cache.unvisited
// Depth first -> stack data structure
for len(unvisited) > 0 {
if len(discoveredRooms) >= w.limit {
break
}
// pop the stack
rv := unvisited[len(unvisited)-1]
unvisited = unvisited[:len(unvisited)-1]
// If this room has already been processed, skip.
// If this room exceeds the specified depth, skip.
if processed.isSet(rv.roomID) || rv.roomID == "" || (w.maxDepth > 0 && rv.depth > w.maxDepth) {
continue
}
discoveredEvents = events
pubRoom := w.publicRoomsChunk(roomID)
roomType := ""
create := w.stateEvent(roomID, gomatrixserverlib.MRoomCreate, "")
// Mark this room as processed.
processed.set(rv.roomID)
// if this room is not a space room, skip.
var roomType string
create := w.stateEvent(rv.roomID, gomatrixserverlib.MRoomCreate, "")
if create != nil {
// escape the `.`s so gjson doesn't think it's nested
roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str
}
// Add the total number of events to `PublicRoomsChunk` under `num_refs`. Add `PublicRoomsChunk` to `rooms`.
// Collect rooms/events to send back (either locally or fetched via federation)
var discoveredChildEvents []gomatrixserverlib.MSC2946StrippedEvent
// If we know about this room and the caller is authorised (joined/world_readable) then pull
// events locally
if w.roomExists(rv.roomID) && w.authorised(rv.roomID) {
// Get all `m.space.child` state events for this room
events, err := w.childReferences(rv.roomID)
if err != nil {
util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Error("failed to extract references for room")
continue
}
discoveredChildEvents = events
pubRoom := w.publicRoomsChunk(rv.roomID)
discoveredRooms = append(discoveredRooms, gomatrixserverlib.MSC2946Room{
PublicRoom: *pubRoom,
NumRefs: len(discoveredEvents),
RoomType: roomType,
ChildrenState: events,
})
} else {
// attempt to query this room over federation, as either we've never heard of it before
// or we've left it and hence are not authorised (but info may be exposed regardless)
fedRes, err := w.federatedRoomInfo(roomID)
fedRes, err := w.federatedRoomInfo(rv.roomID, rv.vias)
if err != nil {
util.GetLogger(w.ctx).WithError(err).WithField("room_id", roomID).Errorf("failed to query federated spaces")
util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Errorf("failed to query federated spaces")
continue
}
if fedRes != nil {
discoveredRooms = fedRes.Rooms
discoveredEvents = fedRes.Events
discoveredChildEvents = fedRes.Room.ChildrenState
discoveredRooms = append(discoveredRooms, fedRes.Room)
if len(fedRes.Children) > 0 {
discoveredRooms = append(discoveredRooms, fedRes.Children...)
}
// mark this room as a space room as the federated server responded.
// we need to do this so we add the children of this room to the unvisited stack
// as these children may be rooms we do know about.
roomType = ConstCreateEventContentValueSpace
}
}
// If this room has not ever been in `rooms` (across multiple requests), send it now
for _, room := range discoveredRooms {
if !w.alreadySent(room.RoomID) && !w.roomIsExcluded(room.RoomID) {
res.Rooms = append(res.Rooms, room)
w.markSent(room.RoomID)
}
}
uniqueRooms := make(set)
// If this is the root room from the original request, insert all these events into `events` if
// they haven't been added before (across multiple requests).
if w.rootRoomID == roomID {
for _, ev := range discoveredEvents {
if !w.alreadySent(eventKey(&ev)) {
res.Events = append(res.Events, ev)
uniqueRooms[ev.RoomID] = true
uniqueRooms[spaceTargetStripped(&ev)] = true
w.markSent(eventKey(&ev))
}
}
} else {
// Else add them to `events` honouring the `limit` and `max_rooms_per_space` values. If either
// are exceeded, stop adding events. If the event has already been added, do not add it again.
numAdded := 0
for _, ev := range discoveredEvents {
if w.req.Limit > 0 && len(res.Events) >= w.req.Limit {
break
}
if w.req.MaxRoomsPerSpace > 0 && numAdded >= w.req.MaxRoomsPerSpace {
break
}
if w.alreadySent(eventKey(&ev)) {
// don't walk the children
// if the parent is not a space room
if roomType != ConstCreateEventContentValueSpace {
continue
}
// Skip the room if it's part of exclude_rooms but ONLY IF the source matches, as we still
// want to catch arrows which point to excluded rooms.
if w.roomIsExcluded(ev.RoomID) {
continue
}
res.Events = append(res.Events, ev)
uniqueRooms[ev.RoomID] = true
uniqueRooms[spaceTargetStripped(&ev)] = true
w.markSent(eventKey(&ev))
// we don't distinguish between child state events and parent state events for the purposes of
// max_rooms_per_space, maybe we should?
numAdded++
}
}
// For each referenced room ID in the events being returned to the caller (both parent and child)
// For each referenced room ID in the child events being returned to the caller
// add the room ID to the queue of unvisited rooms. Loop from the beginning.
for roomID := range uniqueRooms {
unvisited = append(unvisited, roomID)
// We need to invert the order here because the child events are lo->hi on the timestamp,
// so we need to ensure we pop in the same lo->hi order, which won't be the case if we
// insert the highest timestamp last in a stack.
for i := len(discoveredChildEvents) - 1; i >= 0; i-- {
spaceContent := struct {
Via []string `json:"via"`
}{}
ev := discoveredChildEvents[i]
_ = json.Unmarshal(ev.Content, &spaceContent)
unvisited = append(unvisited, roomVisit{
roomID: ev.StateKey,
depth: rv.depth + 1,
vias: spaceContent.Via,
})
}
}
return &res
if len(unvisited) > 0 {
// we still have more rooms so we need to send back a pagination token,
// we probably hit a room limit
cache.processed = processed
cache.unvisited = unvisited
w.storePaginationCache(w.paginationToken, *cache)
} else {
// clear the pagination token so we don't send it back to the client
// Note we do NOT nuke the cache just in case this response is lost
// and the client retries it.
w.paginationToken = ""
}
if w.caller != nil {
// return CS API format
return util.JSONResponse{
Code: 200,
JSON: MSC2946ClientResponse{
Rooms: discoveredRooms,
NextBatch: w.paginationToken,
},
}
}
// return SS API format
// the first discovered room will be the room asked for, and subsequent ones the depth=1 children
if len(discoveredRooms) == 0 {
return util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound("room is unknown/forbidden"),
}
}
return util.JSONResponse{
Code: 200,
JSON: gomatrixserverlib.MSC2946SpacesResponse{
Room: discoveredRooms[0],
Children: discoveredRooms[1:],
},
}
}
func (w *walker) stateEvent(roomID, evType, stateKey string) *gomatrixserverlib.HeaderedEvent {
@ -366,46 +409,41 @@ func (w *walker) publicRoomsChunk(roomID string) *gomatrixserverlib.PublicRoom {
// federatedRoomInfo returns more of the spaces graph from another server. Returns nil if this was
// unsuccessful.
func (w *walker) federatedRoomInfo(roomID string) (*gomatrixserverlib.MSC2946SpacesResponse, error) {
func (w *walker) federatedRoomInfo(roomID string, vias []string) (*gomatrixserverlib.MSC2946SpacesResponse, error) {
// only do federated requests for client requests
if w.caller == nil {
return nil, nil
}
// extract events which point to this room ID and extract their vias
events, err := w.db.References(w.ctx, roomID)
if err != nil {
return nil, fmt.Errorf("failed to get References events: %w", err)
resp, ok := w.cache.GetSpaceSummary(roomID)
if ok {
util.GetLogger(w.ctx).Debugf("Returning cached response for %s", roomID)
return &resp, nil
}
vias := make(set)
for _, ev := range events {
if ev.StateKeyEquals(roomID) {
// event points at this room, extract vias
content := struct {
Vias []string `json:"via"`
}{}
if err = json.Unmarshal(ev.Content(), &content); err != nil {
continue // silently ignore corrupted state events
}
for _, v := range content.Vias {
vias[v] = true
}
}
}
util.GetLogger(w.ctx).Infof("Querying federatedRoomInfo via %+v", vias)
util.GetLogger(w.ctx).Debugf("Querying %s via %+v", roomID, vias)
ctx := context.Background()
// query more of the spaces graph using these servers
for serverName := range vias {
for _, serverName := range vias {
if serverName == string(w.thisServer) {
continue
}
res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, gomatrixserverlib.MSC2946SpacesRequest{
Limit: w.req.Limit,
MaxRoomsPerSpace: w.req.MaxRoomsPerSpace,
})
res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, w.suggestedOnly)
if err != nil {
util.GetLogger(w.ctx).WithError(err).Warnf("failed to call MSC2946Spaces on server %s", serverName)
continue
}
// ensure nil slices are empty as we send this to the client sometimes
if res.Room.ChildrenState == nil {
res.Room.ChildrenState = []gomatrixserverlib.MSC2946StrippedEvent{}
}
for i := 0; i < len(res.Children); i++ {
child := res.Children[i]
if child.ChildrenState == nil {
child.ChildrenState = []gomatrixserverlib.MSC2946StrippedEvent{}
}
res.Children[i] = child
}
w.cache.StoreSpaceSummary(roomID, res)
return &res, nil
}
return nil, nil
@ -501,7 +539,7 @@ func (w *walker) authorisedUser(roomID string) bool {
hisVisEv := queryRes.StateEvents[hisVisTuple]
if memberEv != nil {
membership, _ := memberEv.Membership()
if membership == gomatrixserverlib.Join {
if membership == gomatrixserverlib.Join || membership == gomatrixserverlib.Invite {
return true
}
}
@ -514,29 +552,73 @@ func (w *walker) authorisedUser(roomID string) bool {
return false
}
// references returns all references pointing to or from this room.
func (w *walker) references(roomID string) ([]gomatrixserverlib.MSC2946StrippedEvent, error) {
events, err := w.db.References(w.ctx, roomID)
// references returns all child references pointing to or from this room.
func (w *walker) childReferences(roomID string) ([]gomatrixserverlib.MSC2946StrippedEvent, error) {
createTuple := gomatrixserverlib.StateKeyTuple{
EventType: gomatrixserverlib.MRoomCreate,
StateKey: "",
}
var res roomserver.QueryCurrentStateResponse
err := w.rsAPI.QueryCurrentState(context.Background(), &roomserver.QueryCurrentStateRequest{
RoomID: roomID,
AllowWildcards: true,
StateTuples: []gomatrixserverlib.StateKeyTuple{
createTuple, {
EventType: ConstSpaceChildEventType,
StateKey: "*",
},
},
}, &res)
if err != nil {
return nil, err
}
el := make([]gomatrixserverlib.MSC2946StrippedEvent, 0, len(events))
for _, ev := range events {
// don't return any child refs if the room is not a space room
if res.StateEvents[createTuple] != nil {
// escape the `.`s so gjson doesn't think it's nested
roomType := gjson.GetBytes(res.StateEvents[createTuple].Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str
if roomType != ConstCreateEventContentValueSpace {
return []gomatrixserverlib.MSC2946StrippedEvent{}, nil
}
}
delete(res.StateEvents, createTuple)
el := make([]gomatrixserverlib.MSC2946StrippedEvent, 0, len(res.StateEvents))
for _, ev := range res.StateEvents {
content := gjson.ParseBytes(ev.Content())
// only return events that have a `via` key as per MSC1772
// else we'll incorrectly walk redacted events (as the link
// is in the state_key)
if gjson.GetBytes(ev.Content(), "via").Exists() {
if content.Get("via").Exists() {
strip := stripped(ev.Event)
if strip == nil {
continue
}
// if suggested only and this child isn't suggested, skip it.
// if suggested only = false we include everything so don't need to check the content.
if w.suggestedOnly && !content.Get("suggested").Bool() {
continue
}
el = append(el, *strip)
}
}
// sort by origin_server_ts as per MSC2946
sort.Slice(el, func(i, j int) bool {
return el[i].OriginServerTS < el[j].OriginServerTS
})
return el, nil
}
type set map[string]bool
type set map[string]struct{}
func (s set) set(val string) {
s[val] = struct{}{}
}
func (s set) isSet(val string) bool {
_, ok := s[val]
return ok
}
func stripped(ev *gomatrixserverlib.Event) *gomatrixserverlib.MSC2946StrippedEvent {
if ev.StateKey() == nil {
@ -548,6 +630,7 @@ func stripped(ev *gomatrixserverlib.Event) *gomatrixserverlib.MSC2946StrippedEve
Content: ev.Content(),
Sender: ev.Sender(),
RoomID: ev.RoomID(),
OriginServerTS: ev.OriginServerTS(),
}
}
@ -567,3 +650,11 @@ func spaceTargetStripped(event *gomatrixserverlib.MSC2946StrippedEvent) string {
}
return ""
}
func parseInt(intstr string, defaultVal int) int {
i, err := strconv.ParseInt(intstr, 10, 32)
if err != nil {
return defaultVal
}
return int(i)
}

View file

@ -1,464 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msc2946_test
import (
"bytes"
"context"
"crypto/ed25519"
"encoding/json"
"io/ioutil"
"net/http"
"net/url"
"testing"
"time"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/hooks"
"github.com/matrix-org/dendrite/internal/httputil"
roomserver "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/mscs/msc2946"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
var (
client = &http.Client{
Timeout: 10 * time.Second,
}
roomVer = gomatrixserverlib.RoomVersionV6
)
// Basic sanity check of MSC2946 logic. Tests a single room with a few state events
// and a bit of recursion to subspaces. Makes a graph like:
// Root
// ____|_____
// | | |
// R1 R2 S1
// |_________
// | | |
// R3 R4 S2
// | <-- this link is just a parent, not a child
// R5
//
// Alice is not joined to R4, but R4 is "world_readable".
func TestMSC2946(t *testing.T) {
alice := "@alice:localhost"
// give access token to alice
nopUserAPI := &testUserAPI{
accessTokens: make(map[string]userapi.Device),
}
nopUserAPI.accessTokens["alice"] = userapi.Device{
AccessToken: "alice",
DisplayName: "Alice",
UserID: alice,
}
rootSpace := "!rootspace:localhost"
subSpaceS1 := "!subspaceS1:localhost"
subSpaceS2 := "!subspaceS2:localhost"
room1 := "!room1:localhost"
room2 := "!room2:localhost"
room3 := "!room3:localhost"
room4 := "!room4:localhost"
empty := ""
room5 := "!room5:localhost"
allRooms := []string{
rootSpace, subSpaceS1, subSpaceS2,
room1, room2, room3, room4, room5,
}
rootToR1 := mustCreateEvent(t, fledglingEvent{
RoomID: rootSpace,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &room1,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
rootToR2 := mustCreateEvent(t, fledglingEvent{
RoomID: rootSpace,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &room2,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
rootToS1 := mustCreateEvent(t, fledglingEvent{
RoomID: rootSpace,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &subSpaceS1,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
s1ToR3 := mustCreateEvent(t, fledglingEvent{
RoomID: subSpaceS1,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &room3,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
s1ToR4 := mustCreateEvent(t, fledglingEvent{
RoomID: subSpaceS1,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &room4,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
s1ToS2 := mustCreateEvent(t, fledglingEvent{
RoomID: subSpaceS1,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &subSpaceS2,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
// This is a parent link only
s2ToR5 := mustCreateEvent(t, fledglingEvent{
RoomID: room5,
Sender: alice,
Type: msc2946.ConstSpaceParentEventType,
StateKey: &subSpaceS2,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
// history visibility for R4
r4HisVis := mustCreateEvent(t, fledglingEvent{
RoomID: room4,
Sender: "@someone:localhost",
Type: gomatrixserverlib.MRoomHistoryVisibility,
StateKey: &empty,
Content: map[string]interface{}{
"history_visibility": "world_readable",
},
})
var joinEvents []*gomatrixserverlib.HeaderedEvent
for _, roomID := range allRooms {
if roomID == room4 {
continue // not joined to that room
}
joinEvents = append(joinEvents, mustCreateEvent(t, fledglingEvent{
RoomID: roomID,
Sender: alice,
StateKey: &alice,
Type: gomatrixserverlib.MRoomMember,
Content: map[string]interface{}{
"membership": "join",
},
}))
}
roomNameTuple := gomatrixserverlib.StateKeyTuple{
EventType: "m.room.name",
StateKey: "",
}
hisVisTuple := gomatrixserverlib.StateKeyTuple{
EventType: "m.room.history_visibility",
StateKey: "",
}
nopRsAPI := &testRoomserverAPI{
joinEvents: joinEvents,
events: map[string]*gomatrixserverlib.HeaderedEvent{
rootToR1.EventID(): rootToR1,
rootToR2.EventID(): rootToR2,
rootToS1.EventID(): rootToS1,
s1ToR3.EventID(): s1ToR3,
s1ToR4.EventID(): s1ToR4,
s1ToS2.EventID(): s1ToS2,
s2ToR5.EventID(): s2ToR5,
r4HisVis.EventID(): r4HisVis,
},
pubRoomState: map[string]map[gomatrixserverlib.StateKeyTuple]string{
rootSpace: {
roomNameTuple: "Root",
hisVisTuple: "shared",
},
subSpaceS1: {
roomNameTuple: "Sub-Space 1",
hisVisTuple: "joined",
},
subSpaceS2: {
roomNameTuple: "Sub-Space 2",
hisVisTuple: "shared",
},
room1: {
hisVisTuple: "joined",
},
room2: {
hisVisTuple: "joined",
},
room3: {
hisVisTuple: "joined",
},
room4: {
hisVisTuple: "world_readable",
},
room5: {
hisVisTuple: "joined",
},
},
}
allEvents := []*gomatrixserverlib.HeaderedEvent{
rootToR1, rootToR2, rootToS1,
s1ToR3, s1ToR4, s1ToS2,
s2ToR5, r4HisVis,
}
allEvents = append(allEvents, joinEvents...)
router := injectEvents(t, nopUserAPI, nopRsAPI, allEvents)
cancel := runServer(t, router)
defer cancel()
t.Run("returns no events for unknown rooms", func(t *testing.T) {
res := postSpaces(t, 200, "alice", "!unknown:localhost", newReq(t, map[string]interface{}{}))
if len(res.Events) > 0 {
t.Errorf("got %d events, want 0", len(res.Events))
}
if len(res.Rooms) > 0 {
t.Errorf("got %d rooms, want 0", len(res.Rooms))
}
})
t.Run("returns the entire graph", func(t *testing.T) {
res := postSpaces(t, 200, "alice", rootSpace, newReq(t, map[string]interface{}{}))
if len(res.Events) != 7 {
t.Errorf("got %d events, want 7", len(res.Events))
}
if len(res.Rooms) != len(allRooms) {
t.Errorf("got %d rooms, want %d", len(res.Rooms), len(allRooms))
}
})
t.Run("can update the graph", func(t *testing.T) {
// remove R3 from the graph
rmS1ToR3 := mustCreateEvent(t, fledglingEvent{
RoomID: subSpaceS1,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &room3,
Content: map[string]interface{}{}, // redacted
})
nopRsAPI.events[rmS1ToR3.EventID()] = rmS1ToR3
hooks.Run(hooks.KindNewEventPersisted, rmS1ToR3)
res := postSpaces(t, 200, "alice", rootSpace, newReq(t, map[string]interface{}{}))
if len(res.Events) != 6 { // one less since we don't return redacted events
t.Errorf("got %d events, want 6", len(res.Events))
}
if len(res.Rooms) != (len(allRooms) - 1) { // one less due to lack of R3
t.Errorf("got %d rooms, want %d", len(res.Rooms), len(allRooms)-1)
}
})
}
func newReq(t *testing.T, jsonBody map[string]interface{}) *gomatrixserverlib.MSC2946SpacesRequest {
t.Helper()
b, err := json.Marshal(jsonBody)
if err != nil {
t.Fatalf("Failed to marshal request: %s", err)
}
var r gomatrixserverlib.MSC2946SpacesRequest
if err := json.Unmarshal(b, &r); err != nil {
t.Fatalf("Failed to unmarshal request: %s", err)
}
return &r
}
func runServer(t *testing.T, router *mux.Router) func() {
t.Helper()
externalServ := &http.Server{
Addr: string(":8010"),
WriteTimeout: 60 * time.Second,
Handler: router,
}
go func() {
externalServ.ListenAndServe()
}()
// wait to listen on the port
time.Sleep(500 * time.Millisecond)
return func() {
externalServ.Shutdown(context.TODO())
}
}
func postSpaces(t *testing.T, expectCode int, accessToken, roomID string, req *gomatrixserverlib.MSC2946SpacesRequest) *gomatrixserverlib.MSC2946SpacesResponse {
t.Helper()
var r gomatrixserverlib.MSC2946SpacesRequest
msc2946.Defaults(&r)
data, err := json.Marshal(req)
if err != nil {
t.Fatalf("failed to marshal request: %s", err)
}
httpReq, err := http.NewRequest(
"POST", "http://localhost:8010/_matrix/client/unstable/org.matrix.msc2946/rooms/"+url.PathEscape(roomID)+"/spaces",
bytes.NewBuffer(data),
)
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
if err != nil {
t.Fatalf("failed to prepare request: %s", err)
}
res, err := client.Do(httpReq)
if err != nil {
t.Fatalf("failed to do request: %s", err)
}
if res.StatusCode != expectCode {
body, _ := ioutil.ReadAll(res.Body)
t.Fatalf("wrong response code, got %d want %d - body: %s", res.StatusCode, expectCode, string(body))
}
if res.StatusCode == 200 {
var result gomatrixserverlib.MSC2946SpacesResponse
body, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatalf("response 200 OK but failed to read response body: %s", err)
}
t.Logf("Body: %s", string(body))
if err := json.Unmarshal(body, &result); err != nil {
t.Fatalf("response 200 OK but failed to deserialise JSON : %s\nbody: %s", err, string(body))
}
return &result
}
return nil
}
type testUserAPI struct {
userapi.UserInternalAPITrace
accessTokens map[string]userapi.Device
}
func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAccessTokenRequest, res *userapi.QueryAccessTokenResponse) error {
dev, ok := u.accessTokens[req.AccessToken]
if !ok {
res.Err = "unknown token"
return nil
}
res.Device = &dev
return nil
}
type testRoomserverAPI struct {
// use a trace API as it implements method stubs so we don't need to have them here.
// We'll override the functions we care about.
roomserver.RoomserverInternalAPITrace
joinEvents []*gomatrixserverlib.HeaderedEvent
events map[string]*gomatrixserverlib.HeaderedEvent
pubRoomState map[string]map[gomatrixserverlib.StateKeyTuple]string
}
func (r *testRoomserverAPI) QueryServerJoinedToRoom(ctx context.Context, req *roomserver.QueryServerJoinedToRoomRequest, res *roomserver.QueryServerJoinedToRoomResponse) error {
res.IsInRoom = true
res.RoomExists = true
return nil
}
func (r *testRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *roomserver.QueryBulkStateContentRequest, res *roomserver.QueryBulkStateContentResponse) error {
res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string)
for _, roomID := range req.RoomIDs {
pubRoomData, ok := r.pubRoomState[roomID]
if ok {
res.Rooms[roomID] = pubRoomData
}
}
return nil
}
func (r *testRoomserverAPI) QueryCurrentState(ctx context.Context, req *roomserver.QueryCurrentStateRequest, res *roomserver.QueryCurrentStateResponse) error {
res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent)
checkEvent := func(he *gomatrixserverlib.HeaderedEvent) {
if he.RoomID() != req.RoomID {
return
}
if he.StateKey() == nil {
return
}
tuple := gomatrixserverlib.StateKeyTuple{
EventType: he.Type(),
StateKey: *he.StateKey(),
}
for _, t := range req.StateTuples {
if t == tuple {
res.StateEvents[t] = he
}
}
}
for _, he := range r.joinEvents {
checkEvent(he)
}
for _, he := range r.events {
checkEvent(he)
}
return nil
}
func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserver.RoomserverInternalAPI, events []*gomatrixserverlib.HeaderedEvent) *mux.Router {
t.Helper()
cfg := &config.Dendrite{}
cfg.Defaults(true)
cfg.Global.ServerName = "localhost"
cfg.MSCs.Database.ConnectionString = "file:msc2946_test.db"
cfg.MSCs.MSCs = []string{"msc2946"}
base := &base.BaseDendrite{
Cfg: cfg,
PublicClientAPIMux: mux.NewRouter().PathPrefix(httputil.PublicClientPathPrefix).Subrouter(),
PublicFederationAPIMux: mux.NewRouter().PathPrefix(httputil.PublicFederationPathPrefix).Subrouter(),
}
err := msc2946.Enable(base, rsAPI, userAPI, nil, nil)
if err != nil {
t.Fatalf("failed to enable MSC2946: %s", err)
}
for _, ev := range events {
hooks.Run(hooks.KindNewEventPersisted, ev)
}
return base.PublicClientAPIMux
}
type fledglingEvent struct {
Type string
StateKey *string
Content interface{}
Sender string
RoomID string
}
func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib.HeaderedEvent) {
t.Helper()
seed := make([]byte, ed25519.SeedSize) // zero seed
key := ed25519.NewKeyFromSeed(seed)
eb := gomatrixserverlib.EventBuilder{
Sender: ev.Sender,
Depth: 999,
Type: ev.Type,
StateKey: ev.StateKey,
RoomID: ev.RoomID,
}
err := eb.SetContent(ev.Content)
if err != nil {
t.Fatalf("mustCreateEvent: failed to marshal event content %+v", ev.Content)
}
// make sure the origin_server_ts changes so we can test recency
time.Sleep(1 * time.Millisecond)
signedEvent, err := eb.Build(time.Now(), gomatrixserverlib.ServerName("localhost"), "ed25519:test", key, roomVer)
if err != nil {
t.Fatalf("mustCreateEvent: failed to sign event: %s", err)
}
h := signedEvent.Headered(roomVer)
return h
}

View file

@ -1,182 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msc2946
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
var (
relTypes = map[string]int{
ConstSpaceChildEventType: 1,
ConstSpaceParentEventType: 2,
}
)
type Database interface {
// StoreReference persists a child or parent space mapping.
StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error
// References returns all events which have the given roomID as a parent or child space.
References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error)
}
type DB struct {
db *sql.DB
writer sqlutil.Writer
insertEdgeStmt *sql.Stmt
selectEdgesStmt *sql.Stmt
}
// NewDatabase loads the database for msc2836
func NewDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
if dbOpts.ConnectionString.IsPostgres() {
return newPostgresDatabase(dbOpts)
}
return newSQLiteDatabase(dbOpts)
}
func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
d := DB{
writer: sqlutil.NewDummyWriter(),
}
var err error
if d.db, err = sqlutil.Open(dbOpts); err != nil {
return nil, err
}
_, err = d.db.Exec(`
CREATE TABLE IF NOT EXISTS msc2946_edges (
room_version TEXT NOT NULL,
-- the room ID of the event, the source of the arrow
source_room_id TEXT NOT NULL,
-- the target room ID, the arrow destination
dest_room_id TEXT NOT NULL,
-- the kind of relation, either child or parent (1,2)
rel_type SMALLINT NOT NULL,
event_json TEXT NOT NULL,
CONSTRAINT msc2946_edges_uniq UNIQUE (source_room_id, dest_room_id, rel_type)
);
`)
if err != nil {
return nil, err
}
if d.insertEdgeStmt, err = d.db.Prepare(`
INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json)
VALUES($1, $2, $3, $4, $5)
ON CONFLICT ON CONSTRAINT msc2946_edges_uniq DO UPDATE SET event_json = $5
`); err != nil {
return nil, err
}
if d.selectEdgesStmt, err = d.db.Prepare(`
SELECT room_version, event_json FROM msc2946_edges
WHERE source_room_id = $1 OR dest_room_id = $2
`); err != nil {
return nil, err
}
return &d, err
}
func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
d := DB{
writer: sqlutil.NewExclusiveWriter(),
}
var err error
if d.db, err = sqlutil.Open(dbOpts); err != nil {
return nil, err
}
_, err = d.db.Exec(`
CREATE TABLE IF NOT EXISTS msc2946_edges (
room_version TEXT NOT NULL,
-- the room ID of the event, the source of the arrow
source_room_id TEXT NOT NULL,
-- the target room ID, the arrow destination
dest_room_id TEXT NOT NULL,
-- the kind of relation, either child or parent (1,2)
rel_type SMALLINT NOT NULL,
event_json TEXT NOT NULL,
UNIQUE (source_room_id, dest_room_id, rel_type)
);
`)
if err != nil {
return nil, err
}
if d.insertEdgeStmt, err = d.db.Prepare(`
INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json)
VALUES($1, $2, $3, $4, $5)
ON CONFLICT (source_room_id, dest_room_id, rel_type) DO UPDATE SET event_json = $5
`); err != nil {
return nil, err
}
if d.selectEdgesStmt, err = d.db.Prepare(`
SELECT room_version, event_json FROM msc2946_edges
WHERE source_room_id = $1 OR dest_room_id = $2
`); err != nil {
return nil, err
}
return &d, err
}
func (d *DB) StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error {
target := SpaceTarget(he)
if target == "" {
return nil // malformed event
}
relType := relTypes[he.Type()]
_, err := d.insertEdgeStmt.ExecContext(ctx, he.RoomVersion, he.RoomID(), target, relType, he.JSON())
return err
}
func (d *DB) References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) {
rows, err := d.selectEdgesStmt.QueryContext(ctx, roomID, roomID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "failed to close References")
refs := make([]*gomatrixserverlib.HeaderedEvent, 0)
for rows.Next() {
var roomVer string
var jsonBytes []byte
if err := rows.Scan(&roomVer, &jsonBytes); err != nil {
return nil, err
}
ev, err := gomatrixserverlib.NewEventFromTrustedJSON(jsonBytes, false, gomatrixserverlib.RoomVersion(roomVer))
if err != nil {
return nil, err
}
he := ev.Headered(gomatrixserverlib.RoomVersion(roomVer))
refs = append(refs, he)
}
return refs, nil
}
// SpaceTarget returns the destination room ID for the space event. This is either a child or a parent
// depending on the event type.
func SpaceTarget(he *gomatrixserverlib.HeaderedEvent) string {
if he.StateKey() == nil {
return "" // no-op
}
switch he.Type() {
case ConstSpaceParentEventType:
return *he.StateKey()
case ConstSpaceChildEventType:
return *he.StateKey()
}
return ""
}

View file

@ -42,7 +42,7 @@ func EnableMSC(base *base.BaseDendrite, monolith *setup.Monolith, msc string) er
case "msc2836":
return msc2836.Enable(base, monolith.RoomserverAPI, monolith.FederationAPI, monolith.UserAPI, monolith.KeyRing)
case "msc2946":
return msc2946.Enable(base, monolith.RoomserverAPI, monolith.UserAPI, monolith.FederationAPI, monolith.KeyRing)
return msc2946.Enable(base, monolith.RoomserverAPI, monolith.UserAPI, monolith.FederationAPI, monolith.KeyRing, base.Caches)
case "msc2444": // enabled inside federationapi
case "msc2753": // enabled inside clientapi
default:

View file

@ -82,7 +82,16 @@ func DeviceListCatchup(
util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed")
return to, hasNew, nil
}
// QueryKeyChanges gets ALL users who have changed keys, we want the ones who share rooms with the user.
// Work out which user IDs we care about — that includes those in the original request,
// the response from QueryKeyChanges (which includes ALL users who have changed keys)
// as well as every user who has a join or leave event in the current sync response. We
// will request information about which rooms these users are joined to, so that we can
// see if we still share any rooms with them.
joinUserIDs, leaveUserIDs := membershipEvents(res)
queryRes.UserIDs = append(queryRes.UserIDs, joinUserIDs...)
queryRes.UserIDs = append(queryRes.UserIDs, leaveUserIDs...)
queryRes.UserIDs = util.UniqueStrings(queryRes.UserIDs)
var sharedUsersMap map[string]int
sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, rsAPI, userID, queryRes.UserIDs)
util.GetLogger(ctx).Debugf(
@ -100,9 +109,8 @@ func DeviceListCatchup(
userSet[userID] = true
}
}
// if the response has any join/leave events, add them now.
// Finally, add in users who have joined or left.
// TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them.
joinUserIDs, leaveUserIDs := membershipEvents(res)
for _, userID := range joinUserIDs {
if !userSet[userID] {
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
@ -214,6 +222,7 @@ func filterSharedUsers(
var sharedUsersRes roomserverAPI.QuerySharedUsersResponse
err := rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{
UserID: userID,
OtherUserIDs: usersWithChangedKeys,
}, &sharedUsersRes)
if err != nil {
// default to all users so we do needless queries rather than miss some important device update

View file

@ -44,7 +44,7 @@ func Context(
syncDB storage.Database,
roomID, eventID string,
) util.JSONResponse {
filter, err := parseContextParams(req)
filter, err := parseRoomEventFilter(req)
if err != nil {
errMsg := ""
switch err.(type) {
@ -164,7 +164,7 @@ func applyLazyLoadMembers(filter *gomatrixserverlib.RoomEventFilter, eventsAfter
return newState
}
func parseContextParams(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) {
func parseRoomEventFilter(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) {
// Default room filter
filter := &gomatrixserverlib.RoomEventFilter{Limit: 10}

View file

@ -55,13 +55,13 @@ func Test_parseContextParams(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotFilter, err := parseContextParams(tt.req)
gotFilter, err := parseRoomEventFilter(tt.req)
if (err != nil) != tt.wantErr {
t.Errorf("parseContextParams() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("parseRoomEventFilter() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(gotFilter, tt.wantFilter) {
t.Errorf("parseContextParams() gotFilter = %v, want %v", gotFilter, tt.wantFilter)
t.Errorf("parseRoomEventFilter() gotFilter = %v, want %v", gotFilter, tt.wantFilter)
}
})
}

View file

@ -19,7 +19,6 @@ import (
"fmt"
"net/http"
"sort"
"strconv"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/roomserver/api"
@ -45,8 +44,8 @@ type messagesReq struct {
fromStream *types.StreamingToken
device *userapi.Device
wasToProvided bool
limit int
backwardOrdering bool
filter *gomatrixserverlib.RoomEventFilter
}
type messagesResp struct {
@ -54,10 +53,9 @@ type messagesResp struct {
StartStream string `json:"start_stream,omitempty"` // NOTSPEC: so clients can hit /messages then immediately /sync with a latest sync token
End string `json:"end"`
Chunk []gomatrixserverlib.ClientEvent `json:"chunk"`
State []gomatrixserverlib.ClientEvent `json:"state"`
}
const defaultMessagesLimit = 10
// OnIncomingMessagesRequest implements the /messages endpoint from the
// client-server API.
// See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages
@ -83,6 +81,14 @@ func OnIncomingMessagesRequest(
}
}
filter, err := parseRoomEventFilter(req)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("unable to parse filter"),
}
}
// Extract parameters from the request's URL.
// Pagination tokens.
var fromStream *types.StreamingToken
@ -143,18 +149,6 @@ func OnIncomingMessagesRequest(
wasToProvided = false
}
// Maximum number of events to return; defaults to 10.
limit := defaultMessagesLimit
if len(req.URL.Query().Get("limit")) > 0 {
limit, err = strconv.Atoi(req.URL.Query().Get("limit"))
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("limit could not be parsed into an integer: " + err.Error()),
}
}
}
// TODO: Implement filtering (#587)
// Check the room ID's format.
@ -176,7 +170,7 @@ func OnIncomingMessagesRequest(
to: &to,
fromStream: fromStream,
wasToProvided: wasToProvided,
limit: limit,
filter: filter,
backwardOrdering: backwardOrdering,
device: device,
}
@ -187,10 +181,27 @@ func OnIncomingMessagesRequest(
return jsonerror.InternalServerError()
}
// at least fetch the membership events for the users returned in chunk if LazyLoadMembers is set
state := []gomatrixserverlib.ClientEvent{}
if filter.LazyLoadMembers {
memberShipToUser := make(map[string]*gomatrixserverlib.HeaderedEvent)
for _, evt := range clientEvents {
memberShip, err := db.GetStateEvent(req.Context(), roomID, gomatrixserverlib.MRoomMember, evt.Sender)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("failed to get membership event for user")
continue
}
memberShipToUser[evt.Sender] = memberShip
}
for _, evt := range memberShipToUser {
state = append(state, gomatrixserverlib.HeaderedToClientEvent(evt, gomatrixserverlib.FormatAll))
}
}
util.GetLogger(req.Context()).WithFields(logrus.Fields{
"from": from.String(),
"to": to.String(),
"limit": limit,
"limit": filter.Limit,
"backwards": backwardOrdering,
"return_start": start.String(),
"return_end": end.String(),
@ -200,6 +211,7 @@ func OnIncomingMessagesRequest(
Chunk: clientEvents,
Start: start.String(),
End: end.String(),
State: state,
}
if emptyFromSupplied {
res.StartStream = fromStream.String()
@ -234,19 +246,18 @@ func (r *messagesReq) retrieveEvents() (
clientEvents []gomatrixserverlib.ClientEvent, start,
end types.TopologyToken, err error,
) {
eventFilter := gomatrixserverlib.DefaultRoomEventFilter()
eventFilter.Limit = r.limit
eventFilter := r.filter
// Retrieve the events from the local database.
var streamEvents []types.StreamEvent
if r.fromStream != nil {
toStream := r.to.StreamToken()
streamEvents, err = r.db.GetEventsInStreamingRange(
r.ctx, r.fromStream, &toStream, r.roomID, &eventFilter, r.backwardOrdering,
r.ctx, r.fromStream, &toStream, r.roomID, eventFilter, r.backwardOrdering,
)
} else {
streamEvents, err = r.db.GetEventsInTopologicalRange(
r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering,
r.ctx, r.from, r.to, r.roomID, eventFilter.Limit, r.backwardOrdering,
)
}
if err != nil {
@ -434,7 +445,7 @@ func (r *messagesReq) handleEmptyEventsSlice() (
// Check if we have backward extremities for this room.
if len(backwardExtremities) > 0 {
// If so, retrieve as much events as needed through backfilling.
events, err = r.backfill(r.roomID, backwardExtremities, r.limit)
events, err = r.backfill(r.roomID, backwardExtremities, r.filter.Limit)
if err != nil {
return
}
@ -456,7 +467,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
events []*gomatrixserverlib.HeaderedEvent, err error,
) {
// Check if we have enough events.
isSetLargeEnough := len(streamEvents) >= r.limit
isSetLargeEnough := len(streamEvents) >= r.filter.Limit
if !isSetLargeEnough {
// it might be fine we don't have up to 'limit' events, let's find out
if r.backwardOrdering {
@ -483,7 +494,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
if len(backwardExtremities) > 0 && !isSetLargeEnough && r.backwardOrdering {
var pdus []*gomatrixserverlib.HeaderedEvent
// Only ask the remote server for enough events to reach the limit.
pdus, err = r.backfill(r.roomID, backwardExtremities, r.limit-len(streamEvents))
pdus, err = r.backfill(r.roomID, backwardExtremities, r.filter.Limit-len(streamEvents))
if err != nil {
return
}

View file

@ -96,7 +96,7 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
}
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) {
lastPos := streamPos
var lastPos types.StreamPosition
rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos)
if err != nil {
return 0, nil, fmt.Errorf("unable to query room receipts: %w", err)

View file

@ -101,7 +101,7 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
// SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) {
selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1)
lastPos := streamPos
var lastPos types.StreamPosition
params := make([]interface{}, len(roomIDs)+1)
params[0] = streamPos
for k, v := range roomIDs {

View file

@ -63,7 +63,6 @@ func (p *ReceiptStreamProvider) IncrementalSync(
if existing, ok := req.Response.Rooms.Join[roomID]; ok {
jr = existing
}
var ok bool
ev := gomatrixserverlib.ClientEvent{
Type: gomatrixserverlib.MReceipt,
@ -71,8 +70,8 @@ func (p *ReceiptStreamProvider) IncrementalSync(
}
content := make(map[string]eduAPI.ReceiptMRead)
for _, receipt := range receipts {
var read eduAPI.ReceiptMRead
if read, ok = content[receipt.EventID]; !ok {
read, ok := content[receipt.EventID]
if !ok {
read = eduAPI.ReceiptMRead{
User: make(map[string]eduAPI.ReceiptTS),
}

View file

@ -69,7 +69,7 @@ func (s *Streams) Latest(ctx context.Context) types.StreamingToken {
return types.StreamingToken{
PDUPosition: s.PDUStreamProvider.LatestPosition(ctx),
TypingPosition: s.TypingStreamProvider.LatestPosition(ctx),
ReceiptPosition: s.PDUStreamProvider.LatestPosition(ctx),
ReceiptPosition: s.ReceiptStreamProvider.LatestPosition(ctx),
InvitePosition: s.InviteStreamProvider.LatestPosition(ctx),
SendToDevicePosition: s.SendToDeviceStreamProvider.LatestPosition(ctx),
AccountDataPosition: s.AccountDataStreamProvider.LatestPosition(ctx),

View file

@ -597,6 +597,7 @@ Device list doesn't change if remote server is down
/context/ on non world readable room does not work
/context/ returns correct number of events
/context/ with lazy_load_members filter works
GET /rooms/:room_id/messages lazy loads members correctly
Can query remote device keys using POST after notification
Device deletion propagates over federation
Get left notifs in sync and /keys/changes when other user leaves
@ -604,4 +605,6 @@ Remote banned user is kicked and may not rejoin until unbanned
registration remembers parameters
registration accepts non-ascii passwords
registration with inhibit_login inhibits login
The operation must be consistent through an interactive authentication session
Multiple calls to /sync should not cause 500 errors