Remove pdu.UserID & pass in the required accessor

This commit is contained in:
Devon Hudson 2023-06-05 20:20:01 -06:00
parent 3582ec7de7
commit d7e41177f3
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
52 changed files with 355 additions and 136 deletions

View file

@ -181,7 +181,9 @@ func (s *OutputRoomEventConsumer) sendEvents(
// Create the transaction body. // Create the transaction body.
transaction, err := json.Marshal( transaction, err := json.Marshal(
ApplicationServiceTransaction{ ApplicationServiceTransaction{
Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll), Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return s.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
}),
}, },
) )
if err != nil { if err != nil {
@ -234,7 +236,7 @@ func (s *appserviceState) backoffAndPause(err error) error {
// TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682 // TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682
func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool { func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool {
user := "" user := ""
userID, err := event.UserID() userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err == nil { if err == nil {
user = userID.String() user = userID.String()
} }

View file

@ -331,7 +331,9 @@ func generateSendEvent(
stateEvents[i] = queryRes.StateEvents[i].PDU stateEvents[i] = queryRes.StateEvents[i].PDU
} }
provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents)) provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents))
if err = gomatrixserverlib.Allowed(e.PDU, &provider); err != nil { if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
}); err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: spec.Forbidden(err.Error()), // TODO: Is this error string comprehensible to the client? JSON: spec.Forbidden(err.Error()), // TODO: Is this error string comprehensible to the client?

View file

@ -142,7 +142,9 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
for _, ev := range stateRes.StateEvents { for _, ev := range stateRes.StateEvents {
stateEvents = append( stateEvents = append(
stateEvents, stateEvents,
synctypes.ToClientEvent(ev, synctypes.FormatAll), synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomAliasOrID string, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
}),
) )
} }
} else { } else {
@ -164,7 +166,9 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
for _, ev := range stateAfterRes.StateEvents { for _, ev := range stateAfterRes.StateEvents {
stateEvents = append( stateEvents = append(
stateEvents, stateEvents,
synctypes.ToClientEvent(ev, synctypes.FormatAll), synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomAliasOrID string, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
}),
) )
} }
} }
@ -335,7 +339,9 @@ func OnIncomingStateTypeRequest(
} }
stateEvent := stateEventInStateResp{ stateEvent := stateEventInStateResp{
ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll), ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomAliasOrID string, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
}),
} }
var res interface{} var res interface{}

View file

@ -18,6 +18,7 @@ import (
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
) )
// This is a utility for inspecting state snapshots and running state resolution // This is a utility for inspecting state snapshots and running state resolution
@ -182,7 +183,9 @@ func main() {
fmt.Println("Resolving state") fmt.Println("Resolving state")
var resolved Events var resolved Events
resolved, err = gomatrixserverlib.ResolveConflicts( resolved, err = gomatrixserverlib.ResolveConflicts(
gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return roomserverDB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
},
) )
if err != nil { if err != nil {
panic(err) panic(err)

View file

@ -36,6 +36,10 @@ type fedRoomserverAPI struct {
queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error
} }
func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) {
return spec.NewUserID(senderID, true)
}
// PerformJoin will call this function // PerformJoin will call this function
func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) { func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) {
if f.inputRoomEvents == nil { if f.inputRoomEvents == nil {

View file

@ -156,15 +156,20 @@ func (r *FederationInternalAPI) performJoinUsingServer(
} }
joinInput := gomatrixserverlib.PerformJoinInput{ joinInput := gomatrixserverlib.PerformJoinInput{
UserID: user, UserID: user,
RoomID: room, RoomID: room,
ServerName: serverName, ServerName: serverName,
Content: content, Content: content,
Unsigned: unsigned, Unsigned: unsigned,
PrivateKey: r.cfg.Matrix.PrivateKey, PrivateKey: r.cfg.Matrix.PrivateKey,
KeyID: r.cfg.Matrix.KeyID, KeyID: r.cfg.Matrix.KeyID,
KeyRing: r.keyRing, KeyRing: r.keyRing,
EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName), EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
}),
UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
},
} }
response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput) response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput)
@ -358,8 +363,11 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
// authenticate the state returned (check its auth events etc) // authenticate the state returned (check its auth events etc)
// the equivalent of CheckSendJoinResponse() // the equivalent of CheckSendJoinResponse()
userIDProvider := func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
}
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse( authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(
ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedEventProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName), ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedEventProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName, userIDProvider), userIDProvider,
) )
if err != nil { if err != nil {
return fmt.Errorf("error checking state returned from peeking: %w", err) return fmt.Errorf("error checking state returned from peeking: %w", err)
@ -509,7 +517,7 @@ func (r *FederationInternalAPI) SendInvite(
event gomatrixserverlib.PDU, event gomatrixserverlib.PDU,
strippedState []gomatrixserverlib.InviteStrippedState, strippedState []gomatrixserverlib.InviteStrippedState,
) (gomatrixserverlib.PDU, error) { ) (gomatrixserverlib.PDU, error) {
inviter, err := event.UserID() inviter, err := r.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -640,6 +648,7 @@ func checkEventsContainCreateEvent(events []gomatrixserverlib.PDU) error {
func federatedEventProvider( func federatedEventProvider(
ctx context.Context, federation fclient.FederationClient, ctx context.Context, federation fclient.FederationClient,
keyRing gomatrixserverlib.JSONVerifier, origin, server spec.ServerName, keyRing gomatrixserverlib.JSONVerifier, origin, server spec.ServerName,
userIDForSender spec.UserIDForSender,
) gomatrixserverlib.EventProvider { ) gomatrixserverlib.EventProvider {
// A list of events that we have retried, if they were not included in // A list of events that we have retried, if they were not included in
// the auth events supplied in the send_join. // the auth events supplied in the send_join.
@ -689,7 +698,7 @@ func federatedEventProvider(
} }
// Check the signatures of the event. // Check the signatures of the event.
if err := gomatrixserverlib.VerifyEventSignatures(ctx, ev, keyRing); err != nil { if err := gomatrixserverlib.VerifyEventSignatures(ctx, ev, keyRing, userIDForSender); err != nil {
return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err) return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err)
} }

View file

@ -147,15 +147,18 @@ func MakeJoin(
} }
input := gomatrixserverlib.HandleMakeJoinInput{ input := gomatrixserverlib.HandleMakeJoinInput{
Context: httpReq.Context(), Context: httpReq.Context(),
UserID: userID, UserID: userID,
RoomID: roomID, RoomID: roomID,
RoomVersion: roomVersion, RoomVersion: roomVersion,
RemoteVersions: remoteVersions, RemoteVersions: remoteVersions,
RequestOrigin: request.Origin(), RequestOrigin: request.Origin(),
LocalServerName: cfg.Matrix.ServerName, LocalServerName: cfg.Matrix.ServerName,
LocalServerInRoom: res.RoomExists && res.IsInRoom, LocalServerInRoom: res.RoomExists && res.IsInRoom,
RoomQuerier: &roomQuerier, RoomQuerier: &roomQuerier,
UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomAliasOrID, senderID)
},
BuildEventTemplate: createJoinTemplate, BuildEventTemplate: createJoinTemplate,
} }
response, internalErr := gomatrixserverlib.HandleMakeJoin(input) response, internalErr := gomatrixserverlib.HandleMakeJoin(input)

View file

@ -223,7 +223,7 @@ func SendLeave(
// Check that the sender belongs to the server that is sending us // Check that the sender belongs to the server that is sending us
// the request. By this point we've already asserted that the sender // the request. By this point we've already asserted that the sender
// and the state key are equal so we don't need to check both. // and the state key are equal so we don't need to check both.
sender, err := event.UserID() sender, err := rsAPI.QueryUserIDForSender(httpReq.Context(), event.RoomID(), event.SenderID())
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,

4
go.mod
View file

@ -22,7 +22,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20230603021032-d30b8fdc7ced github.com/matrix-org/gomatrixserverlib v0.0.0-20230606021710-b68a1b0eef30
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
github.com/mattn/go-sqlite3 v1.14.16 github.com/mattn/go-sqlite3 v1.14.16
@ -34,7 +34,7 @@ require (
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.13.0 github.com/prometheus/client_golang v1.13.0
github.com/sirupsen/logrus v1.9.2 github.com/sirupsen/logrus v1.9.3
github.com/stretchr/testify v1.8.2 github.com/stretchr/testify v1.8.2
github.com/tidwall/gjson v1.14.4 github.com/tidwall/gjson v1.14.4
github.com/tidwall/sjson v1.2.5 github.com/tidwall/sjson v1.2.5

8
go.sum
View file

@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230603021032-d30b8fdc7ced h1:pbCM+nno+r2wW3jwxP65xmkzk6008CdMNZaOWYBwB1c= github.com/matrix-org/gomatrixserverlib v0.0.0-20230606021710-b68a1b0eef30 h1:G+Do1UoWazY0Fetq+eAX1h1+fimf19NGGyaS86hWg8s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230603021032-d30b8fdc7ced/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/gomatrixserverlib v0.0.0-20230606021710-b68a1b0eef30/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
@ -444,8 +444,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/sirupsen/logrus v1.9.2 h1:oxx1eChJGI6Uks2ZC4W1zpLlVgqB8ner4EuQwV4Ik1Y= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.2/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s= github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s=
github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=

View file

@ -6,6 +6,7 @@ import (
"strings" "strings"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
) )
// A RuleSetEvaluator encapsulates context to evaluate an event // A RuleSetEvaluator encapsulates context to evaluate an event
@ -53,7 +54,7 @@ func NewRuleSetEvaluator(ec EvaluationContext, ruleSet *RuleSet) *RuleSetEvaluat
// MatchEvent returns the first matching rule. Returns nil if there // MatchEvent returns the first matching rule. Returns nil if there
// was no match rule. // was no match rule.
func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU) (*Rule, error) { func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender) (*Rule, error) {
// TODO: server-default rules have lower priority than user rules, // TODO: server-default rules have lower priority than user rules,
// but they are stored together with the user rules. It's a bit // but they are stored together with the user rules. It's a bit
// unclear what the specification (11.14.1.4 Predefined rules) // unclear what the specification (11.14.1.4 Predefined rules)
@ -68,7 +69,7 @@ func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU) (*Rule, err
if rule.Default != defRules { if rule.Default != defRules {
continue continue
} }
ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec) ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec, userIDForSender)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -83,7 +84,7 @@ func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU) (*Rule, err
return nil, nil return nil, nil
} }
func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec EvaluationContext) (bool, error) { func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec EvaluationContext, userIDForSender spec.UserIDForSender) (bool, error) {
if !rule.Enabled { if !rule.Enabled {
return false, nil return false, nil
} }
@ -114,7 +115,7 @@ func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec Evaluati
case SenderKind: case SenderKind:
userID := "" userID := ""
sender, err := event.UserID() sender, err := userIDForSender(event.RoomID(), event.SenderID())
if err == nil { if err == nil {
userID = sender.String() userID = sender.String()
} }

View file

@ -5,8 +5,13 @@ import (
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
) )
func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) {
return spec.NewUserID(senderID, true)
}
func TestRuleSetEvaluatorMatchEvent(t *testing.T) { func TestRuleSetEvaluatorMatchEvent(t *testing.T) {
ev := mustEventFromJSON(t, `{}`) ev := mustEventFromJSON(t, `{}`)
defaultEnabled := &Rule{ defaultEnabled := &Rule{
@ -45,7 +50,7 @@ func TestRuleSetEvaluatorMatchEvent(t *testing.T) {
for _, tst := range tsts { for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) { t.Run(tst.Name, func(t *testing.T) {
rse := NewRuleSetEvaluator(fakeEvaluationContext{3}, &tst.RuleSet) rse := NewRuleSetEvaluator(fakeEvaluationContext{3}, &tst.RuleSet)
got, err := rse.MatchEvent(tst.Event) got, err := rse.MatchEvent(tst.Event, UserIDForSender)
if err != nil { if err != nil {
t.Fatalf("MatchEvent failed: %v", err) t.Fatalf("MatchEvent failed: %v", err)
} }
@ -90,7 +95,7 @@ func TestRuleMatches(t *testing.T) {
} }
for _, tst := range tsts { for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) { t.Run(tst.Name, func(t *testing.T) {
got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil) got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil, UserIDForSender)
if err != nil { if err != nil {
t.Fatalf("ruleMatches failed: %v", err) t.Fatalf("ruleMatches failed: %v", err)
} }

View file

@ -167,7 +167,9 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut
} }
continue continue
} }
if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys); err != nil { if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return t.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
}); err != nil {
util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID())
results[event.EventID()] = fclient.PDUResult{ results[event.EventID()] = fclient.PDUResult{
Error: err.Error(), Error: err.Error(),

View file

@ -70,6 +70,10 @@ type FakeRsAPI struct {
bannedFromRoom bool bannedFromRoom bool
} }
func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) {
return spec.NewUserID(senderID, true)
}
func (r *FakeRsAPI) QueryRoomVersionForRoom( func (r *FakeRsAPI) QueryRoomVersionForRoom(
ctx context.Context, ctx context.Context,
roomID string, roomID string,
@ -638,6 +642,10 @@ type testRoomserverAPI struct {
queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse
} }
func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) {
return spec.NewUserID(senderID, true)
}
func (t *testRoomserverAPI) InputRoomEvents( func (t *testRoomserverAPI) InputRoomEvents(
ctx context.Context, ctx context.Context,
request *rsAPI.InputRoomEventsRequest, request *rsAPI.InputRoomEventsRequest,

View file

@ -39,6 +39,7 @@ type RoomserverInternalAPI interface {
ClientRoomserverAPI ClientRoomserverAPI
UserRoomserverAPI UserRoomserverAPI
FederationRoomserverAPI FederationRoomserverAPI
QuerySenderIDAPI
// needed to avoid chicken and egg scenario when setting up the // needed to avoid chicken and egg scenario when setting up the
// interdependencies between the roomserver and other input APIs // interdependencies between the roomserver and other input APIs
@ -68,7 +69,7 @@ type InputRoomEventsAPI interface {
type QuerySenderIDAPI interface { type QuerySenderIDAPI interface {
// Accepts either roomID or alias // Accepts either roomID or alias
QuerySenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) QuerySenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error)
QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (spec.UserID, error) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error)
} }
// Query the latest events and state for a room from the room server. // Query the latest events and state for a room from the room server.
@ -98,6 +99,7 @@ type QueryEventsAPI interface {
type SyncRoomserverAPI interface { type SyncRoomserverAPI interface {
QueryLatestEventsAndStateAPI QueryLatestEventsAndStateAPI
QueryBulkStateContentAPI QueryBulkStateContentAPI
QuerySenderIDAPI
// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
@ -138,6 +140,7 @@ type SyncRoomserverAPI interface {
} }
type AppserviceRoomserverAPI interface { type AppserviceRoomserverAPI interface {
QuerySenderIDAPI
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
// which room to use by querying the first events roomID. // which room to use by querying the first events roomID.
QueryEventsByID( QueryEventsByID(
@ -197,6 +200,7 @@ type ClientRoomserverAPI interface {
} }
type UserRoomserverAPI interface { type UserRoomserverAPI interface {
QuerySenderIDAPI
QueryLatestEventsAndStateAPI QueryLatestEventsAndStateAPI
KeyserverRoomserverAPI KeyserverRoomserverAPI
QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error
@ -209,6 +213,8 @@ type FederationRoomserverAPI interface {
InputRoomEventsAPI InputRoomEventsAPI
QueryLatestEventsAndStateAPI QueryLatestEventsAndStateAPI
QueryBulkStateContentAPI QueryBulkStateContentAPI
QuerySenderIDAPI
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error

View file

@ -463,10 +463,10 @@ type MembershipQuerier struct {
Roomserver FederationRoomserverAPI Roomserver FederationRoomserverAPI
} }
func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID string) (string, error) { func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
req := QueryMembershipForUserRequest{ req := QueryMembershipForUserRequest{
RoomID: roomID.String(), RoomID: roomID.String(),
UserID: senderID, UserID: string(senderID),
} }
res := QueryMembershipForUserResponse{} res := QueryMembershipForUserResponse{}
err := mq.Roomserver.QueryMembershipForUser(ctx, &req, &res) err := mq.Roomserver.QueryMembershipForUser(ctx, &req, &res)

View file

@ -76,7 +76,9 @@ func CheckForSoftFail(
} }
// Check if the event is allowed. // Check if the event is allowed.
if err = gomatrixserverlib.Allowed(event.PDU, &authEvents); err != nil { if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return db.GetUserIDForSender(ctx, roomAliasOrID, senderID)
}); err != nil {
// return true, nil // return true, nil
return true, err return true, err
} }

View file

@ -128,7 +128,7 @@ func (r *Inputer) processRoomEvent(
if roomInfo == nil && !isCreateEvent { if roomInfo == nil && !isCreateEvent {
return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID())
} }
sender, err := event.UserID() sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err != nil { if err != nil {
return fmt.Errorf("event has invalid sender %q", event.SenderID()) return fmt.Errorf("event has invalid sender %q", event.SenderID())
} }
@ -276,7 +276,9 @@ func (r *Inputer) processRoomEvent(
// Check if the event is allowed by its auth events. If it isn't then // Check if the event is allowed by its auth events. If it isn't then
// we consider the event to be "rejected" — it will still be persisted. // we consider the event to be "rejected" — it will still be persisted.
if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
}); err != nil {
isRejected = true isRejected = true
rejectionErr = err rejectionErr = err
logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID()) logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID())
@ -579,7 +581,9 @@ func (r *Inputer) processStateBefore(
stateBeforeAuth := gomatrixserverlib.NewAuthEvents( stateBeforeAuth := gomatrixserverlib.NewAuthEvents(
gomatrixserverlib.ToPDUs(stateBeforeEvent), gomatrixserverlib.ToPDUs(stateBeforeEvent),
) )
if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth); rejectionErr != nil { if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
}); rejectionErr != nil {
rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr) rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr)
return return
} }
@ -690,7 +694,9 @@ nextAuthEvent:
// Check the signatures of the event. If this fails then we'll simply // Check the signatures of the event. If this fails then we'll simply
// skip it, because gomatrixserverlib.Allowed() will notice a problem // skip it, because gomatrixserverlib.Allowed() will notice a problem
// if a critical event is missing anyway. // if a critical event is missing anyway.
if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing()); err != nil { if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
}); err != nil {
continue nextAuthEvent continue nextAuthEvent
} }
@ -706,7 +712,9 @@ nextAuthEvent:
} }
// Check if the auth event should be rejected. // Check if the auth event should be rejected.
err := gomatrixserverlib.Allowed(authEvent, auth) err := gomatrixserverlib.Allowed(authEvent, auth, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
})
if isRejected = err != nil; isRejected { if isRejected = err != nil; isRejected {
logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID()) logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID())
} }

View file

@ -58,7 +58,7 @@ func Test_EventAuth(t *testing.T) {
} }
// Finally check that the event is NOT allowed // Finally check that the event is NOT allowed
if err := gomatrixserverlib.Allowed(ev.PDU, &allower); err == nil { if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomAliasOrID, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) }); err == nil {
t.Fatalf("event should not be allowed, but it was") t.Fatalf("event should not be allowed, but it was")
} }
} }

View file

@ -473,14 +473,18 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion
stateEventList = append(stateEventList, state.StateEvents...) stateEventList = append(stateEventList, state.StateEvents...)
} }
resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts( resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(
roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID)
},
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
// apply the current event // apply the current event
retryAllowedState: retryAllowedState:
if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents); err != nil { if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID)
}); err != nil {
switch missing := err.(type) { switch missing := err.(type) {
case gomatrixserverlib.MissingAuthEventError: case gomatrixserverlib.MissingAuthEventError:
h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true) h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true)
@ -565,7 +569,9 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver
// will be added and duplicates will be removed. // will be added and duplicates will be removed.
missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events)) missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events))
for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) { for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys); err != nil { if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID)
}); err != nil {
continue continue
} }
missingEvents = append(missingEvents, t.cacheAndReturn(ev)) missingEvents = append(missingEvents, t.cacheAndReturn(ev))
@ -654,7 +660,9 @@ func (t *missingStateReq) lookupMissingStateViaState(
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{ authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{
StateEvents: state.GetStateEvents(), StateEvents: state.GetStateEvents(),
AuthEvents: state.GetAuthEvents(), AuthEvents: state.GetAuthEvents(),
}, roomVersion, t.keys, nil) }, roomVersion, t.keys, nil, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID)
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -889,14 +897,16 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers)) t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers))
return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers)) return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers))
} }
if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys); err != nil { if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID)
}); err != nil {
t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())
return nil, verifySigError{event.EventID(), err} return nil, verifySigError{event.EventID(), err}
} }
return t.cacheAndReturn(event), nil return t.cacheAndReturn(event), nil
} }
func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU) error { func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender) error {
authUsingState := gomatrixserverlib.NewAuthEvents(nil) authUsingState := gomatrixserverlib.NewAuthEvents(nil)
for i := range stateEvents { for i := range stateEvents {
err := authUsingState.AddEvent(stateEvents[i]) err := authUsingState.AddEvent(stateEvents[i])
@ -904,7 +914,7 @@ func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverli
return err return err
} }
} }
return gomatrixserverlib.Allowed(e, &authUsingState) return gomatrixserverlib.Allowed(e, &authUsingState, userIDForSender)
} }
func (t *missingStateReq) hadEvent(eventID string) { func (t *missingStateReq) hadEvent(eventID string) {

View file

@ -262,13 +262,17 @@ func (r *Admin) PerformAdminDownloadState(
return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err) return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err)
} }
for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) { for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing); err != nil { if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
}); err != nil {
continue continue
} }
authEventMap[authEvent.EventID()] = authEvent authEventMap[authEvent.EventID()] = authEvent
} }
for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) { for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing); err != nil { if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
}); err != nil {
continue continue
} }
stateEventMap[stateEvent.EventID()] = stateEvent stateEventMap[stateEvent.EventID()] = stateEvent

View file

@ -121,7 +121,9 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
// Specifically the test "Outbound federation can backfill events" // Specifically the test "Outbound federation can backfill events"
events, err := gomatrixserverlib.RequestBackfill( events, err := gomatrixserverlib.RequestBackfill(
ctx, req.VirtualHost, requester, ctx, req.VirtualHost, requester,
r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
},
) )
// Only return an error if we really couldn't get any events. // Only return an error if we really couldn't get any events.
if err != nil && len(events) == 0 { if err != nil && len(events) == 0 {
@ -210,7 +212,9 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
continue continue
} }
loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false) loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false)
result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents) result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
})
if err != nil { if err != nil {
logger.WithError(err).Warn("failed to load and verify event") logger.WithError(err).Warn("failed to load and verify event")
continue continue
@ -484,7 +488,7 @@ FindSuccessor:
// Store the server names in a temporary map to avoid duplicates. // Store the server names in a temporary map to avoid duplicates.
serverSet := make(map[spec.ServerName]bool) serverSet := make(map[spec.ServerName]bool)
for _, event := range memberEvents { for _, event := range memberEvents {
if sender, err := event.UserID(); err == nil { if sender, err := b.db.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()); err == nil {
serverSet[sender.Domain()] = true serverSet[sender.Domain()] = true
} }
} }

View file

@ -308,7 +308,9 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
} }
} }
if err = gomatrixserverlib.Allowed(ev, &authEvents); err != nil { if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return c.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
}); err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed") util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed")
return "", &util.JSONResponse{ return "", &util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,

View file

@ -126,7 +126,7 @@ func (r *Inviter) PerformInvite(
) error { ) error {
event := req.Event event := req.Event
sender, err := event.UserID() sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err != nil { if err != nil {
return spec.InvalidParam("The sender user ID is invalid") return spec.InvalidParam("The sender user ID is invalid")
} }
@ -156,6 +156,9 @@ func (r *Inviter) PerformInvite(
StrippedState: req.InviteRoomState, StrippedState: req.InviteRoomState,
MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI}, MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI},
StateQuerier: &QueryState{r.DB}, StateQuerier: &QueryState{r.DB},
UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
},
} }
inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI) inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI)
if err != nil { if err != nil {

View file

@ -484,7 +484,9 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user
} }
if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return r.URSAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
}); err != nil {
return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err) return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err)
} }
@ -567,7 +569,9 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user
stateEvents[i] = queryRes.StateEvents[i].PDU stateEvents[i] = queryRes.StateEvents[i].PDU
} }
provider := gomatrixserverlib.NewAuthEvents(stateEvents) provider := gomatrixserverlib.NewAuthEvents(stateEvents)
if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider); err != nil { if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return r.URSAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
}); err != nil {
return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client? return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client?
} }

View file

@ -122,7 +122,9 @@ func (r *Queryer) QueryStateAfterEvents(
} }
stateEvents, err = gomatrixserverlib.ResolveConflicts( stateEvents, err = gomatrixserverlib.ResolveConflicts(
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
},
) )
if err != nil { if err != nil {
return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err) return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err)
@ -349,7 +351,9 @@ func (r *Queryer) QueryMembershipsForRoom(
return fmt.Errorf("r.DB.Events: %w", err) return fmt.Errorf("r.DB.Events: %w", err)
} }
for _, event := range events { for _, event := range events {
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll) clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
})
response.JoinEvents = append(response.JoinEvents, clientEvent) response.JoinEvents = append(response.JoinEvents, clientEvent)
} }
return nil return nil
@ -398,7 +402,9 @@ func (r *Queryer) QueryMembershipsForRoom(
} }
for _, event := range events { for _, event := range events {
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll) clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
})
response.JoinEvents = append(response.JoinEvents, clientEvent) response.JoinEvents = append(response.JoinEvents, clientEvent)
} }
@ -588,7 +594,9 @@ func (r *Queryer) QueryStateAndAuthChain(
if request.ResolveState { if request.ResolveState {
stateEvents, err = gomatrixserverlib.ResolveConflicts( stateEvents, err = gomatrixserverlib.ResolveConflicts(
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
},
) )
if err != nil { if err != nil {
return err return err
@ -1036,15 +1044,9 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
} }
func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) { func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) {
// TODO: implement this properly with pseudoIDs return r.DB.GetSenderIDForUser(ctx, roomAliasOrID, userID)
return userID.String(), nil
} }
func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (spec.UserID, error) { func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) {
// TODO: implement this properly with pseudoIDs return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
userID, err := spec.NewUserID(senderID, true)
if err != nil {
return spec.UserID{}, err
}
return *userID, err
} }

View file

@ -24,6 +24,7 @@ import (
"time" "time"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
@ -43,6 +44,7 @@ type StateResolutionStorage interface {
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error)
} }
type StateResolution struct { type StateResolution struct {
@ -945,7 +947,9 @@ func (v *StateResolution) resolveConflictsV1(
} }
// Resolve the conflicts. // Resolve the conflicts.
resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents) resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return v.db.GetUserIDForSender(ctx, roomAliasOrID, senderID)
})
// Map from the full events back to numeric state entries. // Map from the full events back to numeric state entries.
for _, resolvedEvent := range resolvedEvents { for _, resolvedEvent := range resolvedEvents {
@ -1057,6 +1061,9 @@ func (v *StateResolution) resolveConflictsV2(
conflictedEvents, conflictedEvents,
nonConflictedEvents, nonConflictedEvents,
authEvents, authEvents,
func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return v.db.GetUserIDForSender(ctx, roomAliasOrID, senderID)
},
) )
}() }()

View file

@ -166,6 +166,10 @@ type Database interface {
GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error) GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error)
// GetKnownUsers searches all users that userID knows about. // GetKnownUsers searches all users that userID knows about.
GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error)
// GetKnownUsers tries to obtain the current mxid for a given user.
GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error)
// GetKnownUsers tries to obtain the current senderID for a given user.
GetSenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error)
// GetKnownRooms returns a list of all rooms we know about. // GetKnownRooms returns a list of all rooms we know about.
GetKnownRooms(ctx context.Context) ([]string, error) GetKnownRooms(ctx context.Context) ([]string, error)
// ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room // ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room
@ -211,6 +215,7 @@ type RoomDatabase interface {
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error)
GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error)
} }
type EventDatabase interface { type EventDatabase interface {

View file

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
@ -250,3 +251,7 @@ func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error {
func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
} }
func (u *RoomUpdater) GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) {
return u.d.GetUserIDForSender(ctx, roomAliasOrID, senderID)
}

View file

@ -988,13 +988,15 @@ func (d *EventDatabase) MaybeRedactEvent(
return nil return nil
} }
// TODO: Don't hack senderID into userID here (pseudoIDs)
sender1Domain := "" sender1Domain := ""
sender1, err1 := redactedEvent.UserID() sender1, err1 := spec.NewUserID(redactedEvent.SenderID(), true)
if err1 == nil { if err1 == nil {
sender1Domain = string(sender1.Domain()) sender1Domain = string(sender1.Domain())
} }
// TODO: Don't hack senderID into userID here (pseudoIDs)
sender2Domain := "" sender2Domain := ""
sender2, err2 := redactionEvent.UserID() sender2, err2 := spec.NewUserID(redactionEvent.SenderID(), true)
if err2 == nil { if err2 == nil {
sender2Domain = string(sender2.Domain()) sender2Domain = string(sender2.Domain())
} }
@ -1522,6 +1524,16 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin
return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit) return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit)
} }
func (d *Database) GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) {
// TODO: Use real logic once DB for pseudoIDs is in place
return spec.NewUserID(senderID, true)
}
func (d *Database) GetSenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) {
// TODO: Use real logic once DB for pseudoIDs is in place
return userID.String(), nil
}
// GetKnownRooms returns a list of all rooms we know about. // GetKnownRooms returns a list of all rooms we know about.
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil) return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil)

View file

@ -92,9 +92,11 @@ type MSC2836EventRelationshipsResponse struct {
ParsedAuthChain []gomatrixserverlib.PDU ParsedAuthChain []gomatrixserverlib.PDU
} }
func toClientResponse(res *MSC2836EventRelationshipsResponse) *EventRelationshipResponse { func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsResponse, rsAPI roomserver.RoomserverInternalAPI) *EventRelationshipResponse {
out := &EventRelationshipResponse{ out := &EventRelationshipResponse{
Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll), Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
}),
Limited: res.Limited, Limited: res.Limited,
NextBatch: res.NextBatch, NextBatch: res.NextBatch,
} }
@ -187,7 +189,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
return util.JSONResponse{ return util.JSONResponse{
Code: 200, Code: 200,
JSON: toClientResponse(res), JSON: toClientResponse(req.Context(), res, rsAPI),
} }
} }
} }

View file

@ -525,6 +525,10 @@ type testRoomserverAPI struct {
events map[string]*types.HeaderedEvent events map[string]*types.HeaderedEvent
} }
func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) {
return spec.NewUserID(senderID, true)
}
func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error { func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error {
for _, eventID := range req.EventIDs { for _, eventID := range req.EventIDs {
ev := r.events[eventID] ev := r.events[eventID]

View file

@ -193,14 +193,20 @@ func Context(
} }
} }
eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll) eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll) return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
})
eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
})
newState := state newState := state
if filter.LazyLoadMembers { if filter.LazyLoadMembers {
allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...) allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...)
allEvents = append(allEvents, &requestedEvent) allEvents = append(allEvents, &requestedEvent)
evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll) evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
})
newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache) newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache)
if err != nil { if err != nil {
logrus.WithError(err).Error("unable to load membership events") logrus.WithError(err).Error("unable to load membership events")
@ -211,12 +217,16 @@ func Context(
} }
} }
ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll) ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
})
response := ContextRespsonse{ response := ContextRespsonse{
Event: &ev, Event: &ev,
EventsAfter: eventsAfterClient, EventsAfter: eventsAfterClient,
EventsBefore: eventsBeforeClient, EventsBefore: eventsBeforeClient,
State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll), State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
}),
} }
if len(response.State) > filter.Limit { if len(response.State) > filter.Limit {

View file

@ -103,6 +103,8 @@ func GetEvent(
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll), JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID)
}),
} }
} }

View file

@ -153,6 +153,8 @@ func GetMemberships(
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll)}, JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID)
})},
} }
} }

View file

@ -241,7 +241,7 @@ func OnIncomingMessagesRequest(
device: device, device: device,
} }
clientEvents, start, end, err := mReq.retrieveEvents() clientEvents, start, end, err := mReq.retrieveEvents(req.Context(), rsAPI)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("mreq.retrieveEvents failed") util.GetLogger(req.Context()).WithError(err).Error("mreq.retrieveEvents failed")
return util.JSONResponse{ return util.JSONResponse{
@ -273,7 +273,9 @@ func OnIncomingMessagesRequest(
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
} }
} }
res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll)...) res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID)
})...)
} }
// If we didn't return any events, set the end to an empty string, so it will be omitted // If we didn't return any events, set the end to an empty string, so it will be omitted
@ -310,7 +312,7 @@ func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.
// homeserver in the room for older events. // homeserver in the room for older events.
// Returns an error if there was an issue talking to the database or with the // Returns an error if there was an issue talking to the database or with the
// remote homeserver. // remote homeserver.
func (r *messagesReq) retrieveEvents() ( func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserverAPI) (
clientEvents []synctypes.ClientEvent, start, clientEvents []synctypes.ClientEvent, start,
end types.TopologyToken, err error, end types.TopologyToken, err error,
) { ) {
@ -382,7 +384,9 @@ func (r *messagesReq) retrieveEvents() (
"events_before": len(events), "events_before": len(events),
"events_after": len(filteredEvents), "events_after": len(filteredEvents),
}).Debug("applied history visibility (messages)") }).Debug("applied history visibility (messages)")
return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll), start, end, err return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
}), start, end, err
} }
func (r *messagesReq) getStartEnd(events []*rstypes.HeaderedEvent) (start, end types.TopologyToken, err error) { func (r *messagesReq) getStartEnd(events []*rstypes.HeaderedEvent) (start, end types.TopologyToken, err error) {

View file

@ -116,7 +116,9 @@ func Relations(
for _, event := range filteredEvents { for _, event := range filteredEvents {
res.Chunk = append( res.Chunk = append(
res.Chunk, res.Chunk,
synctypes.ToClientEvent(event.PDU, synctypes.FormatAll), synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID)
}),
) )
} }

View file

@ -171,7 +171,7 @@ func Setup(
nb := req.FormValue("next_batch") nb := req.FormValue("next_batch")
nextBatch = &nb nextBatch = &nb
} }
return Search(req, device, syncDB, fts, nextBatch) return Search(req, device, syncDB, fts, nextBatch, rsAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)

View file

@ -31,6 +31,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/fulltext"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/synctypes"
@ -38,7 +39,7 @@ import (
) )
// nolint:gocyclo // nolint:gocyclo
func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts fulltext.Indexer, from *string) util.JSONResponse { func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts fulltext.Indexer, from *string, rsAPI roomserverAPI.SyncRoomserverAPI) util.JSONResponse {
start := time.Now() start := time.Now()
var ( var (
searchReq SearchRequest searchReq SearchRequest
@ -225,14 +226,20 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
results = append(results, Result{ results = append(results, Result{
Context: SearchContextResponse{ Context: SearchContextResponse{
Start: startToken.String(), Start: startToken.String(),
End: endToken.String(), End: endToken.String(),
EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync), EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync), return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID)
ProfileInfo: profileInfos, }),
EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID)
}),
ProfileInfo: profileInfos,
}, },
Rank: eventScore[event.EventID()].Score, Rank: eventScore[event.EventID()].Score,
Result: synctypes.ToClientEvent(event, synctypes.FormatAll), Result: synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID)
}),
}) })
roomGroup := groups[event.RoomID()] roomGroup := groups[event.RoomID()]
roomGroup.Results = append(roomGroup.Results, event.EventID()) roomGroup.Results = append(roomGroup.Results, event.EventID())
@ -247,7 +254,9 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
} }
} }
stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync) stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID)
})
} }
} }

View file

@ -2,6 +2,7 @@ package routing
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -9,6 +10,7 @@ import (
"github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/fulltext"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
rstypes "github.com/matrix-org/dendrite/roomserver/types" rstypes "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/synctypes"
@ -21,6 +23,12 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
type FakeSyncRoomserverAPI struct{ rsapi.SyncRoomserverAPI }
func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) {
return spec.NewUserID(senderID, true)
}
func TestSearch(t *testing.T) { func TestSearch(t *testing.T) {
alice := test.NewUser(t) alice := test.NewUser(t)
aliceDevice := userapi.Device{UserID: alice.ID} aliceDevice := userapi.Device{UserID: alice.ID}
@ -247,7 +255,7 @@ func TestSearch(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/", reqBody) req := httptest.NewRequest(http.MethodPost, "/", reqBody)
res := Search(req, tc.device, db, fts, tc.from) res := Search(req, tc.device, db, fts, tc.from, &FakeSyncRoomserverAPI{})
if !tc.wantOK && !res.Is2xx() { if !tc.wantOK && !res.Is2xx() {
return return
} }

View file

@ -10,6 +10,7 @@ import (
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/synctypes"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
@ -17,6 +18,7 @@ import (
type InviteStreamProvider struct { type InviteStreamProvider struct {
DefaultStreamProvider DefaultStreamProvider
rsAPI api.SyncRoomserverAPI
} }
func (p *InviteStreamProvider) Setup( func (p *InviteStreamProvider) Setup(
@ -66,7 +68,9 @@ func (p *InviteStreamProvider) IncrementalSync(
if _, ok := req.IgnoredUsers.List[inviteEvent.SenderID()]; ok { if _, ok := req.IgnoredUsers.List[inviteEvent.SenderID()]; ok {
continue continue
} }
ir := types.NewInviteResponse(inviteEvent) ir := types.NewInviteResponse(inviteEvent, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
})
req.Response.Rooms.Invite[roomID] = ir req.Response.Rooms.Invite[roomID] = ir
} }

View file

@ -376,20 +376,28 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
} }
} }
jr.Timeline.PrevBatch = &prevBatch jr.Timeline.PrevBatch = &prevBatch
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync) jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
})
// If we are limited by the filter AND the history visibility filter // If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited. // didn't "remove" events, return that the response is limited.
jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync) jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
})
req.Response.Rooms.Join[delta.RoomID] = jr req.Response.Rooms.Join[delta.RoomID] = jr
case spec.Peek: case spec.Peek:
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = &prevBatch jr.Timeline.PrevBatch = &prevBatch
// TODO: Apply history visibility on peeked rooms // TODO: Apply history visibility on peeked rooms
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync) jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
})
jr.Timeline.Limited = limited jr.Timeline.Limited = limited
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync) jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
})
req.Response.Rooms.Peek[delta.RoomID] = jr req.Response.Rooms.Peek[delta.RoomID] = jr
case spec.Leave: case spec.Leave:
@ -398,11 +406,15 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
case spec.Ban: case spec.Ban:
lr := types.NewLeaveResponse() lr := types.NewLeaveResponse()
lr.Timeline.PrevBatch = &prevBatch lr.Timeline.PrevBatch = &prevBatch
lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync) lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
})
// If we are limited by the filter AND the history visibility filter // If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited. // didn't "remove" events, return that the response is limited.
lr.Timeline.Limited = limited && len(events) == len(recentEvents) lr.Timeline.Limited = limited && len(events) == len(recentEvents)
lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync) lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
})
req.Response.Rooms.Leave[delta.RoomID] = lr req.Response.Rooms.Leave[delta.RoomID] = lr
} }
@ -552,11 +564,15 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
} }
jr.Timeline.PrevBatch = prevBatch jr.Timeline.PrevBatch = prevBatch
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync) jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
})
// If we are limited by the filter AND the history visibility filter // If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited. // didn't "remove" events, return that the response is limited.
jr.Timeline.Limited = limited && len(events) == len(recentEvents) jr.Timeline.Limited = limited && len(events) == len(recentEvents)
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync) jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
})
return jr, nil return jr, nil
} }

View file

@ -45,6 +45,7 @@ func NewSyncStreamProviders(
}, },
InviteStreamProvider: &InviteStreamProvider{ InviteStreamProvider: &InviteStreamProvider{
DefaultStreamProvider: DefaultStreamProvider{DB: d}, DefaultStreamProvider: DefaultStreamProvider{DB: d},
rsAPI: rsAPI,
}, },
SendToDeviceStreamProvider: &SendToDeviceStreamProvider{ SendToDeviceStreamProvider: &SendToDeviceStreamProvider{
DefaultStreamProvider: DefaultStreamProvider{DB: d}, DefaultStreamProvider: DefaultStreamProvider{DB: d},

View file

@ -40,6 +40,10 @@ type syncRoomserverAPI struct {
rooms []*test.Room rooms []*test.Room
} }
func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) {
return spec.NewUserID(senderID, true)
}
func (s *syncRoomserverAPI) QueryLatestEventsAndState(ctx context.Context, req *rsapi.QueryLatestEventsAndStateRequest, res *rsapi.QueryLatestEventsAndStateResponse) error { func (s *syncRoomserverAPI) QueryLatestEventsAndState(ctx context.Context, req *rsapi.QueryLatestEventsAndStateRequest, res *rsapi.QueryLatestEventsAndStateResponse) error {
var room *test.Room var room *test.Room
for _, r := range s.rooms { for _, r := range s.rooms {

View file

@ -44,21 +44,21 @@ type ClientEvent struct {
} }
// ToClientEvents converts server events to client events. // ToClientEvents converts server events to client events.
func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat) []ClientEvent { func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat, userIDForSender spec.UserIDForSender) []ClientEvent {
evs := make([]ClientEvent, 0, len(serverEvs)) evs := make([]ClientEvent, 0, len(serverEvs))
for _, se := range serverEvs { for _, se := range serverEvs {
if se == nil { if se == nil {
continue // TODO: shouldn't happen? continue // TODO: shouldn't happen?
} }
evs = append(evs, ToClientEvent(se, format)) evs = append(evs, ToClientEvent(se, format, userIDForSender))
} }
return evs return evs
} }
// ToClientEvent converts a single server event to a client event. // ToClientEvent converts a single server event to a client event.
func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat) ClientEvent { func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, userIDForSender spec.UserIDForSender) ClientEvent {
user := "" user := ""
userID, err := se.UserID() userID, err := userIDForSender(se.RoomID(), se.SenderID())
if err == nil { if err == nil {
user = userID.String() user = userID.String()
} }

View file

@ -21,8 +21,13 @@ import (
"testing" "testing"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
) )
func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) {
return spec.NewUserID(senderID, true)
}
func TestToClientEvent(t *testing.T) { // nolint: gocyclo func TestToClientEvent(t *testing.T) { // nolint: gocyclo
ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV1).NewEventFromTrustedJSON([]byte(`{ ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV1).NewEventFromTrustedJSON([]byte(`{
"type": "m.room.name", "type": "m.room.name",
@ -43,7 +48,7 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo
if err != nil { if err != nil {
t.Fatalf("failed to create Event: %s", err) t.Fatalf("failed to create Event: %s", err)
} }
ce := ToClientEvent(ev, FormatAll) ce := ToClientEvent(ev, FormatAll, UserIDForSender)
if ce.EventID != ev.EventID() { if ce.EventID != ev.EventID() {
t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID) t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID)
} }
@ -63,7 +68,7 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo
t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned)) t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned))
} }
user := "" user := ""
userID, err := ev.UserID() userID, err := UserIDForSender("", ev.SenderID())
if err == nil { if err == nil {
user = userID.String() user = userID.String()
} }
@ -103,7 +108,7 @@ func TestToClientFormatSync(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to create Event: %s", err) t.Fatalf("failed to create Event: %s", err)
} }
ce := ToClientEvent(ev, FormatSync) ce := ToClientEvent(ev, FormatSync, UserIDForSender)
if ce.RoomID != "" { if ce.RoomID != "" {
t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID) t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID)
} }

View file

@ -539,7 +539,7 @@ type InviteResponse struct {
} }
// NewInviteResponse creates an empty response with initialised arrays. // NewInviteResponse creates an empty response with initialised arrays.
func NewInviteResponse(event *types.HeaderedEvent) *InviteResponse { func NewInviteResponse(event *types.HeaderedEvent, userIDForSender spec.UserIDForSender) *InviteResponse {
res := InviteResponse{} res := InviteResponse{}
res.InviteState.Events = []json.RawMessage{} res.InviteState.Events = []json.RawMessage{}
@ -552,7 +552,7 @@ func NewInviteResponse(event *types.HeaderedEvent) *InviteResponse {
// Then we'll see if we can create a partial of the invite event itself. // Then we'll see if we can create a partial of the invite event itself.
// This is needed for clients to work out *who* sent the invite. // This is needed for clients to work out *who* sent the invite.
inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync) inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userIDForSender)
inviteEvent.Unsigned = nil inviteEvent.Unsigned = nil
if ev, err := json.Marshal(inviteEvent); err == nil { if ev, err := json.Marshal(inviteEvent); err == nil {
res.InviteState.Events = append(res.InviteState.Events, ev) res.InviteState.Events = append(res.InviteState.Events, ev)

View file

@ -8,8 +8,13 @@ import (
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/synctypes"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
) )
func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) {
return spec.NewUserID(senderID, true)
}
func TestSyncTokens(t *testing.T) { func TestSyncTokens(t *testing.T) {
shouldPass := map[string]string{ shouldPass := map[string]string{
"s4_0_0_0_0_0_0_0_3": StreamingToken{4, 0, 0, 0, 0, 0, 0, 0, 3}.String(), "s4_0_0_0_0_0_0_0_3": StreamingToken{4, 0, 0, 0, 0, 0, 0, 0, 3}.String(),
@ -56,7 +61,7 @@ func TestNewInviteResponse(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}) res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, UserIDForSender)
j, err := json.Marshal(res) j, err := json.Marshal(res)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View file

@ -39,6 +39,10 @@ var (
roomIDCounter = int64(0) roomIDCounter = int64(0)
) )
func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) {
return spec.NewUserID(senderID, true)
}
type Room struct { type Room struct {
ID string ID string
Version gomatrixserverlib.RoomVersion Version gomatrixserverlib.RoomVersion
@ -195,7 +199,7 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten
if err != nil { if err != nil {
t.Fatalf("CreateEvent[%s]: failed to build event: %s", eventType, err) t.Fatalf("CreateEvent[%s]: failed to build event: %s", eventType, err)
} }
if err = gomatrixserverlib.Allowed(ev, &r.authEvents); err != nil { if err = gomatrixserverlib.Allowed(ev, &r.authEvents, UserIDForSender); err != nil {
t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err) t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err)
} }
headeredEvent := &rstypes.HeaderedEvent{PDU: ev} headeredEvent := &rstypes.HeaderedEvent{PDU: ev}

View file

@ -301,7 +301,9 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst
switch { switch {
case event.Type() == spec.MRoomMember: case event.Type() == spec.MRoomMember:
cevent := synctypes.ToClientEvent(event, synctypes.FormatAll) cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return s.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
})
var member *localMembership var member *localMembership
member, err = newLocalMembership(&cevent) member, err = newLocalMembership(&cevent)
if err != nil { if err != nil {
@ -534,7 +536,9 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
// UNSPEC: the spec doesn't say this is a ClientEvent, but the // UNSPEC: the spec doesn't say this is a ClientEvent, but the
// fields seem to match. room_id should be missing, which // fields seem to match. room_id should be missing, which
// matches the behaviour of FormatSync. // matches the behaviour of FormatSync.
Event: synctypes.ToClientEvent(event, synctypes.FormatSync), Event: synctypes.ToClientEvent(event, synctypes.FormatSync, func(roomAliasOrID string, senderID string) (*spec.UserID, error) {
return s.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
}),
// TODO: this is per-device, but it's not part of the primary // TODO: this is per-device, but it's not part of the primary
// key. So inserting one notification per profile tag doesn't // key. So inserting one notification per profile tag doesn't
// make sense. What is this supposed to be? Sytests require it // make sense. What is this supposed to be? Sytests require it
@ -616,7 +620,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
// user. Returns actions (including dont_notify). // user. Returns actions (including dont_notify).
func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
user := "" user := ""
userID, err := event.UserID() userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err == nil { if err == nil {
user = userID.String() user = userID.String()
} }
@ -655,7 +659,9 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
roomSize: roomSize, roomSize: roomSize,
} }
eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global) eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global)
rule, err := eval.MatchEvent(event.PDU) rule, err := eval.MatchEvent(event.PDU, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
return s.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
})
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/pushrules"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage"
@ -44,13 +45,19 @@ func mustCreateEvent(t *testing.T, content string) *types.HeaderedEvent {
return &types.HeaderedEvent{PDU: ev} return &types.HeaderedEvent{PDU: ev}
} }
type FakeUserRoomserverAPI struct{ rsapi.UserRoomserverAPI }
func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) {
return spec.NewUserID(senderID, true)
}
func Test_evaluatePushRules(t *testing.T) { func Test_evaluatePushRules(t *testing.T) {
ctx := context.Background() ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType) db, close := mustCreateDatabase(t, dbType)
defer close() defer close()
consumer := OutputRoomEventConsumer{db: db} consumer := OutputRoomEventConsumer{db: db, rsAPI: &FakeUserRoomserverAPI{}}
testCases := []struct { testCases := []struct {
name string name string

View file

@ -11,6 +11,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/synctypes"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
@ -22,6 +23,10 @@ import (
userUtil "github.com/matrix-org/dendrite/userapi/util" userUtil "github.com/matrix-org/dendrite/userapi/util"
) )
func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) {
return spec.NewUserID(senderID, true)
}
func TestNotifyUserCountsAsync(t *testing.T) { func TestNotifyUserCountsAsync(t *testing.T) {
alice := test.NewUser(t) alice := test.NewUser(t)
aliceLocalpart, serverName, err := gomatrixserverlib.SplitID('@', alice.ID) aliceLocalpart, serverName, err := gomatrixserverlib.SplitID('@', alice.ID)
@ -100,7 +105,7 @@ func TestNotifyUserCountsAsync(t *testing.T) {
// Insert a dummy event // Insert a dummy event
if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{ if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{
Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll), Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, UserIDForSender),
}); err != nil { }); err != nil {
t.Error(err) t.Error(err)
} }