Merge remote-tracking branch 'origin/main' into loginsso

This commit is contained in:
Tommie Gannert 2022-06-08 09:24:56 +02:00
commit 4da104c5c3
49 changed files with 1168 additions and 1126 deletions

View file

@ -1,5 +1,36 @@
# Changelog # Changelog
## Dendrite 0.8.7 (2022-06-01)
### Features
* Support added for room version 10
### Fixes
* A number of state handling bugs have been fixed, which previously resulted in missing state events, unexpected state deletions, reverted memberships and unexpectedly rejected/soft-failed events in some specific cases
* Fixed destination queue performance issues as a result of missing indexes, which speeds up outbound federation considerably
* A bug which could cause the `/register` endpoint to return HTTP 500 has been fixed
## Dendrite 0.8.6 (2022-05-26)
### Features
* Room versions 8 and 9 are now marked as stable
* Dendrite can now assist remote users to join restricted rooms via `/make_join` and `/send_join`
### Fixes
* The sync API no longer returns immediately on `/sync` requests unnecessarily if it can be avoided
* A race condition has been fixed in the sync API when updating presence via `/sync`
* A race condition has been fixed sending E2EE keys to remote servers over federation when joining rooms
* The `trusted_private_chat` preset should now grant power level 100 to all participant users, which should improve the user experience of direct messages
* Invited users are now authed correctly in restricted rooms
* The `join_authorised_by_users_server` key is now correctly stripped in restricted rooms when updating the membership event
* Appservices should now receive invite events correctly
* Device list updates should no longer contain optional fields with `null` values
* The `/deactivate` endpoint has been fixed to no longer confuse Element with incorrect completed flows
## Dendrite 0.8.5 (2022-05-13) ## Dendrite 0.8.5 (2022-05-13)
### Features ### Features

View file

@ -96,10 +96,9 @@ than features that massive deployments may be interested in (User Directory, Ope
This means Dendrite supports amongst others: This means Dendrite supports amongst others:
- Core room functionality (creating rooms, invites, auth rules) - Core room functionality (creating rooms, invites, auth rules)
- Full support for room versions 1 to 7 - Room versions 1 to 10 supported
- Experimental support for room versions 8 to 9
- Backfilling locally and via federation - Backfilling locally and via federation
- Accounts, Profiles and Devices - Accounts, profiles and devices
- Published room lists - Published room lists
- Typing - Typing
- Media APIs - Media APIs

View file

@ -83,29 +83,38 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg)
return true return true
} }
if output.Type != api.OutputTypeNewRoomEvent || output.NewRoomEvent == nil { log.WithFields(log.Fields{
return true "type": output.Type,
} }).Debug("Got a message in OutputRoomEventConsumer")
newEventID := output.NewRoomEvent.Event.EventID() events := []*gomatrixserverlib.HeaderedEvent{}
events := make([]*gomatrixserverlib.HeaderedEvent, 0, len(output.NewRoomEvent.AddsStateEventIDs)) if output.Type == api.OutputTypeNewRoomEvent && output.NewRoomEvent != nil {
events = append(events, output.NewRoomEvent.Event) newEventID := output.NewRoomEvent.Event.EventID()
if len(output.NewRoomEvent.AddsStateEventIDs) > 0 { events = append(events, output.NewRoomEvent.Event)
eventsReq := &api.QueryEventsByIDRequest{ if len(output.NewRoomEvent.AddsStateEventIDs) > 0 {
EventIDs: make([]string, 0, len(output.NewRoomEvent.AddsStateEventIDs)), eventsReq := &api.QueryEventsByIDRequest{
} EventIDs: make([]string, 0, len(output.NewRoomEvent.AddsStateEventIDs)),
eventsRes := &api.QueryEventsByIDResponse{} }
for _, eventID := range output.NewRoomEvent.AddsStateEventIDs { eventsRes := &api.QueryEventsByIDResponse{}
if eventID != newEventID { for _, eventID := range output.NewRoomEvent.AddsStateEventIDs {
eventsReq.EventIDs = append(eventsReq.EventIDs, eventID) if eventID != newEventID {
eventsReq.EventIDs = append(eventsReq.EventIDs, eventID)
}
}
if len(eventsReq.EventIDs) > 0 {
if err := s.rsAPI.QueryEventsByID(s.ctx, eventsReq, eventsRes); err != nil {
return false
}
events = append(events, eventsRes.Events...)
} }
} }
if len(eventsReq.EventIDs) > 0 { } else if output.Type == api.OutputTypeNewInviteEvent && output.NewInviteEvent != nil {
if err := s.rsAPI.QueryEventsByID(s.ctx, eventsReq, eventsRes); err != nil { events = append(events, output.NewInviteEvent.Event)
return false } else {
} log.WithFields(log.Fields{
events = append(events, eventsRes.Events...) "type": output.Type,
} }).Debug("appservice OutputRoomEventConsumer ignoring event", string(msg.Data))
return true
} }
// Send event to any relevant application services // Send event to any relevant application services

View file

@ -261,7 +261,7 @@ func (m *DendriteMonolith) Start() {
cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} cfg.MSCs.MSCs = []string{"msc2836", "msc2946"}
cfg.ClientAPI.RegistrationDisabled = false cfg.ClientAPI.RegistrationDisabled = false
cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true
if err := cfg.Derive(); err != nil { if err = cfg.Derive(); err != nil {
panic(err) panic(err)
} }
@ -342,11 +342,23 @@ func (m *DendriteMonolith) Start() {
go func() { go func() {
m.logger.Info("Listening on ", cfg.Global.ServerName) m.logger.Info("Listening on ", cfg.Global.ServerName)
m.logger.Fatal(m.httpServer.Serve(m.PineconeQUIC.Protocol("matrix")))
switch m.httpServer.Serve(m.PineconeQUIC.Protocol("matrix")) {
case net.ErrClosed, http.ErrServerClosed:
m.logger.Info("Stopped listening on ", cfg.Global.ServerName)
default:
m.logger.Fatal(err)
}
}() }()
go func() { go func() {
logrus.Info("Listening on ", m.listener.Addr()) logrus.Info("Listening on ", m.listener.Addr())
logrus.Fatal(http.Serve(m.listener, httpRouter))
switch http.Serve(m.listener, httpRouter) {
case net.ErrClosed, http.ErrServerClosed:
m.logger.Info("Stopped listening on ", cfg.Global.ServerName)
default:
m.logger.Fatal(err)
}
}() }()
} }

View file

@ -170,11 +170,11 @@ func (m *DendriteMonolith) Start() {
go func() { go func() {
m.logger.Info("Listening on ", ygg.DerivedServerName()) m.logger.Info("Listening on ", ygg.DerivedServerName())
m.logger.Fatal(m.httpServer.Serve(ygg)) m.logger.Error(m.httpServer.Serve(ygg))
}() }()
go func() { go func() {
logrus.Info("Listening on ", m.listener.Addr()) logrus.Info("Listening on ", m.listener.Addr())
logrus.Fatal(http.Serve(m.listener, httpRouter)) logrus.Error(http.Serve(m.listener, httpRouter))
}() }()
go func() { go func() {
logrus.Info("Sending wake-up message to known nodes") logrus.Info("Sending wake-up message to known nodes")

View file

@ -154,6 +154,12 @@ func MissingParam(msg string) *MatrixError {
return &MatrixError{"M_MISSING_PARAM", msg} return &MatrixError{"M_MISSING_PARAM", msg}
} }
// UnableToAuthoriseJoin is an error that is returned when a server can't
// determine whether to allow a restricted join or not.
func UnableToAuthoriseJoin(msg string) *MatrixError {
return &MatrixError{"M_UNABLE_TO_AUTHORISE_JOIN", msg}
}
// LeaveServerNoticeError is an error returned when trying to reject an invite // LeaveServerNoticeError is an error returned when trying to reject an invite
// for a server notice room. // for a server notice room.
func LeaveServerNoticeError() *MatrixError { func LeaveServerNoticeError() *MatrixError {

View file

@ -245,7 +245,9 @@ func createRoom(
case presetTrustedPrivateChat: case presetTrustedPrivateChat:
joinRuleContent.JoinRule = gomatrixserverlib.Invite joinRuleContent.JoinRule = gomatrixserverlib.Invite
historyVisibilityContent.HistoryVisibility = historyVisibilityShared historyVisibilityContent.HistoryVisibility = historyVisibilityShared
// TODO If trusted_private_chat, all invitees are given the same power level as the room creator. for _, invitee := range r.Invite {
powerLevelContent.Users[invitee] = 100
}
case presetPublicChat: case presetPublicChat:
joinRuleContent.JoinRule = gomatrixserverlib.Public joinRuleContent.JoinRule = gomatrixserverlib.Public
historyVisibilityContent.HistoryVisibility = historyVisibilityShared historyVisibilityContent.HistoryVisibility = historyVisibilityShared

View file

@ -101,6 +101,9 @@ func Setup(
"r0.4.0", "r0.4.0",
"r0.5.0", "r0.5.0",
"r0.6.1", "r0.6.1",
"v1.0",
"v1.1",
"v1.2",
}, UnstableFeatures: unstableFeatures}, }, UnstableFeatures: unstableFeatures},
} }
}), }),
@ -149,7 +152,7 @@ func Setup(
synapseAdminRouter.Handle("/admin/v1/send_server_notice/{txnID}", synapseAdminRouter.Handle("/admin/v1/send_server_notice/{txnID}",
httputil.MakeAuthAPI("send_server_notice", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("send_server_notice", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
// not specced, but ensure we're rate limiting requests to this endpoint // not specced, but ensure we're rate limiting requests to this endpoint
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -169,7 +172,7 @@ func Setup(
synapseAdminRouter.Handle("/admin/v1/send_server_notice", synapseAdminRouter.Handle("/admin/v1/send_server_notice",
httputil.MakeAuthAPI("send_server_notice", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("send_server_notice", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
// not specced, but ensure we're rate limiting requests to this endpoint // not specced, but ensure we're rate limiting requests to this endpoint
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
return SendServerNotice( return SendServerNotice(
@ -199,7 +202,7 @@ func Setup(
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/join/{roomIDOrAlias}", v3mux.Handle("/join/{roomIDOrAlias}",
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -215,7 +218,7 @@ func Setup(
if mscCfg.Enabled("msc2753") { if mscCfg.Enabled("msc2753") {
v3mux.Handle("/peek/{roomIDOrAlias}", v3mux.Handle("/peek/{roomIDOrAlias}",
httputil.MakeAuthAPI(gomatrixserverlib.Peek, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI(gomatrixserverlib.Peek, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -235,7 +238,7 @@ func Setup(
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/join", v3mux.Handle("/rooms/{roomID}/join",
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -249,7 +252,7 @@ func Setup(
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/leave", v3mux.Handle("/rooms/{roomID}/leave",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -283,7 +286,7 @@ func Setup(
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/invite", v3mux.Handle("/rooms/{roomID}/invite",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -401,14 +404,14 @@ func Setup(
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, nil); r != nil {
return *r return *r
} }
return Register(req, userAPI, cfg) return Register(req, userAPI, cfg)
})).Methods(http.MethodPost, http.MethodOptions) })).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { v3mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, nil); r != nil {
return *r return *r
} }
return RegisterAvailable(req, cfg, userAPI) return RegisterAvailable(req, cfg, userAPI)
@ -482,7 +485,7 @@ func Setup(
v3mux.Handle("/rooms/{roomID}/typing/{userID}", v3mux.Handle("/rooms/{roomID}/typing/{userID}",
httputil.MakeAuthAPI("rooms_typing", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_typing", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -539,7 +542,7 @@ func Setup(
v3mux.Handle("/account/whoami", v3mux.Handle("/account/whoami",
httputil.MakeAuthAPI("whoami", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("whoami", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
return Whoami(req, device) return Whoami(req, device)
@ -548,7 +551,7 @@ func Setup(
v3mux.Handle("/account/password", v3mux.Handle("/account/password",
httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
return Password(req, userAPI, device, cfg) return Password(req, userAPI, device, cfg)
@ -557,7 +560,7 @@ func Setup(
v3mux.Handle("/account/deactivate", v3mux.Handle("/account/deactivate",
httputil.MakeAuthAPI("deactivate", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("deactivate", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
return Deactivate(req, userInteractiveAuth, userAPI, device) return Deactivate(req, userInteractiveAuth, userAPI, device)
@ -568,7 +571,7 @@ func Setup(
v3mux.Handle("/login", v3mux.Handle("/login",
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, nil); r != nil {
return *r return *r
} }
return Login(req, userAPI, cfg) return Login(req, userAPI, cfg)
@ -695,7 +698,7 @@ func Setup(
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}", v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -761,7 +764,7 @@ func Setup(
v3mux.Handle("/profile/{userID}/avatar_url", v3mux.Handle("/profile/{userID}/avatar_url",
httputil.MakeAuthAPI("profile_avatar_url", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("profile_avatar_url", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -786,7 +789,7 @@ func Setup(
v3mux.Handle("/profile/{userID}/displayname", v3mux.Handle("/profile/{userID}/displayname",
httputil.MakeAuthAPI("profile_displayname", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("profile_displayname", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -825,7 +828,7 @@ func Setup(
v3mux.Handle("/voip/turnServer", v3mux.Handle("/voip/turnServer",
httputil.MakeAuthAPI("turn_server", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("turn_server", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
return RequestTurnServer(req, device, cfg) return RequestTurnServer(req, device, cfg)
@ -904,7 +907,7 @@ func Setup(
v3mux.Handle("/user/{userID}/openid/request_token", v3mux.Handle("/user/{userID}/openid/request_token",
httputil.MakeAuthAPI("openid_request_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("openid_request_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -917,7 +920,7 @@ func Setup(
v3mux.Handle("/user_directory/search", v3mux.Handle("/user_directory/search",
httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
postContent := struct { postContent := struct {
@ -963,7 +966,7 @@ func Setup(
v3mux.Handle("/rooms/{roomID}/read_markers", v3mux.Handle("/rooms/{roomID}/read_markers",
httputil.MakeAuthAPI("rooms_read_markers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_read_markers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -976,7 +979,7 @@ func Setup(
v3mux.Handle("/rooms/{roomID}/forget", v3mux.Handle("/rooms/{roomID}/forget",
httputil.MakeAuthAPI("rooms_forget", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_forget", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -1053,7 +1056,7 @@ func Setup(
v3mux.Handle("/pushers/set", v3mux.Handle("/pushers/set",
httputil.MakeAuthAPI("set_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("set_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
return SetPusher(req, device, userAPI) return SetPusher(req, device, userAPI)
@ -1111,7 +1114,7 @@ func Setup(
v3mux.Handle("/capabilities", v3mux.Handle("/capabilities",
httputil.MakeAuthAPI("capabilities", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("capabilities", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
return GetCapabilities(req, rsAPI) return GetCapabilities(req, rsAPI)
@ -1327,7 +1330,7 @@ func Setup(
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}", v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}",
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))

View file

@ -19,6 +19,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"reflect"
"sync" "sync"
"time" "time"
@ -96,14 +97,28 @@ func SendEvent(
mutex.(*sync.Mutex).Lock() mutex.(*sync.Mutex).Lock()
defer mutex.(*sync.Mutex).Unlock() defer mutex.(*sync.Mutex).Unlock()
startedGeneratingEvent := time.Now()
var r map[string]interface{} // must be a JSON object var r map[string]interface{} // must be a JSON object
resErr := httputil.UnmarshalJSONRequest(req, &r) resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }
if stateKey != nil {
// If the existing/new state content are equal, return the existing event_id, making the request idempotent.
if resp := stateEqual(req.Context(), rsAPI, eventType, *stateKey, roomID, r); resp != nil {
return *resp
}
}
startedGeneratingEvent := time.Now()
// If we're sending a membership update, make sure to strip the authorised
// via key if it is present, otherwise other servers won't be able to auth
// the event if the room is set to the "restricted" join rule.
if eventType == gomatrixserverlib.MRoomMember {
delete(r, "join_authorised_via_users_server")
}
evTime, err := httputil.ParseTSParam(req) evTime, err := httputil.ParseTSParam(req)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -201,6 +216,37 @@ func SendEvent(
return res return res
} }
// stateEqual compares the new and the existing state event content. If they are equal, returns a *util.JSONResponse
// with the existing event_id, making this an idempotent request.
func stateEqual(ctx context.Context, rsAPI api.ClientRoomserverAPI, eventType, stateKey, roomID string, newContent map[string]interface{}) *util.JSONResponse {
stateRes := api.QueryCurrentStateResponse{}
tuple := gomatrixserverlib.StateKeyTuple{
EventType: eventType,
StateKey: stateKey,
}
err := rsAPI.QueryCurrentState(ctx, &api.QueryCurrentStateRequest{
RoomID: roomID,
StateTuples: []gomatrixserverlib.StateKeyTuple{tuple},
}, &stateRes)
if err != nil {
return nil
}
if existingEvent, ok := stateRes.StateEvents[tuple]; ok {
var existingContent map[string]interface{}
if err = json.Unmarshal(existingEvent.Content(), &existingContent); err != nil {
return nil
}
if reflect.DeepEqual(existingContent, newContent) {
return &util.JSONResponse{
Code: http.StatusOK,
JSON: sendEventResponse{existingEvent.EventID()},
}
}
}
return nil
}
func generateSendEvent( func generateSendEvent(
ctx context.Context, ctx context.Context,
r map[string]interface{}, r map[string]interface{},

View file

@ -4,13 +4,17 @@ import (
"context" "context"
"flag" "flag"
"fmt" "fmt"
"os" "sort"
"strconv" "strconv"
"strings"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -23,11 +27,17 @@ import (
// e.g. ./resolve-state --roomversion=5 1254 1235 1282 // e.g. ./resolve-state --roomversion=5 1254 1235 1282
var roomVersion = flag.String("roomversion", "5", "the room version to parse events as") var roomVersion = flag.String("roomversion", "5", "the room version to parse events as")
var filterType = flag.String("filtertype", "", "the event types to filter on")
func main() { func main() {
ctx := context.Background() ctx := context.Background()
cfg := setup.ParseFlags(true) cfg := setup.ParseFlags(true)
args := os.Args[1:] cfg.Logging = append(cfg.Logging[:0], config.LogrusHook{
Type: "std",
Level: "error",
})
base := base.NewBaseDendrite(cfg, "ResolveState", base.DisableMetrics)
args := flag.Args()
fmt.Println("Room version", *roomVersion) fmt.Println("Room version", *roomVersion)
@ -45,30 +55,28 @@ func main() {
panic(err) panic(err)
} }
roomserverDB, err := storage.Open(nil, &cfg.RoomServer.Database, cache) roomserverDB, err := storage.Open(base, &cfg.RoomServer.Database, cache)
if err != nil { if err != nil {
panic(err) panic(err)
} }
blockNIDs, err := roomserverDB.StateBlockNIDs(ctx, snapshotNIDs) stateres := state.NewStateResolution(roomserverDB, &types.RoomInfo{
if err != nil { RoomVersion: gomatrixserverlib.RoomVersion(*roomVersion),
panic(err) })
}
var stateEntries []types.StateEntryList var stateEntries []types.StateEntry
for _, list := range blockNIDs { for _, snapshotNID := range snapshotNIDs {
entries, err2 := roomserverDB.StateEntries(ctx, list.StateBlockNIDs) var entries []types.StateEntry
if err2 != nil { entries, err = stateres.LoadStateAtSnapshot(ctx, snapshotNID)
panic(err2) if err != nil {
panic(err)
} }
stateEntries = append(stateEntries, entries...) stateEntries = append(stateEntries, entries...)
} }
var eventNIDs []types.EventNID var eventNIDs []types.EventNID
for _, entry := range stateEntries { for _, entry := range stateEntries {
for _, e := range entry.StateEntries { eventNIDs = append(eventNIDs, entry.EventNID)
eventNIDs = append(eventNIDs, e.EventNID)
}
} }
fmt.Println("Fetching", len(eventNIDs), "state events") fmt.Println("Fetching", len(eventNIDs), "state events")
@ -103,7 +111,8 @@ func main() {
} }
fmt.Println("Resolving state") fmt.Println("Resolving state")
resolved, err := gomatrixserverlib.ResolveConflicts( var resolved Events
resolved, err = gomatrixserverlib.ResolveConflicts(
gomatrixserverlib.RoomVersion(*roomVersion), gomatrixserverlib.RoomVersion(*roomVersion),
events, events,
authEvents, authEvents,
@ -113,9 +122,41 @@ func main() {
} }
fmt.Println("Resolved state contains", len(resolved), "events") fmt.Println("Resolved state contains", len(resolved), "events")
sort.Sort(resolved)
filteringEventType := *filterType
count := 0
for _, event := range resolved { for _, event := range resolved {
if filteringEventType != "" && event.Type() != filteringEventType {
continue
}
count++
fmt.Println() fmt.Println()
fmt.Printf("* %s %s %q\n", event.EventID(), event.Type(), *event.StateKey()) fmt.Printf("* %s %s %q\n", event.EventID(), event.Type(), *event.StateKey())
fmt.Printf(" %s\n", string(event.Content())) fmt.Printf(" %s\n", string(event.Content()))
} }
fmt.Println()
fmt.Println("Returned", count, "state events after filtering")
}
type Events []*gomatrixserverlib.Event
func (e Events) Len() int {
return len(e)
}
func (e Events) Swap(i, j int) {
e[i], e[j] = e[j], e[i]
}
func (e Events) Less(i, j int) bool {
typeDelta := strings.Compare(e[i].Type(), e[j].Type())
if typeDelta < 0 {
return true
}
if typeDelta > 0 {
return false
}
stateKeyDelta := strings.Compare(*e[i].StateKey(), *e[j].StateKey())
return stateKeyDelta < 0
} }

View file

@ -160,11 +160,14 @@ client_api:
# Settings for rate-limited endpoints. Rate limiting kicks in after the threshold # Settings for rate-limited endpoints. Rate limiting kicks in after the threshold
# number of "slots" have been taken by requests from a specific host. Each "slot" # number of "slots" have been taken by requests from a specific host. Each "slot"
# will be released after the cooloff time in milliseconds. # will be released after the cooloff time in milliseconds. Server administrators
# and appservice users are exempt from rate limiting by default.
rate_limiting: rate_limiting:
enabled: true enabled: true
threshold: 5 threshold: 5
cooloff_ms: 500 cooloff_ms: 500
exempt_user_ids:
# - @user:domain.com
# Configuration for the Federation API. # Configuration for the Federation API.
federation_api: federation_api:

View file

@ -163,11 +163,14 @@ client_api:
# Settings for rate-limited endpoints. Rate limiting kicks in after the threshold # Settings for rate-limited endpoints. Rate limiting kicks in after the threshold
# number of "slots" have been taken by requests from a specific host. Each "slot" # number of "slots" have been taken by requests from a specific host. Each "slot"
# will be released after the cooloff time in milliseconds. # will be released after the cooloff time in milliseconds. Server administrators
# and appservice users are exempt from rate limiting by default.
rate_limiting: rate_limiting:
enabled: true enabled: true
threshold: 5 threshold: 5
cooloff_ms: 500 cooloff_ms: 500
exempt_user_ids:
# - @user:domain.com
# Configuration for the Federation API. # Configuration for the Federation API.
federation_api: federation_api:

View file

@ -9,21 +9,19 @@ permalink: /installation/planning
## Modes ## Modes
Dendrite can be run in one of two configurations: Dendrite consists of several components, each responsible for a different aspect of the Matrix protocol.
Users can run Dendrite in one of two modes which dictate how these components are executed and communicate.
* **Monolith mode**: All components run in the same process. In this mode, * **Monolith mode** runs all components in a single process. Components communicate through an internal NATS
it is possible to run an in-process NATS Server instead of running a standalone deployment. server with generally low overhead. This mode dramatically simplifies deployment complexity and offers the
This will usually be the preferred model for low-to-mid volume deployments, providing the best best balance between performance and resource usage for low-to-mid volume deployments.
balance between performance and resource usage.
* **Polylith mode**: A cluster of individual components running in their own processes, dealing * **Polylith mode** runs all components in isolated processes. Components communicate through an external NATS
with different aspects of the Matrix protocol. Components communicate with each other using server and HTTP APIs, which incur considerable overhead. While this mode allows for more granular control of
internal HTTP APIs and NATS Server. This will almost certainly be the preferred model for very resources dedicated toward individual processes, given the additional communications overhead, it is only
large deployments but scalability comes with a cost. API calls are expensive and therefore a necessary for very large deployments.
polylith deployment may end up using disproportionately more resources for a smaller number of
users compared to a monolith deployment.
At present, we **recommend monolith mode deployments** in all cases. Given our current state of development, **we recommend monolith mode** for all deployments.
## Databases ## Databases

View file

@ -55,60 +55,62 @@ var servers = map[string]*server{
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
// Set up the server key API for each "server" that we // Set up the server key API for each "server" that we
// will use in our tests. // will use in our tests.
for _, s := range servers { os.Exit(func() int {
// Generate a new key. for _, s := range servers {
_, testPriv, err := ed25519.GenerateKey(nil) // Generate a new key.
if err != nil { _, testPriv, err := ed25519.GenerateKey(nil)
panic("can't generate identity key: " + err.Error()) if err != nil {
panic("can't generate identity key: " + err.Error())
}
// Create a new cache but don't enable prometheus!
s.cache, err = caching.NewInMemoryLRUCache(false)
if err != nil {
panic("can't create cache: " + err.Error())
}
// Create a temporary directory for JetStream.
d, err := ioutil.TempDir("./", "jetstream*")
if err != nil {
panic(err)
}
defer os.RemoveAll(d)
// Draw up just enough Dendrite config for the server key
// API to work.
cfg := &config.Dendrite{}
cfg.Defaults(true)
cfg.Global.ServerName = gomatrixserverlib.ServerName(s.name)
cfg.Global.PrivateKey = testPriv
cfg.Global.JetStream.InMemory = true
cfg.Global.JetStream.TopicPrefix = string(s.name[:1])
cfg.Global.JetStream.StoragePath = config.Path(d)
cfg.Global.KeyID = serverKeyID
cfg.Global.KeyValidityPeriod = s.validity
cfg.FederationAPI.Database.ConnectionString = config.DataSource("file::memory:")
s.config = &cfg.FederationAPI
// Create a transport which redirects federation requests to
// the mock round tripper. Since we're not *really* listening for
// federation requests then this will return the key instead.
transport := &http.Transport{}
transport.RegisterProtocol("matrix", &MockRoundTripper{})
// Create the federation client.
s.fedclient = gomatrixserverlib.NewFederationClient(
s.config.Matrix.ServerName, serverKeyID, testPriv,
gomatrixserverlib.WithTransport(transport),
)
// Finally, build the server key APIs.
sbase := base.NewBaseDendrite(cfg, "Monolith", base.DisableMetrics)
s.api = NewInternalAPI(sbase, s.fedclient, nil, s.cache, nil, true)
} }
// Create a new cache but don't enable prometheus! // Now that we have built our server key APIs, start the
s.cache, err = caching.NewInMemoryLRUCache(false) // rest of the tests.
if err != nil { return m.Run()
panic("can't create cache: " + err.Error()) }())
}
// Create a temporary directory for JetStream.
d, err := ioutil.TempDir("./", "jetstream*")
if err != nil {
panic(err)
}
defer os.RemoveAll(d)
// Draw up just enough Dendrite config for the server key
// API to work.
cfg := &config.Dendrite{}
cfg.Defaults(true)
cfg.Global.ServerName = gomatrixserverlib.ServerName(s.name)
cfg.Global.PrivateKey = testPriv
cfg.Global.JetStream.InMemory = true
cfg.Global.JetStream.TopicPrefix = string(s.name[:1])
cfg.Global.JetStream.StoragePath = config.Path(d)
cfg.Global.KeyID = serverKeyID
cfg.Global.KeyValidityPeriod = s.validity
cfg.FederationAPI.Database.ConnectionString = config.DataSource("file::memory:")
s.config = &cfg.FederationAPI
// Create a transport which redirects federation requests to
// the mock round tripper. Since we're not *really* listening for
// federation requests then this will return the key instead.
transport := &http.Transport{}
transport.RegisterProtocol("matrix", &MockRoundTripper{})
// Create the federation client.
s.fedclient = gomatrixserverlib.NewFederationClient(
s.config.Matrix.ServerName, serverKeyID, testPriv,
gomatrixserverlib.WithTransport(transport),
)
// Finally, build the server key APIs.
sbase := base.NewBaseDendrite(cfg, "Monolith", base.DisableMetrics)
s.api = NewInternalAPI(sbase, s.fedclient, nil, s.cache, nil, true)
}
// Now that we have built our server key APIs, start the
// rest of the tests.
os.Exit(m.Run())
} }
type MockRoundTripper struct{} type MockRoundTripper struct{}

View file

@ -166,7 +166,8 @@ func (r *FederationInternalAPI) performJoinUsingServer(
if content == nil { if content == nil {
content = map[string]interface{}{} content = map[string]interface{}{}
} }
content["membership"] = "join" _ = json.Unmarshal(respMakeJoin.JoinEvent.Content, &content)
content["membership"] = gomatrixserverlib.Join
if err = respMakeJoin.JoinEvent.SetContent(content); err != nil { if err = respMakeJoin.JoinEvent.SetContent(content); err != nil {
return fmt.Errorf("respMakeJoin.JoinEvent.SetContent: %w", err) return fmt.Errorf("respMakeJoin.JoinEvent.SetContent: %w", err)
} }
@ -209,10 +210,22 @@ func (r *FederationInternalAPI) performJoinUsingServer(
} }
r.statistics.ForServer(serverName).Success() r.statistics.ForServer(serverName).Success()
authEvents := respSendJoin.AuthEvents.UntrustedEvents(respMakeJoin.RoomVersion) // If the remote server returned an event in the "event" key of
// the send_join request then we should use that instead. It may
// contain signatures that we don't know about.
if len(respSendJoin.Event) > 0 {
var remoteEvent *gomatrixserverlib.Event
remoteEvent, err = respSendJoin.Event.UntrustedEvent(respMakeJoin.RoomVersion)
if err == nil && isWellFormedMembershipEvent(
remoteEvent, roomID, userID, r.cfg.Matrix.ServerName,
) {
event = remoteEvent
}
}
// Sanity-check the join response to ensure that it has a create // Sanity-check the join response to ensure that it has a create
// event, that the room version is known, etc. // event, that the room version is known, etc.
authEvents := respSendJoin.AuthEvents.UntrustedEvents(respMakeJoin.RoomVersion)
if err = sanityCheckAuthChain(authEvents); err != nil { if err = sanityCheckAuthChain(authEvents); err != nil {
return fmt.Errorf("sanityCheckAuthChain: %w", err) return fmt.Errorf("sanityCheckAuthChain: %w", err)
} }
@ -270,6 +283,26 @@ func (r *FederationInternalAPI) performJoinUsingServer(
return nil return nil
} }
// isWellFormedMembershipEvent returns true if the event looks like a legitimate
// membership event.
func isWellFormedMembershipEvent(event *gomatrixserverlib.Event, roomID, userID string, origin gomatrixserverlib.ServerName) bool {
if membership, err := event.Membership(); err != nil {
return false
} else if membership != gomatrixserverlib.Join {
return false
}
if event.RoomID() != roomID {
return false
}
if event.Origin() != origin {
return false
}
if !event.StateKeyEquals(userID) {
return false
}
return true
}
// PerformOutboundPeekRequest implements api.FederationInternalAPI // PerformOutboundPeekRequest implements api.FederationInternalAPI
func (r *FederationInternalAPI) PerformOutboundPeek( func (r *FederationInternalAPI) PerformOutboundPeek(
ctx context.Context, ctx context.Context,

View file

@ -15,6 +15,7 @@
package routing package routing
import ( import (
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"sort" "sort"
@ -103,6 +104,16 @@ func MakeJoin(
} }
} }
// Check if the restricted join is allowed. If the room doesn't
// support restricted joins then this is effectively a no-op.
res, authorisedVia, err := checkRestrictedJoin(httpReq, rsAPI, verRes.RoomVersion, roomID, userID)
if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("checkRestrictedJoin failed")
return jsonerror.InternalServerError()
} else if res != nil {
return *res
}
// Try building an event for the server // Try building an event for the server
builder := gomatrixserverlib.EventBuilder{ builder := gomatrixserverlib.EventBuilder{
Sender: userID, Sender: userID,
@ -110,8 +121,11 @@ func MakeJoin(
Type: "m.room.member", Type: "m.room.member",
StateKey: &userID, StateKey: &userID,
} }
err = builder.SetContent(map[string]interface{}{"membership": gomatrixserverlib.Join}) content := gomatrixserverlib.MemberContent{
if err != nil { Membership: gomatrixserverlib.Join,
AuthorisedVia: authorisedVia,
}
if err = builder.SetContent(content); err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("builder.SetContent failed") util.GetLogger(httpReq.Context()).WithError(err).Error("builder.SetContent failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -161,6 +175,7 @@ func MakeJoin(
// SendJoin implements the /send_join API // SendJoin implements the /send_join API
// The make-join send-join dance makes much more sense as a single // The make-join send-join dance makes much more sense as a single
// flow so the cyclomatic complexity is high: // flow so the cyclomatic complexity is high:
// nolint:gocyclo
func SendJoin( func SendJoin(
httpReq *http.Request, httpReq *http.Request,
request *gomatrixserverlib.FederationRequest, request *gomatrixserverlib.FederationRequest,
@ -314,6 +329,40 @@ func SendJoin(
} }
} }
// If the membership content contains a user ID for a server that is not
// ours then we should kick it back.
var memberContent gomatrixserverlib.MemberContent
if err := json.Unmarshal(event.Content(), &memberContent); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(err.Error()),
}
}
if memberContent.AuthorisedVia != "" {
_, domain, err := gomatrixserverlib.SplitID('@', memberContent.AuthorisedVia)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(fmt.Sprintf("The authorising username %q is invalid.", memberContent.AuthorisedVia)),
}
}
if domain != cfg.Matrix.ServerName {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(fmt.Sprintf("The authorising username %q does not belong to this server.", memberContent.AuthorisedVia)),
}
}
}
// Sign the membership event. This is required for restricted joins to work
// in the case that the authorised via user is one of our own users. It also
// doesn't hurt to do it even if it isn't a restricted join.
signed := event.Sign(
string(cfg.Matrix.ServerName),
cfg.Matrix.KeyID,
cfg.Matrix.PrivateKey,
)
// Send the events to the room server. // Send the events to the room server.
// We are responsible for notifying other servers that the user has joined // We are responsible for notifying other servers that the user has joined
// the room, so set SendAsServer to cfg.Matrix.ServerName // the room, so set SendAsServer to cfg.Matrix.ServerName
@ -323,7 +372,7 @@ func SendJoin(
InputRoomEvents: []api.InputRoomEvent{ InputRoomEvents: []api.InputRoomEvent{
{ {
Kind: api.KindNew, Kind: api.KindNew,
Event: event.Headered(stateAndAuthChainResponse.RoomVersion), Event: signed.Headered(stateAndAuthChainResponse.RoomVersion),
SendAsServer: string(cfg.Matrix.ServerName), SendAsServer: string(cfg.Matrix.ServerName),
TransactionID: nil, TransactionID: nil,
}, },
@ -354,10 +403,77 @@ func SendJoin(
StateEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.StateEvents), StateEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.StateEvents),
AuthEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.AuthChainEvents), AuthEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.AuthChainEvents),
Origin: cfg.Matrix.ServerName, Origin: cfg.Matrix.ServerName,
Event: signed.JSON(),
}, },
} }
} }
// checkRestrictedJoin finds out whether or not we can assist in processing
// a restricted room join. If the room version does not support restricted
// joins then this function returns with no side effects. This returns three
// values:
// * an optional JSON response body (i.e. M_UNABLE_TO_AUTHORISE_JOIN) which
// should always be sent back to the client if one is specified
// * a user ID of an authorising user, typically a user that has power to
// issue invites in the room, if one has been found
// * an error if there was a problem finding out if this was allowable,
// like if the room version isn't known or a problem happened talking to
// the roomserver
func checkRestrictedJoin(
httpReq *http.Request,
rsAPI api.FederationRoomserverAPI,
roomVersion gomatrixserverlib.RoomVersion,
roomID, userID string,
) (*util.JSONResponse, string, error) {
if allowRestricted, err := roomVersion.MayAllowRestrictedJoinsInEventAuth(); err != nil {
return nil, "", err
} else if !allowRestricted {
return nil, "", nil
}
req := &api.QueryRestrictedJoinAllowedRequest{
RoomID: roomID,
UserID: userID,
}
res := &api.QueryRestrictedJoinAllowedResponse{}
if err := rsAPI.QueryRestrictedJoinAllowed(httpReq.Context(), req, res); err != nil {
return nil, "", err
}
switch {
case !res.Restricted:
// The join rules for the room don't restrict membership.
return nil, "", nil
case !res.Resident:
// The join rules restrict membership but our server isn't currently
// joined to all of the allowed rooms, so we can't actually decide
// whether or not to allow the user to join. This error code should
// tell the joining server to try joining via another resident server
// instead.
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.UnableToAuthoriseJoin("This server cannot authorise the join."),
}, "", nil
case !res.Allowed:
// The join rules restrict membership, our server is in the relevant
// rooms and the user wasn't joined to join any of the allowed rooms
// and therefore can't join this room.
return &util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("You are not joined to any matching rooms."),
}, "", nil
default:
// The join rules restrict membership, our server is in the relevant
// rooms and the user was allowed to join because they belong to one
// of the allowed rooms. We now need to pick one of our own local users
// from within the room to use as the authorising user ID, so that it
// can be referred to from within the membership content.
return nil, res.AuthorisedVia, nil
}
}
type eventsByDepth []*gomatrixserverlib.HeaderedEvent type eventsByDepth []*gomatrixserverlib.HeaderedEvent
func (e eventsByDepth) Len() int { func (e eventsByDepth) Len() int {

View file

@ -36,6 +36,10 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_edus (
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx
ON federationsender_queue_edus (json_nid, server_name); ON federationsender_queue_edus (json_nid, server_name);
CREATE INDEX IF NOT EXISTS federationsender_queue_edus_nid_idx
ON federationsender_queue_edus (json_nid);
CREATE INDEX IF NOT EXISTS federationsender_queue_edus_server_name_idx
ON federationsender_queue_edus (server_name);
` `
const insertQueueEDUSQL = "" + const insertQueueEDUSQL = "" +

View file

@ -33,6 +33,9 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_json (
-- The JSON body. Text so that we preserve UTF-8. -- The JSON body. Text so that we preserve UTF-8.
json_body TEXT NOT NULL json_body TEXT NOT NULL
); );
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_json_json_nid_idx
ON federationsender_queue_json (json_nid);
` `
const insertJSONSQL = "" + const insertJSONSQL = "" +

View file

@ -36,6 +36,10 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_pdus (
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx
ON federationsender_queue_pdus (json_nid, server_name); ON federationsender_queue_pdus (json_nid, server_name);
CREATE INDEX IF NOT EXISTS federationsender_queue_pdus_json_nid_idx
ON federationsender_queue_pdus (json_nid);
CREATE INDEX IF NOT EXISTS federationsender_queue_pdus_server_name_idx
ON federationsender_queue_pdus (server_name);
` `
const insertQueuePDUSQL = "" + const insertQueuePDUSQL = "" +

View file

@ -37,6 +37,10 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_edus (
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx
ON federationsender_queue_edus (json_nid, server_name); ON federationsender_queue_edus (json_nid, server_name);
CREATE INDEX IF NOT EXISTS federationsender_queue_edus_nid_idx
ON federationsender_queue_edus (json_nid);
CREATE INDEX IF NOT EXISTS federationsender_queue_edus_server_name_idx
ON federationsender_queue_edus (server_name);
` `
const insertQueueEDUSQL = "" + const insertQueueEDUSQL = "" +

View file

@ -35,6 +35,9 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_json (
-- The JSON body. Text so that we preserve UTF-8. -- The JSON body. Text so that we preserve UTF-8.
json_body TEXT NOT NULL json_body TEXT NOT NULL
); );
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_json_json_nid_idx
ON federationsender_queue_json (json_nid);
` `
const insertJSONSQL = "" + const insertJSONSQL = "" +

View file

@ -38,6 +38,10 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_pdus (
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx
ON federationsender_queue_pdus (json_nid, server_name); ON federationsender_queue_pdus (json_nid, server_name);
CREATE INDEX IF NOT EXISTS federationsender_queue_pdus_json_nid_idx
ON federationsender_queue_pdus (json_nid);
CREATE INDEX IF NOT EXISTS federationsender_queue_pdus_server_name_idx
ON federationsender_queue_pdus (server_name);
` `
const insertQueuePDUSQL = "" + const insertQueuePDUSQL = "" +

47
go.mod
View file

@ -5,20 +5,24 @@ replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-serve
replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.13.1-0.20220419101051-b262d9f0be1e replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.13.1-0.20220419101051-b262d9f0be1e
require ( require (
github.com/Arceliar/ironwood v0.0.0-20211125050254-8951369625d0 github.com/Arceliar/ironwood v0.0.0-20220306165321-319147a02d98
github.com/Arceliar/phony v0.0.0-20210209235338-dde1a8dca979 github.com/Arceliar/phony v0.0.0-20210209235338-dde1a8dca979
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/DATA-DOG/go-sqlmock v1.5.0 github.com/DATA-DOG/go-sqlmock v1.5.0
github.com/HdrHistogram/hdrhistogram-go v1.1.2 // indirect github.com/HdrHistogram/hdrhistogram-go v1.1.2 // indirect
github.com/MFAshby/stdemuxerhook v1.0.0 github.com/MFAshby/stdemuxerhook v1.0.0
github.com/Masterminds/semver/v3 v3.1.1 github.com/Masterminds/semver/v3 v3.1.1
github.com/Microsoft/go-winio v0.5.1 // indirect
github.com/codeclysm/extract v2.2.0+incompatible github.com/codeclysm/extract v2.2.0+incompatible
github.com/containerd/containerd v1.6.2 // indirect github.com/docker/distribution v2.7.1+incompatible // indirect
github.com/docker/docker v20.10.14+incompatible github.com/docker/docker v20.10.16+incompatible
github.com/docker/go-connections v0.4.0 github.com/docker/go-connections v0.4.0
github.com/docker/go-units v0.4.0 // indirect
github.com/frankban/quicktest v1.14.3 // indirect github.com/frankban/quicktest v1.14.3 // indirect
github.com/getsentry/sentry-go v0.13.0 github.com/getsentry/sentry-go v0.13.0
github.com/gogo/protobuf v1.3.2 // indirect
github.com/gologme/log v1.3.0 github.com/gologme/log v1.3.0
github.com/google/go-cmp v0.5.7 github.com/google/go-cmp v0.5.8
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/gorilla/mux v1.8.0 github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.5.0 github.com/gorilla/websocket v1.5.0
@ -30,38 +34,47 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
github.com/matrix-org/gomatrixserverlib v0.0.0-20220513103617-eee8fd528433 github.com/matrix-org/gomatrixserverlib v0.0.0-20220607143425-e55d796fd0b3
github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.10 github.com/mattn/go-sqlite3 v1.14.13
github.com/miekg/dns v1.1.31 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect
github.com/miekg/dns v1.1.49 // indirect
github.com/moby/term v0.0.0-20210610120745-9d4ed1856297 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/nats-io/nats-server/v2 v2.7.4-0.20220309205833-773636c1c5bb github.com/nats-io/nats-server/v2 v2.7.4-0.20220309205833-773636c1c5bb
github.com/nats-io/nats.go v1.14.0 github.com/nats-io/nats.go v1.14.0
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/ngrok/sqlmw v0.0.0-20211220175533-9d16fdc47b31 github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79
github.com/opencontainers/image-spec v1.0.2 // indirect github.com/onsi/gomega v1.17.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.0.3-0.20211202183452-c5a74bcca799 // indirect
github.com/opentracing/opentracing-go v1.2.0 github.com/opentracing/opentracing-go v1.2.0
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/pressly/goose v2.7.0+incompatible github.com/pressly/goose v2.7.0+incompatible
github.com/prometheus/client_golang v1.12.1 github.com/prometheus/client_golang v1.12.2
github.com/sirupsen/logrus v1.8.1 github.com/sirupsen/logrus v1.8.1
github.com/stretchr/testify v1.7.0 github.com/stretchr/objx v0.2.0 // indirect
github.com/stretchr/testify v1.7.1
github.com/tidwall/gjson v1.14.1 github.com/tidwall/gjson v1.14.1
github.com/tidwall/sjson v1.2.4 github.com/tidwall/sjson v1.2.4
github.com/uber/jaeger-client-go v2.30.0+incompatible github.com/uber/jaeger-client-go v2.30.0+incompatible
github.com/uber/jaeger-lib v2.4.1+incompatible github.com/uber/jaeger-lib v2.4.1+incompatible
github.com/yggdrasil-network/yggdrasil-go v0.4.3 github.com/yggdrasil-network/yggdrasil-go v0.4.3
go.uber.org/atomic v1.9.0 go.uber.org/atomic v1.9.0
golang.org/x/crypto v0.0.0-20220507011949-2cf3adece122 golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e
golang.org/x/image v0.0.0-20220321031419-a8550c1d254a golang.org/x/image v0.0.0-20220413100746-70e8d0d3baa9
golang.org/x/mobile v0.0.0-20220407111146-e579adbbc4a2 golang.org/x/mobile v0.0.0-20220518205345-8578da9835fd
golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3 golang.org/x/net v0.0.0-20220524220425-1d687d428aca
golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6 // indirect golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 golang.org/x/term v0.0.0-20220526004731-065cf7ba2467
google.golang.org/protobuf v1.27.1 // indirect
gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/h2non/bimg.v1 v1.1.9
gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v2 v2.4.0
gopkg.in/yaml.v3 v3.0.0 // indirect
gotest.tools/v3 v3.0.3 // indirect
nhooyr.io/websocket v1.8.7 nhooyr.io/websocket v1.8.7
) )

874
go.sum

File diff suppressed because it is too large Load diff

View file

@ -7,6 +7,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -17,6 +18,7 @@ type RateLimits struct {
enabled bool enabled bool
requestThreshold int64 requestThreshold int64
cooloffDuration time.Duration cooloffDuration time.Duration
exemptUserIDs map[string]struct{}
} }
func NewRateLimits(cfg *config.RateLimiting) *RateLimits { func NewRateLimits(cfg *config.RateLimiting) *RateLimits {
@ -25,6 +27,10 @@ func NewRateLimits(cfg *config.RateLimiting) *RateLimits {
enabled: cfg.Enabled, enabled: cfg.Enabled,
requestThreshold: cfg.Threshold, requestThreshold: cfg.Threshold,
cooloffDuration: time.Duration(cfg.CooloffMS) * time.Millisecond, cooloffDuration: time.Duration(cfg.CooloffMS) * time.Millisecond,
exemptUserIDs: map[string]struct{}{},
}
for _, userID := range cfg.ExemptUserIDs {
l.exemptUserIDs[userID] = struct{}{}
} }
if l.enabled { if l.enabled {
go l.clean() go l.clean()
@ -52,7 +58,7 @@ func (l *RateLimits) clean() {
} }
} }
func (l *RateLimits) Limit(req *http.Request) *util.JSONResponse { func (l *RateLimits) Limit(req *http.Request, device *userapi.Device) *util.JSONResponse {
// If rate limiting is disabled then do nothing. // If rate limiting is disabled then do nothing.
if !l.enabled { if !l.enabled {
return nil return nil
@ -67,9 +73,26 @@ func (l *RateLimits) Limit(req *http.Request) *util.JSONResponse {
// First of all, work out if X-Forwarded-For was sent to us. If not // First of all, work out if X-Forwarded-For was sent to us. If not
// then we'll just use the IP address of the caller. // then we'll just use the IP address of the caller.
caller := req.RemoteAddr var caller string
if forwardedFor := req.Header.Get("X-Forwarded-For"); forwardedFor != "" { if device != nil {
caller = forwardedFor switch device.AccountType {
case userapi.AccountTypeAdmin:
return nil // don't rate-limit server administrators
case userapi.AccountTypeAppService:
return nil // don't rate-limit appservice users
default:
if _, ok := l.exemptUserIDs[device.UserID]; ok {
// If the user is exempt from rate limiting then do nothing.
return nil
}
caller = device.UserID + device.ID
}
} else {
if forwardedFor := req.Header.Get("X-Forwarded-For"); forwardedFor != "" {
caller = forwardedFor
} else {
caller = req.RemoteAddr
}
} }
// Look up the caller's channel, if they have one. // Look up the caller's channel, if they have one.

View file

@ -37,7 +37,7 @@ type traceInterceptor struct {
sqlmw.NullInterceptor sqlmw.NullInterceptor
} }
func (in *traceInterceptor) StmtQueryContext(ctx context.Context, stmt driver.StmtQueryContext, query string, args []driver.NamedValue) (driver.Rows, error) { func (in *traceInterceptor) StmtQueryContext(ctx context.Context, stmt driver.StmtQueryContext, query string, args []driver.NamedValue) (context.Context, driver.Rows, error) {
startedAt := time.Now() startedAt := time.Now()
rows, err := stmt.QueryContext(ctx, args) rows, err := stmt.QueryContext(ctx, args)
@ -45,7 +45,7 @@ func (in *traceInterceptor) StmtQueryContext(ctx context.Context, stmt driver.St
logrus.WithField("duration", time.Since(startedAt)).WithField(logrus.ErrorKey, err).Debug("executed sql query ", query, " args: ", args) logrus.WithField("duration", time.Since(startedAt)).WithField(logrus.ErrorKey, err).Debug("executed sql query ", query, " args: ", args)
return rows, err return ctx, rows, err
} }
func (in *traceInterceptor) StmtExecContext(ctx context.Context, stmt driver.StmtExecContext, query string, args []driver.NamedValue) (driver.Result, error) { func (in *traceInterceptor) StmtExecContext(ctx context.Context, stmt driver.StmtExecContext, query string, args []driver.NamedValue) (driver.Result, error) {

View file

@ -17,7 +17,7 @@ var build string
const ( const (
VersionMajor = 0 VersionMajor = 0
VersionMinor = 8 VersionMinor = 8
VersionPatch = 5 VersionPatch = 7
VersionTag = "" // example: "rc1" VersionTag = "" // example: "rc1"
) )

View file

@ -374,7 +374,7 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
// fetch stale device lists // fetch stale device lists
userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName}) userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName})
if err != nil { if err != nil {
logger.WithError(err).Error("failed to load stale device lists") logger.WithError(err).Error("Failed to load stale device lists")
return waitTime, true return waitTime, true
} }
failCount := 0 failCount := 0
@ -399,7 +399,7 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
} }
} else { } else {
waitTime = time.Hour waitTime = time.Hour
logger.WithError(err).WithField("user_id", userID).Warn("GetUserDevices returned unknown error type") logger.WithError(err).WithField("user_id", userID).Debug("GetUserDevices returned unknown error type")
} }
continue continue
} }
@ -422,12 +422,12 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
} }
err = u.updateDeviceList(&res) err = u.updateDeviceList(&res)
if err != nil { if err != nil {
logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store/emit it") logger.WithError(err).WithField("user_id", userID).Error("Fetched device list but failed to store/emit it")
failCount += 1 failCount += 1
} }
} }
if failCount > 0 { if failCount > 0 {
logger.WithField("total", len(userIDs)).WithField("failed", failCount).WithField("wait", waitTime).Error("failed to query device keys for some users") logger.WithField("total", len(userIDs)).WithField("failed", failCount).WithField("wait", waitTime).Warn("Failed to query device keys for some users")
} }
for _, userID := range userIDs { for _, userID := range userIDs {
// always clear the channel to unblock Update calls regardless of success/failure // always clear the channel to unblock Update calls regardless of success/failure

View file

@ -62,7 +62,7 @@ func Setup(
uploadHandler := httputil.MakeAuthAPI( uploadHandler := httputil.MakeAuthAPI(
"upload", userAPI, "upload", userAPI,
func(req *http.Request, dev *userapi.Device) util.JSONResponse { func(req *http.Request, dev *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, dev); r != nil {
return *r return *r
} }
return Upload(req, cfg, dev, db, activeThumbnailGeneration) return Upload(req, cfg, dev, db, activeThumbnailGeneration)
@ -70,7 +70,7 @@ func Setup(
) )
configHandler := httputil.MakeAuthAPI("config", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { configHandler := httputil.MakeAuthAPI("config", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }
respondSize := &cfg.MaxFileSizeBytes respondSize := &cfg.MaxFileSizeBytes
@ -126,7 +126,7 @@ func makeDownloadAPI(
// Ratelimit requests // Ratelimit requests
// NOTSPEC: The spec says everything at /media/ should be rate limited, but this causes issues with thumbnails (#2243) // NOTSPEC: The spec says everything at /media/ should be rate limited, but this causes issues with thumbnails (#2243)
if name != "thumbnail" { if name != "thumbnail" {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req, nil); r != nil {
if err := json.NewEncoder(w).Encode(r); err != nil { if err := json.NewEncoder(w).Encode(r); err != nil {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return return

View file

@ -184,6 +184,7 @@ type FederationRoomserverAPI interface {
// Query whether a server is allowed to see an event // Query whether a server is allowed to see an event
QueryServerAllowedToSeeEvent(ctx context.Context, req *QueryServerAllowedToSeeEventRequest, res *QueryServerAllowedToSeeEventResponse) error QueryServerAllowedToSeeEvent(ctx context.Context, req *QueryServerAllowedToSeeEventRequest, res *QueryServerAllowedToSeeEventResponse) error
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
QueryRestrictedJoinAllowed(ctx context.Context, req *QueryRestrictedJoinAllowedRequest, res *QueryRestrictedJoinAllowedResponse) error
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error
// Query a given amount (or less) of events prior to a given set of events. // Query a given amount (or less) of events prior to a given set of events.

View file

@ -354,6 +354,16 @@ func (t *RoomserverInternalAPITrace) QueryAuthChain(
return err return err
} }
func (t *RoomserverInternalAPITrace) QueryRestrictedJoinAllowed(
ctx context.Context,
request *QueryRestrictedJoinAllowedRequest,
response *QueryRestrictedJoinAllowedResponse,
) error {
err := t.Impl.QueryRestrictedJoinAllowed(ctx, request, response)
util.GetLogger(ctx).WithError(err).Infof("QueryRestrictedJoinAllowed req=%+v res=%+v", js(request), js(response))
return err
}
func js(thing interface{}) string { func js(thing interface{}) string {
b, err := json.Marshal(thing) b, err := json.Marshal(thing)
if err != nil { if err != nil {

View file

@ -348,6 +348,26 @@ type QueryServerBannedFromRoomResponse struct {
Banned bool `json:"banned"` Banned bool `json:"banned"`
} }
type QueryRestrictedJoinAllowedRequest struct {
UserID string `json:"user_id"`
RoomID string `json:"room_id"`
}
type QueryRestrictedJoinAllowedResponse struct {
// True if the room membership is restricted by the join rule being set to "restricted"
Restricted bool `json:"restricted"`
// True if our local server is joined to all of the allowed rooms specified in the "allow"
// key of the join rule, false if we are missing from some of them and therefore can't
// reliably decide whether or not we can satisfy the join
Resident bool `json:"resident"`
// True if the restricted join is allowed because we found the membership in one of the
// allowed rooms from the join rule, false if not
Allowed bool `json:"allowed"`
// Contains the user ID of the selected user ID that has power to issue invites, this will
// get populated into the "join_authorised_via_users_server" content in the membership
AuthorisedVia string `json:"authorised_via,omitempty"`
}
// MarshalJSON stringifies the room ID and StateKeyTuple keys so they can be sent over the wire in HTTP API mode. // MarshalJSON stringifies the room ID and StateKeyTuple keys so they can be sent over the wire in HTTP API mode.
func (r *QueryBulkStateContentResponse) MarshalJSON() ([]byte, error) { func (r *QueryBulkStateContentResponse) MarshalJSON() ([]byte, error) {
se := make(map[string]string) se := make(map[string]string)

View file

@ -33,6 +33,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/opentracing/opentracing-go"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -75,6 +76,11 @@ func (r *Inputer) processRoomEvent(
default: default:
} }
span, ctx := opentracing.StartSpanFromContext(ctx, "processRoomEvent")
span.SetTag("room_id", input.Event.RoomID())
span.SetTag("event_id", input.Event.EventID())
defer span.Finish()
// Measure how long it takes to process this event. // Measure how long it takes to process this event.
started := time.Now() started := time.Now()
defer func() { defer func() {
@ -411,6 +417,9 @@ func (r *Inputer) fetchAuthEvents(
known map[string]*types.Event, known map[string]*types.Event,
servers []gomatrixserverlib.ServerName, servers []gomatrixserverlib.ServerName,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "fetchAuthEvents")
defer span.Finish()
unknown := map[string]struct{}{} unknown := map[string]struct{}{}
authEventIDs := event.AuthEventIDs() authEventIDs := event.AuthEventIDs()
if len(authEventIDs) == 0 { if len(authEventIDs) == 0 {
@ -526,6 +535,9 @@ func (r *Inputer) calculateAndSetState(
event *gomatrixserverlib.Event, event *gomatrixserverlib.Event,
isRejected bool, isRejected bool,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "calculateAndSetState")
defer span.Finish()
var succeeded bool var succeeded bool
updater, err := r.DB.GetRoomUpdater(ctx, roomInfo) updater, err := r.DB.GetRoomUpdater(ctx, roomInfo)
if err != nil { if err != nil {
@ -535,8 +547,6 @@ func (r *Inputer) calculateAndSetState(
roomState := state.NewStateResolution(updater, roomInfo) roomState := state.NewStateResolution(updater, roomInfo)
if input.HasState { if input.HasState {
stateAtEvent.Overwrite = true
// We've been told what the state at the event is so we don't need to calculate it. // We've been told what the state at the event is so we don't need to calculate it.
// Check that those state events are in the database and store the state. // Check that those state events are in the database and store the state.
var entries []types.StateEntry var entries []types.StateEntry
@ -549,8 +559,6 @@ func (r *Inputer) calculateAndSetState(
return fmt.Errorf("updater.AddState: %w", err) return fmt.Errorf("updater.AddState: %w", err)
} }
} else { } else {
stateAtEvent.Overwrite = false
// We haven't been told what the state at the event is so we need to calculate it from the prev_events // We haven't been told what the state at the event is so we need to calculate it from the prev_events
if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event, isRejected); err != nil { if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event, isRejected); err != nil {
return fmt.Errorf("roomState.CalculateAndStoreStateBeforeEvent: %w", err) return fmt.Errorf("roomState.CalculateAndStoreStateBeforeEvent: %w", err)

View file

@ -27,6 +27,8 @@ import (
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/opentracing/opentracing-go"
"github.com/sirupsen/logrus"
) )
// updateLatestEvents updates the list of latest events for this room in the database and writes the // updateLatestEvents updates the list of latest events for this room in the database and writes the
@ -55,6 +57,9 @@ func (r *Inputer) updateLatestEvents(
transactionID *api.TransactionID, transactionID *api.TransactionID,
rewritesState bool, rewritesState bool,
) (err error) { ) (err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "updateLatestEvents")
defer span.Finish()
var succeeded bool var succeeded bool
updater, err := r.DB.GetRoomUpdater(ctx, roomInfo) updater, err := r.DB.GetRoomUpdater(ctx, roomInfo)
if err != nil { if err != nil {
@ -101,8 +106,8 @@ type latestEventsUpdater struct {
// The eventID of the event that was processed before this one. // The eventID of the event that was processed before this one.
lastEventIDSent string lastEventIDSent string
// The latest events in the room after processing this event. // The latest events in the room after processing this event.
oldLatest []types.StateAtEventAndReference oldLatest types.StateAtEventAndReferences
latest []types.StateAtEventAndReference latest types.StateAtEventAndReferences
// The state entries removed from and added to the current state of the // The state entries removed from and added to the current state of the
// room as a result of processing this event. They are sorted lists. // room as a result of processing this event. They are sorted lists.
removed []types.StateEntry removed []types.StateEntry
@ -125,7 +130,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
// state snapshot from somewhere else, e.g. a federated room join, // state snapshot from somewhere else, e.g. a federated room join,
// then start with an empty set - none of the forward extremities // then start with an empty set - none of the forward extremities
// that we knew about before matter anymore. // that we knew about before matter anymore.
u.oldLatest = []types.StateAtEventAndReference{} u.oldLatest = types.StateAtEventAndReferences{}
if !u.rewritesState { if !u.rewritesState {
u.oldStateNID = u.updater.CurrentStateSnapshotNID() u.oldStateNID = u.updater.CurrentStateSnapshotNID()
u.oldLatest = u.updater.LatestEvents() u.oldLatest = u.updater.LatestEvents()
@ -199,13 +204,16 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
} }
func (u *latestEventsUpdater) latestState() error { func (u *latestEventsUpdater) latestState() error {
span, ctx := opentracing.StartSpanFromContext(u.ctx, "processEventWithMissingState")
defer span.Finish()
var err error var err error
roomState := state.NewStateResolution(u.updater, u.roomInfo) roomState := state.NewStateResolution(u.updater, u.roomInfo)
// Work out if the state at the extremities has actually changed // Work out if the state at the extremities has actually changed
// or not. If they haven't then we won't bother doing all of the // or not. If they haven't then we won't bother doing all of the
// hard work. // hard work.
if u.event.StateKey() == nil { if !u.stateAtEvent.IsStateEvent() {
stateChanged := false stateChanged := false
oldStateNIDs := make([]types.StateSnapshotNID, 0, len(u.oldLatest)) oldStateNIDs := make([]types.StateSnapshotNID, 0, len(u.oldLatest))
newStateNIDs := make([]types.StateSnapshotNID, 0, len(u.latest)) newStateNIDs := make([]types.StateSnapshotNID, 0, len(u.latest))
@ -233,54 +241,51 @@ func (u *latestEventsUpdater) latestState() error {
} }
} }
// Take the old set of extremities and the new set of extremities and // Get a list of the current latest events. This may or may not
// mash them together into a list. This may or may not include the new event // include the new event from the input path, depending on whether
// from the input path, depending on whether it became a forward extremity // it is a forward extremity or not.
// or not. We'll then run state resolution across all of them to determine latestStateAtEvents := make([]types.StateAtEvent, len(u.latest))
// the new current state of the room. Including the old extremities here for i := range u.latest {
// ensures that new forward extremities with bad state snapshots (from latestStateAtEvents[i] = u.latest[i].StateAtEvent
// possible malicious actors) can't completely corrupt the room state
// away from what it was before.
combinedExtremities := types.StateAtEventAndReferences(append(u.oldLatest, u.latest...))
combinedExtremities = combinedExtremities[:util.SortAndUnique(combinedExtremities)]
latestStateAtEvents := make([]types.StateAtEvent, len(combinedExtremities))
for i := range combinedExtremities {
latestStateAtEvents[i] = combinedExtremities[i].StateAtEvent
} }
// Takes the NIDs of the latest events and creates a state snapshot // Takes the NIDs of the latest events and creates a state snapshot
// of the state after the events. The snapshot state will be resolved // of the state after the events. The snapshot state will be resolved
// using the correct state resolution algorithm for the room. // using the correct state resolution algorithm for the room.
u.newStateNID, err = roomState.CalculateAndStoreStateAfterEvents( u.newStateNID, err = roomState.CalculateAndStoreStateAfterEvents(
u.ctx, latestStateAtEvents, ctx, latestStateAtEvents,
) )
if err != nil { if err != nil {
return fmt.Errorf("roomState.CalculateAndStoreStateAfterEvents: %w", err) return fmt.Errorf("roomState.CalculateAndStoreStateAfterEvents: %w", err)
} }
// If we are overwriting the state then we should make sure that we
// don't send anything out over federation again, it will very likely
// be a repeat.
if u.stateAtEvent.Overwrite {
u.sendAsServer = ""
}
// Now that we have a new state snapshot based on the latest events, // Now that we have a new state snapshot based on the latest events,
// we can compare that new snapshot to the previous one and see what // we can compare that new snapshot to the previous one and see what
// has changed. This gives us one list of removed state events and // has changed. This gives us one list of removed state events and
// another list of added ones. Replacing a value for a state-key tuple // another list of added ones. Replacing a value for a state-key tuple
// will result one removed (the old event) and one added (the new event). // will result one removed (the old event) and one added (the new event).
u.removed, u.added, err = roomState.DifferenceBetweeenStateSnapshots( u.removed, u.added, err = roomState.DifferenceBetweeenStateSnapshots(
u.ctx, u.oldStateNID, u.newStateNID, ctx, u.oldStateNID, u.newStateNID,
) )
if err != nil { if err != nil {
return fmt.Errorf("roomState.DifferenceBetweenStateSnapshots: %w", err) return fmt.Errorf("roomState.DifferenceBetweenStateSnapshots: %w", err)
} }
if removed := len(u.removed) - len(u.added); removed > 0 {
logrus.WithFields(logrus.Fields{
"event_id": u.event.EventID(),
"room_id": u.event.RoomID(),
"old_state_nid": u.oldStateNID,
"new_state_nid": u.newStateNID,
"old_latest": u.oldLatest.EventIDs(),
"new_latest": u.latest.EventIDs(),
}).Errorf("Unexpected state deletion (removing %d events)", removed)
}
// Also work out the state before the event removes and the event // Also work out the state before the event removes and the event
// adds. // adds.
u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = roomState.DifferenceBetweeenStateSnapshots( u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = roomState.DifferenceBetweeenStateSnapshots(
u.ctx, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID, ctx, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID,
) )
if err != nil { if err != nil {
return fmt.Errorf("roomState.DifferenceBetweeenStateSnapshots: %w", err) return fmt.Errorf("roomState.DifferenceBetweeenStateSnapshots: %w", err)
@ -296,6 +301,9 @@ func (u *latestEventsUpdater) calculateLatest(
newEvent *gomatrixserverlib.Event, newEvent *gomatrixserverlib.Event,
newStateAndRef types.StateAtEventAndReference, newStateAndRef types.StateAtEventAndReference,
) (bool, error) { ) (bool, error) {
span, _ := opentracing.StartSpanFromContext(u.ctx, "calculateLatest")
defer span.Finish()
// First of all, get a list of all of the events in our current // First of all, get a list of all of the events in our current
// set of forward extremities. // set of forward extremities.
existingRefs := make(map[string]*types.StateAtEventAndReference) existingRefs := make(map[string]*types.StateAtEventAndReference)

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/opentracing/opentracing-go"
) )
// updateMembership updates the current membership and the invites for each // updateMembership updates the current membership and the invites for each
@ -34,6 +35,9 @@ func (r *Inputer) updateMemberships(
updater *shared.RoomUpdater, updater *shared.RoomUpdater,
removed, added []types.StateEntry, removed, added []types.StateEntry,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "updateMemberships")
defer span.Finish()
changes := membershipChanges(removed, added) changes := membershipChanges(removed, added)
var eventNIDs []types.EventNID var eventNIDs []types.EventNID
for _, change := range changes { for _, change := range changes {

View file

@ -15,6 +15,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/opentracing/opentracing-go"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -59,6 +60,9 @@ type missingStateReq struct {
func (t *missingStateReq) processEventWithMissingState( func (t *missingStateReq) processEventWithMissingState(
ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion,
) (*parsedRespState, error) { ) (*parsedRespState, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "processEventWithMissingState")
defer span.Finish()
// We are missing the previous events for this events. // We are missing the previous events for this events.
// This means that there is a gap in our view of the history of the // This means that there is a gap in our view of the history of the
// room. There two ways that we can handle such a gap: // room. There two ways that we can handle such a gap:
@ -235,6 +239,9 @@ func (t *missingStateReq) processEventWithMissingState(
} }
func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (*parsedRespState, error) { func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (*parsedRespState, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "lookupResolvedStateBeforeEvent")
defer span.Finish()
type respState struct { type respState struct {
// A snapshot is considered trustworthy if it came from our own roomserver. // A snapshot is considered trustworthy if it came from our own roomserver.
// That's because the state will have been through state resolution once // That's because the state will have been through state resolution once
@ -310,6 +317,9 @@ func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e
// lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) // lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event)
// added into the mix. // added into the mix.
func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (*parsedRespState, bool, error) { func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (*parsedRespState, bool, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "lookupStateAfterEvent")
defer span.Finish()
// try doing all this locally before we resort to querying federation // try doing all this locally before we resort to querying federation
respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID) respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID)
if respState != nil { if respState != nil {
@ -361,6 +371,9 @@ func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.Event) *gomatrixs
} }
func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *parsedRespState { func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *parsedRespState {
span, ctx := opentracing.StartSpanFromContext(ctx, "lookupStateAfterEventLocally")
defer span.Finish()
var res parsedRespState var res parsedRespState
roomInfo, err := t.db.RoomInfo(ctx, roomID) roomInfo, err := t.db.RoomInfo(ctx, roomID)
if err != nil { if err != nil {
@ -435,12 +448,17 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, room
// the server supports. // the server supports.
func (t *missingStateReq) lookupStateBeforeEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) ( func (t *missingStateReq) lookupStateBeforeEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (
*parsedRespState, error) { *parsedRespState, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "lookupStateBeforeEvent")
defer span.Finish()
// Attempt to fetch the missing state using /state_ids and /events // Attempt to fetch the missing state using /state_ids and /events
return t.lookupMissingStateViaStateIDs(ctx, roomID, eventID, roomVersion) return t.lookupMissingStateViaStateIDs(ctx, roomID, eventID, roomVersion)
} }
func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, states []*parsedRespState, backwardsExtremity *gomatrixserverlib.Event) (*parsedRespState, error) { func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, states []*parsedRespState, backwardsExtremity *gomatrixserverlib.Event) (*parsedRespState, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "resolveStatesAndCheck")
defer span.Finish()
var authEventList []*gomatrixserverlib.Event var authEventList []*gomatrixserverlib.Event
var stateEventList []*gomatrixserverlib.Event var stateEventList []*gomatrixserverlib.Event
for _, state := range states { for _, state := range states {
@ -484,6 +502,9 @@ retryAllowedState:
// get missing events for `e`. If `isGapFilled`=true then `newEvents` contains all the events to inject, // get missing events for `e`. If `isGapFilled`=true then `newEvents` contains all the events to inject,
// without `e`. If `isGapFilled=false` then `newEvents` contains the response to /get_missing_events // without `e`. If `isGapFilled=false` then `newEvents` contains the response to /get_missing_events
func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, isGapFilled, prevStateKnown bool, err error) { func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, isGapFilled, prevStateKnown bool, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "getMissingEvents")
defer span.Finish()
logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID())
latest, _, _, err := t.db.LatestEventIDs(ctx, t.roomInfo.RoomNID) latest, _, _, err := t.db.LatestEventIDs(ctx, t.roomInfo.RoomNID)
if err != nil { if err != nil {
@ -608,6 +629,9 @@ func (t *missingStateReq) isPrevStateKnown(ctx context.Context, e *gomatrixserve
func (t *missingStateReq) lookupMissingStateViaState( func (t *missingStateReq) lookupMissingStateViaState(
ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
) (respState *parsedRespState, err error) { ) (respState *parsedRespState, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "lookupMissingStateViaState")
defer span.Finish()
state, err := t.federation.LookupState(ctx, t.origin, roomID, eventID, roomVersion) state, err := t.federation.LookupState(ctx, t.origin, roomID, eventID, roomVersion)
if err != nil { if err != nil {
return nil, err return nil, err
@ -637,6 +661,9 @@ func (t *missingStateReq) lookupMissingStateViaState(
func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
*parsedRespState, error) { *parsedRespState, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "lookupMissingStateViaStateIDs")
defer span.Finish()
util.GetLogger(ctx).WithField("room_id", roomID).Infof("lookupMissingStateViaStateIDs %s", eventID) util.GetLogger(ctx).WithField("room_id", roomID).Infof("lookupMissingStateViaStateIDs %s", eventID)
// fetch the state event IDs at the time of the event // fetch the state event IDs at the time of the event
stateIDs, err := t.federation.LookupStateIDs(ctx, t.origin, roomID, eventID) stateIDs, err := t.federation.LookupStateIDs(ctx, t.origin, roomID, eventID)
@ -799,6 +826,9 @@ func (t *missingStateReq) createRespStateFromStateIDs(
} }
func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, _, missingEventID string, localFirst bool) (*gomatrixserverlib.Event, error) { func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, _, missingEventID string, localFirst bool) (*gomatrixserverlib.Event, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "lookupEvent")
defer span.Finish()
if localFirst { if localFirst {
// fetch from the roomserver // fetch from the roomserver
events, err := t.db.EventsFromIDs(ctx, []string{missingEventID}) events, err := t.db.EventsFromIDs(ctx, []string{missingEventID})

View file

@ -24,6 +24,7 @@ import (
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
fsAPI "github.com/matrix-org/dendrite/federationapi/api" fsAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api"
rsAPI "github.com/matrix-org/dendrite/roomserver/api" rsAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/internal/input"
@ -160,6 +161,7 @@ func (r *Joiner) performJoinRoomByAlias(
} }
// TODO: Break this function up a bit // TODO: Break this function up a bit
// nolint:gocyclo
func (r *Joiner) performJoinRoomByID( func (r *Joiner) performJoinRoomByID(
ctx context.Context, ctx context.Context,
req *rsAPI.PerformJoinRequest, req *rsAPI.PerformJoinRequest,
@ -210,6 +212,11 @@ func (r *Joiner) performJoinRoomByID(
req.Content = map[string]interface{}{} req.Content = map[string]interface{}{}
} }
req.Content["membership"] = gomatrixserverlib.Join req.Content["membership"] = gomatrixserverlib.Join
if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req); aerr != nil {
return "", "", aerr
} else if authorisedVia != "" {
req.Content["join_authorised_via_users_server"] = authorisedVia
}
if err = eb.SetContent(req.Content); err != nil { if err = eb.SetContent(req.Content); err != nil {
return "", "", fmt.Errorf("eb.SetContent: %w", err) return "", "", fmt.Errorf("eb.SetContent: %w", err)
} }
@ -350,6 +357,33 @@ func (r *Joiner) performFederatedJoinRoomByID(
return fedRes.JoinedVia, nil return fedRes.JoinedVia, nil
} }
func (r *Joiner) populateAuthorisedViaUserForRestrictedJoin(
ctx context.Context,
joinReq *rsAPI.PerformJoinRequest,
) (string, error) {
req := &api.QueryRestrictedJoinAllowedRequest{
UserID: joinReq.UserID,
RoomID: joinReq.RoomIDOrAlias,
}
res := &api.QueryRestrictedJoinAllowedResponse{}
if err := r.Queryer.QueryRestrictedJoinAllowed(ctx, req, res); err != nil {
return "", fmt.Errorf("r.Queryer.QueryRestrictedJoinAllowed: %w", err)
}
if !res.Restricted {
return "", nil
}
if !res.Resident {
return "", nil
}
if !res.Allowed {
return "", &rsAPI.PerformError{
Code: rsAPI.PerformErrorNotAllowed,
Msg: fmt.Sprintf("The join to room %s was not allowed.", joinReq.RoomIDOrAlias),
}
}
return res.AuthorisedVia, nil
}
func buildEvent( func buildEvent(
ctx context.Context, db storage.Database, cfg *config.Global, builder *gomatrixserverlib.EventBuilder, ctx context.Context, db storage.Database, cfg *config.Global, builder *gomatrixserverlib.EventBuilder,
) (*gomatrixserverlib.HeaderedEvent, *rsAPI.QueryLatestEventsAndStateResponse, error) { ) (*gomatrixserverlib.HeaderedEvent, *rsAPI.QueryLatestEventsAndStateResponse, error) {

View file

@ -105,13 +105,13 @@ func (r *Upgrader) performRoomUpgrade(
return "", pErr return "", pErr
} }
// 5. Send the tombstone event to the old room (must do this before we set the new canonical_alias) // Send the setup events to the new room
if pErr = r.sendHeaderedEvent(ctx, tombstoneEvent); pErr != nil { if pErr = r.sendInitialEvents(ctx, evTime, userID, newRoomID, string(req.RoomVersion), eventsToMake); pErr != nil {
return "", pErr return "", pErr
} }
// Send the setup events to the new room // 5. Send the tombstone event to the old room
if pErr = r.sendInitialEvents(ctx, evTime, userID, newRoomID, string(req.RoomVersion), eventsToMake); pErr != nil { if pErr = r.sendHeaderedEvent(ctx, tombstoneEvent, string(r.Cfg.Matrix.ServerName)); pErr != nil {
return "", pErr return "", pErr
} }
@ -147,7 +147,7 @@ func (r *Upgrader) getRoomPowerLevels(ctx context.Context, roomID string) (*goma
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error() util.GetLogger(ctx).WithError(err).Error()
return nil, &api.PerformError{ return nil, &api.PerformError{
Msg: "powerLevel event was not actually a power level event", Msg: "Power level event was invalid or malformed",
} }
} }
return powerLevelContent, nil return powerLevelContent, nil
@ -182,7 +182,7 @@ func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.T
return resErr return resErr
} }
} else { } else {
if resErr = r.sendHeaderedEvent(ctx, restrictedPowerLevelsHeadered); resErr != nil { if resErr = r.sendHeaderedEvent(ctx, restrictedPowerLevelsHeadered, api.DoNotSendToOtherServers); resErr != nil {
return resErr return resErr
} }
} }
@ -198,7 +198,7 @@ func moveLocalAliases(ctx context.Context,
aliasRes := api.GetAliasesForRoomIDResponse{} aliasRes := api.GetAliasesForRoomIDResponse{}
if err = URSAPI.GetAliasesForRoomID(ctx, &aliasReq, &aliasRes); err != nil { if err = URSAPI.GetAliasesForRoomID(ctx, &aliasReq, &aliasRes); err != nil {
return &api.PerformError{ return &api.PerformError{
Msg: "Could not get aliases for old room", Msg: fmt.Sprintf("Failed to get old room aliases: %s", err),
} }
} }
@ -207,7 +207,7 @@ func moveLocalAliases(ctx context.Context,
removeAliasRes := api.RemoveRoomAliasResponse{} removeAliasRes := api.RemoveRoomAliasResponse{}
if err = URSAPI.RemoveRoomAlias(ctx, &removeAliasReq, &removeAliasRes); err != nil { if err = URSAPI.RemoveRoomAlias(ctx, &removeAliasReq, &removeAliasRes); err != nil {
return &api.PerformError{ return &api.PerformError{
Msg: "api.RemoveRoomAlias failed", Msg: fmt.Sprintf("Failed to remove old room alias: %s", err),
} }
} }
@ -215,7 +215,7 @@ func moveLocalAliases(ctx context.Context,
setAliasRes := api.SetRoomAliasResponse{} setAliasRes := api.SetRoomAliasResponse{}
if err = URSAPI.SetRoomAlias(ctx, &setAliasReq, &setAliasRes); err != nil { if err = URSAPI.SetRoomAlias(ctx, &setAliasReq, &setAliasRes); err != nil {
return &api.PerformError{ return &api.PerformError{
Msg: "api.SetRoomAlias failed", Msg: fmt.Sprintf("Failed to set new room alias: %s", err),
} }
} }
} }
@ -253,7 +253,7 @@ func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api
return resErr return resErr
} }
} else { } else {
if resErr = r.sendHeaderedEvent(ctx, emptyCanonicalAliasEvent); resErr != nil { if resErr = r.sendHeaderedEvent(ctx, emptyCanonicalAliasEvent, api.DoNotSendToOtherServers); resErr != nil {
return resErr return resErr
} }
} }
@ -509,7 +509,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user
err = builder.SetContent(e.Content) err = builder.SetContent(e.Content)
if err != nil { if err != nil {
return &api.PerformError{ return &api.PerformError{
Msg: "builder.SetContent failed", Msg: fmt.Sprintf("Failed to set content of new %q event: %s", builder.Type, err),
} }
} }
if i > 0 { if i > 0 {
@ -519,13 +519,13 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user
event, err = r.buildEvent(&builder, &authEvents, evTime, gomatrixserverlib.RoomVersion(newVersion)) event, err = r.buildEvent(&builder, &authEvents, evTime, gomatrixserverlib.RoomVersion(newVersion))
if err != nil { if err != nil {
return &api.PerformError{ return &api.PerformError{
Msg: "buildEvent failed", Msg: fmt.Sprintf("Failed to build new %q event: %s", builder.Type, err),
} }
} }
if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil {
return &api.PerformError{ return &api.PerformError{
Msg: "gomatrixserverlib.Allowed failed", Msg: fmt.Sprintf("Failed to auth new %q event: %s", builder.Type, err),
} }
} }
@ -534,7 +534,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user
err = authEvents.AddEvent(event) err = authEvents.AddEvent(event)
if err != nil { if err != nil {
return &api.PerformError{ return &api.PerformError{
Msg: "authEvents.AddEvent failed", Msg: fmt.Sprintf("Failed to add new %q event to auth set: %s", builder.Type, err),
} }
} }
} }
@ -550,7 +550,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user
} }
if err = api.SendInputRoomEvents(ctx, r.URSAPI, inputs, false); err != nil { if err = api.SendInputRoomEvents(ctx, r.URSAPI, inputs, false); err != nil {
return &api.PerformError{ return &api.PerformError{
Msg: "api.SendInputRoomEvents failed", Msg: fmt.Sprintf("Failed to send new room %q to roomserver: %s", newRoomID, err),
} }
} }
return nil return nil
@ -582,7 +582,7 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user
err := builder.SetContent(event.Content) err := builder.SetContent(event.Content)
if err != nil { if err != nil {
return nil, &api.PerformError{ return nil, &api.PerformError{
Msg: "builder.SetContent failed", Msg: fmt.Sprintf("Failed to set new %q event content: %s", builder.Type, err),
} }
} }
var queryRes api.QueryLatestEventsAndStateResponse var queryRes api.QueryLatestEventsAndStateResponse
@ -607,7 +607,7 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user
} }
} else if err != nil { } else if err != nil {
return nil, &api.PerformError{ return nil, &api.PerformError{
Msg: "eventutil.BuildEvent failed", Msg: fmt.Sprintf("Failed to build new %q event: %s", builder.Type, err),
} }
} }
// check to see if this user can perform this operation // check to see if this user can perform this operation
@ -619,7 +619,7 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user
if err = gomatrixserverlib.Allowed(headeredEvent.Event, &provider); err != nil { if err = gomatrixserverlib.Allowed(headeredEvent.Event, &provider); err != nil {
return nil, &api.PerformError{ return nil, &api.PerformError{
Code: api.PerformErrorNotAllowed, Code: api.PerformErrorNotAllowed,
Msg: err.Error(), // TODO: Is this error string comprehensible to the client? Msg: fmt.Sprintf("Failed to auth new %q event: %s", builder.Type, err), // TODO: Is this error string comprehensible to the client?
} }
} }
@ -666,17 +666,18 @@ func createTemporaryPowerLevels(powerLevelContent *gomatrixserverlib.PowerLevelC
func (r *Upgrader) sendHeaderedEvent( func (r *Upgrader) sendHeaderedEvent(
ctx context.Context, ctx context.Context,
headeredEvent *gomatrixserverlib.HeaderedEvent, headeredEvent *gomatrixserverlib.HeaderedEvent,
sendAsServer string,
) *api.PerformError { ) *api.PerformError {
var inputs []api.InputRoomEvent var inputs []api.InputRoomEvent
inputs = append(inputs, api.InputRoomEvent{ inputs = append(inputs, api.InputRoomEvent{
Kind: api.KindNew, Kind: api.KindNew,
Event: headeredEvent, Event: headeredEvent,
Origin: r.Cfg.Matrix.ServerName, Origin: r.Cfg.Matrix.ServerName,
SendAsServer: api.DoNotSendToOtherServers, SendAsServer: sendAsServer,
}) })
if err := api.SendInputRoomEvents(ctx, r.URSAPI, inputs, false); err != nil { if err := api.SendInputRoomEvents(ctx, r.URSAPI, inputs, false); err != nil {
return &api.PerformError{ return &api.PerformError{
Msg: "api.SendInputRoomEvents failed", Msg: fmt.Sprintf("Failed to send new %q event to roomserver: %s", headeredEvent.Type(), err),
} }
} }
@ -703,7 +704,7 @@ func (r *Upgrader) buildEvent(
r.Cfg.Matrix.PrivateKey, roomVersion, r.Cfg.Matrix.PrivateKey, roomVersion,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot build event %s : Builder failed to build. %w", builder.Type, err) return nil, err
} }
return event, nil return event, nil
} }

View file

@ -16,6 +16,7 @@ package query
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -757,3 +758,131 @@ func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainReq
res.AuthChain = hchain res.AuthChain = hchain
return nil return nil
} }
// nolint:gocyclo
func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.QueryRestrictedJoinAllowedRequest, res *api.QueryRestrictedJoinAllowedResponse) error {
// Look up if we know anything about the room. If it doesn't exist
// or is a stub entry then we can't do anything.
roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID)
if err != nil {
return fmt.Errorf("r.DB.RoomInfo: %w", err)
}
if roomInfo == nil || roomInfo.IsStub {
return nil // fmt.Errorf("room %q doesn't exist or is stub room", req.RoomID)
}
// If the room version doesn't allow restricted joins then don't
// try to process any further.
allowRestrictedJoins, err := roomInfo.RoomVersion.MayAllowRestrictedJoinsInEventAuth()
if err != nil {
return fmt.Errorf("roomInfo.RoomVersion.AllowRestrictedJoinsInEventAuth: %w", err)
} else if !allowRestrictedJoins {
return nil
}
// Start off by populating the "resident" flag in the response. If we
// come across any rooms in the request that are missing, we will unset
// the flag.
res.Resident = true
// Get the join rules to work out if the join rule is "restricted".
joinRulesEvent, err := r.DB.GetStateEvent(ctx, req.RoomID, gomatrixserverlib.MRoomJoinRules, "")
if err != nil {
return fmt.Errorf("r.DB.GetStateEvent: %w", err)
}
if joinRulesEvent == nil {
return nil
}
var joinRules gomatrixserverlib.JoinRuleContent
if err = json.Unmarshal(joinRulesEvent.Content(), &joinRules); err != nil {
return fmt.Errorf("json.Unmarshal: %w", err)
}
// If the join rule isn't "restricted" then there's nothing more to do.
res.Restricted = joinRules.JoinRule == gomatrixserverlib.Restricted
if !res.Restricted {
return nil
}
// If the user is already invited to the room then the join is allowed
// but we don't specify an authorised via user, since the event auth
// will allow the join anyway.
var pending bool
if pending, _, _, err = helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID); err != nil {
return fmt.Errorf("helpers.IsInvitePending: %w", err)
} else if pending {
res.Allowed = true
return nil
}
// We need to get the power levels content so that we can determine which
// users in the room are entitled to issue invites. We need to use one of
// these users as the authorising user.
powerLevelsEvent, err := r.DB.GetStateEvent(ctx, req.RoomID, gomatrixserverlib.MRoomPowerLevels, "")
if err != nil {
return fmt.Errorf("r.DB.GetStateEvent: %w", err)
}
var powerLevels gomatrixserverlib.PowerLevelContent
if err = json.Unmarshal(powerLevelsEvent.Content(), &powerLevels); err != nil {
return fmt.Errorf("json.Unmarshal: %w", err)
}
// Step through the join rules and see if the user matches any of them.
for _, rule := range joinRules.Allow {
// We only understand "m.room_membership" rules at this point in
// time, so skip any rule that doesn't match those.
if rule.Type != gomatrixserverlib.MRoomMembership {
continue
}
// See if the room exists. If it doesn't exist or if it's a stub
// room entry then we can't check memberships.
targetRoomInfo, err := r.DB.RoomInfo(ctx, rule.RoomID)
if err != nil || targetRoomInfo == nil || targetRoomInfo.IsStub {
res.Resident = false
continue
}
// First of all work out if *we* are still in the room, otherwise
// it's possible that the memberships will be out of date.
isIn, err := r.DB.GetLocalServerInRoom(ctx, targetRoomInfo.RoomNID)
if err != nil || !isIn {
// If we aren't in the room, we can no longer tell if the room
// memberships are up-to-date.
res.Resident = false
continue
}
// At this point we're happy that we are in the room, so now let's
// see if the target user is in the room.
_, isIn, _, err = r.DB.GetMembership(ctx, targetRoomInfo.RoomNID, req.UserID)
if err != nil {
continue
}
// If the user is not in the room then we will skip them.
if !isIn {
continue
}
// The user is in the room, so now we will need to authorise the
// join using the user ID of one of our own users in the room. Pick
// one.
joinNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, targetRoomInfo.RoomNID, true, true)
if err != nil || len(joinNIDs) == 0 {
// There should always be more than one join NID at this point
// because we are gated behind GetLocalServerInRoom, but y'know,
// sometimes strange things happen.
continue
}
// For each of the joined users, let's see if we can get a valid
// membership event.
for _, joinNID := range joinNIDs {
events, err := r.DB.Events(ctx, []types.EventNID{joinNID})
if err != nil || len(events) != 1 {
continue
}
event := events[0]
if event.Type() != gomatrixserverlib.MRoomMember || event.StateKey() == nil {
continue // shouldn't happen
}
// Only users that have the power to invite should be chosen.
if powerLevels.UserLevel(*event.StateKey()) < powerLevels.Invite {
continue
}
res.Resident = true
res.Allowed = true
res.AuthorisedVia = *event.StateKey()
return nil
}
}
return nil
}

View file

@ -61,6 +61,7 @@ const (
RoomserverQueryKnownUsersPath = "/roomserver/queryKnownUsers" RoomserverQueryKnownUsersPath = "/roomserver/queryKnownUsers"
RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom" RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom"
RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain" RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain"
RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed"
) )
type httpRoomserverInternalAPI struct { type httpRoomserverInternalAPI struct {
@ -557,6 +558,16 @@ func (h *httpRoomserverInternalAPI) QueryServerBannedFromRoom(
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
} }
func (h *httpRoomserverInternalAPI) QueryRestrictedJoinAllowed(
ctx context.Context, req *api.QueryRestrictedJoinAllowedRequest, res *api.QueryRestrictedJoinAllowedResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRestrictedJoinAllowed")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryRestrictedJoinAllowed
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpRoomserverInternalAPI) PerformForget(ctx context.Context, req *api.PerformForgetRequest, res *api.PerformForgetResponse) error { func (h *httpRoomserverInternalAPI) PerformForget(ctx context.Context, req *api.PerformForgetRequest, res *api.PerformForgetResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformForget") span, ctx := opentracing.StartSpanFromContext(ctx, "PerformForget")
defer span.Finish() defer span.Finish()

View file

@ -472,4 +472,17 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(RoomserverQueryRestrictedJoinAllowed,
httputil.MakeInternalAPI("queryRestrictedJoinAllowed", func(req *http.Request) util.JSONResponse {
request := api.QueryRestrictedJoinAllowedRequest{}
response := api.QueryRestrictedJoinAllowedResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QueryRestrictedJoinAllowed(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
} }

View file

@ -20,9 +20,11 @@ import (
"context" "context"
"fmt" "fmt"
"sort" "sort"
"sync"
"time" "time"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/opentracing/opentracing-go"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
@ -39,6 +41,7 @@ type StateResolutionStorage interface {
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error)
} }
type StateResolution struct { type StateResolution struct {
@ -61,6 +64,9 @@ func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) Sta
func (v *StateResolution) LoadStateAtSnapshot( func (v *StateResolution) LoadStateAtSnapshot(
ctx context.Context, stateNID types.StateSnapshotNID, ctx context.Context, stateNID types.StateSnapshotNID,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtSnapshot")
defer span.Finish()
stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
if err != nil { if err != nil {
return nil, err return nil, err
@ -99,6 +105,9 @@ func (v *StateResolution) LoadStateAtSnapshot(
func (v *StateResolution) LoadStateAtEvent( func (v *StateResolution) LoadStateAtEvent(
ctx context.Context, eventID string, ctx context.Context, eventID string,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtEvent")
defer span.Finish()
snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID) snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID)
if err != nil { if err != nil {
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %s", eventID, err) return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %s", eventID, err)
@ -121,6 +130,9 @@ func (v *StateResolution) LoadStateAtEvent(
func (v *StateResolution) LoadCombinedStateAfterEvents( func (v *StateResolution) LoadCombinedStateAfterEvents(
ctx context.Context, prevStates []types.StateAtEvent, ctx context.Context, prevStates []types.StateAtEvent,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadCombinedStateAfterEvents")
defer span.Finish()
stateNIDs := make([]types.StateSnapshotNID, len(prevStates)) stateNIDs := make([]types.StateSnapshotNID, len(prevStates))
for i, state := range prevStates { for i, state := range prevStates {
stateNIDs[i] = state.BeforeStateSnapshotNID stateNIDs[i] = state.BeforeStateSnapshotNID
@ -193,6 +205,9 @@ func (v *StateResolution) LoadCombinedStateAfterEvents(
func (v *StateResolution) DifferenceBetweeenStateSnapshots( func (v *StateResolution) DifferenceBetweeenStateSnapshots(
ctx context.Context, oldStateNID, newStateNID types.StateSnapshotNID, ctx context.Context, oldStateNID, newStateNID types.StateSnapshotNID,
) (removed, added []types.StateEntry, err error) { ) (removed, added []types.StateEntry, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.DifferenceBetweeenStateSnapshots")
defer span.Finish()
if oldStateNID == newStateNID { if oldStateNID == newStateNID {
// If the snapshot NIDs are the same then nothing has changed // If the snapshot NIDs are the same then nothing has changed
return nil, nil, nil return nil, nil, nil
@ -254,6 +269,9 @@ func (v *StateResolution) LoadStateAtSnapshotForStringTuples(
stateNID types.StateSnapshotNID, stateNID types.StateSnapshotNID,
stateKeyTuples []gomatrixserverlib.StateKeyTuple, stateKeyTuples []gomatrixserverlib.StateKeyTuple,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtSnapshotForStringTuples")
defer span.Finish()
numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples) numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples)
if err != nil { if err != nil {
return nil, err return nil, err
@ -268,6 +286,9 @@ func (v *StateResolution) stringTuplesToNumericTuples(
ctx context.Context, ctx context.Context,
stringTuples []gomatrixserverlib.StateKeyTuple, stringTuples []gomatrixserverlib.StateKeyTuple,
) ([]types.StateKeyTuple, error) { ) ([]types.StateKeyTuple, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.stringTuplesToNumericTuples")
defer span.Finish()
eventTypes := make([]string, len(stringTuples)) eventTypes := make([]string, len(stringTuples))
stateKeys := make([]string, len(stringTuples)) stateKeys := make([]string, len(stringTuples))
for i := range stringTuples { for i := range stringTuples {
@ -310,6 +331,9 @@ func (v *StateResolution) loadStateAtSnapshotForNumericTuples(
stateNID types.StateSnapshotNID, stateNID types.StateSnapshotNID,
stateKeyTuples []types.StateKeyTuple, stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadStateAtSnapshotForNumericTuples")
defer span.Finish()
stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
if err != nil { if err != nil {
return nil, err return nil, err
@ -358,6 +382,9 @@ func (v *StateResolution) LoadStateAfterEventsForStringTuples(
prevStates []types.StateAtEvent, prevStates []types.StateAtEvent,
stateKeyTuples []gomatrixserverlib.StateKeyTuple, stateKeyTuples []gomatrixserverlib.StateKeyTuple,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAfterEventsForStringTuples")
defer span.Finish()
numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples) numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples)
if err != nil { if err != nil {
return nil, err return nil, err
@ -370,6 +397,9 @@ func (v *StateResolution) loadStateAfterEventsForNumericTuples(
prevStates []types.StateAtEvent, prevStates []types.StateAtEvent,
stateKeyTuples []types.StateKeyTuple, stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadStateAfterEventsForNumericTuples")
defer span.Finish()
if len(prevStates) == 1 { if len(prevStates) == 1 {
// Fast path for a single event. // Fast path for a single event.
prevState := prevStates[0] prevState := prevStates[0]
@ -542,6 +572,9 @@ func (v *StateResolution) CalculateAndStoreStateBeforeEvent(
event *gomatrixserverlib.Event, event *gomatrixserverlib.Event,
isRejected bool, isRejected bool,
) (types.StateSnapshotNID, error) { ) (types.StateSnapshotNID, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.CalculateAndStoreStateBeforeEvent")
defer span.Finish()
// Load the state at the prev events. // Load the state at the prev events.
prevStates, err := v.db.StateAtEventIDs(ctx, event.PrevEventIDs()) prevStates, err := v.db.StateAtEventIDs(ctx, event.PrevEventIDs())
if err != nil { if err != nil {
@ -558,6 +591,9 @@ func (v *StateResolution) CalculateAndStoreStateAfterEvents(
ctx context.Context, ctx context.Context,
prevStates []types.StateAtEvent, prevStates []types.StateAtEvent,
) (types.StateSnapshotNID, error) { ) (types.StateSnapshotNID, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.CalculateAndStoreStateAfterEvents")
defer span.Finish()
metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)} metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)}
if len(prevStates) == 0 { if len(prevStates) == 0 {
@ -630,6 +666,9 @@ func (v *StateResolution) calculateAndStoreStateAfterManyEvents(
prevStates []types.StateAtEvent, prevStates []types.StateAtEvent,
metrics calculateStateMetrics, metrics calculateStateMetrics,
) (types.StateSnapshotNID, error) { ) (types.StateSnapshotNID, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.calculateAndStoreStateAfterManyEvents")
defer span.Finish()
state, algorithm, conflictLength, err := state, algorithm, conflictLength, err :=
v.calculateStateAfterManyEvents(ctx, v.roomInfo.RoomVersion, prevStates) v.calculateStateAfterManyEvents(ctx, v.roomInfo.RoomVersion, prevStates)
metrics.algorithm = algorithm metrics.algorithm = algorithm
@ -648,6 +687,9 @@ func (v *StateResolution) calculateStateAfterManyEvents(
ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, ctx context.Context, roomVersion gomatrixserverlib.RoomVersion,
prevStates []types.StateAtEvent, prevStates []types.StateAtEvent,
) (state []types.StateEntry, algorithm string, conflictLength int, err error) { ) (state []types.StateEntry, algorithm string, conflictLength int, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.calculateStateAfterManyEvents")
defer span.Finish()
var combined []types.StateEntry var combined []types.StateEntry
// Conflict resolution. // Conflict resolution.
// First stage: load the state after each of the prev events. // First stage: load the state after each of the prev events.
@ -659,15 +701,13 @@ func (v *StateResolution) calculateStateAfterManyEvents(
} }
// Collect all the entries with the same type and key together. // Collect all the entries with the same type and key together.
// We don't care about the order here because the conflict resolution // This is done so findDuplicateStateKeys can work in groups.
// algorithm doesn't depend on the order of the prev events. // We remove duplicates (same type, state key and event NID) too.
// Remove duplicate entires.
combined = combined[:util.SortAndUnique(stateEntrySorter(combined))] combined = combined[:util.SortAndUnique(stateEntrySorter(combined))]
// Find the conflicts // Find the conflicts
conflicts := findDuplicateStateKeys(combined) if conflicts := findDuplicateStateKeys(combined); len(conflicts) > 0 {
conflictMap := stateEntryMap(conflicts)
if len(conflicts) > 0 {
conflictLength = len(conflicts) conflictLength = len(conflicts)
// 5) There are conflicting state events, for each conflict workout // 5) There are conflicting state events, for each conflict workout
@ -676,7 +716,7 @@ func (v *StateResolution) calculateStateAfterManyEvents(
// Work out which entries aren't conflicted. // Work out which entries aren't conflicted.
var notConflicted []types.StateEntry var notConflicted []types.StateEntry
for _, entry := range combined { for _, entry := range combined {
if _, ok := stateEntryMap(conflicts).lookup(entry.StateKeyTuple); !ok { if _, ok := conflictMap.lookup(entry.StateKeyTuple); !ok {
notConflicted = append(notConflicted, entry) notConflicted = append(notConflicted, entry)
} }
} }
@ -689,7 +729,7 @@ func (v *StateResolution) calculateStateAfterManyEvents(
return return
} }
algorithm = "full_state_with_conflicts" algorithm = "full_state_with_conflicts"
state = resolved[:util.SortAndUnique(stateEntrySorter(resolved))] state = resolved
} else { } else {
algorithm = "full_state_no_conflicts" algorithm = "full_state_no_conflicts"
// 6) There weren't any conflicts // 6) There weren't any conflicts
@ -702,6 +742,9 @@ func (v *StateResolution) resolveConflicts(
ctx context.Context, version gomatrixserverlib.RoomVersion, ctx context.Context, version gomatrixserverlib.RoomVersion,
notConflicted, conflicted []types.StateEntry, notConflicted, conflicted []types.StateEntry,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.resolveConflicts")
defer span.Finish()
stateResAlgo, err := version.StateResAlgorithm() stateResAlgo, err := version.StateResAlgorithm()
if err != nil { if err != nil {
return nil, err return nil, err
@ -726,6 +769,8 @@ func (v *StateResolution) resolveConflictsV1(
ctx context.Context, ctx context.Context,
notConflicted, conflicted []types.StateEntry, notConflicted, conflicted []types.StateEntry,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.resolveConflictsV1")
defer span.Finish()
// Load the conflicted events // Load the conflicted events
conflictedEvents, eventIDMap, err := v.loadStateEvents(ctx, conflicted) conflictedEvents, eventIDMap, err := v.loadStateEvents(ctx, conflicted)
@ -789,6 +834,9 @@ func (v *StateResolution) resolveConflictsV2(
ctx context.Context, ctx context.Context,
notConflicted, conflicted []types.StateEntry, notConflicted, conflicted []types.StateEntry,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.resolveConflictsV2")
defer span.Finish()
estimate := len(conflicted) + len(notConflicted) estimate := len(conflicted) + len(notConflicted)
eventIDMap := make(map[string]types.StateEntry, estimate) eventIDMap := make(map[string]types.StateEntry, estimate)
@ -816,51 +864,47 @@ func (v *StateResolution) resolveConflictsV2(
authEvents := make([]*gomatrixserverlib.Event, 0, estimate*3) authEvents := make([]*gomatrixserverlib.Event, 0, estimate*3)
gotAuthEvents := make(map[string]struct{}, estimate*3) gotAuthEvents := make(map[string]struct{}, estimate*3)
authDifference := make([]*gomatrixserverlib.Event, 0, estimate) authDifference := make([]*gomatrixserverlib.Event, 0, estimate)
knownAuthEvents := make(map[string]types.Event, estimate*3)
// For each conflicted event, let's try and get the needed auth events. // For each conflicted event, let's try and get the needed auth events.
neededStateKeys := make([]string, 16) if err = func() error {
authEntries := make([]types.StateEntry, 16) span, sctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadAuthEvents")
for _, conflictedEvent := range conflictedEvents { defer span.Finish()
// Work out which auth events we need to load.
key := conflictedEvent.EventID()
needed := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{conflictedEvent})
// Find the numeric IDs for the necessary state keys. loader := authEventLoader{
neededStateKeys = neededStateKeys[:0] v: v,
neededStateKeys = append(neededStateKeys, needed.Member...) lookupFromDB: make([]string, 0, len(conflictedEvents)*3),
neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...) lookupFromMem: make([]string, 0, len(conflictedEvents)*3),
stateKeyNIDMap, err := v.db.EventStateKeyNIDs(ctx, neededStateKeys) lookedUpEvents: make([]types.Event, 0, len(conflictedEvents)*3),
if err != nil { eventMap: map[string]types.Event{},
return nil, err
} }
for _, conflictedEvent := range conflictedEvents {
// Work out which auth events we need to load.
key := conflictedEvent.EventID()
// Load the necessary auth events. // Store the newly found auth events in the auth set for this event.
tuplesNeeded := v.stateKeyTuplesNeeded(stateKeyNIDMap, needed) var authEventMap map[string]types.StateEntry
authEntries = authEntries[:0] authSets[key], authEventMap, err = loader.loadAuthEvents(sctx, conflictedEvent, knownAuthEvents)
for _, tuple := range tuplesNeeded { if err != nil {
if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok { return err
authEntries = append(authEntries, types.StateEntry{ }
StateKeyTuple: tuple, for k, v := range authEventMap {
EventNID: eventNID, eventIDMap[k] = v
}) }
}
} // Only add auth events into the authEvents slice once, otherwise the
// check for the auth difference can become expensive and produce
// Store the newly found auth events in the auth set for this event. // duplicate entries, which just waste memory and CPU time.
authSets[key], _, err = v.loadStateEvents(ctx, authEntries) for _, event := range authSets[key] {
if err != nil { if _, ok := gotAuthEvents[event.EventID()]; !ok {
return nil, err authEvents = append(authEvents, event)
} gotAuthEvents[event.EventID()] = struct{}{}
}
// Only add auth events into the authEvents slice once, otherwise the
// check for the auth difference can become expensive and produce
// duplicate entries, which just waste memory and CPU time.
for _, event := range authSets[key] {
if _, ok := gotAuthEvents[event.EventID()]; !ok {
authEvents = append(authEvents, event)
gotAuthEvents[event.EventID()] = struct{}{}
} }
} }
return nil
}(); err != nil {
return nil, err
} }
// Kill the reference to this so that the GC may pick it up, since we no // Kill the reference to this so that the GC may pick it up, since we no
@ -891,25 +935,35 @@ func (v *StateResolution) resolveConflictsV2(
// Look through all of the auth events that we've been given and work out if // Look through all of the auth events that we've been given and work out if
// there are any events which don't appear in all of the auth sets. If they // there are any events which don't appear in all of the auth sets. If they
// don't then we add them to the auth difference. // don't then we add them to the auth difference.
for _, event := range authEvents { func() {
if !isInAllAuthLists(event) { span, _ := opentracing.StartSpanFromContext(ctx, "isInAllAuthLists")
authDifference = append(authDifference, event) defer span.Finish()
for _, event := range authEvents {
if !isInAllAuthLists(event) {
authDifference = append(authDifference, event)
}
} }
} }()
// Resolve the conflicts. // Resolve the conflicts.
resolvedEvents := gomatrixserverlib.ResolveStateConflictsV2( resolvedEvents := func() []*gomatrixserverlib.Event {
conflictedEvents, span, _ := opentracing.StartSpanFromContext(ctx, "gomatrixserverlib.ResolveStateConflictsV2")
nonConflictedEvents, defer span.Finish()
authEvents,
authDifference, return gomatrixserverlib.ResolveStateConflictsV2(
) conflictedEvents,
nonConflictedEvents,
authEvents,
authDifference,
)
}()
// Map from the full events back to numeric state entries. // Map from the full events back to numeric state entries.
for _, resolvedEvent := range resolvedEvents { for _, resolvedEvent := range resolvedEvents {
entry, ok := eventIDMap[resolvedEvent.EventID()] entry, ok := eventIDMap[resolvedEvent.EventID()]
if !ok { if !ok {
panic(fmt.Errorf("missing state entry for event ID %q", resolvedEvent.EventID())) return nil, fmt.Errorf("missing state entry for event ID %q", resolvedEvent.EventID())
} }
notConflicted = append(notConflicted, entry) notConflicted = append(notConflicted, entry)
} }
@ -968,6 +1022,9 @@ func (v *StateResolution) stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.E
func (v *StateResolution) loadStateEvents( func (v *StateResolution) loadStateEvents(
ctx context.Context, entries []types.StateEntry, ctx context.Context, entries []types.StateEntry,
) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { ) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadStateEvents")
defer span.Finish()
result := make([]*gomatrixserverlib.Event, 0, len(entries)) result := make([]*gomatrixserverlib.Event, 0, len(entries))
eventEntries := make([]types.StateEntry, 0, len(entries)) eventEntries := make([]types.StateEntry, 0, len(entries))
eventNIDs := make([]types.EventNID, 0, len(entries)) eventNIDs := make([]types.EventNID, 0, len(entries))
@ -996,6 +1053,127 @@ func (v *StateResolution) loadStateEvents(
return result, eventIDMap, nil return result, eventIDMap, nil
} }
type authEventLoader struct {
sync.Mutex
v *StateResolution
lookupFromDB []string // scratch space
lookupFromMem []string // scratch space
lookedUpEvents []types.Event // scratch space
eventMap map[string]types.Event
}
// loadAuthEvents loads all of the auth events for a given event recursively,
// along with a map that contains state entries for all of the auth events.
func (l *authEventLoader) loadAuthEvents(
ctx context.Context, event *gomatrixserverlib.Event, eventMap map[string]types.Event,
) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) {
l.Lock()
defer l.Unlock()
authEvents := []types.Event{} // our returned list
included := map[string]struct{}{} // dedupes authEvents above
queue := event.AuthEventIDs()
for i := 0; i < len(queue); i++ {
// Reuse the same underlying memory, since it reduces the
// amount of allocations we make the more times we call
// loadAuthEvents.
l.lookupFromDB = l.lookupFromDB[:0]
l.lookupFromMem = l.lookupFromMem[:0]
l.lookedUpEvents = l.lookedUpEvents[:0]
// Separate out the list of events in the queue based on if
// we think we already know the event in memory or not.
for _, authEventID := range queue {
if _, ok := included[authEventID]; ok {
continue
}
if _, ok := eventMap[authEventID]; ok {
l.lookupFromMem = append(l.lookupFromMem, authEventID)
} else {
l.lookupFromDB = append(l.lookupFromDB, authEventID)
}
}
// If there's nothing to do, stop here.
if len(l.lookupFromDB) == 0 && len(l.lookupFromMem) == 0 {
break
}
// If we need to get events from the database, go and fetch
// those now.
if len(l.lookupFromDB) > 0 {
eventsFromDB, err := l.v.db.EventsFromIDs(ctx, l.lookupFromDB)
if err != nil {
return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err)
}
l.lookedUpEvents = append(l.lookedUpEvents, eventsFromDB...)
for _, event := range eventsFromDB {
eventMap[event.EventID()] = event
}
}
// Fill in the gaps with events that we already have in memory.
if len(l.lookupFromMem) > 0 {
for _, eventID := range l.lookupFromMem {
l.lookedUpEvents = append(l.lookedUpEvents, eventMap[eventID])
}
}
// From the events that we've retrieved, work out which auth
// events to look up on the next iteration.
add := map[string]struct{}{}
for _, event := range l.lookedUpEvents {
authEvents = append(authEvents, event)
included[event.EventID()] = struct{}{}
for _, authEventID := range event.AuthEventIDs() {
if _, ok := included[authEventID]; ok {
continue
}
add[authEventID] = struct{}{}
}
}
for authEventID := range add {
queue = append(queue, authEventID)
}
}
authEventTypes := map[string]struct{}{}
authEventStateKeys := map[string]struct{}{}
for _, authEvent := range authEvents {
authEventTypes[authEvent.Type()] = struct{}{}
authEventStateKeys[*authEvent.StateKey()] = struct{}{}
}
lookupAuthEventTypes := make([]string, 0, len(authEventTypes))
lookupAuthEventStateKeys := make([]string, 0, len(authEventStateKeys))
for eventType := range authEventTypes {
lookupAuthEventTypes = append(lookupAuthEventTypes, eventType)
}
for eventStateKey := range authEventStateKeys {
lookupAuthEventStateKeys = append(lookupAuthEventStateKeys, eventStateKey)
}
eventTypes, err := l.v.db.EventTypeNIDs(ctx, lookupAuthEventTypes)
if err != nil {
return nil, nil, fmt.Errorf("v.db.EventTypeNIDs: %w", err)
}
eventStateKeys, err := l.v.db.EventStateKeyNIDs(ctx, lookupAuthEventStateKeys)
if err != nil {
return nil, nil, fmt.Errorf("v.db.EventStateKeyNIDs: %w", err)
}
stateEntryMap := map[string]types.StateEntry{}
for _, authEvent := range authEvents {
stateEntryMap[authEvent.EventID()] = types.StateEntry{
EventNID: authEvent.EventNID,
StateKeyTuple: types.StateKeyTuple{
EventTypeNID: eventTypes[authEvent.Type()],
EventStateKeyNID: eventStateKeys[*authEvent.StateKey()],
},
}
}
nakedEvents := make([]*gomatrixserverlib.Event, 0, len(authEvents))
for _, authEvent := range authEvents {
nakedEvents = append(nakedEvents, authEvent.Event)
}
return nakedEvents, stateEntryMap, nil
}
// findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list. // findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list.
// Returns a sorted list of those state entries. // Returns a sorted list of those state entries.
func findDuplicateStateKeys(a []types.StateEntry) []types.StateEntry { func findDuplicateStateKeys(a []types.StateEntry) []types.StateEntry {

View file

@ -192,6 +192,10 @@ func (u *RoomUpdater) StateAtEventIDs(
return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs) return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs)
} }
func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, eventIDs, false)
}
func (u *RoomUpdater) UnsentEventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { func (u *RoomUpdater) UnsentEventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, eventIDs, true) return u.d.eventsFromIDs(ctx, u.txn, eventIDs, true)
} }

View file

@ -85,7 +85,6 @@ func Test_EventsTable(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
stateAtEvent := types.StateAtEvent{ stateAtEvent := types.StateAtEvent{
Overwrite: false,
BeforeStateSnapshotNID: types.StateSnapshotNID(stateSnapshot), BeforeStateSnapshotNID: types.StateSnapshotNID(stateSnapshot),
IsRejected: false, IsRejected: false,
StateEntry: types.StateEntry{ StateEntry: types.StateEntry{

View file

@ -173,10 +173,6 @@ func DeduplicateStateEntries(a []StateEntry) []StateEntry {
// StateAtEvent is the state before and after a matrix event. // StateAtEvent is the state before and after a matrix event.
type StateAtEvent struct { type StateAtEvent struct {
// Should this state overwrite the latest events and memberships of the room?
// This might be necessary when rejoining a federated room after a period of
// absence, as our state and latest events will be out of date.
Overwrite bool
// The state before the event. // The state before the event.
BeforeStateSnapshotNID StateSnapshotNID BeforeStateSnapshotNID StateSnapshotNID
// True if this StateEntry is rejected. State resolution should then treat this // True if this StateEntry is rejected. State resolution should then treat this
@ -214,6 +210,14 @@ func (s StateAtEventAndReferences) Swap(a, b int) {
s[a], s[b] = s[b], s[a] s[a], s[b] = s[b], s[a]
} }
func (s StateAtEventAndReferences) EventIDs() string {
strs := make([]string, 0, len(s))
for _, r := range s {
strs = append(strs, r.EventID)
}
return "[" + strings.Join(strs, " ") + "]"
}
// An Event is a gomatrixserverlib.Event with the numeric event ID attached. // An Event is a gomatrixserverlib.Event with the numeric event ID attached.
// It is when performing bulk event lookup in the database. // It is when performing bulk event lookup in the database.
type Event struct { type Event struct {

View file

@ -334,6 +334,10 @@ type RateLimiting struct {
// The cooloff period in milliseconds after a request before the "slot" // The cooloff period in milliseconds after a request before the "slot"
// is freed again // is freed again
CooloffMS int64 `yaml:"cooloff_ms"` CooloffMS int64 `yaml:"cooloff_ms"`
// A list of users that are exempt from rate limiting, i.e. if you want
// to run Mjolnir or other bots.
ExemptUserIDs []string `yaml:"exempt_user_ids"`
} }
func (r *RateLimiting) Verify(configErrs *ConfigErrors) { func (r *RateLimiting) Verify(configErrs *ConfigErrors) {

View file

@ -715,4 +715,7 @@ Presence can be set from sync
PUT /rooms/:room_id/redact/:event_id/:txn_id is idempotent PUT /rooms/:room_id/redact/:event_id/:txn_id is idempotent
Unnamed room comes with a name summary Unnamed room comes with a name summary
Named room comes with just joined member count summary Named room comes with just joined member count summary
Room summary only has 5 heroes Room summary only has 5 heroes
Setting state twice is idempotent
Joining room twice is idempotent
Inbound federation can return missing events for shared visibility

View file

@ -65,7 +65,7 @@ const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE" "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
const selectNewNumericLocalpartSQL = "" + const selectNewNumericLocalpartSQL = "" +
"SELECT COALESCE(MAX(localpart::integer), 0) FROM account_accounts WHERE localpart ~ '^[0-9]*$'" "SELECT COALESCE(MAX(localpart::bigint), 0) FROM account_accounts WHERE localpart ~ '^[0-9]{1,}$'"
type accountsStatements struct { type accountsStatements struct {
insertAccountStmt *sql.Stmt insertAccountStmt *sql.Stmt

View file

@ -124,6 +124,23 @@ func Test_Accounts(t *testing.T) {
_, err = db.GetAccountByLocalpart(ctx, "unusename") _, err = db.GetAccountByLocalpart(ctx, "unusename")
assert.Error(t, err, "expected an error for non existent localpart") assert.Error(t, err, "expected an error for non existent localpart")
// create an empty localpart; this should never happen, but is required to test getting a numeric localpart
// if there's already a user without a localpart in the database
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeUser)
assert.NoError(t, err)
// test getting a numeric localpart, with an existing user without a localpart
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest)
assert.NoError(t, err)
// Create a user with a high numeric localpart, out of range for the Postgres integer (2147483647) type
_, err = db.CreateAccount(ctx, "2147483650", "", "", api.AccountTypeUser)
assert.NoError(t, err)
// Now try to create a new guest user
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest)
assert.NoError(t, err)
}) })
} }