mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-11 08:03:09 -06:00
Remove pdu.UserID & pass in the required accessor
This commit is contained in:
parent
3582ec7de7
commit
d7e41177f3
|
|
@ -181,7 +181,9 @@ func (s *OutputRoomEventConsumer) sendEvents(
|
||||||
// Create the transaction body.
|
// Create the transaction body.
|
||||||
transaction, err := json.Marshal(
|
transaction, err := json.Marshal(
|
||||||
ApplicationServiceTransaction{
|
ApplicationServiceTransaction{
|
||||||
Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll),
|
Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return s.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -234,7 +236,7 @@ func (s *appserviceState) backoffAndPause(err error) error {
|
||||||
// TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682
|
// TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682
|
||||||
func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool {
|
func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool {
|
||||||
user := ""
|
user := ""
|
||||||
userID, err := event.UserID()
|
userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
|
||||||
if err == nil {
|
if err == nil {
|
||||||
user = userID.String()
|
user = userID.String()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -331,7 +331,9 @@ func generateSendEvent(
|
||||||
stateEvents[i] = queryRes.StateEvents[i].PDU
|
stateEvents[i] = queryRes.StateEvents[i].PDU
|
||||||
}
|
}
|
||||||
provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents))
|
provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents))
|
||||||
if err = gomatrixserverlib.Allowed(e.PDU, &provider); err != nil {
|
if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}); err != nil {
|
||||||
return nil, &util.JSONResponse{
|
return nil, &util.JSONResponse{
|
||||||
Code: http.StatusForbidden,
|
Code: http.StatusForbidden,
|
||||||
JSON: spec.Forbidden(err.Error()), // TODO: Is this error string comprehensible to the client?
|
JSON: spec.Forbidden(err.Error()), // TODO: Is this error string comprehensible to the client?
|
||||||
|
|
|
||||||
|
|
@ -142,7 +142,9 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
|
||||||
for _, ev := range stateRes.StateEvents {
|
for _, ev := range stateRes.StateEvents {
|
||||||
stateEvents = append(
|
stateEvents = append(
|
||||||
stateEvents,
|
stateEvents,
|
||||||
synctypes.ToClientEvent(ev, synctypes.FormatAll),
|
synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomAliasOrID string, senderID string) (*spec.UserID, error) {
|
||||||
|
return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -164,7 +166,9 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
|
||||||
for _, ev := range stateAfterRes.StateEvents {
|
for _, ev := range stateAfterRes.StateEvents {
|
||||||
stateEvents = append(
|
stateEvents = append(
|
||||||
stateEvents,
|
stateEvents,
|
||||||
synctypes.ToClientEvent(ev, synctypes.FormatAll),
|
synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomAliasOrID string, senderID string) (*spec.UserID, error) {
|
||||||
|
return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -335,7 +339,9 @@ func OnIncomingStateTypeRequest(
|
||||||
}
|
}
|
||||||
|
|
||||||
stateEvent := stateEventInStateResp{
|
stateEvent := stateEventInStateResp{
|
||||||
ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll),
|
ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomAliasOrID string, senderID string) (*spec.UserID, error) {
|
||||||
|
return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
|
|
||||||
var res interface{}
|
var res interface{}
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/setup/process"
|
"github.com/matrix-org/dendrite/setup/process"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
)
|
)
|
||||||
|
|
||||||
// This is a utility for inspecting state snapshots and running state resolution
|
// This is a utility for inspecting state snapshots and running state resolution
|
||||||
|
|
@ -182,7 +183,9 @@ func main() {
|
||||||
fmt.Println("Resolving state")
|
fmt.Println("Resolving state")
|
||||||
var resolved Events
|
var resolved Events
|
||||||
resolved, err = gomatrixserverlib.ResolveConflicts(
|
resolved, err = gomatrixserverlib.ResolveConflicts(
|
||||||
gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents,
|
gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return roomserverDB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,10 @@ type fedRoomserverAPI struct {
|
||||||
queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error
|
queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) {
|
||||||
|
return spec.NewUserID(senderID, true)
|
||||||
|
}
|
||||||
|
|
||||||
// PerformJoin will call this function
|
// PerformJoin will call this function
|
||||||
func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) {
|
func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) {
|
||||||
if f.inputRoomEvents == nil {
|
if f.inputRoomEvents == nil {
|
||||||
|
|
|
||||||
|
|
@ -156,15 +156,20 @@ func (r *FederationInternalAPI) performJoinUsingServer(
|
||||||
}
|
}
|
||||||
|
|
||||||
joinInput := gomatrixserverlib.PerformJoinInput{
|
joinInput := gomatrixserverlib.PerformJoinInput{
|
||||||
UserID: user,
|
UserID: user,
|
||||||
RoomID: room,
|
RoomID: room,
|
||||||
ServerName: serverName,
|
ServerName: serverName,
|
||||||
Content: content,
|
Content: content,
|
||||||
Unsigned: unsigned,
|
Unsigned: unsigned,
|
||||||
PrivateKey: r.cfg.Matrix.PrivateKey,
|
PrivateKey: r.cfg.Matrix.PrivateKey,
|
||||||
KeyID: r.cfg.Matrix.KeyID,
|
KeyID: r.cfg.Matrix.KeyID,
|
||||||
KeyRing: r.keyRing,
|
KeyRing: r.keyRing,
|
||||||
EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName),
|
EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return r.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}),
|
||||||
|
UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return r.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
},
|
||||||
}
|
}
|
||||||
response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput)
|
response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput)
|
||||||
|
|
||||||
|
|
@ -358,8 +363,11 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
|
||||||
|
|
||||||
// authenticate the state returned (check its auth events etc)
|
// authenticate the state returned (check its auth events etc)
|
||||||
// the equivalent of CheckSendJoinResponse()
|
// the equivalent of CheckSendJoinResponse()
|
||||||
|
userIDProvider := func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return r.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}
|
||||||
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(
|
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(
|
||||||
ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedEventProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName),
|
ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedEventProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName, userIDProvider), userIDProvider,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error checking state returned from peeking: %w", err)
|
return fmt.Errorf("error checking state returned from peeking: %w", err)
|
||||||
|
|
@ -509,7 +517,7 @@ func (r *FederationInternalAPI) SendInvite(
|
||||||
event gomatrixserverlib.PDU,
|
event gomatrixserverlib.PDU,
|
||||||
strippedState []gomatrixserverlib.InviteStrippedState,
|
strippedState []gomatrixserverlib.InviteStrippedState,
|
||||||
) (gomatrixserverlib.PDU, error) {
|
) (gomatrixserverlib.PDU, error) {
|
||||||
inviter, err := event.UserID()
|
inviter, err := r.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -640,6 +648,7 @@ func checkEventsContainCreateEvent(events []gomatrixserverlib.PDU) error {
|
||||||
func federatedEventProvider(
|
func federatedEventProvider(
|
||||||
ctx context.Context, federation fclient.FederationClient,
|
ctx context.Context, federation fclient.FederationClient,
|
||||||
keyRing gomatrixserverlib.JSONVerifier, origin, server spec.ServerName,
|
keyRing gomatrixserverlib.JSONVerifier, origin, server spec.ServerName,
|
||||||
|
userIDForSender spec.UserIDForSender,
|
||||||
) gomatrixserverlib.EventProvider {
|
) gomatrixserverlib.EventProvider {
|
||||||
// A list of events that we have retried, if they were not included in
|
// A list of events that we have retried, if they were not included in
|
||||||
// the auth events supplied in the send_join.
|
// the auth events supplied in the send_join.
|
||||||
|
|
@ -689,7 +698,7 @@ func federatedEventProvider(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the signatures of the event.
|
// Check the signatures of the event.
|
||||||
if err := gomatrixserverlib.VerifyEventSignatures(ctx, ev, keyRing); err != nil {
|
if err := gomatrixserverlib.VerifyEventSignatures(ctx, ev, keyRing, userIDForSender); err != nil {
|
||||||
return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err)
|
return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -147,15 +147,18 @@ func MakeJoin(
|
||||||
}
|
}
|
||||||
|
|
||||||
input := gomatrixserverlib.HandleMakeJoinInput{
|
input := gomatrixserverlib.HandleMakeJoinInput{
|
||||||
Context: httpReq.Context(),
|
Context: httpReq.Context(),
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
RoomID: roomID,
|
RoomID: roomID,
|
||||||
RoomVersion: roomVersion,
|
RoomVersion: roomVersion,
|
||||||
RemoteVersions: remoteVersions,
|
RemoteVersions: remoteVersions,
|
||||||
RequestOrigin: request.Origin(),
|
RequestOrigin: request.Origin(),
|
||||||
LocalServerName: cfg.Matrix.ServerName,
|
LocalServerName: cfg.Matrix.ServerName,
|
||||||
LocalServerInRoom: res.RoomExists && res.IsInRoom,
|
LocalServerInRoom: res.RoomExists && res.IsInRoom,
|
||||||
RoomQuerier: &roomQuerier,
|
RoomQuerier: &roomQuerier,
|
||||||
|
UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomAliasOrID, senderID)
|
||||||
|
},
|
||||||
BuildEventTemplate: createJoinTemplate,
|
BuildEventTemplate: createJoinTemplate,
|
||||||
}
|
}
|
||||||
response, internalErr := gomatrixserverlib.HandleMakeJoin(input)
|
response, internalErr := gomatrixserverlib.HandleMakeJoin(input)
|
||||||
|
|
|
||||||
|
|
@ -223,7 +223,7 @@ func SendLeave(
|
||||||
// Check that the sender belongs to the server that is sending us
|
// Check that the sender belongs to the server that is sending us
|
||||||
// the request. By this point we've already asserted that the sender
|
// the request. By this point we've already asserted that the sender
|
||||||
// and the state key are equal so we don't need to check both.
|
// and the state key are equal so we don't need to check both.
|
||||||
sender, err := event.UserID()
|
sender, err := rsAPI.QueryUserIDForSender(httpReq.Context(), event.RoomID(), event.SenderID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusForbidden,
|
Code: http.StatusForbidden,
|
||||||
|
|
|
||||||
4
go.mod
4
go.mod
|
|
@ -22,7 +22,7 @@ require (
|
||||||
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
|
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
|
||||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
|
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
|
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230603021032-d30b8fdc7ced
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20230606021710-b68a1b0eef30
|
||||||
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
|
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
|
||||||
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
|
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
|
||||||
github.com/mattn/go-sqlite3 v1.14.16
|
github.com/mattn/go-sqlite3 v1.14.16
|
||||||
|
|
@ -34,7 +34,7 @@ require (
|
||||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||||
github.com/pkg/errors v0.9.1
|
github.com/pkg/errors v0.9.1
|
||||||
github.com/prometheus/client_golang v1.13.0
|
github.com/prometheus/client_golang v1.13.0
|
||||||
github.com/sirupsen/logrus v1.9.2
|
github.com/sirupsen/logrus v1.9.3
|
||||||
github.com/stretchr/testify v1.8.2
|
github.com/stretchr/testify v1.8.2
|
||||||
github.com/tidwall/gjson v1.14.4
|
github.com/tidwall/gjson v1.14.4
|
||||||
github.com/tidwall/sjson v1.2.5
|
github.com/tidwall/sjson v1.2.5
|
||||||
|
|
|
||||||
8
go.sum
8
go.sum
|
|
@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
|
||||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
|
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
|
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230603021032-d30b8fdc7ced h1:pbCM+nno+r2wW3jwxP65xmkzk6008CdMNZaOWYBwB1c=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20230606021710-b68a1b0eef30 h1:G+Do1UoWazY0Fetq+eAX1h1+fimf19NGGyaS86hWg8s=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230603021032-d30b8fdc7ced/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20230606021710-b68a1b0eef30/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
|
||||||
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
|
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
|
||||||
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
|
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
|
||||||
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
|
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
|
||||||
|
|
@ -444,8 +444,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx
|
||||||
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
|
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
|
||||||
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
|
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
|
||||||
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
|
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
|
||||||
github.com/sirupsen/logrus v1.9.2 h1:oxx1eChJGI6Uks2ZC4W1zpLlVgqB8ner4EuQwV4Ik1Y=
|
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||||
github.com/sirupsen/logrus v1.9.2/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||||
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
|
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
|
||||||
github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s=
|
github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s=
|
||||||
github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
|
github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
)
|
)
|
||||||
|
|
||||||
// A RuleSetEvaluator encapsulates context to evaluate an event
|
// A RuleSetEvaluator encapsulates context to evaluate an event
|
||||||
|
|
@ -53,7 +54,7 @@ func NewRuleSetEvaluator(ec EvaluationContext, ruleSet *RuleSet) *RuleSetEvaluat
|
||||||
|
|
||||||
// MatchEvent returns the first matching rule. Returns nil if there
|
// MatchEvent returns the first matching rule. Returns nil if there
|
||||||
// was no match rule.
|
// was no match rule.
|
||||||
func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU) (*Rule, error) {
|
func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender) (*Rule, error) {
|
||||||
// TODO: server-default rules have lower priority than user rules,
|
// TODO: server-default rules have lower priority than user rules,
|
||||||
// but they are stored together with the user rules. It's a bit
|
// but they are stored together with the user rules. It's a bit
|
||||||
// unclear what the specification (11.14.1.4 Predefined rules)
|
// unclear what the specification (11.14.1.4 Predefined rules)
|
||||||
|
|
@ -68,7 +69,7 @@ func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU) (*Rule, err
|
||||||
if rule.Default != defRules {
|
if rule.Default != defRules {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec)
|
ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec, userIDForSender)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -83,7 +84,7 @@ func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU) (*Rule, err
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec EvaluationContext) (bool, error) {
|
func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec EvaluationContext, userIDForSender spec.UserIDForSender) (bool, error) {
|
||||||
if !rule.Enabled {
|
if !rule.Enabled {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
@ -114,7 +115,7 @@ func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec Evaluati
|
||||||
|
|
||||||
case SenderKind:
|
case SenderKind:
|
||||||
userID := ""
|
userID := ""
|
||||||
sender, err := event.UserID()
|
sender, err := userIDForSender(event.RoomID(), event.SenderID())
|
||||||
if err == nil {
|
if err == nil {
|
||||||
userID = sender.String()
|
userID = sender.String()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,8 +5,13 @@ import (
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) {
|
||||||
|
return spec.NewUserID(senderID, true)
|
||||||
|
}
|
||||||
|
|
||||||
func TestRuleSetEvaluatorMatchEvent(t *testing.T) {
|
func TestRuleSetEvaluatorMatchEvent(t *testing.T) {
|
||||||
ev := mustEventFromJSON(t, `{}`)
|
ev := mustEventFromJSON(t, `{}`)
|
||||||
defaultEnabled := &Rule{
|
defaultEnabled := &Rule{
|
||||||
|
|
@ -45,7 +50,7 @@ func TestRuleSetEvaluatorMatchEvent(t *testing.T) {
|
||||||
for _, tst := range tsts {
|
for _, tst := range tsts {
|
||||||
t.Run(tst.Name, func(t *testing.T) {
|
t.Run(tst.Name, func(t *testing.T) {
|
||||||
rse := NewRuleSetEvaluator(fakeEvaluationContext{3}, &tst.RuleSet)
|
rse := NewRuleSetEvaluator(fakeEvaluationContext{3}, &tst.RuleSet)
|
||||||
got, err := rse.MatchEvent(tst.Event)
|
got, err := rse.MatchEvent(tst.Event, UserIDForSender)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("MatchEvent failed: %v", err)
|
t.Fatalf("MatchEvent failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -90,7 +95,7 @@ func TestRuleMatches(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tst := range tsts {
|
for _, tst := range tsts {
|
||||||
t.Run(tst.Name, func(t *testing.T) {
|
t.Run(tst.Name, func(t *testing.T) {
|
||||||
got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil)
|
got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil, UserIDForSender)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ruleMatches failed: %v", err)
|
t.Fatalf("ruleMatches failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -167,7 +167,9 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys); err != nil {
|
if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return t.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}); err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID())
|
util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID())
|
||||||
results[event.EventID()] = fclient.PDUResult{
|
results[event.EventID()] = fclient.PDUResult{
|
||||||
Error: err.Error(),
|
Error: err.Error(),
|
||||||
|
|
|
||||||
|
|
@ -70,6 +70,10 @@ type FakeRsAPI struct {
|
||||||
bannedFromRoom bool
|
bannedFromRoom bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) {
|
||||||
|
return spec.NewUserID(senderID, true)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *FakeRsAPI) QueryRoomVersionForRoom(
|
func (r *FakeRsAPI) QueryRoomVersionForRoom(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
roomID string,
|
roomID string,
|
||||||
|
|
@ -638,6 +642,10 @@ type testRoomserverAPI struct {
|
||||||
queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse
|
queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) {
|
||||||
|
return spec.NewUserID(senderID, true)
|
||||||
|
}
|
||||||
|
|
||||||
func (t *testRoomserverAPI) InputRoomEvents(
|
func (t *testRoomserverAPI) InputRoomEvents(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *rsAPI.InputRoomEventsRequest,
|
request *rsAPI.InputRoomEventsRequest,
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,7 @@ type RoomserverInternalAPI interface {
|
||||||
ClientRoomserverAPI
|
ClientRoomserverAPI
|
||||||
UserRoomserverAPI
|
UserRoomserverAPI
|
||||||
FederationRoomserverAPI
|
FederationRoomserverAPI
|
||||||
|
QuerySenderIDAPI
|
||||||
|
|
||||||
// needed to avoid chicken and egg scenario when setting up the
|
// needed to avoid chicken and egg scenario when setting up the
|
||||||
// interdependencies between the roomserver and other input APIs
|
// interdependencies between the roomserver and other input APIs
|
||||||
|
|
@ -68,7 +69,7 @@ type InputRoomEventsAPI interface {
|
||||||
type QuerySenderIDAPI interface {
|
type QuerySenderIDAPI interface {
|
||||||
// Accepts either roomID or alias
|
// Accepts either roomID or alias
|
||||||
QuerySenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error)
|
QuerySenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error)
|
||||||
QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (spec.UserID, error)
|
QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query the latest events and state for a room from the room server.
|
// Query the latest events and state for a room from the room server.
|
||||||
|
|
@ -98,6 +99,7 @@ type QueryEventsAPI interface {
|
||||||
type SyncRoomserverAPI interface {
|
type SyncRoomserverAPI interface {
|
||||||
QueryLatestEventsAndStateAPI
|
QueryLatestEventsAndStateAPI
|
||||||
QueryBulkStateContentAPI
|
QueryBulkStateContentAPI
|
||||||
|
QuerySenderIDAPI
|
||||||
// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
|
// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
|
||||||
QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error
|
QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error
|
||||||
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
|
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
|
||||||
|
|
@ -138,6 +140,7 @@ type SyncRoomserverAPI interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type AppserviceRoomserverAPI interface {
|
type AppserviceRoomserverAPI interface {
|
||||||
|
QuerySenderIDAPI
|
||||||
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
|
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
|
||||||
// which room to use by querying the first events roomID.
|
// which room to use by querying the first events roomID.
|
||||||
QueryEventsByID(
|
QueryEventsByID(
|
||||||
|
|
@ -197,6 +200,7 @@ type ClientRoomserverAPI interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserRoomserverAPI interface {
|
type UserRoomserverAPI interface {
|
||||||
|
QuerySenderIDAPI
|
||||||
QueryLatestEventsAndStateAPI
|
QueryLatestEventsAndStateAPI
|
||||||
KeyserverRoomserverAPI
|
KeyserverRoomserverAPI
|
||||||
QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error
|
QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error
|
||||||
|
|
@ -209,6 +213,8 @@ type FederationRoomserverAPI interface {
|
||||||
InputRoomEventsAPI
|
InputRoomEventsAPI
|
||||||
QueryLatestEventsAndStateAPI
|
QueryLatestEventsAndStateAPI
|
||||||
QueryBulkStateContentAPI
|
QueryBulkStateContentAPI
|
||||||
|
QuerySenderIDAPI
|
||||||
|
|
||||||
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
|
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
|
||||||
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
|
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
|
||||||
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error
|
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error
|
||||||
|
|
|
||||||
|
|
@ -463,10 +463,10 @@ type MembershipQuerier struct {
|
||||||
Roomserver FederationRoomserverAPI
|
Roomserver FederationRoomserverAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID string) (string, error) {
|
func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
|
||||||
req := QueryMembershipForUserRequest{
|
req := QueryMembershipForUserRequest{
|
||||||
RoomID: roomID.String(),
|
RoomID: roomID.String(),
|
||||||
UserID: senderID,
|
UserID: string(senderID),
|
||||||
}
|
}
|
||||||
res := QueryMembershipForUserResponse{}
|
res := QueryMembershipForUserResponse{}
|
||||||
err := mq.Roomserver.QueryMembershipForUser(ctx, &req, &res)
|
err := mq.Roomserver.QueryMembershipForUser(ctx, &req, &res)
|
||||||
|
|
|
||||||
|
|
@ -76,7 +76,9 @@ func CheckForSoftFail(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the event is allowed.
|
// Check if the event is allowed.
|
||||||
if err = gomatrixserverlib.Allowed(event.PDU, &authEvents); err != nil {
|
if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return db.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}); err != nil {
|
||||||
// return true, nil
|
// return true, nil
|
||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -128,7 +128,7 @@ func (r *Inputer) processRoomEvent(
|
||||||
if roomInfo == nil && !isCreateEvent {
|
if roomInfo == nil && !isCreateEvent {
|
||||||
return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID())
|
return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID())
|
||||||
}
|
}
|
||||||
sender, err := event.UserID()
|
sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("event has invalid sender %q", event.SenderID())
|
return fmt.Errorf("event has invalid sender %q", event.SenderID())
|
||||||
}
|
}
|
||||||
|
|
@ -276,7 +276,9 @@ func (r *Inputer) processRoomEvent(
|
||||||
|
|
||||||
// Check if the event is allowed by its auth events. If it isn't then
|
// Check if the event is allowed by its auth events. If it isn't then
|
||||||
// we consider the event to be "rejected" — it will still be persisted.
|
// we consider the event to be "rejected" — it will still be persisted.
|
||||||
if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil {
|
if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}); err != nil {
|
||||||
isRejected = true
|
isRejected = true
|
||||||
rejectionErr = err
|
rejectionErr = err
|
||||||
logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID())
|
logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID())
|
||||||
|
|
@ -579,7 +581,9 @@ func (r *Inputer) processStateBefore(
|
||||||
stateBeforeAuth := gomatrixserverlib.NewAuthEvents(
|
stateBeforeAuth := gomatrixserverlib.NewAuthEvents(
|
||||||
gomatrixserverlib.ToPDUs(stateBeforeEvent),
|
gomatrixserverlib.ToPDUs(stateBeforeEvent),
|
||||||
)
|
)
|
||||||
if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth); rejectionErr != nil {
|
if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}); rejectionErr != nil {
|
||||||
rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr)
|
rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -690,7 +694,9 @@ nextAuthEvent:
|
||||||
// Check the signatures of the event. If this fails then we'll simply
|
// Check the signatures of the event. If this fails then we'll simply
|
||||||
// skip it, because gomatrixserverlib.Allowed() will notice a problem
|
// skip it, because gomatrixserverlib.Allowed() will notice a problem
|
||||||
// if a critical event is missing anyway.
|
// if a critical event is missing anyway.
|
||||||
if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing()); err != nil {
|
if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}); err != nil {
|
||||||
continue nextAuthEvent
|
continue nextAuthEvent
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -706,7 +712,9 @@ nextAuthEvent:
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the auth event should be rejected.
|
// Check if the auth event should be rejected.
|
||||||
err := gomatrixserverlib.Allowed(authEvent, auth)
|
err := gomatrixserverlib.Allowed(authEvent, auth, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
})
|
||||||
if isRejected = err != nil; isRejected {
|
if isRejected = err != nil; isRejected {
|
||||||
logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID())
|
logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -58,7 +58,7 @@ func Test_EventAuth(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finally check that the event is NOT allowed
|
// Finally check that the event is NOT allowed
|
||||||
if err := gomatrixserverlib.Allowed(ev.PDU, &allower); err == nil {
|
if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomAliasOrID, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) }); err == nil {
|
||||||
t.Fatalf("event should not be allowed, but it was")
|
t.Fatalf("event should not be allowed, but it was")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -473,14 +473,18 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion
|
||||||
stateEventList = append(stateEventList, state.StateEvents...)
|
stateEventList = append(stateEventList, state.StateEvents...)
|
||||||
}
|
}
|
||||||
resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(
|
resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(
|
||||||
roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList),
|
roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// apply the current event
|
// apply the current event
|
||||||
retryAllowedState:
|
retryAllowedState:
|
||||||
if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents); err != nil {
|
if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}); err != nil {
|
||||||
switch missing := err.(type) {
|
switch missing := err.(type) {
|
||||||
case gomatrixserverlib.MissingAuthEventError:
|
case gomatrixserverlib.MissingAuthEventError:
|
||||||
h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true)
|
h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true)
|
||||||
|
|
@ -565,7 +569,9 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver
|
||||||
// will be added and duplicates will be removed.
|
// will be added and duplicates will be removed.
|
||||||
missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events))
|
missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events))
|
||||||
for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) {
|
for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) {
|
||||||
if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys); err != nil {
|
if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
missingEvents = append(missingEvents, t.cacheAndReturn(ev))
|
missingEvents = append(missingEvents, t.cacheAndReturn(ev))
|
||||||
|
|
@ -654,7 +660,9 @@ func (t *missingStateReq) lookupMissingStateViaState(
|
||||||
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{
|
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{
|
||||||
StateEvents: state.GetStateEvents(),
|
StateEvents: state.GetStateEvents(),
|
||||||
AuthEvents: state.GetAuthEvents(),
|
AuthEvents: state.GetAuthEvents(),
|
||||||
}, roomVersion, t.keys, nil)
|
}, roomVersion, t.keys, nil, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -889,14 +897,16 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
|
||||||
t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers))
|
t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers))
|
||||||
return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers))
|
return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers))
|
||||||
}
|
}
|
||||||
if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys); err != nil {
|
if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}); err != nil {
|
||||||
t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())
|
t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())
|
||||||
return nil, verifySigError{event.EventID(), err}
|
return nil, verifySigError{event.EventID(), err}
|
||||||
}
|
}
|
||||||
return t.cacheAndReturn(event), nil
|
return t.cacheAndReturn(event), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU) error {
|
func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender) error {
|
||||||
authUsingState := gomatrixserverlib.NewAuthEvents(nil)
|
authUsingState := gomatrixserverlib.NewAuthEvents(nil)
|
||||||
for i := range stateEvents {
|
for i := range stateEvents {
|
||||||
err := authUsingState.AddEvent(stateEvents[i])
|
err := authUsingState.AddEvent(stateEvents[i])
|
||||||
|
|
@ -904,7 +914,7 @@ func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverli
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return gomatrixserverlib.Allowed(e, &authUsingState)
|
return gomatrixserverlib.Allowed(e, &authUsingState, userIDForSender)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *missingStateReq) hadEvent(eventID string) {
|
func (t *missingStateReq) hadEvent(eventID string) {
|
||||||
|
|
|
||||||
|
|
@ -262,13 +262,17 @@ func (r *Admin) PerformAdminDownloadState(
|
||||||
return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err)
|
return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err)
|
||||||
}
|
}
|
||||||
for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) {
|
for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) {
|
||||||
if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing); err != nil {
|
if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
authEventMap[authEvent.EventID()] = authEvent
|
authEventMap[authEvent.EventID()] = authEvent
|
||||||
}
|
}
|
||||||
for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) {
|
for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) {
|
||||||
if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing); err != nil {
|
if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
stateEventMap[stateEvent.EventID()] = stateEvent
|
stateEventMap[stateEvent.EventID()] = stateEvent
|
||||||
|
|
|
||||||
|
|
@ -121,7 +121,9 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
|
||||||
// Specifically the test "Outbound federation can backfill events"
|
// Specifically the test "Outbound federation can backfill events"
|
||||||
events, err := gomatrixserverlib.RequestBackfill(
|
events, err := gomatrixserverlib.RequestBackfill(
|
||||||
ctx, req.VirtualHost, requester,
|
ctx, req.VirtualHost, requester,
|
||||||
r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100,
|
r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
},
|
||||||
)
|
)
|
||||||
// Only return an error if we really couldn't get any events.
|
// Only return an error if we really couldn't get any events.
|
||||||
if err != nil && len(events) == 0 {
|
if err != nil && len(events) == 0 {
|
||||||
|
|
@ -210,7 +212,9 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false)
|
loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false)
|
||||||
result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents)
|
result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithError(err).Warn("failed to load and verify event")
|
logger.WithError(err).Warn("failed to load and verify event")
|
||||||
continue
|
continue
|
||||||
|
|
@ -484,7 +488,7 @@ FindSuccessor:
|
||||||
// Store the server names in a temporary map to avoid duplicates.
|
// Store the server names in a temporary map to avoid duplicates.
|
||||||
serverSet := make(map[spec.ServerName]bool)
|
serverSet := make(map[spec.ServerName]bool)
|
||||||
for _, event := range memberEvents {
|
for _, event := range memberEvents {
|
||||||
if sender, err := event.UserID(); err == nil {
|
if sender, err := b.db.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()); err == nil {
|
||||||
serverSet[sender.Domain()] = true
|
serverSet[sender.Domain()] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -308,7 +308,9 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = gomatrixserverlib.Allowed(ev, &authEvents); err != nil {
|
if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return c.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}); err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed")
|
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed")
|
||||||
return "", &util.JSONResponse{
|
return "", &util.JSONResponse{
|
||||||
Code: http.StatusInternalServerError,
|
Code: http.StatusInternalServerError,
|
||||||
|
|
|
||||||
|
|
@ -126,7 +126,7 @@ func (r *Inviter) PerformInvite(
|
||||||
) error {
|
) error {
|
||||||
event := req.Event
|
event := req.Event
|
||||||
|
|
||||||
sender, err := event.UserID()
|
sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return spec.InvalidParam("The sender user ID is invalid")
|
return spec.InvalidParam("The sender user ID is invalid")
|
||||||
}
|
}
|
||||||
|
|
@ -156,6 +156,9 @@ func (r *Inviter) PerformInvite(
|
||||||
StrippedState: req.InviteRoomState,
|
StrippedState: req.InviteRoomState,
|
||||||
MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI},
|
MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI},
|
||||||
StateQuerier: &QueryState{r.DB},
|
StateQuerier: &QueryState{r.DB},
|
||||||
|
UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
},
|
||||||
}
|
}
|
||||||
inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI)
|
inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -484,7 +484,9 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil {
|
if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return r.URSAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}); err != nil {
|
||||||
return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err)
|
return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -567,7 +569,9 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user
|
||||||
stateEvents[i] = queryRes.StateEvents[i].PDU
|
stateEvents[i] = queryRes.StateEvents[i].PDU
|
||||||
}
|
}
|
||||||
provider := gomatrixserverlib.NewAuthEvents(stateEvents)
|
provider := gomatrixserverlib.NewAuthEvents(stateEvents)
|
||||||
if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider); err != nil {
|
if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return r.URSAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}); err != nil {
|
||||||
return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client?
|
return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client?
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -122,7 +122,9 @@ func (r *Queryer) QueryStateAfterEvents(
|
||||||
}
|
}
|
||||||
|
|
||||||
stateEvents, err = gomatrixserverlib.ResolveConflicts(
|
stateEvents, err = gomatrixserverlib.ResolveConflicts(
|
||||||
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents),
|
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err)
|
return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err)
|
||||||
|
|
@ -349,7 +351,9 @@ func (r *Queryer) QueryMembershipsForRoom(
|
||||||
return fmt.Errorf("r.DB.Events: %w", err)
|
return fmt.Errorf("r.DB.Events: %w", err)
|
||||||
}
|
}
|
||||||
for _, event := range events {
|
for _, event := range events {
|
||||||
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll)
|
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
})
|
||||||
response.JoinEvents = append(response.JoinEvents, clientEvent)
|
response.JoinEvents = append(response.JoinEvents, clientEvent)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
@ -398,7 +402,9 @@ func (r *Queryer) QueryMembershipsForRoom(
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, event := range events {
|
for _, event := range events {
|
||||||
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll)
|
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
})
|
||||||
response.JoinEvents = append(response.JoinEvents, clientEvent)
|
response.JoinEvents = append(response.JoinEvents, clientEvent)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -588,7 +594,9 @@ func (r *Queryer) QueryStateAndAuthChain(
|
||||||
|
|
||||||
if request.ResolveState {
|
if request.ResolveState {
|
||||||
stateEvents, err = gomatrixserverlib.ResolveConflicts(
|
stateEvents, err = gomatrixserverlib.ResolveConflicts(
|
||||||
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents),
|
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -1036,15 +1044,9 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) {
|
func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) {
|
||||||
// TODO: implement this properly with pseudoIDs
|
return r.DB.GetSenderIDForUser(ctx, roomAliasOrID, userID)
|
||||||
return userID.String(), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (spec.UserID, error) {
|
func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) {
|
||||||
// TODO: implement this properly with pseudoIDs
|
return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
userID, err := spec.NewUserID(senderID, true)
|
|
||||||
if err != nil {
|
|
||||||
return spec.UserID{}, err
|
|
||||||
}
|
|
||||||
return *userID, err
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
|
||||||
|
|
@ -43,6 +44,7 @@ type StateResolutionStorage interface {
|
||||||
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
|
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
|
||||||
Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
|
Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
|
||||||
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
|
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
|
||||||
|
GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type StateResolution struct {
|
type StateResolution struct {
|
||||||
|
|
@ -945,7 +947,9 @@ func (v *StateResolution) resolveConflictsV1(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve the conflicts.
|
// Resolve the conflicts.
|
||||||
resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents)
|
resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return v.db.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
})
|
||||||
|
|
||||||
// Map from the full events back to numeric state entries.
|
// Map from the full events back to numeric state entries.
|
||||||
for _, resolvedEvent := range resolvedEvents {
|
for _, resolvedEvent := range resolvedEvents {
|
||||||
|
|
@ -1057,6 +1061,9 @@ func (v *StateResolution) resolveConflictsV2(
|
||||||
conflictedEvents,
|
conflictedEvents,
|
||||||
nonConflictedEvents,
|
nonConflictedEvents,
|
||||||
authEvents,
|
authEvents,
|
||||||
|
func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return v.db.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
},
|
||||||
)
|
)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -166,6 +166,10 @@ type Database interface {
|
||||||
GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error)
|
GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error)
|
||||||
// GetKnownUsers searches all users that userID knows about.
|
// GetKnownUsers searches all users that userID knows about.
|
||||||
GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error)
|
GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error)
|
||||||
|
// GetKnownUsers tries to obtain the current mxid for a given user.
|
||||||
|
GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error)
|
||||||
|
// GetKnownUsers tries to obtain the current senderID for a given user.
|
||||||
|
GetSenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error)
|
||||||
// GetKnownRooms returns a list of all rooms we know about.
|
// GetKnownRooms returns a list of all rooms we know about.
|
||||||
GetKnownRooms(ctx context.Context) ([]string, error)
|
GetKnownRooms(ctx context.Context) ([]string, error)
|
||||||
// ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room
|
// ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room
|
||||||
|
|
@ -211,6 +215,7 @@ type RoomDatabase interface {
|
||||||
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
|
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
|
||||||
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
|
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
|
||||||
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error)
|
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error)
|
||||||
|
GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type EventDatabase interface {
|
type EventDatabase interface {
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
|
|
@ -250,3 +251,7 @@ func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error {
|
||||||
func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
|
func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
|
||||||
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
|
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *RoomUpdater) GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) {
|
||||||
|
return u.d.GetUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -988,13 +988,15 @@ func (d *EventDatabase) MaybeRedactEvent(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Don't hack senderID into userID here (pseudoIDs)
|
||||||
sender1Domain := ""
|
sender1Domain := ""
|
||||||
sender1, err1 := redactedEvent.UserID()
|
sender1, err1 := spec.NewUserID(redactedEvent.SenderID(), true)
|
||||||
if err1 == nil {
|
if err1 == nil {
|
||||||
sender1Domain = string(sender1.Domain())
|
sender1Domain = string(sender1.Domain())
|
||||||
}
|
}
|
||||||
|
// TODO: Don't hack senderID into userID here (pseudoIDs)
|
||||||
sender2Domain := ""
|
sender2Domain := ""
|
||||||
sender2, err2 := redactionEvent.UserID()
|
sender2, err2 := spec.NewUserID(redactionEvent.SenderID(), true)
|
||||||
if err2 == nil {
|
if err2 == nil {
|
||||||
sender2Domain = string(sender2.Domain())
|
sender2Domain = string(sender2.Domain())
|
||||||
}
|
}
|
||||||
|
|
@ -1522,6 +1524,16 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin
|
||||||
return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit)
|
return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) {
|
||||||
|
// TODO: Use real logic once DB for pseudoIDs is in place
|
||||||
|
return spec.NewUserID(senderID, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) GetSenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) {
|
||||||
|
// TODO: Use real logic once DB for pseudoIDs is in place
|
||||||
|
return userID.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetKnownRooms returns a list of all rooms we know about.
|
// GetKnownRooms returns a list of all rooms we know about.
|
||||||
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
|
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
|
||||||
return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil)
|
return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil)
|
||||||
|
|
|
||||||
|
|
@ -92,9 +92,11 @@ type MSC2836EventRelationshipsResponse struct {
|
||||||
ParsedAuthChain []gomatrixserverlib.PDU
|
ParsedAuthChain []gomatrixserverlib.PDU
|
||||||
}
|
}
|
||||||
|
|
||||||
func toClientResponse(res *MSC2836EventRelationshipsResponse) *EventRelationshipResponse {
|
func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsResponse, rsAPI roomserver.RoomserverInternalAPI) *EventRelationshipResponse {
|
||||||
out := &EventRelationshipResponse{
|
out := &EventRelationshipResponse{
|
||||||
Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll),
|
Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}),
|
||||||
Limited: res.Limited,
|
Limited: res.Limited,
|
||||||
NextBatch: res.NextBatch,
|
NextBatch: res.NextBatch,
|
||||||
}
|
}
|
||||||
|
|
@ -187,7 +189,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
|
||||||
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 200,
|
Code: 200,
|
||||||
JSON: toClientResponse(res),
|
JSON: toClientResponse(req.Context(), res, rsAPI),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -525,6 +525,10 @@ type testRoomserverAPI struct {
|
||||||
events map[string]*types.HeaderedEvent
|
events map[string]*types.HeaderedEvent
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) {
|
||||||
|
return spec.NewUserID(senderID, true)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error {
|
func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error {
|
||||||
for _, eventID := range req.EventIDs {
|
for _, eventID := range req.EventIDs {
|
||||||
ev := r.events[eventID]
|
ev := r.events[eventID]
|
||||||
|
|
|
||||||
|
|
@ -193,14 +193,20 @@ func Context(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll)
|
eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll)
|
return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
})
|
||||||
|
eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
})
|
||||||
|
|
||||||
newState := state
|
newState := state
|
||||||
if filter.LazyLoadMembers {
|
if filter.LazyLoadMembers {
|
||||||
allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...)
|
allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...)
|
||||||
allEvents = append(allEvents, &requestedEvent)
|
allEvents = append(allEvents, &requestedEvent)
|
||||||
evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll)
|
evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
})
|
||||||
newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache)
|
newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Error("unable to load membership events")
|
logrus.WithError(err).Error("unable to load membership events")
|
||||||
|
|
@ -211,12 +217,16 @@ func Context(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll)
|
ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
})
|
||||||
response := ContextRespsonse{
|
response := ContextRespsonse{
|
||||||
Event: &ev,
|
Event: &ev,
|
||||||
EventsAfter: eventsAfterClient,
|
EventsAfter: eventsAfterClient,
|
||||||
EventsBefore: eventsBeforeClient,
|
EventsBefore: eventsBeforeClient,
|
||||||
State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll),
|
State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(response.State) > filter.Limit {
|
if len(response.State) > filter.Limit {
|
||||||
|
|
|
||||||
|
|
@ -103,6 +103,8 @@ func GetEvent(
|
||||||
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll),
|
JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID)
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -153,6 +153,8 @@ func GetMemberships(
|
||||||
}
|
}
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll)},
|
JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID)
|
||||||
|
})},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -241,7 +241,7 @@ func OnIncomingMessagesRequest(
|
||||||
device: device,
|
device: device,
|
||||||
}
|
}
|
||||||
|
|
||||||
clientEvents, start, end, err := mReq.retrieveEvents()
|
clientEvents, start, end, err := mReq.retrieveEvents(req.Context(), rsAPI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("mreq.retrieveEvents failed")
|
util.GetLogger(req.Context()).WithError(err).Error("mreq.retrieveEvents failed")
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
|
|
@ -273,7 +273,9 @@ func OnIncomingMessagesRequest(
|
||||||
JSON: spec.InternalServerError{},
|
JSON: spec.InternalServerError{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll)...)
|
res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID)
|
||||||
|
})...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we didn't return any events, set the end to an empty string, so it will be omitted
|
// If we didn't return any events, set the end to an empty string, so it will be omitted
|
||||||
|
|
@ -310,7 +312,7 @@ func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.
|
||||||
// homeserver in the room for older events.
|
// homeserver in the room for older events.
|
||||||
// Returns an error if there was an issue talking to the database or with the
|
// Returns an error if there was an issue talking to the database or with the
|
||||||
// remote homeserver.
|
// remote homeserver.
|
||||||
func (r *messagesReq) retrieveEvents() (
|
func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserverAPI) (
|
||||||
clientEvents []synctypes.ClientEvent, start,
|
clientEvents []synctypes.ClientEvent, start,
|
||||||
end types.TopologyToken, err error,
|
end types.TopologyToken, err error,
|
||||||
) {
|
) {
|
||||||
|
|
@ -382,7 +384,9 @@ func (r *messagesReq) retrieveEvents() (
|
||||||
"events_before": len(events),
|
"events_before": len(events),
|
||||||
"events_after": len(filteredEvents),
|
"events_after": len(filteredEvents),
|
||||||
}).Debug("applied history visibility (messages)")
|
}).Debug("applied history visibility (messages)")
|
||||||
return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll), start, end, err
|
return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}), start, end, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *messagesReq) getStartEnd(events []*rstypes.HeaderedEvent) (start, end types.TopologyToken, err error) {
|
func (r *messagesReq) getStartEnd(events []*rstypes.HeaderedEvent) (start, end types.TopologyToken, err error) {
|
||||||
|
|
|
||||||
|
|
@ -116,7 +116,9 @@ func Relations(
|
||||||
for _, event := range filteredEvents {
|
for _, event := range filteredEvents {
|
||||||
res.Chunk = append(
|
res.Chunk = append(
|
||||||
res.Chunk,
|
res.Chunk,
|
||||||
synctypes.ToClientEvent(event.PDU, synctypes.FormatAll),
|
synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID)
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -171,7 +171,7 @@ func Setup(
|
||||||
nb := req.FormValue("next_batch")
|
nb := req.FormValue("next_batch")
|
||||||
nextBatch = &nb
|
nextBatch = &nb
|
||||||
}
|
}
|
||||||
return Search(req, device, syncDB, fts, nextBatch)
|
return Search(req, device, syncDB, fts, nextBatch, rsAPI)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/clientapi/httputil"
|
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||||
"github.com/matrix-org/dendrite/internal/fulltext"
|
"github.com/matrix-org/dendrite/internal/fulltext"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
"github.com/matrix-org/dendrite/syncapi/synctypes"
|
"github.com/matrix-org/dendrite/syncapi/synctypes"
|
||||||
|
|
@ -38,7 +39,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// nolint:gocyclo
|
// nolint:gocyclo
|
||||||
func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts fulltext.Indexer, from *string) util.JSONResponse {
|
func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts fulltext.Indexer, from *string, rsAPI roomserverAPI.SyncRoomserverAPI) util.JSONResponse {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
var (
|
var (
|
||||||
searchReq SearchRequest
|
searchReq SearchRequest
|
||||||
|
|
@ -225,14 +226,20 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
|
||||||
|
|
||||||
results = append(results, Result{
|
results = append(results, Result{
|
||||||
Context: SearchContextResponse{
|
Context: SearchContextResponse{
|
||||||
Start: startToken.String(),
|
Start: startToken.String(),
|
||||||
End: endToken.String(),
|
End: endToken.String(),
|
||||||
EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync),
|
EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync),
|
return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID)
|
||||||
ProfileInfo: profileInfos,
|
}),
|
||||||
|
EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID)
|
||||||
|
}),
|
||||||
|
ProfileInfo: profileInfos,
|
||||||
},
|
},
|
||||||
Rank: eventScore[event.EventID()].Score,
|
Rank: eventScore[event.EventID()].Score,
|
||||||
Result: synctypes.ToClientEvent(event, synctypes.FormatAll),
|
Result: synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID)
|
||||||
|
}),
|
||||||
})
|
})
|
||||||
roomGroup := groups[event.RoomID()]
|
roomGroup := groups[event.RoomID()]
|
||||||
roomGroup.Results = append(roomGroup.Results, event.EventID())
|
roomGroup.Results = append(roomGroup.Results, event.EventID())
|
||||||
|
|
@ -247,7 +254,9 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
|
||||||
JSON: spec.InternalServerError{},
|
JSON: spec.InternalServerError{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync)
|
stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
|
@ -9,6 +10,7 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/fulltext"
|
"github.com/matrix-org/dendrite/internal/fulltext"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
rsapi "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
rstypes "github.com/matrix-org/dendrite/roomserver/types"
|
rstypes "github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
"github.com/matrix-org/dendrite/syncapi/synctypes"
|
"github.com/matrix-org/dendrite/syncapi/synctypes"
|
||||||
|
|
@ -21,6 +23,12 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type FakeSyncRoomserverAPI struct{ rsapi.SyncRoomserverAPI }
|
||||||
|
|
||||||
|
func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) {
|
||||||
|
return spec.NewUserID(senderID, true)
|
||||||
|
}
|
||||||
|
|
||||||
func TestSearch(t *testing.T) {
|
func TestSearch(t *testing.T) {
|
||||||
alice := test.NewUser(t)
|
alice := test.NewUser(t)
|
||||||
aliceDevice := userapi.Device{UserID: alice.ID}
|
aliceDevice := userapi.Device{UserID: alice.ID}
|
||||||
|
|
@ -247,7 +255,7 @@ func TestSearch(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
req := httptest.NewRequest(http.MethodPost, "/", reqBody)
|
req := httptest.NewRequest(http.MethodPost, "/", reqBody)
|
||||||
|
|
||||||
res := Search(req, tc.device, db, fts, tc.from)
|
res := Search(req, tc.device, db, fts, tc.from, &FakeSyncRoomserverAPI{})
|
||||||
if !tc.wantOK && !res.Is2xx() {
|
if !tc.wantOK && !res.Is2xx() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
"github.com/matrix-org/dendrite/syncapi/synctypes"
|
"github.com/matrix-org/dendrite/syncapi/synctypes"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
|
@ -17,6 +18,7 @@ import (
|
||||||
|
|
||||||
type InviteStreamProvider struct {
|
type InviteStreamProvider struct {
|
||||||
DefaultStreamProvider
|
DefaultStreamProvider
|
||||||
|
rsAPI api.SyncRoomserverAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *InviteStreamProvider) Setup(
|
func (p *InviteStreamProvider) Setup(
|
||||||
|
|
@ -66,7 +68,9 @@ func (p *InviteStreamProvider) IncrementalSync(
|
||||||
if _, ok := req.IgnoredUsers.List[inviteEvent.SenderID()]; ok {
|
if _, ok := req.IgnoredUsers.List[inviteEvent.SenderID()]; ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
ir := types.NewInviteResponse(inviteEvent)
|
ir := types.NewInviteResponse(inviteEvent, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
})
|
||||||
req.Response.Rooms.Invite[roomID] = ir
|
req.Response.Rooms.Invite[roomID] = ir
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -376,20 +376,28 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
jr.Timeline.PrevBatch = &prevBatch
|
jr.Timeline.PrevBatch = &prevBatch
|
||||||
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync)
|
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
})
|
||||||
// If we are limited by the filter AND the history visibility filter
|
// If we are limited by the filter AND the history visibility filter
|
||||||
// didn't "remove" events, return that the response is limited.
|
// didn't "remove" events, return that the response is limited.
|
||||||
jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined
|
jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined
|
||||||
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync)
|
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
})
|
||||||
req.Response.Rooms.Join[delta.RoomID] = jr
|
req.Response.Rooms.Join[delta.RoomID] = jr
|
||||||
|
|
||||||
case spec.Peek:
|
case spec.Peek:
|
||||||
jr := types.NewJoinResponse()
|
jr := types.NewJoinResponse()
|
||||||
jr.Timeline.PrevBatch = &prevBatch
|
jr.Timeline.PrevBatch = &prevBatch
|
||||||
// TODO: Apply history visibility on peeked rooms
|
// TODO: Apply history visibility on peeked rooms
|
||||||
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync)
|
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
})
|
||||||
jr.Timeline.Limited = limited
|
jr.Timeline.Limited = limited
|
||||||
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync)
|
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
})
|
||||||
req.Response.Rooms.Peek[delta.RoomID] = jr
|
req.Response.Rooms.Peek[delta.RoomID] = jr
|
||||||
|
|
||||||
case spec.Leave:
|
case spec.Leave:
|
||||||
|
|
@ -398,11 +406,15 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
||||||
case spec.Ban:
|
case spec.Ban:
|
||||||
lr := types.NewLeaveResponse()
|
lr := types.NewLeaveResponse()
|
||||||
lr.Timeline.PrevBatch = &prevBatch
|
lr.Timeline.PrevBatch = &prevBatch
|
||||||
lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync)
|
lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
})
|
||||||
// If we are limited by the filter AND the history visibility filter
|
// If we are limited by the filter AND the history visibility filter
|
||||||
// didn't "remove" events, return that the response is limited.
|
// didn't "remove" events, return that the response is limited.
|
||||||
lr.Timeline.Limited = limited && len(events) == len(recentEvents)
|
lr.Timeline.Limited = limited && len(events) == len(recentEvents)
|
||||||
lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync)
|
lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
})
|
||||||
req.Response.Rooms.Leave[delta.RoomID] = lr
|
req.Response.Rooms.Leave[delta.RoomID] = lr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -552,11 +564,15 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
||||||
}
|
}
|
||||||
|
|
||||||
jr.Timeline.PrevBatch = prevBatch
|
jr.Timeline.PrevBatch = prevBatch
|
||||||
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync)
|
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
})
|
||||||
// If we are limited by the filter AND the history visibility filter
|
// If we are limited by the filter AND the history visibility filter
|
||||||
// didn't "remove" events, return that the response is limited.
|
// didn't "remove" events, return that the response is limited.
|
||||||
jr.Timeline.Limited = limited && len(events) == len(recentEvents)
|
jr.Timeline.Limited = limited && len(events) == len(recentEvents)
|
||||||
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync)
|
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
})
|
||||||
return jr, nil
|
return jr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,7 @@ func NewSyncStreamProviders(
|
||||||
},
|
},
|
||||||
InviteStreamProvider: &InviteStreamProvider{
|
InviteStreamProvider: &InviteStreamProvider{
|
||||||
DefaultStreamProvider: DefaultStreamProvider{DB: d},
|
DefaultStreamProvider: DefaultStreamProvider{DB: d},
|
||||||
|
rsAPI: rsAPI,
|
||||||
},
|
},
|
||||||
SendToDeviceStreamProvider: &SendToDeviceStreamProvider{
|
SendToDeviceStreamProvider: &SendToDeviceStreamProvider{
|
||||||
DefaultStreamProvider: DefaultStreamProvider{DB: d},
|
DefaultStreamProvider: DefaultStreamProvider{DB: d},
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,10 @@ type syncRoomserverAPI struct {
|
||||||
rooms []*test.Room
|
rooms []*test.Room
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) {
|
||||||
|
return spec.NewUserID(senderID, true)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *syncRoomserverAPI) QueryLatestEventsAndState(ctx context.Context, req *rsapi.QueryLatestEventsAndStateRequest, res *rsapi.QueryLatestEventsAndStateResponse) error {
|
func (s *syncRoomserverAPI) QueryLatestEventsAndState(ctx context.Context, req *rsapi.QueryLatestEventsAndStateRequest, res *rsapi.QueryLatestEventsAndStateResponse) error {
|
||||||
var room *test.Room
|
var room *test.Room
|
||||||
for _, r := range s.rooms {
|
for _, r := range s.rooms {
|
||||||
|
|
|
||||||
|
|
@ -44,21 +44,21 @@ type ClientEvent struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToClientEvents converts server events to client events.
|
// ToClientEvents converts server events to client events.
|
||||||
func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat) []ClientEvent {
|
func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat, userIDForSender spec.UserIDForSender) []ClientEvent {
|
||||||
evs := make([]ClientEvent, 0, len(serverEvs))
|
evs := make([]ClientEvent, 0, len(serverEvs))
|
||||||
for _, se := range serverEvs {
|
for _, se := range serverEvs {
|
||||||
if se == nil {
|
if se == nil {
|
||||||
continue // TODO: shouldn't happen?
|
continue // TODO: shouldn't happen?
|
||||||
}
|
}
|
||||||
evs = append(evs, ToClientEvent(se, format))
|
evs = append(evs, ToClientEvent(se, format, userIDForSender))
|
||||||
}
|
}
|
||||||
return evs
|
return evs
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToClientEvent converts a single server event to a client event.
|
// ToClientEvent converts a single server event to a client event.
|
||||||
func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat) ClientEvent {
|
func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, userIDForSender spec.UserIDForSender) ClientEvent {
|
||||||
user := ""
|
user := ""
|
||||||
userID, err := se.UserID()
|
userID, err := userIDForSender(se.RoomID(), se.SenderID())
|
||||||
if err == nil {
|
if err == nil {
|
||||||
user = userID.String()
|
user = userID.String()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -21,8 +21,13 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) {
|
||||||
|
return spec.NewUserID(senderID, true)
|
||||||
|
}
|
||||||
|
|
||||||
func TestToClientEvent(t *testing.T) { // nolint: gocyclo
|
func TestToClientEvent(t *testing.T) { // nolint: gocyclo
|
||||||
ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV1).NewEventFromTrustedJSON([]byte(`{
|
ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV1).NewEventFromTrustedJSON([]byte(`{
|
||||||
"type": "m.room.name",
|
"type": "m.room.name",
|
||||||
|
|
@ -43,7 +48,7 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create Event: %s", err)
|
t.Fatalf("failed to create Event: %s", err)
|
||||||
}
|
}
|
||||||
ce := ToClientEvent(ev, FormatAll)
|
ce := ToClientEvent(ev, FormatAll, UserIDForSender)
|
||||||
if ce.EventID != ev.EventID() {
|
if ce.EventID != ev.EventID() {
|
||||||
t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID)
|
t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID)
|
||||||
}
|
}
|
||||||
|
|
@ -63,7 +68,7 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo
|
||||||
t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned))
|
t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned))
|
||||||
}
|
}
|
||||||
user := ""
|
user := ""
|
||||||
userID, err := ev.UserID()
|
userID, err := UserIDForSender("", ev.SenderID())
|
||||||
if err == nil {
|
if err == nil {
|
||||||
user = userID.String()
|
user = userID.String()
|
||||||
}
|
}
|
||||||
|
|
@ -103,7 +108,7 @@ func TestToClientFormatSync(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create Event: %s", err)
|
t.Fatalf("failed to create Event: %s", err)
|
||||||
}
|
}
|
||||||
ce := ToClientEvent(ev, FormatSync)
|
ce := ToClientEvent(ev, FormatSync, UserIDForSender)
|
||||||
if ce.RoomID != "" {
|
if ce.RoomID != "" {
|
||||||
t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID)
|
t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -539,7 +539,7 @@ type InviteResponse struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewInviteResponse creates an empty response with initialised arrays.
|
// NewInviteResponse creates an empty response with initialised arrays.
|
||||||
func NewInviteResponse(event *types.HeaderedEvent) *InviteResponse {
|
func NewInviteResponse(event *types.HeaderedEvent, userIDForSender spec.UserIDForSender) *InviteResponse {
|
||||||
res := InviteResponse{}
|
res := InviteResponse{}
|
||||||
res.InviteState.Events = []json.RawMessage{}
|
res.InviteState.Events = []json.RawMessage{}
|
||||||
|
|
||||||
|
|
@ -552,7 +552,7 @@ func NewInviteResponse(event *types.HeaderedEvent) *InviteResponse {
|
||||||
|
|
||||||
// Then we'll see if we can create a partial of the invite event itself.
|
// Then we'll see if we can create a partial of the invite event itself.
|
||||||
// This is needed for clients to work out *who* sent the invite.
|
// This is needed for clients to work out *who* sent the invite.
|
||||||
inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync)
|
inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userIDForSender)
|
||||||
inviteEvent.Unsigned = nil
|
inviteEvent.Unsigned = nil
|
||||||
if ev, err := json.Marshal(inviteEvent); err == nil {
|
if ev, err := json.Marshal(inviteEvent); err == nil {
|
||||||
res.InviteState.Events = append(res.InviteState.Events, ev)
|
res.InviteState.Events = append(res.InviteState.Events, ev)
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,13 @@ import (
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/dendrite/syncapi/synctypes"
|
"github.com/matrix-org/dendrite/syncapi/synctypes"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) {
|
||||||
|
return spec.NewUserID(senderID, true)
|
||||||
|
}
|
||||||
|
|
||||||
func TestSyncTokens(t *testing.T) {
|
func TestSyncTokens(t *testing.T) {
|
||||||
shouldPass := map[string]string{
|
shouldPass := map[string]string{
|
||||||
"s4_0_0_0_0_0_0_0_3": StreamingToken{4, 0, 0, 0, 0, 0, 0, 0, 3}.String(),
|
"s4_0_0_0_0_0_0_0_3": StreamingToken{4, 0, 0, 0, 0, 0, 0, 0, 3}.String(),
|
||||||
|
|
@ -56,7 +61,7 @@ func TestNewInviteResponse(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
res := NewInviteResponse(&types.HeaderedEvent{PDU: ev})
|
res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, UserIDForSender)
|
||||||
j, err := json.Marshal(res)
|
j, err := json.Marshal(res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,10 @@ var (
|
||||||
roomIDCounter = int64(0)
|
roomIDCounter = int64(0)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) {
|
||||||
|
return spec.NewUserID(senderID, true)
|
||||||
|
}
|
||||||
|
|
||||||
type Room struct {
|
type Room struct {
|
||||||
ID string
|
ID string
|
||||||
Version gomatrixserverlib.RoomVersion
|
Version gomatrixserverlib.RoomVersion
|
||||||
|
|
@ -195,7 +199,7 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("CreateEvent[%s]: failed to build event: %s", eventType, err)
|
t.Fatalf("CreateEvent[%s]: failed to build event: %s", eventType, err)
|
||||||
}
|
}
|
||||||
if err = gomatrixserverlib.Allowed(ev, &r.authEvents); err != nil {
|
if err = gomatrixserverlib.Allowed(ev, &r.authEvents, UserIDForSender); err != nil {
|
||||||
t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err)
|
t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err)
|
||||||
}
|
}
|
||||||
headeredEvent := &rstypes.HeaderedEvent{PDU: ev}
|
headeredEvent := &rstypes.HeaderedEvent{PDU: ev}
|
||||||
|
|
|
||||||
|
|
@ -301,7 +301,9 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case event.Type() == spec.MRoomMember:
|
case event.Type() == spec.MRoomMember:
|
||||||
cevent := synctypes.ToClientEvent(event, synctypes.FormatAll)
|
cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return s.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
})
|
||||||
var member *localMembership
|
var member *localMembership
|
||||||
member, err = newLocalMembership(&cevent)
|
member, err = newLocalMembership(&cevent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -534,7 +536,9 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
|
||||||
// UNSPEC: the spec doesn't say this is a ClientEvent, but the
|
// UNSPEC: the spec doesn't say this is a ClientEvent, but the
|
||||||
// fields seem to match. room_id should be missing, which
|
// fields seem to match. room_id should be missing, which
|
||||||
// matches the behaviour of FormatSync.
|
// matches the behaviour of FormatSync.
|
||||||
Event: synctypes.ToClientEvent(event, synctypes.FormatSync),
|
Event: synctypes.ToClientEvent(event, synctypes.FormatSync, func(roomAliasOrID string, senderID string) (*spec.UserID, error) {
|
||||||
|
return s.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
}),
|
||||||
// TODO: this is per-device, but it's not part of the primary
|
// TODO: this is per-device, but it's not part of the primary
|
||||||
// key. So inserting one notification per profile tag doesn't
|
// key. So inserting one notification per profile tag doesn't
|
||||||
// make sense. What is this supposed to be? Sytests require it
|
// make sense. What is this supposed to be? Sytests require it
|
||||||
|
|
@ -616,7 +620,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
|
||||||
// user. Returns actions (including dont_notify).
|
// user. Returns actions (including dont_notify).
|
||||||
func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
|
func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
|
||||||
user := ""
|
user := ""
|
||||||
userID, err := event.UserID()
|
userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
|
||||||
if err == nil {
|
if err == nil {
|
||||||
user = userID.String()
|
user = userID.String()
|
||||||
}
|
}
|
||||||
|
|
@ -655,7 +659,9 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
|
||||||
roomSize: roomSize,
|
roomSize: roomSize,
|
||||||
}
|
}
|
||||||
eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global)
|
eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global)
|
||||||
rule, err := eval.MatchEvent(event.PDU)
|
rule, err := eval.MatchEvent(event.PDU, func(roomAliasOrID, senderID string) (*spec.UserID, error) {
|
||||||
|
return s.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/pushrules"
|
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||||
|
rsapi "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/test"
|
"github.com/matrix-org/dendrite/test"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage"
|
"github.com/matrix-org/dendrite/userapi/storage"
|
||||||
|
|
@ -44,13 +45,19 @@ func mustCreateEvent(t *testing.T, content string) *types.HeaderedEvent {
|
||||||
return &types.HeaderedEvent{PDU: ev}
|
return &types.HeaderedEvent{PDU: ev}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type FakeUserRoomserverAPI struct{ rsapi.UserRoomserverAPI }
|
||||||
|
|
||||||
|
func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) {
|
||||||
|
return spec.NewUserID(senderID, true)
|
||||||
|
}
|
||||||
|
|
||||||
func Test_evaluatePushRules(t *testing.T) {
|
func Test_evaluatePushRules(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
db, close := mustCreateDatabase(t, dbType)
|
db, close := mustCreateDatabase(t, dbType)
|
||||||
defer close()
|
defer close()
|
||||||
consumer := OutputRoomEventConsumer{db: db}
|
consumer := OutputRoomEventConsumer{db: db, rsAPI: &FakeUserRoomserverAPI{}}
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/syncapi/synctypes"
|
"github.com/matrix-org/dendrite/syncapi/synctypes"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
|
|
@ -22,6 +23,10 @@ import (
|
||||||
userUtil "github.com/matrix-org/dendrite/userapi/util"
|
userUtil "github.com/matrix-org/dendrite/userapi/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) {
|
||||||
|
return spec.NewUserID(senderID, true)
|
||||||
|
}
|
||||||
|
|
||||||
func TestNotifyUserCountsAsync(t *testing.T) {
|
func TestNotifyUserCountsAsync(t *testing.T) {
|
||||||
alice := test.NewUser(t)
|
alice := test.NewUser(t)
|
||||||
aliceLocalpart, serverName, err := gomatrixserverlib.SplitID('@', alice.ID)
|
aliceLocalpart, serverName, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
|
|
@ -100,7 +105,7 @@ func TestNotifyUserCountsAsync(t *testing.T) {
|
||||||
|
|
||||||
// Insert a dummy event
|
// Insert a dummy event
|
||||||
if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{
|
if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{
|
||||||
Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll),
|
Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, UserIDForSender),
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue