Use spec.RoomID for userID query interface

This commit is contained in:
Devon Hudson 2023-06-13 14:39:54 +01:00
parent 66a29cb01c
commit 50a783f0eb
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
57 changed files with 326 additions and 157 deletions

View file

@ -181,7 +181,7 @@ func (s *OutputRoomEventConsumer) sendEvents(
// Create the transaction body. // Create the transaction body.
transaction, err := json.Marshal( transaction, err := json.Marshal(
ApplicationServiceTransaction{ ApplicationServiceTransaction{
Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}), }),
}, },
@ -236,7 +236,11 @@ func (s *appserviceState) backoffAndPause(err error) error {
// TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682 // TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682
func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool { func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool {
user := "" user := ""
userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return false
}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err == nil { if err == nil {
user = userID.String() user = userID.String()
} }

View file

@ -235,10 +235,9 @@ func RemoveLocalAlias(
validRoomID, err := spec.NewRoomID(roomIDRes.RoomID) validRoomID, err := spec.NewRoomID(roomIDRes.RoomID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("roomID is invalid")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusNotFound,
JSON: spec.InternalServerError{}, JSON: spec.NotFound("The alias does not exist."),
} }
} }
deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *userID) deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *userID)

View file

@ -351,8 +351,8 @@ func generateSendEvent(
stateEvents[i] = queryRes.StateEvents[i].PDU stateEvents[i] = queryRes.StateEvents[i].PDU
} }
provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents)) provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents))
if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, *validRoomID, senderID)
}); err != nil { }); err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,

View file

@ -150,7 +150,7 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
for _, ev := range stateRes.StateEvents { for _, ev := range stateRes.StateEvents {
stateEvents = append( stateEvents = append(
stateEvents, stateEvents,
synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, ev), }, ev),
) )
@ -173,14 +173,19 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
} }
for _, ev := range stateAfterRes.StateEvents { for _, ev := range stateAfterRes.StateEvents {
sender := spec.UserID{} sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID()) evRoomID, err := spec.NewRoomID(ev.RoomID())
if err != nil {
util.GetLogger(ctx).WithError(err).Error("Event roomID is invalid")
continue
}
userID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, ev.SenderID())
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := ev.StateKey() sk := ev.StateKey()
if sk != nil && *sk != "" { if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey())) skUserID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, spec.SenderID(*ev.StateKey()))
if err == nil && skUserID != nil { if err == nil && skUserID != nil {
skString := skUserID.String() skString := skUserID.String()
sk = &skString sk = &skString
@ -367,7 +372,7 @@ func OnIncomingStateTypeRequest(
} }
stateEvent := stateEventInStateResp{ stateEvent := stateEventInStateResp{
ClientEvent: synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { ClientEvent: synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, event), }, event),
} }

View file

@ -189,7 +189,7 @@ func main() {
fmt.Println("Resolving state") fmt.Println("Resolving state")
var resolved Events var resolved Events
resolved, err = gomatrixserverlib.ResolveConflicts( resolved, err = gomatrixserverlib.ResolveConflicts(
gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, },
) )

View file

@ -36,7 +36,7 @@ type fedRoomserverAPI struct {
queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error
} }
func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }

View file

@ -166,10 +166,10 @@ func (r *FederationInternalAPI) performJoinUsingServer(
PrivateKey: r.cfg.Matrix.PrivateKey, PrivateKey: r.cfg.Matrix.PrivateKey,
KeyID: r.cfg.Matrix.KeyID, KeyID: r.cfg.Matrix.KeyID,
KeyRing: r.keyRing, KeyRing: r.keyRing,
EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}), }),
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, },
} }
@ -365,7 +365,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
// authenticate the state returned (check its auth events etc) // authenticate the state returned (check its auth events etc)
// the equivalent of CheckSendJoinResponse() // the equivalent of CheckSendJoinResponse()
userIDProvider := func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { userIDProvider := func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
} }
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse( authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(
@ -528,7 +528,11 @@ func (r *FederationInternalAPI) SendInvite(
event gomatrixserverlib.PDU, event gomatrixserverlib.PDU,
strippedState []gomatrixserverlib.InviteStrippedState, strippedState []gomatrixserverlib.InviteStrippedState,
) (gomatrixserverlib.PDU, error) { ) (gomatrixserverlib.PDU, error) {
inviter, err := r.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return nil, err
}
inviter, err := r.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -95,7 +95,7 @@ func InviteV2(
StateQuerier: rsAPI.StateQuerier(), StateQuerier: rsAPI.StateQuerier(),
InviteEvent: inviteReq.Event(), InviteEvent: inviteReq.Event(),
StrippedState: inviteReq.InviteRoomState(), StrippedState: inviteReq.InviteRoomState(),
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
}, },
} }
@ -188,7 +188,7 @@ func InviteV1(
StateQuerier: rsAPI.StateQuerier(), StateQuerier: rsAPI.StateQuerier(),
InviteEvent: event, InviteEvent: event,
StrippedState: strippedState, StrippedState: strippedState,
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
}, },
} }

View file

@ -118,7 +118,7 @@ func MakeJoin(
LocalServerName: cfg.Matrix.ServerName, LocalServerName: cfg.Matrix.ServerName,
LocalServerInRoom: res.RoomExists && res.IsInRoom, LocalServerInRoom: res.RoomExists && res.IsInRoom,
RoomQuerier: &roomQuerier, RoomQuerier: &roomQuerier,
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
}, },
BuildEventTemplate: createJoinTemplate, BuildEventTemplate: createJoinTemplate,
@ -215,7 +215,7 @@ func SendJoin(
PrivateKey: cfg.Matrix.PrivateKey, PrivateKey: cfg.Matrix.PrivateKey,
Verifier: keys, Verifier: keys,
MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI}, MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI},
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
}, },
} }

View file

@ -105,7 +105,7 @@ func MakeLeave(
LocalServerName: cfg.Matrix.ServerName, LocalServerName: cfg.Matrix.ServerName,
LocalServerInRoom: res.RoomExists && res.IsInRoom, LocalServerInRoom: res.RoomExists && res.IsInRoom,
BuildEventTemplate: createLeaveTemplate, BuildEventTemplate: createLeaveTemplate,
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
}, },
} }
@ -236,7 +236,14 @@ func SendLeave(
// Check that the sender belongs to the server that is sending us // Check that the sender belongs to the server that is sending us
// the request. By this point we've already asserted that the sender // the request. By this point we've already asserted that the sender
// and the state key are equal so we don't need to check both. // and the state key are equal so we don't need to check both.
sender, err := rsAPI.QueryUserIDForSender(httpReq.Context(), event.RoomID(), event.SenderID()) validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("Room ID is invalid."),
}
}
sender, err := rsAPI.QueryUserIDForSender(httpReq.Context(), *validRoomID, event.SenderID())
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,

View file

@ -140,7 +140,14 @@ func ExchangeThirdPartyInvite(
} }
} }
userID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, spec.SenderID(proto.SenderID)) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("Invalid room ID"),
}
}
userID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), *validRoomID, spec.SenderID(proto.SenderID))
if err != nil || userID == nil { if err != nil || userID == nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
@ -150,7 +157,7 @@ func ExchangeThirdPartyInvite(
senderDomain := userID.Domain() senderDomain := userID.Domain()
// Check that the state key is correct. // Check that the state key is correct.
targetUserID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, spec.SenderID(*proto.StateKey)) targetUserID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), *validRoomID, spec.SenderID(*proto.StateKey))
if err != nil || targetUserID == nil { if err != nil || targetUserID == nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,

10
go.mod
View file

@ -22,7 +22,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077 github.com/matrix-org/gomatrixserverlib v0.0.0-20230613121253-ff6a04adf44e
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
github.com/mattn/go-sqlite3 v1.14.16 github.com/mattn/go-sqlite3 v1.14.16
@ -42,11 +42,11 @@ require (
github.com/uber/jaeger-lib v2.4.1+incompatible github.com/uber/jaeger-lib v2.4.1+incompatible
github.com/yggdrasil-network/yggdrasil-go v0.4.6 github.com/yggdrasil-network/yggdrasil-go v0.4.6
go.uber.org/atomic v1.10.0 go.uber.org/atomic v1.10.0
golang.org/x/crypto v0.9.0 golang.org/x/crypto v0.10.0
golang.org/x/image v0.5.0 golang.org/x/image v0.5.0
golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e
golang.org/x/sync v0.1.0 golang.org/x/sync v0.1.0
golang.org/x/term v0.8.0 golang.org/x/term v0.9.0
gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/h2non/bimg.v1 v1.1.9
gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v2 v2.4.0
gotest.tools/v3 v3.4.0 gotest.tools/v3 v3.4.0
@ -127,8 +127,8 @@ require (
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect
golang.org/x/mod v0.8.0 // indirect golang.org/x/mod v0.8.0 // indirect
golang.org/x/net v0.10.0 // indirect golang.org/x/net v0.10.0 // indirect
golang.org/x/sys v0.8.0 // indirect golang.org/x/sys v0.9.0 // indirect
golang.org/x/text v0.9.0 // indirect golang.org/x/text v0.10.0 // indirect
golang.org/x/time v0.3.0 // indirect golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.6.0 // indirect golang.org/x/tools v0.6.0 // indirect
google.golang.org/protobuf v1.28.1 // indirect google.golang.org/protobuf v1.28.1 // indirect

20
go.sum
View file

@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077 h1:AmKkAUjy9rZA2K+qHXm/O/dPEPnUYfRE2I6SL+Dj+LU= github.com/matrix-org/gomatrixserverlib v0.0.0-20230613121253-ff6a04adf44e h1:jS+UZAfaTiyClN/l5LKbC+alXJ94/SDXFWByUgzBCYA=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/gomatrixserverlib v0.0.0-20230613121253-ff6a04adf44e/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
@ -511,8 +511,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -669,12 +669,12 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.8.0 h1:n5xxQn2i3PC0yLAbjTpNT85q/Kgzcr2gIoX9OrJUols= golang.org/x/term v0.9.0 h1:GRRCnKYhdQrD8kfRAdQ6Zcw1P0OcELxGLKJvtjVMZ28=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.9.0/go.mod h1:M6DEAAIenWoTxdKrOltXcmDY3rSplQUkrvaDU5FcQyo=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -683,8 +683,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=

View file

@ -115,7 +115,11 @@ func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec Evaluati
case SenderKind: case SenderKind:
userID := "" userID := ""
sender, err := userIDForSender(event.RoomID(), event.SenderID()) validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return false, err
}
sender, err := userIDForSender(*validRoomID, event.SenderID())
if err == nil { if err == nil {
userID = sender.String() userID = sender.String()
} }

View file

@ -8,7 +8,7 @@ import (
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
) )
func UserIDForSender(roomID string, senderID spec.SenderID) (*spec.UserID, error) { func UserIDForSender(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }
@ -73,7 +73,7 @@ func TestRuleMatches(t *testing.T) {
{"emptyOverride", OverrideKind, emptyRule, `{}`, true}, {"emptyOverride", OverrideKind, emptyRule, `{}`, true},
{"emptyContent", ContentKind, emptyRule, `{}`, false}, {"emptyContent", ContentKind, emptyRule, `{}`, false},
{"emptyRoom", RoomKind, emptyRule, `{}`, true}, {"emptyRoom", RoomKind, emptyRule, `{}`, true},
{"emptySender", SenderKind, emptyRule, `{}`, true}, {"emptySender", SenderKind, emptyRule, `{"room_id":"!room:example.com"}`, true},
{"emptyUnderride", UnderrideKind, emptyRule, `{}`, true}, {"emptyUnderride", UnderrideKind, emptyRule, `{}`, true},
{"disabled", OverrideKind, Rule{}, `{}`, false}, {"disabled", OverrideKind, Rule{}, `{}`, false},
@ -90,8 +90,8 @@ func TestRuleMatches(t *testing.T) {
{"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!room:example.com"}`, true}, {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!room:example.com"}`, true},
{"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!otherroom:example.com"}`, false}, {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!otherroom:example.com"}`, false},
{"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@user:example.com"}`, true}, {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@user:example.com","room_id":"!room:example.com"}`, true},
{"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@otheruser:example.com"}`, false}, {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@otheruser:example.com","room_id":"!room:example.com"}`, false},
} }
for _, tst := range tsts { for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) { t.Run(tst.Name, func(t *testing.T) {

View file

@ -167,7 +167,7 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut
} }
continue continue
} }
if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID())

View file

@ -70,7 +70,7 @@ type FakeRsAPI struct {
bannedFromRoom bool bannedFromRoom bool
} }
func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }
@ -642,7 +642,7 @@ type testRoomserverAPI struct {
queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse
} }
func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }

View file

@ -78,7 +78,7 @@ type InputRoomEventsAPI interface {
type QuerySenderIDAPI interface { type QuerySenderIDAPI interface {
QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error)
QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error)
} }
// Query the latest events and state for a room from the room server. // Query the latest events and state for a room from the room server.

View file

@ -89,7 +89,11 @@ func IsAnyUserOnServerWithMembership(ctx context.Context, querier api.QuerySende
continue continue
} }
userID, err := querier.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*stateKey)) validRoomID, err := spec.NewRoomID(ev.RoomID())
if err != nil {
continue
}
userID, err := querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*stateKey))
if err != nil { if err != nil {
continue continue
} }

View file

@ -14,7 +14,7 @@ type FakeQuerier struct {
api.QuerySenderIDAPI api.QuerySenderIDAPI
} }
func (f *FakeQuerier) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { func (f *FakeQuerier) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }

View file

@ -113,6 +113,7 @@ func (r *RoomserverInternalAPI) GetAliasesForRoomID(
return nil return nil
} }
// nolint:gocyclo
// RemoveRoomAlias implements alias.RoomserverInternalAPI // RemoveRoomAlias implements alias.RoomserverInternalAPI
func (r *RoomserverInternalAPI) RemoveRoomAlias( func (r *RoomserverInternalAPI) RemoveRoomAlias(
ctx context.Context, ctx context.Context,
@ -129,7 +130,12 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
return nil return nil
} }
sender, err := r.QueryUserIDForSender(ctx, roomID, request.SenderID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return err
}
sender, err := r.QueryUserIDForSender(ctx, *validRoomID, request.SenderID)
if err != nil || sender == nil { if err != nil || sender == nil {
return fmt.Errorf("r.QueryUserIDForSender: %w", err) return fmt.Errorf("r.QueryUserIDForSender: %w", err)
} }
@ -177,7 +183,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
if request.SenderID != ev.SenderID() { if request.SenderID != ev.SenderID() {
senderID = ev.SenderID() senderID = ev.SenderID()
} }
sender, err := r.QueryUserIDForSender(ctx, roomID, senderID) sender, err := r.QueryUserIDForSender(ctx, *validRoomID, senderID)
if err != nil || sender == nil { if err != nil || sender == nil {
return err return err
} }

View file

@ -78,7 +78,7 @@ func CheckForSoftFail(
} }
// Check if the event is allowed. // Check if the event is allowed.
if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return querier.QueryUserIDForSender(ctx, roomID, senderID) return querier.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
// return true, nil // return true, nil

View file

@ -318,9 +318,13 @@ func slowGetHistoryVisibilityState(
// If the event state key doesn't match the given servername // If the event state key doesn't match the given servername
// then we'll filter it out. This does preserve state keys that // then we'll filter it out. This does preserve state keys that
// are "" since these will contain history visibility etc. // are "" since these will contain history visibility etc.
validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return nil, err
}
for nid, key := range stateKeys { for nid, key := range stateKeys {
if key != "" { if key != "" {
userID, err := querier.QueryUserIDForSender(ctx, roomID, spec.SenderID(key)) userID, err := querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(key))
if err == nil && userID != nil { if err == nil && userID != nil {
if userID.Domain() != serverName { if userID.Domain() != serverName {
delete(stateKeys, nid) delete(stateKeys, nid)

View file

@ -128,7 +128,11 @@ func (r *Inputer) processRoomEvent(
if roomInfo == nil && !isCreateEvent { if roomInfo == nil && !isCreateEvent {
return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID())
} }
sender, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return err
}
sender, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err != nil { if err != nil {
return fmt.Errorf("failed getting userID for sender %q. %w", event.SenderID(), err) return fmt.Errorf("failed getting userID for sender %q. %w", event.SenderID(), err)
} }
@ -282,7 +286,7 @@ func (r *Inputer) processRoomEvent(
// Check if the event is allowed by its auth events. If it isn't then // Check if the event is allowed by its auth events. If it isn't then
// we consider the event to be "rejected" — it will still be persisted. // we consider the event to be "rejected" — it will still be persisted.
if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
isRejected = true isRejected = true
@ -587,7 +591,7 @@ func (r *Inputer) processStateBefore(
stateBeforeAuth := gomatrixserverlib.NewAuthEvents( stateBeforeAuth := gomatrixserverlib.NewAuthEvents(
gomatrixserverlib.ToPDUs(stateBeforeEvent), gomatrixserverlib.ToPDUs(stateBeforeEvent),
) )
if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); rejectionErr != nil { }); rejectionErr != nil {
rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr) rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr)
@ -700,7 +704,7 @@ nextAuthEvent:
// Check the signatures of the event. If this fails then we'll simply // Check the signatures of the event. If this fails then we'll simply
// skip it, because gomatrixserverlib.Allowed() will notice a problem // skip it, because gomatrixserverlib.Allowed() will notice a problem
// if a critical event is missing anyway. // if a critical event is missing anyway.
if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
continue nextAuthEvent continue nextAuthEvent
@ -718,7 +722,7 @@ nextAuthEvent:
} }
// Check if the auth event should be rejected. // Check if the auth event should be rejected.
err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}) })
if isRejected = err != nil; isRejected { if isRejected = err != nil; isRejected {
@ -842,7 +846,11 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r
continue continue
} }
memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, memberEvent.RoomID(), spec.SenderID(*memberEvent.StateKey())) validRoomID, err := spec.NewRoomID(memberEvent.RoomID())
if err != nil {
continue
}
memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*memberEvent.StateKey()))
if err != nil { if err != nil {
continue continue
} }

View file

@ -58,7 +58,7 @@ func Test_EventAuth(t *testing.T) {
} }
// Finally check that the event is NOT allowed // Finally check that the event is NOT allowed
if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
}); err == nil { }); err == nil {
t.Fatalf("event should not be allowed, but it was") t.Fatalf("event should not be allowed, but it was")

View file

@ -139,7 +139,11 @@ func (r *Inputer) updateMembership(
func (r *Inputer) isLocalTarget(ctx context.Context, event *types.Event) bool { func (r *Inputer) isLocalTarget(ctx context.Context, event *types.Event) bool {
isTargetLocalUser := false isTargetLocalUser := false
if statekey := event.StateKey(); statekey != nil { if statekey := event.StateKey(); statekey != nil {
userID, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*statekey)) validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return isTargetLocalUser
}
userID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*statekey))
if err != nil || userID == nil { if err != nil || userID == nil {
return isTargetLocalUser return isTargetLocalUser
} }

View file

@ -473,7 +473,7 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion
stateEventList = append(stateEventList, state.StateEvents...) stateEventList = append(stateEventList, state.StateEvents...)
} }
resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts( resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(
roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}, },
) )
@ -482,7 +482,7 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion
} }
// apply the current event // apply the current event
retryAllowedState: retryAllowedState:
if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
switch missing := err.(type) { switch missing := err.(type) {
@ -569,7 +569,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver
// will be added and duplicates will be removed. // will be added and duplicates will be removed.
missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events)) missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events))
for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) { for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
continue continue
@ -660,7 +660,7 @@ func (t *missingStateReq) lookupMissingStateViaState(
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{ authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{
StateEvents: state.GetStateEvents(), StateEvents: state.GetStateEvents(),
AuthEvents: state.GetAuthEvents(), AuthEvents: state.GetAuthEvents(),
}, roomVersion, t.keys, nil, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { }, roomVersion, t.keys, nil, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}) })
if err != nil { if err != nil {
@ -897,7 +897,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers)) t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers))
return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers)) return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers))
} }
if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())

View file

@ -74,6 +74,10 @@ func (r *Admin) PerformAdminEvacuateRoom(
if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil { if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil {
return nil, err return nil, err
} }
validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return nil, err
}
prevEvents := latestRes.LatestEvents prevEvents := latestRes.LatestEvents
var senderDomain spec.ServerName var senderDomain spec.ServerName
@ -100,7 +104,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
PrevEvents: prevEvents, PrevEvents: prevEvents,
} }
userID, err := r.Queryer.QueryUserIDForSender(ctx, roomID, spec.SenderID(fledglingEvent.SenderID)) userID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(fledglingEvent.SenderID))
if err != nil || userID == nil { if err != nil || userID == nil {
continue continue
} }
@ -264,7 +268,7 @@ func (r *Admin) PerformAdminDownloadState(
return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err) return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err)
} }
for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) { for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
continue continue
@ -272,7 +276,7 @@ func (r *Admin) PerformAdminDownloadState(
authEventMap[authEvent.EventID()] = authEvent authEventMap[authEvent.EventID()] = authEvent
} }
for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) { for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
continue continue

View file

@ -122,7 +122,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
// Specifically the test "Outbound federation can backfill events" // Specifically the test "Outbound federation can backfill events"
events, err := gomatrixserverlib.RequestBackfill( events, err := gomatrixserverlib.RequestBackfill(
ctx, req.VirtualHost, requester, ctx, req.VirtualHost, requester,
r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.Querier.QueryUserIDForSender(ctx, roomID, senderID) return r.Querier.QueryUserIDForSender(ctx, roomID, senderID)
}, },
) )
@ -213,7 +213,7 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
continue continue
} }
loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false) loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false)
result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.Querier.QueryUserIDForSender(ctx, roomID, senderID) return r.Querier.QueryUserIDForSender(ctx, roomID, senderID)
}) })
if err != nil { if err != nil {
@ -492,7 +492,11 @@ FindSuccessor:
// Store the server names in a temporary map to avoid duplicates. // Store the server names in a temporary map to avoid duplicates.
serverSet := make(map[spec.ServerName]bool) serverSet := make(map[spec.ServerName]bool)
for _, event := range memberEvents { for _, event := range memberEvents {
if sender, err := b.querier.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()); err == nil { validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
continue
}
if sender, err := b.querier.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()); err == nil {
serverSet[sender.Domain()] = true serverSet[sender.Domain()] = true
} }
} }

View file

@ -322,7 +322,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
} }
} }
if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return c.RSAPI.QueryUserIDForSender(ctx, roomID, senderID) return c.RSAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed") util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed")

View file

@ -99,7 +99,11 @@ func (r *Inviter) ProcessInviteMembership(
var outputUpdates []api.OutputEvent var outputUpdates []api.OutputEvent
var updater *shared.MembershipUpdater var updater *shared.MembershipUpdater
userID, err := r.RSAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey())) validRoomID, err := spec.NewRoomID(inviteEvent.RoomID())
if err != nil {
return nil, err
}
userID, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*inviteEvent.StateKey()))
if err != nil { if err != nil {
return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())} return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())}
} }
@ -127,7 +131,12 @@ func (r *Inviter) PerformInvite(
) error { ) error {
event := req.Event event := req.Event
sender, err := r.RSAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return err
}
sender, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err != nil { if err != nil {
return spec.InvalidParam("The sender user ID is invalid") return spec.InvalidParam("The sender user ID is invalid")
} }
@ -138,17 +147,12 @@ func (r *Inviter) PerformInvite(
if event.StateKey() == nil || *event.StateKey() == "" { if event.StateKey() == nil || *event.StateKey() == "" {
return fmt.Errorf("invite must be a state event") return fmt.Errorf("invite must be a state event")
} }
invitedUser, err := r.RSAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) invitedUser, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey()))
if err != nil || invitedUser == nil { if err != nil || invitedUser == nil {
return spec.InvalidParam("Could not find the matching senderID for this user") return spec.InvalidParam("Could not find the matching senderID for this user")
} }
isTargetLocal := r.Cfg.Matrix.IsLocalServerName(invitedUser.Domain()) isTargetLocal := r.Cfg.Matrix.IsLocalServerName(invitedUser.Domain())
validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return err
}
invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, *validRoomID, *invitedUser) invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, *validRoomID, *invitedUser)
if err != nil { if err != nil {
return fmt.Errorf("failed looking up senderID for invited user") return fmt.Errorf("failed looking up senderID for invited user")
@ -163,7 +167,7 @@ func (r *Inviter) PerformInvite(
StrippedState: req.InviteRoomState, StrippedState: req.InviteRoomState,
MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI}, MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI},
StateQuerier: &QueryState{r.DB, r.RSAPI}, StateQuerier: &QueryState{r.DB, r.RSAPI},
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.RSAPI.QueryUserIDForSender(ctx, roomID, senderID) return r.RSAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, },
} }

View file

@ -215,7 +215,7 @@ func (r *Joiner) performJoinRoomByID(
// and we aren't in the room. // and we aren't in the room.
isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID) isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID)
if err == nil && !serverInRoom && isInvitePending { if err == nil && !serverInRoom && isInvitePending {
inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomIDOrAlias, inviteSender) inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, inviteSender)
if queryErr != nil { if queryErr != nil {
return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr) return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr)
} }

View file

@ -91,7 +91,7 @@ func (r *Leaver) performLeaveRoomByID(
// that. // that.
isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, leaver) isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, leaver)
if err == nil && isInvitePending { if err == nil && isInvitePending {
sender, serr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomID, senderUser) sender, serr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, senderUser)
if serr != nil || sender == nil { if serr != nil || sender == nil {
return nil, fmt.Errorf("sender %q has no matching userID", senderUser) return nil, fmt.Errorf("sender %q has no matching userID", senderUser)
} }

View file

@ -492,7 +492,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, send
} }
if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err) return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err)
@ -573,7 +573,7 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, send
stateEvents[i] = queryRes.StateEvents[i].PDU stateEvents[i] = queryRes.StateEvents[i].PDU
} }
provider := gomatrixserverlib.NewAuthEvents(stateEvents) provider := gomatrixserverlib.NewAuthEvents(stateEvents)
if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client? return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client?

View file

@ -159,7 +159,7 @@ func (r *Queryer) QueryStateAfterEvents(
} }
stateEvents, err = gomatrixserverlib.ResolveConflicts( stateEvents, err = gomatrixserverlib.ResolveConflicts(
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.QueryUserIDForSender(ctx, roomID, senderID) return r.QueryUserIDForSender(ctx, roomID, senderID)
}, },
) )
@ -407,7 +407,7 @@ func (r *Queryer) QueryMembershipsForRoom(
return fmt.Errorf("r.DB.Events: %w", err) return fmt.Errorf("r.DB.Events: %w", err)
} }
for _, event := range events { for _, event := range events {
clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { clientEvent := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.QueryUserIDForSender(ctx, roomID, senderID) return r.QueryUserIDForSender(ctx, roomID, senderID)
}, event) }, event)
response.JoinEvents = append(response.JoinEvents, clientEvent) response.JoinEvents = append(response.JoinEvents, clientEvent)
@ -458,7 +458,7 @@ func (r *Queryer) QueryMembershipsForRoom(
} }
for _, event := range events { for _, event := range events {
clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { clientEvent := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.QueryUserIDForSender(ctx, roomID, senderID) return r.QueryUserIDForSender(ctx, roomID, senderID)
}, event) }, event)
response.JoinEvents = append(response.JoinEvents, clientEvent) response.JoinEvents = append(response.JoinEvents, clientEvent)
@ -651,7 +651,7 @@ func (r *Queryer) QueryStateAndAuthChain(
if request.ResolveState { if request.ResolveState {
stateEvents, err = gomatrixserverlib.ResolveConflicts( stateEvents, err = gomatrixserverlib.ResolveConflicts(
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.QueryUserIDForSender(ctx, roomID, senderID) return r.QueryUserIDForSender(ctx, roomID, senderID)
}, },
) )
@ -1007,10 +1007,11 @@ func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID,
} }
} }
func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
userID, err := spec.NewUserID(string(senderID), true) userID, err := spec.NewUserID(string(senderID), true)
if err == nil { if err == nil {
return userID, nil return userID, nil
} }
return r.DB.GetUserIDForSender(ctx, roomID, senderID) // TODO: pseudoIDs
return r.DB.GetUserIDForSender(ctx, roomID.String(), senderID)
} }

View file

@ -949,7 +949,7 @@ func (v *StateResolution) resolveConflictsV1(
} }
// Resolve the conflicts. // Resolve the conflicts.
resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return v.Querier.QueryUserIDForSender(ctx, roomID, senderID) return v.Querier.QueryUserIDForSender(ctx, roomID, senderID)
}) })
@ -1063,7 +1063,7 @@ func (v *StateResolution) resolveConflictsV2(
conflictedEvents, conflictedEvents,
nonConflictedEvents, nonConflictedEvents,
authEvents, authEvents,
func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return v.Querier.QueryUserIDForSender(ctx, roomID, senderID) return v.Querier.QueryUserIDForSender(ctx, roomID, senderID)
}, },
) )

View file

@ -94,7 +94,7 @@ type MSC2836EventRelationshipsResponse struct {
func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsResponse, rsAPI roomserver.RoomserverInternalAPI) *EventRelationshipResponse { func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsResponse, rsAPI roomserver.RoomserverInternalAPI) *EventRelationshipResponse {
out := &EventRelationshipResponse{ out := &EventRelationshipResponse{
Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}), }),
Limited: res.Limited, Limited: res.Limited,

View file

@ -525,7 +525,7 @@ type testRoomserverAPI struct {
events map[string]*types.HeaderedEvent events map[string]*types.HeaderedEvent
} }
func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }

View file

@ -377,7 +377,11 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *rst
return sp, fmt.Errorf("unexpected nil state_key") return sp, fmt.Errorf("unexpected nil state_key")
} }
userID, err := s.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey())) validRoomID, err := spec.NewRoomID(ev.RoomID())
if err != nil {
return sp, err
}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*ev.StateKey()))
if err != nil || userID == nil { if err != nil || userID == nil {
return sp, fmt.Errorf("failed getting userID for sender: %w", err) return sp, fmt.Errorf("failed getting userID for sender: %w", err)
} }
@ -404,7 +408,11 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
return return
} }
userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.Event.RoomID(), spec.SenderID(*msg.Event.StateKey())) validRoomID, err := spec.NewRoomID(msg.Event.RoomID())
if err != nil {
return
}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*msg.Event.StateKey()))
if err != nil || userID == nil { if err != nil || userID == nil {
return return
} }
@ -454,7 +462,16 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
// Notify any active sync requests that the invite has been retired. // Notify any active sync requests that the invite has been retired.
s.inviteStream.Advance(pduPos) s.inviteStream.Advance(pduPos)
userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.RoomID, msg.TargetSenderID) validRoomID, err := spec.NewRoomID(msg.RoomID)
if err != nil {
log.WithFields(log.Fields{
"event_id": msg.EventID,
"room_id": msg.RoomID,
log.ErrorKey: err,
}).Errorf("roomID is invalid")
return
}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, msg.TargetSenderID)
if err != nil || userID == nil { if err != nil || userID == nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"event_id": msg.EventID, "event_id": msg.EventID,

View file

@ -174,7 +174,11 @@ func TrackChangedUsers(
if membership != spec.Join { if membership != spec.Join {
continue continue
} }
user, queryErr := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey)) validRoomID, roomErr := spec.NewRoomID(roomID)
if roomErr != nil {
continue
}
user, queryErr := rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(tuple.StateKey))
if queryErr != nil || user == nil { if queryErr != nil || user == nil {
continue continue
} }
@ -222,7 +226,11 @@ func TrackChangedUsers(
} }
// new user who we weren't previously sharing rooms with // new user who we weren't previously sharing rooms with
if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok { if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok {
user, err := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey)) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
continue
}
user, err := rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(tuple.StateKey))
if err != nil || user == nil { if err != nil || user == nil {
continue continue
} }

View file

@ -64,7 +64,7 @@ type mockRoomserverAPI struct {
roomIDToJoinedMembers map[string][]string roomIDToJoinedMembers map[string][]string
} }
func (s *mockRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { func (s *mockRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }

View file

@ -101,13 +101,20 @@ func (n *Notifier) OnNewEvent(
n._removeEmptyUserStreams() n._removeEmptyUserStreams()
if ev != nil { if ev != nil {
validRoomID, err := spec.NewRoomID(ev.RoomID())
if err != nil {
log.WithError(err).WithField("event_id", ev.EventID()).Errorf(
"Notifier.OnNewEvent: RoomID is invalid",
)
return
}
// Map this event's room_id to a list of joined users, and wake them up. // Map this event's room_id to a list of joined users, and wake them up.
usersToNotify := n._joinedUsers(ev.RoomID()) usersToNotify := n._joinedUsers(ev.RoomID())
// Map this event's room_id to a list of peeking devices, and wake them up. // Map this event's room_id to a list of peeking devices, and wake them up.
peekingDevicesToNotify := n._peekingDevices(ev.RoomID()) peekingDevicesToNotify := n._peekingDevices(ev.RoomID())
// If this is an invite, also add in the invitee to this list. // If this is an invite, also add in the invitee to this list.
if ev.Type() == "m.room.member" && ev.StateKey() != nil { if ev.Type() == "m.room.member" && ev.StateKey() != nil {
targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), ev.RoomID(), spec.SenderID(*ev.StateKey())) targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), *validRoomID, spec.SenderID(*ev.StateKey()))
if err != nil { if err != nil {
log.WithError(err).WithField("event_id", ev.EventID()).Errorf( log.WithError(err).WithField("event_id", ev.EventID()).Errorf(
"Notifier.OnNewEvent: Failed to find the userID for this event", "Notifier.OnNewEvent: Failed to find the userID for this event",

View file

@ -109,7 +109,7 @@ func mustEqualPositions(t *testing.T, got, want types.StreamingToken) {
type TestRoomServer struct{ api.SyncRoomserverAPI } type TestRoomServer struct{ api.SyncRoomserverAPI }
func (t *TestRoomServer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { func (t *TestRoomServer) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }

View file

@ -200,10 +200,10 @@ func Context(
} }
} }
eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
@ -211,7 +211,7 @@ func Context(
if filter.LazyLoadMembers { if filter.LazyLoadMembers {
allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...) allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...)
allEvents = append(allEvents, &requestedEvent) allEvents = append(allEvents, &requestedEvent)
evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache) newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache)
@ -224,14 +224,14 @@ func Context(
} }
} }
ev := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { ev := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, requestedEvent) }, requestedEvent)
response := ContextRespsonse{ response := ContextRespsonse{
Event: &ev, Event: &ev,
EventsAfter: eventsAfterClient, EventsAfter: eventsAfterClient,
EventsBefore: eventsBeforeClient, EventsBefore: eventsBeforeClient,
State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}), }),
} }

View file

@ -102,14 +102,28 @@ func GetEvent(
} }
sender := spec.UserID{} sender := spec.UserID{}
senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), roomID, events[0].SenderID()) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("roomID is invalid"),
}
}
senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, events[0].SenderID())
if err == nil && senderUserID != nil { if err == nil && senderUserID != nil {
sender = *senderUserID sender = *senderUserID
} }
sk := events[0].StateKey() sk := events[0].StateKey()
if sk != nil && *sk != "" { if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(ctx, events[0].RoomID(), spec.SenderID(*events[0].StateKey())) evRoomID, err := spec.NewRoomID(events[0].RoomID())
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("roomID is invalid"),
}
}
skUserID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, spec.SenderID(*events[0].StateKey()))
if err == nil && skUserID != nil { if err == nil && skUserID != nil {
skString := skUserID.String() skString := skUserID.String()
sk = &skString sk = &skString

View file

@ -152,7 +152,15 @@ func GetMemberships(
} }
} }
userID, err := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID()) validRoomID, err := spec.NewRoomID(ev.RoomID())
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("roomID is invalid")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, ev.SenderID())
if err != nil || userID == nil { if err != nil || userID == nil {
util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryUserIDForSender failed") util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryUserIDForSender failed")
return util.JSONResponse{ return util.JSONResponse{
@ -175,7 +183,7 @@ func GetMemberships(
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
})}, })},
} }

View file

@ -273,7 +273,7 @@ func OnIncomingMessagesRequest(
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
} }
} }
res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
})...) })...)
} }
@ -389,7 +389,7 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv
"events_before": len(events), "events_before": len(events),
"events_after": len(filteredEvents), "events_after": len(filteredEvents),
}).Debug("applied history visibility (messages)") }).Debug("applied history visibility (messages)")
return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}), start, end, err }), start, end, err
} }

View file

@ -115,14 +115,18 @@ func Relations(
res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents)) res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents))
for _, event := range filteredEvents { for _, event := range filteredEvents {
sender := spec.UserID{} sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID()) validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
continue
}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, event.SenderID())
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := event.StateKey() sk := event.StateKey()
if sk != nil && *sk != "" { if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey())) skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil { if err == nil && skUserID != nil {
skString := skUserID.String() skString := skUserID.String()
sk = &skString sk = &skString

View file

@ -205,9 +205,14 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
profileInfos := make(map[string]ProfileInfoResponse) profileInfos := make(map[string]ProfileInfoResponse)
for _, ev := range append(eventsBefore, eventsAfter...) { for _, ev := range append(eventsBefore, eventsAfter...) {
userID, queryErr := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID()) validRoomID, roomErr := spec.NewRoomID(ev.RoomID())
if err != nil {
logrus.WithError(roomErr).WithField("room_id", ev.RoomID()).Warn("failed to query userprofile")
continue
}
userID, queryErr := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, ev.SenderID())
if queryErr != nil { if queryErr != nil {
logrus.WithError(queryErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") logrus.WithError(queryErr).WithField("sender_id", ev.SenderID()).Warn("failed to query userprofile")
continue continue
} }
@ -231,14 +236,19 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
} }
sender := spec.UserID{} sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID()) validRoomID, roomErr := spec.NewRoomID(event.RoomID())
if err != nil {
logrus.WithError(roomErr).WithField("room_id", event.RoomID()).Warn("failed to query userprofile")
continue
}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, event.SenderID())
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := event.StateKey() sk := event.StateKey()
if sk != nil && *sk != "" { if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey())) skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil { if err == nil && skUserID != nil {
skString := skUserID.String() skString := skUserID.String()
sk = &skString sk = &skString
@ -248,10 +258,10 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
Context: SearchContextResponse{ Context: SearchContextResponse{
Start: startToken.String(), Start: startToken.String(),
End: endToken.String(), End: endToken.String(),
EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
}), }),
EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
}), }),
ProfileInfo: profileInfos, ProfileInfo: profileInfos,
@ -272,7 +282,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
} }
} }
stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
}) })
} }

View file

@ -25,7 +25,7 @@ import (
type FakeSyncRoomserverAPI struct{ rsapi.SyncRoomserverAPI } type FakeSyncRoomserverAPI struct{ rsapi.SyncRoomserverAPI }
func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }

View file

@ -65,14 +65,18 @@ func (p *InviteStreamProvider) IncrementalSync(
for roomID, inviteEvent := range invites { for roomID, inviteEvent := range invites {
user := spec.UserID{} user := spec.UserID{}
sender, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), inviteEvent.SenderID()) validRoomID, err := spec.NewRoomID(inviteEvent.RoomID())
if err != nil {
continue
}
sender, err := p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, inviteEvent.SenderID())
if err == nil && sender != nil { if err == nil && sender != nil {
user = *sender user = *sender
} }
sk := inviteEvent.StateKey() sk := inviteEvent.StateKey()
if sk != nil && *sk != "" { if sk != nil && *sk != "" {
skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey())) skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*inviteEvent.StateKey()))
if err == nil && skUserID != nil { if err == nil && skUserID != nil {
skString := skUserID.String() skString := skUserID.String()
sk = &skString sk = &skString

View file

@ -376,13 +376,13 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
} }
} }
jr.Timeline.PrevBatch = &prevBatch jr.Timeline.PrevBatch = &prevBatch
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
// If we are limited by the filter AND the history visibility filter // If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited. // didn't "remove" events, return that the response is limited.
jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
req.Response.Rooms.Join[delta.RoomID] = jr req.Response.Rooms.Join[delta.RoomID] = jr
@ -391,11 +391,11 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = &prevBatch jr.Timeline.PrevBatch = &prevBatch
// TODO: Apply history visibility on peeked rooms // TODO: Apply history visibility on peeked rooms
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
jr.Timeline.Limited = limited jr.Timeline.Limited = limited
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
req.Response.Rooms.Peek[delta.RoomID] = jr req.Response.Rooms.Peek[delta.RoomID] = jr
@ -406,13 +406,13 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
case spec.Ban: case spec.Ban:
lr := types.NewLeaveResponse() lr := types.NewLeaveResponse()
lr.Timeline.PrevBatch = &prevBatch lr.Timeline.PrevBatch = &prevBatch
lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
// If we are limited by the filter AND the history visibility filter // If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited. // didn't "remove" events, return that the response is limited.
lr.Timeline.Limited = limited && len(events) == len(recentEvents) lr.Timeline.Limited = limited && len(events) == len(recentEvents)
lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
req.Response.Rooms.Leave[delta.RoomID] = lr req.Response.Rooms.Leave[delta.RoomID] = lr
@ -564,13 +564,13 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
} }
jr.Timeline.PrevBatch = prevBatch jr.Timeline.PrevBatch = prevBatch
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
// If we are limited by the filter AND the history visibility filter // If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited. // didn't "remove" events, return that the response is limited.
jr.Timeline.Limited = limited && len(events) == len(recentEvents) jr.Timeline.Limited = limited && len(events) == len(recentEvents)
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
return jr, nil return jr, nil
@ -585,6 +585,10 @@ func (p *PDUStreamProvider) lazyLoadMembers(
if len(timelineEvents) == 0 { if len(timelineEvents) == 0 {
return stateEvents, nil return stateEvents, nil
} }
validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return nil, err
}
// Work out which memberships to include // Work out which memberships to include
timelineUsers := make(map[string]struct{}) timelineUsers := make(map[string]struct{})
if !incremental { if !incremental {
@ -606,8 +610,8 @@ func (p *PDUStreamProvider) lazyLoadMembers(
isGappedIncremental := limited && incremental isGappedIncremental := limited && incremental
// We want this users membership event, keep it in the list // We want this users membership event, keep it in the list
userID := "" userID := ""
stateKeyUserID, err := p.rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(*event.StateKey())) stateKeyUserID, queryErr := p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey()))
if err == nil && stateKeyUserID != nil { if queryErr == nil && stateKeyUserID != nil {
userID = stateKeyUserID.String() userID = stateKeyUserID.String()
} }
if _, ok := timelineUsers[userID]; ok || isGappedIncremental || userID == device.UserID { if _, ok := timelineUsers[userID]; ok || isGappedIncremental || userID == device.UserID {

View file

@ -40,7 +40,7 @@ type syncRoomserverAPI struct {
rooms []*test.Room rooms []*test.Room
} }
func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }

View file

@ -52,14 +52,18 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat,
continue // TODO: shouldn't happen? continue // TODO: shouldn't happen?
} }
sender := spec.UserID{} sender := spec.UserID{}
userID, err := userIDForSender(se.RoomID(), se.SenderID()) validRoomID, err := spec.NewRoomID(se.RoomID())
if err != nil {
continue
}
userID, err := userIDForSender(*validRoomID, se.SenderID())
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := se.StateKey() sk := se.StateKey()
if sk != nil && *sk != "" { if sk != nil && *sk != "" {
skUserID, err := userIDForSender(se.RoomID(), spec.SenderID(*sk)) skUserID, err := userIDForSender(*validRoomID, spec.SenderID(*sk))
if err == nil && skUserID != nil { if err == nil && skUserID != nil {
skString := skUserID.String() skString := skUserID.String()
sk = &skString sk = &skString
@ -95,14 +99,18 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender sp
// It provides default logic for event.SenderID & event.StateKey -> userID conversions. // It provides default logic for event.SenderID & event.StateKey -> userID conversions.
func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent { func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent {
sender := spec.UserID{} sender := spec.UserID{}
userID, err := userIDQuery(event.RoomID(), event.SenderID()) validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return ClientEvent{}
}
userID, err := userIDQuery(*validRoomID, event.SenderID())
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := event.StateKey() sk := event.StateKey()
if sk != nil && *sk != "" { if sk != nil && *sk != "" {
skUserID, err := userIDQuery(event.RoomID(), spec.SenderID(*event.StateKey())) skUserID, err := userIDQuery(*validRoomID, spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil { if err == nil && skUserID != nil {
skString := skUserID.String() skString := skUserID.String()
sk = &skString sk = &skString

View file

@ -39,7 +39,7 @@ var (
roomIDCounter = int64(0) roomIDCounter = int64(0)
) )
func UserIDForSender(roomID string, senderID spec.SenderID) (*spec.UserID, error) { func UserIDForSender(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }

View file

@ -302,14 +302,18 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst
switch { switch {
case event.Type() == spec.MRoomMember: case event.Type() == spec.MRoomMember:
sender := spec.UserID{} sender := spec.UserID{}
userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) validRoomID, roomErr := spec.NewRoomID(event.RoomID())
if roomErr != nil {
return roomErr
}
userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if queryErr == nil && userID != nil { if queryErr == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := event.StateKey() sk := event.StateKey()
if sk != nil && *sk != "" { if sk != nil && *sk != "" {
skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey()))
if queryErr == nil && skUserID != nil { if queryErr == nil && skUserID != nil {
skString := skUserID.String() skString := skUserID.String()
sk = &skString sk = &skString
@ -544,14 +548,18 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
} }
sender := spec.UserID{} sender := spec.UserID{}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return err
}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := event.StateKey() sk := event.StateKey()
if sk != nil && *sk != "" { if sk != nil && *sk != "" {
skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey()))
if queryErr == nil && skUserID != nil { if queryErr == nil && skUserID != nil {
skString := skUserID.String() skString := skUserID.String()
sk = &skString sk = &skString
@ -644,7 +652,11 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
// user. Returns actions (including dont_notify). // user. Returns actions (including dont_notify).
func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
user := "" user := ""
sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return nil, err
}
sender, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err == nil { if err == nil {
user = sender.String() user = sender.String()
} }
@ -682,7 +694,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
roomSize: roomSize, roomSize: roomSize,
} }
eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global) eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global)
rule, err := eval.MatchEvent(event.PDU, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { rule, err := eval.MatchEvent(event.PDU, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
if err != nil { if err != nil {
@ -790,7 +802,11 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes
} }
default: default:
sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return nil, err
}
sender, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err != nil { if err != nil {
logger.WithError(err).Errorf("Failed to get userID for sender %s", event.SenderID()) logger.WithError(err).Errorf("Failed to get userID for sender %s", event.SenderID())
return nil, err return nil, err

View file

@ -47,7 +47,7 @@ func mustCreateEvent(t *testing.T, content string) *types.HeaderedEvent {
type FakeUserRoomserverAPI struct{ rsapi.UserRoomserverAPI } type FakeUserRoomserverAPI struct{ rsapi.UserRoomserverAPI }
func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }
@ -68,13 +68,13 @@ func Test_evaluatePushRules(t *testing.T) {
}{ }{
{ {
name: "m.receipt doesn't notify", name: "m.receipt doesn't notify",
eventContent: `{"type":"m.receipt"}`, eventContent: `{"type":"m.receipt","room_id":"!room:example.com"}`,
wantAction: pushrules.UnknownAction, wantAction: pushrules.UnknownAction,
wantActions: nil, wantActions: nil,
}, },
{ {
name: "m.reaction doesn't notify", name: "m.reaction doesn't notify",
eventContent: `{"type":"m.reaction"}`, eventContent: `{"type":"m.reaction","room_id":"!room:example.com"}`,
wantAction: pushrules.DontNotifyAction, wantAction: pushrules.DontNotifyAction,
wantActions: []*pushrules.Action{ wantActions: []*pushrules.Action{
{ {
@ -84,7 +84,7 @@ func Test_evaluatePushRules(t *testing.T) {
}, },
{ {
name: "m.room.message notifies", name: "m.room.message notifies",
eventContent: `{"type":"m.room.message"}`, eventContent: `{"type":"m.room.message","room_id":"!room:example.com"}`,
wantNotify: true, wantNotify: true,
wantAction: pushrules.NotifyAction, wantAction: pushrules.NotifyAction,
wantActions: []*pushrules.Action{ wantActions: []*pushrules.Action{
@ -93,7 +93,7 @@ func Test_evaluatePushRules(t *testing.T) {
}, },
{ {
name: "m.room.message highlights", name: "m.room.message highlights",
eventContent: `{"type":"m.room.message", "content": {"body": "test"}}`, eventContent: `{"type":"m.room.message", "content": {"body": "test"},"room_id":"!room:example.com"}`,
wantNotify: true, wantNotify: true,
wantAction: pushrules.NotifyAction, wantAction: pushrules.NotifyAction,
wantActions: []*pushrules.Action{ wantActions: []*pushrules.Action{