Use spec.RoomID for userID query interface

This commit is contained in:
Devon Hudson 2023-06-13 14:39:54 +01:00
parent 66a29cb01c
commit 50a783f0eb
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
57 changed files with 326 additions and 157 deletions

View file

@ -181,7 +181,7 @@ func (s *OutputRoomEventConsumer) sendEvents(
// Create the transaction body.
transaction, err := json.Marshal(
ApplicationServiceTransaction{
Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}),
},
@ -236,7 +236,11 @@ func (s *appserviceState) backoffAndPause(err error) error {
// TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682
func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool {
user := ""
userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return false
}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err == nil {
user = userID.String()
}

View file

@ -235,10 +235,9 @@ func RemoveLocalAlias(
validRoomID, err := spec.NewRoomID(roomIDRes.RoomID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("roomID is invalid")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
Code: http.StatusNotFound,
JSON: spec.NotFound("The alias does not exist."),
}
}
deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *userID)

View file

@ -351,8 +351,8 @@ func generateSendEvent(
stateEvents[i] = queryRes.StateEvents[i].PDU
}
provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents))
if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, *validRoomID, senderID)
}); err != nil {
return nil, &util.JSONResponse{
Code: http.StatusForbidden,

View file

@ -150,7 +150,7 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
for _, ev := range stateRes.StateEvents {
stateEvents = append(
stateEvents,
synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, ev),
)
@ -173,14 +173,19 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
}
for _, ev := range stateAfterRes.StateEvents {
sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID())
evRoomID, err := spec.NewRoomID(ev.RoomID())
if err != nil {
util.GetLogger(ctx).WithError(err).Error("Event roomID is invalid")
continue
}
userID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, ev.SenderID())
if err == nil && userID != nil {
sender = *userID
}
sk := ev.StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey()))
skUserID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, spec.SenderID(*ev.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
@ -367,7 +372,7 @@ func OnIncomingStateTypeRequest(
}
stateEvent := stateEventInStateResp{
ClientEvent: synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
ClientEvent: synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, event),
}

View file

@ -189,7 +189,7 @@ func main() {
fmt.Println("Resolving state")
var resolved Events
resolved, err = gomatrixserverlib.ResolveConflicts(
gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
},
)

View file

@ -36,7 +36,7 @@ type fedRoomserverAPI struct {
queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error
}
func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}

View file

@ -166,10 +166,10 @@ func (r *FederationInternalAPI) performJoinUsingServer(
PrivateKey: r.cfg.Matrix.PrivateKey,
KeyID: r.cfg.Matrix.KeyID,
KeyRing: r.keyRing,
EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}),
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
},
}
@ -365,7 +365,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
// authenticate the state returned (check its auth events etc)
// the equivalent of CheckSendJoinResponse()
userIDProvider := func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
userIDProvider := func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(
@ -528,7 +528,11 @@ func (r *FederationInternalAPI) SendInvite(
event gomatrixserverlib.PDU,
strippedState []gomatrixserverlib.InviteStrippedState,
) (gomatrixserverlib.PDU, error) {
inviter, err := r.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return nil, err
}
inviter, err := r.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err != nil {
return nil, err
}

View file

@ -95,7 +95,7 @@ func InviteV2(
StateQuerier: rsAPI.StateQuerier(),
InviteEvent: inviteReq.Event(),
StrippedState: inviteReq.InviteRoomState(),
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
},
}
@ -188,7 +188,7 @@ func InviteV1(
StateQuerier: rsAPI.StateQuerier(),
InviteEvent: event,
StrippedState: strippedState,
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
},
}

View file

@ -118,7 +118,7 @@ func MakeJoin(
LocalServerName: cfg.Matrix.ServerName,
LocalServerInRoom: res.RoomExists && res.IsInRoom,
RoomQuerier: &roomQuerier,
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
},
BuildEventTemplate: createJoinTemplate,
@ -215,7 +215,7 @@ func SendJoin(
PrivateKey: cfg.Matrix.PrivateKey,
Verifier: keys,
MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI},
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
},
}

View file

@ -105,7 +105,7 @@ func MakeLeave(
LocalServerName: cfg.Matrix.ServerName,
LocalServerInRoom: res.RoomExists && res.IsInRoom,
BuildEventTemplate: createLeaveTemplate,
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
},
}
@ -236,7 +236,14 @@ func SendLeave(
// Check that the sender belongs to the server that is sending us
// the request. By this point we've already asserted that the sender
// and the state key are equal so we don't need to check both.
sender, err := rsAPI.QueryUserIDForSender(httpReq.Context(), event.RoomID(), event.SenderID())
validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("Room ID is invalid."),
}
}
sender, err := rsAPI.QueryUserIDForSender(httpReq.Context(), *validRoomID, event.SenderID())
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,

View file

@ -140,7 +140,14 @@ func ExchangeThirdPartyInvite(
}
}
userID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, spec.SenderID(proto.SenderID))
validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("Invalid room ID"),
}
}
userID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), *validRoomID, spec.SenderID(proto.SenderID))
if err != nil || userID == nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
@ -150,7 +157,7 @@ func ExchangeThirdPartyInvite(
senderDomain := userID.Domain()
// Check that the state key is correct.
targetUserID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, spec.SenderID(*proto.StateKey))
targetUserID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), *validRoomID, spec.SenderID(*proto.StateKey))
if err != nil || targetUserID == nil {
return util.JSONResponse{
Code: http.StatusBadRequest,

10
go.mod
View file

@ -22,7 +22,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077
github.com/matrix-org/gomatrixserverlib v0.0.0-20230613121253-ff6a04adf44e
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
github.com/mattn/go-sqlite3 v1.14.16
@ -42,11 +42,11 @@ require (
github.com/uber/jaeger-lib v2.4.1+incompatible
github.com/yggdrasil-network/yggdrasil-go v0.4.6
go.uber.org/atomic v1.10.0
golang.org/x/crypto v0.9.0
golang.org/x/crypto v0.10.0
golang.org/x/image v0.5.0
golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e
golang.org/x/sync v0.1.0
golang.org/x/term v0.8.0
golang.org/x/term v0.9.0
gopkg.in/h2non/bimg.v1 v1.1.9
gopkg.in/yaml.v2 v2.4.0
gotest.tools/v3 v3.4.0
@ -127,8 +127,8 @@ require (
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect
golang.org/x/mod v0.8.0 // indirect
golang.org/x/net v0.10.0 // indirect
golang.org/x/sys v0.8.0 // indirect
golang.org/x/text v0.9.0 // indirect
golang.org/x/sys v0.9.0 // indirect
golang.org/x/text v0.10.0 // indirect
golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.6.0 // indirect
google.golang.org/protobuf v1.28.1 // indirect

20
go.sum
View file

@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077 h1:AmKkAUjy9rZA2K+qHXm/O/dPEPnUYfRE2I6SL+Dj+LU=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230613121253-ff6a04adf44e h1:jS+UZAfaTiyClN/l5LKbC+alXJ94/SDXFWByUgzBCYA=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230613121253-ff6a04adf44e/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
@ -511,8 +511,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM=
golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -669,12 +669,12 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.8.0 h1:n5xxQn2i3PC0yLAbjTpNT85q/Kgzcr2gIoX9OrJUols=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.9.0 h1:GRRCnKYhdQrD8kfRAdQ6Zcw1P0OcELxGLKJvtjVMZ28=
golang.org/x/term v0.9.0/go.mod h1:M6DEAAIenWoTxdKrOltXcmDY3rSplQUkrvaDU5FcQyo=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -683,8 +683,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58=
golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=

View file

@ -115,7 +115,11 @@ func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec Evaluati
case SenderKind:
userID := ""
sender, err := userIDForSender(event.RoomID(), event.SenderID())
validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return false, err
}
sender, err := userIDForSender(*validRoomID, event.SenderID())
if err == nil {
userID = sender.String()
}

View file

@ -8,7 +8,7 @@ import (
"github.com/matrix-org/gomatrixserverlib/spec"
)
func UserIDForSender(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
func UserIDForSender(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}
@ -73,7 +73,7 @@ func TestRuleMatches(t *testing.T) {
{"emptyOverride", OverrideKind, emptyRule, `{}`, true},
{"emptyContent", ContentKind, emptyRule, `{}`, false},
{"emptyRoom", RoomKind, emptyRule, `{}`, true},
{"emptySender", SenderKind, emptyRule, `{}`, true},
{"emptySender", SenderKind, emptyRule, `{"room_id":"!room:example.com"}`, true},
{"emptyUnderride", UnderrideKind, emptyRule, `{}`, true},
{"disabled", OverrideKind, Rule{}, `{}`, false},
@ -90,8 +90,8 @@ func TestRuleMatches(t *testing.T) {
{"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!room:example.com"}`, true},
{"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!otherroom:example.com"}`, false},
{"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@user:example.com"}`, true},
{"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@otheruser:example.com"}`, false},
{"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@user:example.com","room_id":"!room:example.com"}`, true},
{"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@otheruser:example.com","room_id":"!room:example.com"}`, false},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {

View file

@ -167,7 +167,7 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut
}
continue
}
if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID())

View file

@ -70,7 +70,7 @@ type FakeRsAPI struct {
bannedFromRoom bool
}
func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}
@ -642,7 +642,7 @@ type testRoomserverAPI struct {
queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse
}
func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}

View file

@ -78,7 +78,7 @@ type InputRoomEventsAPI interface {
type QuerySenderIDAPI interface {
QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error)
QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error)
}
// Query the latest events and state for a room from the room server.

View file

@ -89,7 +89,11 @@ func IsAnyUserOnServerWithMembership(ctx context.Context, querier api.QuerySende
continue
}
userID, err := querier.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*stateKey))
validRoomID, err := spec.NewRoomID(ev.RoomID())
if err != nil {
continue
}
userID, err := querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*stateKey))
if err != nil {
continue
}

View file

@ -14,7 +14,7 @@ type FakeQuerier struct {
api.QuerySenderIDAPI
}
func (f *FakeQuerier) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
func (f *FakeQuerier) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}

View file

@ -113,6 +113,7 @@ func (r *RoomserverInternalAPI) GetAliasesForRoomID(
return nil
}
// nolint:gocyclo
// RemoveRoomAlias implements alias.RoomserverInternalAPI
func (r *RoomserverInternalAPI) RemoveRoomAlias(
ctx context.Context,
@ -129,7 +130,12 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
return nil
}
sender, err := r.QueryUserIDForSender(ctx, roomID, request.SenderID)
validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return err
}
sender, err := r.QueryUserIDForSender(ctx, *validRoomID, request.SenderID)
if err != nil || sender == nil {
return fmt.Errorf("r.QueryUserIDForSender: %w", err)
}
@ -177,7 +183,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
if request.SenderID != ev.SenderID() {
senderID = ev.SenderID()
}
sender, err := r.QueryUserIDForSender(ctx, roomID, senderID)
sender, err := r.QueryUserIDForSender(ctx, *validRoomID, senderID)
if err != nil || sender == nil {
return err
}

View file

@ -78,7 +78,7 @@ func CheckForSoftFail(
}
// Check if the event is allowed.
if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return querier.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
// return true, nil

View file

@ -318,9 +318,13 @@ func slowGetHistoryVisibilityState(
// If the event state key doesn't match the given servername
// then we'll filter it out. This does preserve state keys that
// are "" since these will contain history visibility etc.
validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return nil, err
}
for nid, key := range stateKeys {
if key != "" {
userID, err := querier.QueryUserIDForSender(ctx, roomID, spec.SenderID(key))
userID, err := querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(key))
if err == nil && userID != nil {
if userID.Domain() != serverName {
delete(stateKeys, nid)

View file

@ -128,7 +128,11 @@ func (r *Inputer) processRoomEvent(
if roomInfo == nil && !isCreateEvent {
return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID())
}
sender, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return err
}
sender, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err != nil {
return fmt.Errorf("failed getting userID for sender %q. %w", event.SenderID(), err)
}
@ -282,7 +286,7 @@ func (r *Inputer) processRoomEvent(
// Check if the event is allowed by its auth events. If it isn't then
// we consider the event to be "rejected" — it will still be persisted.
if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
isRejected = true
@ -587,7 +591,7 @@ func (r *Inputer) processStateBefore(
stateBeforeAuth := gomatrixserverlib.NewAuthEvents(
gomatrixserverlib.ToPDUs(stateBeforeEvent),
)
if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); rejectionErr != nil {
rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr)
@ -700,7 +704,7 @@ nextAuthEvent:
// Check the signatures of the event. If this fails then we'll simply
// skip it, because gomatrixserverlib.Allowed() will notice a problem
// if a critical event is missing anyway.
if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
continue nextAuthEvent
@ -718,7 +722,7 @@ nextAuthEvent:
}
// Check if the auth event should be rejected.
err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
})
if isRejected = err != nil; isRejected {
@ -842,7 +846,11 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r
continue
}
memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, memberEvent.RoomID(), spec.SenderID(*memberEvent.StateKey()))
validRoomID, err := spec.NewRoomID(memberEvent.RoomID())
if err != nil {
continue
}
memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*memberEvent.StateKey()))
if err != nil {
continue
}

View file

@ -58,7 +58,7 @@ func Test_EventAuth(t *testing.T) {
}
// Finally check that the event is NOT allowed
if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}); err == nil {
t.Fatalf("event should not be allowed, but it was")

View file

@ -139,7 +139,11 @@ func (r *Inputer) updateMembership(
func (r *Inputer) isLocalTarget(ctx context.Context, event *types.Event) bool {
isTargetLocalUser := false
if statekey := event.StateKey(); statekey != nil {
userID, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*statekey))
validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return isTargetLocalUser
}
userID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*statekey))
if err != nil || userID == nil {
return isTargetLocalUser
}

View file

@ -473,7 +473,7 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion
stateEventList = append(stateEventList, state.StateEvents...)
}
resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(
roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
},
)
@ -482,7 +482,7 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion
}
// apply the current event
retryAllowedState:
if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
switch missing := err.(type) {
@ -569,7 +569,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver
// will be added and duplicates will be removed.
missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events))
for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
continue
@ -660,7 +660,7 @@ func (t *missingStateReq) lookupMissingStateViaState(
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{
StateEvents: state.GetStateEvents(),
AuthEvents: state.GetAuthEvents(),
}, roomVersion, t.keys, nil, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
}, roomVersion, t.keys, nil, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
})
if err != nil {
@ -897,7 +897,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers))
return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers))
}
if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())

View file

@ -74,6 +74,10 @@ func (r *Admin) PerformAdminEvacuateRoom(
if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil {
return nil, err
}
validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return nil, err
}
prevEvents := latestRes.LatestEvents
var senderDomain spec.ServerName
@ -100,7 +104,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
PrevEvents: prevEvents,
}
userID, err := r.Queryer.QueryUserIDForSender(ctx, roomID, spec.SenderID(fledglingEvent.SenderID))
userID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(fledglingEvent.SenderID))
if err != nil || userID == nil {
continue
}
@ -264,7 +268,7 @@ func (r *Admin) PerformAdminDownloadState(
return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err)
}
for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
continue
@ -272,7 +276,7 @@ func (r *Admin) PerformAdminDownloadState(
authEventMap[authEvent.EventID()] = authEvent
}
for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
continue

View file

@ -122,7 +122,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
// Specifically the test "Outbound federation can backfill events"
events, err := gomatrixserverlib.RequestBackfill(
ctx, req.VirtualHost, requester,
r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.Querier.QueryUserIDForSender(ctx, roomID, senderID)
},
)
@ -213,7 +213,7 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
continue
}
loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false)
result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.Querier.QueryUserIDForSender(ctx, roomID, senderID)
})
if err != nil {
@ -492,7 +492,11 @@ FindSuccessor:
// Store the server names in a temporary map to avoid duplicates.
serverSet := make(map[spec.ServerName]bool)
for _, event := range memberEvents {
if sender, err := b.querier.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()); err == nil {
validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
continue
}
if sender, err := b.querier.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()); err == nil {
serverSet[sender.Domain()] = true
}
}

View file

@ -322,7 +322,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
}
}
if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return c.RSAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed")

View file

@ -99,7 +99,11 @@ func (r *Inviter) ProcessInviteMembership(
var outputUpdates []api.OutputEvent
var updater *shared.MembershipUpdater
userID, err := r.RSAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey()))
validRoomID, err := spec.NewRoomID(inviteEvent.RoomID())
if err != nil {
return nil, err
}
userID, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*inviteEvent.StateKey()))
if err != nil {
return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())}
}
@ -127,7 +131,12 @@ func (r *Inviter) PerformInvite(
) error {
event := req.Event
sender, err := r.RSAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return err
}
sender, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err != nil {
return spec.InvalidParam("The sender user ID is invalid")
}
@ -138,17 +147,12 @@ func (r *Inviter) PerformInvite(
if event.StateKey() == nil || *event.StateKey() == "" {
return fmt.Errorf("invite must be a state event")
}
invitedUser, err := r.RSAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
invitedUser, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey()))
if err != nil || invitedUser == nil {
return spec.InvalidParam("Could not find the matching senderID for this user")
}
isTargetLocal := r.Cfg.Matrix.IsLocalServerName(invitedUser.Domain())
validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return err
}
invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, *validRoomID, *invitedUser)
if err != nil {
return fmt.Errorf("failed looking up senderID for invited user")
@ -163,7 +167,7 @@ func (r *Inviter) PerformInvite(
StrippedState: req.InviteRoomState,
MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI},
StateQuerier: &QueryState{r.DB, r.RSAPI},
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.RSAPI.QueryUserIDForSender(ctx, roomID, senderID)
},
}

View file

@ -215,7 +215,7 @@ func (r *Joiner) performJoinRoomByID(
// and we aren't in the room.
isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID)
if err == nil && !serverInRoom && isInvitePending {
inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomIDOrAlias, inviteSender)
inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, inviteSender)
if queryErr != nil {
return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr)
}

View file

@ -91,7 +91,7 @@ func (r *Leaver) performLeaveRoomByID(
// that.
isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, leaver)
if err == nil && isInvitePending {
sender, serr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomID, senderUser)
sender, serr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, senderUser)
if serr != nil || sender == nil {
return nil, fmt.Errorf("sender %q has no matching userID", senderUser)
}

View file

@ -492,7 +492,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, send
}
if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err)
@ -573,7 +573,7 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, send
stateEvents[i] = queryRes.StateEvents[i].PDU
}
provider := gomatrixserverlib.NewAuthEvents(stateEvents)
if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client?

View file

@ -159,7 +159,7 @@ func (r *Queryer) QueryStateAfterEvents(
}
stateEvents, err = gomatrixserverlib.ResolveConflicts(
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.QueryUserIDForSender(ctx, roomID, senderID)
},
)
@ -407,7 +407,7 @@ func (r *Queryer) QueryMembershipsForRoom(
return fmt.Errorf("r.DB.Events: %w", err)
}
for _, event := range events {
clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
clientEvent := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.QueryUserIDForSender(ctx, roomID, senderID)
}, event)
response.JoinEvents = append(response.JoinEvents, clientEvent)
@ -458,7 +458,7 @@ func (r *Queryer) QueryMembershipsForRoom(
}
for _, event := range events {
clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
clientEvent := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.QueryUserIDForSender(ctx, roomID, senderID)
}, event)
response.JoinEvents = append(response.JoinEvents, clientEvent)
@ -651,7 +651,7 @@ func (r *Queryer) QueryStateAndAuthChain(
if request.ResolveState {
stateEvents, err = gomatrixserverlib.ResolveConflicts(
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.QueryUserIDForSender(ctx, roomID, senderID)
},
)
@ -1007,10 +1007,11 @@ func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID,
}
}
func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
userID, err := spec.NewUserID(string(senderID), true)
if err == nil {
return userID, nil
}
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
// TODO: pseudoIDs
return r.DB.GetUserIDForSender(ctx, roomID.String(), senderID)
}

View file

@ -949,7 +949,7 @@ func (v *StateResolution) resolveConflictsV1(
}
// Resolve the conflicts.
resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return v.Querier.QueryUserIDForSender(ctx, roomID, senderID)
})
@ -1063,7 +1063,7 @@ func (v *StateResolution) resolveConflictsV2(
conflictedEvents,
nonConflictedEvents,
authEvents,
func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return v.Querier.QueryUserIDForSender(ctx, roomID, senderID)
},
)

View file

@ -94,7 +94,7 @@ type MSC2836EventRelationshipsResponse struct {
func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsResponse, rsAPI roomserver.RoomserverInternalAPI) *EventRelationshipResponse {
out := &EventRelationshipResponse{
Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}),
Limited: res.Limited,

View file

@ -525,7 +525,7 @@ type testRoomserverAPI struct {
events map[string]*types.HeaderedEvent
}
func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}

View file

@ -377,7 +377,11 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *rst
return sp, fmt.Errorf("unexpected nil state_key")
}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey()))
validRoomID, err := spec.NewRoomID(ev.RoomID())
if err != nil {
return sp, err
}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*ev.StateKey()))
if err != nil || userID == nil {
return sp, fmt.Errorf("failed getting userID for sender: %w", err)
}
@ -404,7 +408,11 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
return
}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.Event.RoomID(), spec.SenderID(*msg.Event.StateKey()))
validRoomID, err := spec.NewRoomID(msg.Event.RoomID())
if err != nil {
return
}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*msg.Event.StateKey()))
if err != nil || userID == nil {
return
}
@ -454,7 +462,16 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
// Notify any active sync requests that the invite has been retired.
s.inviteStream.Advance(pduPos)
userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.RoomID, msg.TargetSenderID)
validRoomID, err := spec.NewRoomID(msg.RoomID)
if err != nil {
log.WithFields(log.Fields{
"event_id": msg.EventID,
"room_id": msg.RoomID,
log.ErrorKey: err,
}).Errorf("roomID is invalid")
return
}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, msg.TargetSenderID)
if err != nil || userID == nil {
log.WithFields(log.Fields{
"event_id": msg.EventID,

View file

@ -174,7 +174,11 @@ func TrackChangedUsers(
if membership != spec.Join {
continue
}
user, queryErr := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey))
validRoomID, roomErr := spec.NewRoomID(roomID)
if roomErr != nil {
continue
}
user, queryErr := rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(tuple.StateKey))
if queryErr != nil || user == nil {
continue
}
@ -222,7 +226,11 @@ func TrackChangedUsers(
}
// new user who we weren't previously sharing rooms with
if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok {
user, err := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey))
validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
continue
}
user, err := rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(tuple.StateKey))
if err != nil || user == nil {
continue
}

View file

@ -64,7 +64,7 @@ type mockRoomserverAPI struct {
roomIDToJoinedMembers map[string][]string
}
func (s *mockRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
func (s *mockRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}

View file

@ -101,13 +101,20 @@ func (n *Notifier) OnNewEvent(
n._removeEmptyUserStreams()
if ev != nil {
validRoomID, err := spec.NewRoomID(ev.RoomID())
if err != nil {
log.WithError(err).WithField("event_id", ev.EventID()).Errorf(
"Notifier.OnNewEvent: RoomID is invalid",
)
return
}
// Map this event's room_id to a list of joined users, and wake them up.
usersToNotify := n._joinedUsers(ev.RoomID())
// Map this event's room_id to a list of peeking devices, and wake them up.
peekingDevicesToNotify := n._peekingDevices(ev.RoomID())
// If this is an invite, also add in the invitee to this list.
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), ev.RoomID(), spec.SenderID(*ev.StateKey()))
targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), *validRoomID, spec.SenderID(*ev.StateKey()))
if err != nil {
log.WithError(err).WithField("event_id", ev.EventID()).Errorf(
"Notifier.OnNewEvent: Failed to find the userID for this event",

View file

@ -109,7 +109,7 @@ func mustEqualPositions(t *testing.T, got, want types.StreamingToken) {
type TestRoomServer struct{ api.SyncRoomserverAPI }
func (t *TestRoomServer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
func (t *TestRoomServer) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}

View file

@ -200,10 +200,10 @@ func Context(
}
}
eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
@ -211,7 +211,7 @@ func Context(
if filter.LazyLoadMembers {
allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...)
allEvents = append(allEvents, &requestedEvent)
evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache)
@ -224,14 +224,14 @@ func Context(
}
}
ev := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
ev := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, requestedEvent)
response := ContextRespsonse{
Event: &ev,
EventsAfter: eventsAfterClient,
EventsBefore: eventsBeforeClient,
State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}),
}

View file

@ -102,14 +102,28 @@ func GetEvent(
}
sender := spec.UserID{}
senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), roomID, events[0].SenderID())
validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("roomID is invalid"),
}
}
senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, events[0].SenderID())
if err == nil && senderUserID != nil {
sender = *senderUserID
}
sk := events[0].StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(ctx, events[0].RoomID(), spec.SenderID(*events[0].StateKey()))
evRoomID, err := spec.NewRoomID(events[0].RoomID())
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("roomID is invalid"),
}
}
skUserID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, spec.SenderID(*events[0].StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString

View file

@ -152,7 +152,15 @@ func GetMemberships(
}
}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID())
validRoomID, err := spec.NewRoomID(ev.RoomID())
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("roomID is invalid")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, ev.SenderID())
if err != nil || userID == nil {
util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryUserIDForSender failed")
return util.JSONResponse{
@ -175,7 +183,7 @@ func GetMemberships(
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
})},
}

View file

@ -273,7 +273,7 @@ func OnIncomingMessagesRequest(
JSON: spec.InternalServerError{},
}
}
res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
})...)
}
@ -389,7 +389,7 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv
"events_before": len(events),
"events_after": len(filteredEvents),
}).Debug("applied history visibility (messages)")
return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}), start, end, err
}

View file

@ -115,14 +115,18 @@ func Relations(
res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents))
for _, event := range filteredEvents {
sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID())
validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
continue
}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, event.SenderID())
if err == nil && userID != nil {
sender = *userID
}
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey()))
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString

View file

@ -205,9 +205,14 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
profileInfos := make(map[string]ProfileInfoResponse)
for _, ev := range append(eventsBefore, eventsAfter...) {
userID, queryErr := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID())
validRoomID, roomErr := spec.NewRoomID(ev.RoomID())
if err != nil {
logrus.WithError(roomErr).WithField("room_id", ev.RoomID()).Warn("failed to query userprofile")
continue
}
userID, queryErr := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, ev.SenderID())
if queryErr != nil {
logrus.WithError(queryErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile")
logrus.WithError(queryErr).WithField("sender_id", ev.SenderID()).Warn("failed to query userprofile")
continue
}
@ -231,14 +236,19 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
}
sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID())
validRoomID, roomErr := spec.NewRoomID(event.RoomID())
if err != nil {
logrus.WithError(roomErr).WithField("room_id", event.RoomID()).Warn("failed to query userprofile")
continue
}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, event.SenderID())
if err == nil && userID != nil {
sender = *userID
}
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey()))
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
@ -248,10 +258,10 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
Context: SearchContextResponse{
Start: startToken.String(),
End: endToken.String(),
EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
}),
EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
}),
ProfileInfo: profileInfos,
@ -272,7 +282,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
JSON: spec.InternalServerError{},
}
}
stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
})
}

View file

@ -25,7 +25,7 @@ import (
type FakeSyncRoomserverAPI struct{ rsapi.SyncRoomserverAPI }
func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}

View file

@ -65,14 +65,18 @@ func (p *InviteStreamProvider) IncrementalSync(
for roomID, inviteEvent := range invites {
user := spec.UserID{}
sender, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), inviteEvent.SenderID())
validRoomID, err := spec.NewRoomID(inviteEvent.RoomID())
if err != nil {
continue
}
sender, err := p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, inviteEvent.SenderID())
if err == nil && sender != nil {
user = *sender
}
sk := inviteEvent.StateKey()
if sk != nil && *sk != "" {
skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey()))
skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*inviteEvent.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString

View file

@ -376,13 +376,13 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
}
}
jr.Timeline.PrevBatch = &prevBatch
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
// If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited.
jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
req.Response.Rooms.Join[delta.RoomID] = jr
@ -391,11 +391,11 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = &prevBatch
// TODO: Apply history visibility on peeked rooms
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
jr.Timeline.Limited = limited
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
req.Response.Rooms.Peek[delta.RoomID] = jr
@ -406,13 +406,13 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
case spec.Ban:
lr := types.NewLeaveResponse()
lr.Timeline.PrevBatch = &prevBatch
lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
// If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited.
lr.Timeline.Limited = limited && len(events) == len(recentEvents)
lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
req.Response.Rooms.Leave[delta.RoomID] = lr
@ -564,13 +564,13 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
}
jr.Timeline.PrevBatch = prevBatch
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
// If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited.
jr.Timeline.Limited = limited && len(events) == len(recentEvents)
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
return jr, nil
@ -585,6 +585,10 @@ func (p *PDUStreamProvider) lazyLoadMembers(
if len(timelineEvents) == 0 {
return stateEvents, nil
}
validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return nil, err
}
// Work out which memberships to include
timelineUsers := make(map[string]struct{})
if !incremental {
@ -606,8 +610,8 @@ func (p *PDUStreamProvider) lazyLoadMembers(
isGappedIncremental := limited && incremental
// We want this users membership event, keep it in the list
userID := ""
stateKeyUserID, err := p.rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(*event.StateKey()))
if err == nil && stateKeyUserID != nil {
stateKeyUserID, queryErr := p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey()))
if queryErr == nil && stateKeyUserID != nil {
userID = stateKeyUserID.String()
}
if _, ok := timelineUsers[userID]; ok || isGappedIncremental || userID == device.UserID {

View file

@ -40,7 +40,7 @@ type syncRoomserverAPI struct {
rooms []*test.Room
}
func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}

View file

@ -52,14 +52,18 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat,
continue // TODO: shouldn't happen?
}
sender := spec.UserID{}
userID, err := userIDForSender(se.RoomID(), se.SenderID())
validRoomID, err := spec.NewRoomID(se.RoomID())
if err != nil {
continue
}
userID, err := userIDForSender(*validRoomID, se.SenderID())
if err == nil && userID != nil {
sender = *userID
}
sk := se.StateKey()
if sk != nil && *sk != "" {
skUserID, err := userIDForSender(se.RoomID(), spec.SenderID(*sk))
skUserID, err := userIDForSender(*validRoomID, spec.SenderID(*sk))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
@ -95,14 +99,18 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender sp
// It provides default logic for event.SenderID & event.StateKey -> userID conversions.
func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent {
sender := spec.UserID{}
userID, err := userIDQuery(event.RoomID(), event.SenderID())
validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return ClientEvent{}
}
userID, err := userIDQuery(*validRoomID, event.SenderID())
if err == nil && userID != nil {
sender = *userID
}
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := userIDQuery(event.RoomID(), spec.SenderID(*event.StateKey()))
skUserID, err := userIDQuery(*validRoomID, spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString

View file

@ -39,7 +39,7 @@ var (
roomIDCounter = int64(0)
)
func UserIDForSender(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
func UserIDForSender(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}

View file

@ -302,14 +302,18 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst
switch {
case event.Type() == spec.MRoomMember:
sender := spec.UserID{}
userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
validRoomID, roomErr := spec.NewRoomID(event.RoomID())
if roomErr != nil {
return roomErr
}
userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if queryErr == nil && userID != nil {
sender = *userID
}
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey()))
if queryErr == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
@ -544,14 +548,18 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
}
sender := spec.UserID{}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return err
}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err == nil && userID != nil {
sender = *userID
}
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey()))
if queryErr == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
@ -644,7 +652,11 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
// user. Returns actions (including dont_notify).
func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
user := ""
sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return nil, err
}
sender, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err == nil {
user = sender.String()
}
@ -682,7 +694,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
roomSize: roomSize,
}
eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global)
rule, err := eval.MatchEvent(event.PDU, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
rule, err := eval.MatchEvent(event.PDU, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
if err != nil {
@ -790,7 +802,11 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes
}
default:
sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return nil, err
}
sender, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err != nil {
logger.WithError(err).Errorf("Failed to get userID for sender %s", event.SenderID())
return nil, err

View file

@ -47,7 +47,7 @@ func mustCreateEvent(t *testing.T, content string) *types.HeaderedEvent {
type FakeUserRoomserverAPI struct{ rsapi.UserRoomserverAPI }
func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}
@ -68,13 +68,13 @@ func Test_evaluatePushRules(t *testing.T) {
}{
{
name: "m.receipt doesn't notify",
eventContent: `{"type":"m.receipt"}`,
eventContent: `{"type":"m.receipt","room_id":"!room:example.com"}`,
wantAction: pushrules.UnknownAction,
wantActions: nil,
},
{
name: "m.reaction doesn't notify",
eventContent: `{"type":"m.reaction"}`,
eventContent: `{"type":"m.reaction","room_id":"!room:example.com"}`,
wantAction: pushrules.DontNotifyAction,
wantActions: []*pushrules.Action{
{
@ -84,7 +84,7 @@ func Test_evaluatePushRules(t *testing.T) {
},
{
name: "m.room.message notifies",
eventContent: `{"type":"m.room.message"}`,
eventContent: `{"type":"m.room.message","room_id":"!room:example.com"}`,
wantNotify: true,
wantAction: pushrules.NotifyAction,
wantActions: []*pushrules.Action{
@ -93,7 +93,7 @@ func Test_evaluatePushRules(t *testing.T) {
},
{
name: "m.room.message highlights",
eventContent: `{"type":"m.room.message", "content": {"body": "test"}}`,
eventContent: `{"type":"m.room.message", "content": {"body": "test"},"room_id":"!room:example.com"}`,
wantNotify: true,
wantAction: pushrules.NotifyAction,
wantActions: []*pushrules.Action{