Merge branch 'main' into token-registration

This commit is contained in:
santhoshivan23 2023-06-08 00:49:12 +05:30 committed by GitHub
commit 356edeb6b4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
81 changed files with 1092 additions and 581 deletions

View file

@ -181,7 +181,9 @@ func (s *OutputRoomEventConsumer) sendEvents(
// Create the transaction body. // Create the transaction body.
transaction, err := json.Marshal( transaction, err := json.Marshal(
ApplicationServiceTransaction{ ApplicationServiceTransaction{
Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll), Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}),
}, },
) )
if err != nil { if err != nil {
@ -233,10 +235,16 @@ func (s *appserviceState) backoffAndPause(err error) error {
// //
// TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682 // TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682
func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool { func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool {
user := ""
userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err == nil {
user = userID.String()
}
switch { switch {
case appservice.URL == "": case appservice.URL == "":
return false return false
case appservice.IsInterestedInUserID(event.Sender()): case appservice.IsInterestedInUserID(user):
return true return true
case appservice.IsInterestedInRoomID(event.RoomID()): case appservice.IsInterestedInRoomID(event.RoomID()):
return true return true

View file

@ -215,9 +215,35 @@ func RemoveLocalAlias(
alias string, alias string,
rsAPI roomserverAPI.ClientRoomserverAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
) util.JSONResponse { ) util.JSONResponse {
userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{Err: "UserID for device is invalid"},
}
}
roomIDReq := roomserverAPI.GetRoomIDForAliasRequest{Alias: alias}
roomIDRes := roomserverAPI.GetRoomIDForAliasResponse{}
err = rsAPI.GetRoomIDForAlias(req.Context(), &roomIDReq, &roomIDRes)
if err != nil {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.NotFound("The alias does not exist."),
}
}
deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomIDRes.RoomID, *userID)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{Err: "Could not find SenderID for this device"},
}
}
queryReq := roomserverAPI.RemoveRoomAliasRequest{ queryReq := roomserverAPI.RemoveRoomAliasRequest{
Alias: alias, Alias: alias,
UserID: device.UserID, SenderID: deviceSenderID,
} }
var queryRes roomserverAPI.RemoveRoomAliasResponse var queryRes roomserverAPI.RemoveRoomAliasResponse
if err := rsAPI.RemoveRoomAlias(req.Context(), &queryReq, &queryRes); err != nil { if err := rsAPI.RemoveRoomAlias(req.Context(), &queryReq, &queryRes); err != nil {
@ -312,7 +338,21 @@ func SetVisibility(
// NOTSPEC: Check if the user's power is greater than power required to change m.room.canonical_alias event // NOTSPEC: Check if the user's power is greater than power required to change m.room.canonical_alias event
power, _ := gomatrixserverlib.NewPowerLevelContentFromEvent(queryEventsRes.StateEvents[0].PDU) power, _ := gomatrixserverlib.NewPowerLevelContentFromEvent(queryEventsRes.StateEvents[0].PDU)
if power.UserLevel(dev.UserID) < power.EventLevel(spec.MRoomCanonicalAlias, true) { fullUserID, err := spec.NewUserID(dev.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
}
}
if power.UserLevel(senderID) < power.EventLevel(spec.MRoomCanonicalAlias, true) {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to change visibility"), JSON: spec.Forbidden("userID doesn't have power level to change visibility"),

View file

@ -66,7 +66,21 @@ func SendBan(
if errRes != nil { if errRes != nil {
return *errRes return *errRes
} }
allowedToBan := pl.UserLevel(device.UserID) >= pl.Ban fullUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to ban this user, unknown senderID"),
}
}
allowedToBan := pl.UserLevel(senderID) >= pl.Ban
if !allowedToBan { if !allowedToBan {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
@ -142,7 +156,21 @@ func SendKick(
if errRes != nil { if errRes != nil {
return *errRes return *errRes
} }
allowedToKick := pl.UserLevel(device.UserID) >= pl.Kick fullUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"),
}
}
allowedToKick := pl.UserLevel(senderID) >= pl.Kick
if !allowedToKick { if !allowedToKick {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
@ -151,7 +179,7 @@ func SendKick(
} }
var queryRes roomserverAPI.QueryMembershipForUserResponse var queryRes roomserverAPI.QueryMembershipForUserResponse
err := rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
UserID: body.UserID, UserID: body.UserID,
}, &queryRes) }, &queryRes)
@ -319,7 +347,7 @@ func buildMembershipEventDirect(
rsAPI roomserverAPI.ClientRoomserverAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
) (*types.HeaderedEvent, error) { ) (*types.HeaderedEvent, error) {
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
Sender: sender, SenderID: sender,
RoomID: roomID, RoomID: roomID,
Type: "m.room.member", Type: "m.room.member",
StateKey: &targetUserID, StateKey: &targetUserID,

View file

@ -363,12 +363,21 @@ func buildMembershipEvents(
) ([]*types.HeaderedEvent, error) { ) ([]*types.HeaderedEvent, error) {
evs := []*types.HeaderedEvent{} evs := []*types.HeaderedEvent{}
fullUserID, err := spec.NewUserID(userID, true)
if err != nil {
return nil, err
}
for _, roomID := range roomIDs { for _, roomID := range roomIDs {
senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return nil, err
}
senderIDString := string(senderID)
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
Sender: userID, SenderID: senderIDString,
RoomID: roomID, RoomID: roomID,
Type: "m.room.member", Type: "m.room.member",
StateKey: &userID, StateKey: &senderIDString,
} }
content := gomatrixserverlib.MemberContent{ content := gomatrixserverlib.MemberContent{
@ -378,7 +387,7 @@ func buildMembershipEvents(
content.DisplayName = newProfile.DisplayName content.DisplayName = newProfile.DisplayName
content.AvatarURL = newProfile.AvatarURL content.AvatarURL = newProfile.AvatarURL
if err := proto.SetContent(content); err != nil { if err = proto.SetContent(content); err != nil {
return nil, err return nil, err
} }

View file

@ -73,10 +73,25 @@ func SendRedaction(
} }
} }
fullUserID, userIDErr := spec.NewUserID(device.UserID, true)
if userIDErr != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to redact"),
}
}
senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if queryErr != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to redact"),
}
}
// "Users may redact their own events, and any user with a power level greater than or equal // "Users may redact their own events, and any user with a power level greater than or equal
// to the redact power level of the room may redact events there" // to the redact power level of the room may redact events there"
// https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid // https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid
allowedToRedact := ev.Sender() == device.UserID allowedToRedact := ev.SenderID() == senderID // TODO: Should replace device.UserID with device...PerRoomKey
if !allowedToRedact { if !allowedToRedact {
plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{ plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomPowerLevels, EventType: spec.MRoomPowerLevels,
@ -97,7 +112,7 @@ func SendRedaction(
), ),
} }
} }
allowedToRedact = pl.UserLevel(device.UserID) >= pl.Redact allowedToRedact = pl.UserLevel(senderID) >= pl.Redact
} }
if !allowedToRedact { if !allowedToRedact {
return util.JSONResponse{ return util.JSONResponse{
@ -114,10 +129,10 @@ func SendRedaction(
// create the new event and set all the fields we can // create the new event and set all the fields we can
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
Sender: device.UserID, SenderID: string(senderID),
RoomID: roomID, RoomID: roomID,
Type: spec.MRoomRedaction, Type: spec.MRoomRedaction,
Redacts: eventID, Redacts: eventID,
} }
err := proto.SetContent(r) err := proto.SetContent(r)
if err != nil { if err != nil {

View file

@ -266,16 +266,29 @@ func generateSendEvent(
evTime time.Time, evTime time.Time,
) (gomatrixserverlib.PDU, *util.JSONResponse) { ) (gomatrixserverlib.PDU, *util.JSONResponse) {
// parse the incoming http request // parse the incoming http request
userID := device.UserID fullUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return nil, &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("Bad userID"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return nil, &util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.NotFound("Unable to find senderID for user"),
}
}
// create the new event and set all the fields we can // create the new event and set all the fields we can
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
Sender: userID, SenderID: string(senderID),
RoomID: roomID, RoomID: roomID,
Type: eventType, Type: eventType,
StateKey: stateKey, StateKey: stateKey,
} }
err := proto.SetContent(r) err = proto.SetContent(r)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("proto.SetContent failed") util.GetLogger(ctx).WithError(err).Error("proto.SetContent failed")
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
@ -331,7 +344,9 @@ func generateSendEvent(
stateEvents[i] = queryRes.StateEvents[i].PDU stateEvents[i] = queryRes.StateEvents[i].PDU
} }
provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents)) provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents))
if err = gomatrixserverlib.Allowed(e.PDU, &provider); err != nil { if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: spec.Forbidden(err.Error()), // TODO: Is this error string comprehensible to the client? JSON: spec.Forbidden(err.Error()), // TODO: Is this error string comprehensible to the client?

View file

@ -140,9 +140,14 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
// use the result of the previous QueryLatestEventsAndState response // use the result of the previous QueryLatestEventsAndState response
// to find the state event, if provided. // to find the state event, if provided.
for _, ev := range stateRes.StateEvents { for _, ev := range stateRes.StateEvents {
sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID())
if err == nil && userID != nil {
sender = *userID
}
stateEvents = append( stateEvents = append(
stateEvents, stateEvents,
synctypes.ToClientEvent(ev, synctypes.FormatAll), synctypes.ToClientEvent(ev, synctypes.FormatAll, sender),
) )
} }
} else { } else {
@ -162,9 +167,14 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
} }
} }
for _, ev := range stateAfterRes.StateEvents { for _, ev := range stateAfterRes.StateEvents {
sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID())
if err == nil && userID != nil {
sender = *userID
}
stateEvents = append( stateEvents = append(
stateEvents, stateEvents,
synctypes.ToClientEvent(ev, synctypes.FormatAll), synctypes.ToClientEvent(ev, synctypes.FormatAll, sender),
) )
} }
} }
@ -334,8 +344,13 @@ func OnIncomingStateTypeRequest(
} }
} }
sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err == nil && userID != nil {
sender = *userID
}
stateEvent := stateEventInStateResp{ stateEvent := stateEventInStateResp{
ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll), ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, sender),
} }
var res interface{} var res interface{}

View file

@ -355,8 +355,16 @@ func emit3PIDInviteEvent(
rsAPI api.ClientRoomserverAPI, rsAPI api.ClientRoomserverAPI,
evTime time.Time, evTime time.Time,
) error { ) error {
userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return err
}
sender, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID)
if err != nil {
return err
}
proto := &gomatrixserverlib.ProtoEvent{ proto := &gomatrixserverlib.ProtoEvent{
Sender: device.UserID, SenderID: string(sender),
RoomID: roomID, RoomID: roomID,
Type: "m.room.third_party_invite", Type: "m.room.third_party_invite",
StateKey: &res.Token, StateKey: &res.Token,
@ -370,7 +378,7 @@ func emit3PIDInviteEvent(
PublicKeys: res.PublicKeys, PublicKeys: res.PublicKeys,
} }
if err := proto.SetContent(content); err != nil { if err = proto.SetContent(content); err != nil {
return err return err
} }

View file

@ -18,6 +18,7 @@ import (
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
) )
// This is a utility for inspecting state snapshots and running state resolution // This is a utility for inspecting state snapshots and running state resolution
@ -182,7 +183,9 @@ func main() {
fmt.Println("Resolving state") fmt.Println("Resolving state")
var resolved Events var resolved Events
resolved, err = gomatrixserverlib.ResolveConflicts( resolved, err = gomatrixserverlib.ResolveConflicts(
gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return roomserverDB.GetUserIDForSender(ctx, roomID, senderID)
},
) )
if err != nil { if err != nil {
panic(err) panic(err)

View file

@ -36,6 +36,14 @@ type fedRoomserverAPI struct {
queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error
} }
func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}
func (f *fedRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) {
return spec.SenderID(userID.String()), nil
}
// PerformJoin will call this function // PerformJoin will call this function
func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) { func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) {
if f.inputRoomEvents == nil { if f.inputRoomEvents == nil {
@ -111,12 +119,13 @@ func (f *fedClient) MakeJoin(ctx context.Context, origin, s spec.ServerName, roo
defer f.fedClientMutex.Unlock() defer f.fedClientMutex.Unlock()
for _, r := range f.allowJoins { for _, r := range f.allowJoins {
if r.ID == roomID { if r.ID == roomID {
senderIDString := userID
res.RoomVersion = r.Version res.RoomVersion = r.Version
res.JoinEvent = gomatrixserverlib.ProtoEvent{ res.JoinEvent = gomatrixserverlib.ProtoEvent{
Sender: userID, SenderID: senderIDString,
RoomID: roomID, RoomID: roomID,
Type: "m.room.member", Type: "m.room.member",
StateKey: &userID, StateKey: &senderIDString,
Content: spec.RawJSON([]byte(`{"membership":"join"}`)), Content: spec.RawJSON([]byte(`{"membership":"join"}`)),
PrevEvents: r.ForwardExtremities(), PrevEvents: r.ForwardExtremities(),
} }

View file

@ -170,7 +170,7 @@ func (s *FederationInternalAPI) handleDatabaseKeys(
// in that case. If the key isn't valid right now, then by // in that case. If the key isn't valid right now, then by
// leaving it in the 'requests' map, we'll try to update the // leaving it in the 'requests' map, we'll try to update the
// key using the fetchers in handleFetcherKeys. // key using the fetchers in handleFetcherKeys.
if res.WasValidAt(now, true) { if res.WasValidAt(now, gomatrixserverlib.StrictValiditySignatureCheck) {
delete(requests, req) delete(requests, req)
} }
} }

View file

@ -154,17 +154,27 @@ func (r *FederationInternalAPI) performJoinUsingServer(
if err != nil { if err != nil {
return err return err
} }
senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, roomID, *user)
if err != nil {
return err
}
joinInput := gomatrixserverlib.PerformJoinInput{ joinInput := gomatrixserverlib.PerformJoinInput{
UserID: user, UserID: user,
RoomID: room, SenderID: senderID,
ServerName: serverName, RoomID: room,
Content: content, ServerName: serverName,
Unsigned: unsigned, Content: content,
PrivateKey: r.cfg.Matrix.PrivateKey, Unsigned: unsigned,
KeyID: r.cfg.Matrix.KeyID, PrivateKey: r.cfg.Matrix.PrivateKey,
KeyRing: r.keyRing, KeyID: r.cfg.Matrix.KeyID,
EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName), KeyRing: r.keyRing,
EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}),
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
},
} }
response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput) response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput)
@ -358,8 +368,11 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
// authenticate the state returned (check its auth events etc) // authenticate the state returned (check its auth events etc)
// the equivalent of CheckSendJoinResponse() // the equivalent of CheckSendJoinResponse()
userIDProvider := func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse( authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(
ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedEventProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName), ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedEventProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName, userIDProvider), userIDProvider,
) )
if err != nil { if err != nil {
return fmt.Errorf("error checking state returned from peeking: %w", err) return fmt.Errorf("error checking state returned from peeking: %w", err)
@ -406,7 +419,7 @@ func (r *FederationInternalAPI) PerformLeave(
request *api.PerformLeaveRequest, request *api.PerformLeaveRequest,
response *api.PerformLeaveResponse, response *api.PerformLeaveResponse,
) (err error) { ) (err error) {
_, origin, err := r.cfg.Matrix.SplitLocalID('@', request.UserID) userID, err := spec.NewUserID(request.UserID, true)
if err != nil { if err != nil {
return err return err
} }
@ -425,7 +438,7 @@ func (r *FederationInternalAPI) PerformLeave(
// request. // request.
respMakeLeave, err := r.federation.MakeLeave( respMakeLeave, err := r.federation.MakeLeave(
ctx, ctx,
origin, userID.Domain(),
serverName, serverName,
request.RoomID, request.RoomID,
request.UserID, request.UserID,
@ -446,9 +459,14 @@ func (r *FederationInternalAPI) PerformLeave(
// Set all the fields to be what they should be, this should be a no-op // Set all the fields to be what they should be, this should be a no-op
// but it's possible that the remote server returned us something "odd" // but it's possible that the remote server returned us something "odd"
senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, request.RoomID, *userID)
if err != nil {
return err
}
senderIDString := string(senderID)
respMakeLeave.LeaveEvent.Type = spec.MRoomMember respMakeLeave.LeaveEvent.Type = spec.MRoomMember
respMakeLeave.LeaveEvent.Sender = request.UserID respMakeLeave.LeaveEvent.SenderID = senderIDString
respMakeLeave.LeaveEvent.StateKey = &request.UserID respMakeLeave.LeaveEvent.StateKey = &senderIDString
respMakeLeave.LeaveEvent.RoomID = request.RoomID respMakeLeave.LeaveEvent.RoomID = request.RoomID
respMakeLeave.LeaveEvent.Redacts = "" respMakeLeave.LeaveEvent.Redacts = ""
leaveEB := verImpl.NewEventBuilderFromProtoEvent(&respMakeLeave.LeaveEvent) leaveEB := verImpl.NewEventBuilderFromProtoEvent(&respMakeLeave.LeaveEvent)
@ -470,7 +488,7 @@ func (r *FederationInternalAPI) PerformLeave(
// Build the leave event. // Build the leave event.
event, err := leaveEB.Build( event, err := leaveEB.Build(
time.Now(), time.Now(),
origin, userID.Domain(),
r.cfg.Matrix.KeyID, r.cfg.Matrix.KeyID,
r.cfg.Matrix.PrivateKey, r.cfg.Matrix.PrivateKey,
) )
@ -482,7 +500,7 @@ func (r *FederationInternalAPI) PerformLeave(
// Try to perform a send_leave using the newly built event. // Try to perform a send_leave using the newly built event.
err = r.federation.SendLeave( err = r.federation.SendLeave(
ctx, ctx,
origin, userID.Domain(),
serverName, serverName,
event, event,
) )
@ -509,7 +527,7 @@ func (r *FederationInternalAPI) SendInvite(
event gomatrixserverlib.PDU, event gomatrixserverlib.PDU,
strippedState []gomatrixserverlib.InviteStrippedState, strippedState []gomatrixserverlib.InviteStrippedState,
) (gomatrixserverlib.PDU, error) { ) (gomatrixserverlib.PDU, error) {
_, origin, err := r.cfg.Matrix.SplitLocalID('@', event.Sender()) inviter, err := r.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -542,7 +560,7 @@ func (r *FederationInternalAPI) SendInvite(
return nil, fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err) return nil, fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err)
} }
inviteRes, err := r.federation.SendInviteV2(ctx, origin, destination, inviteReq) inviteRes, err := r.federation.SendInviteV2(ctx, inviter.Domain(), destination, inviteReq)
if err != nil { if err != nil {
return nil, fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err) return nil, fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err)
} }
@ -635,6 +653,7 @@ func checkEventsContainCreateEvent(events []gomatrixserverlib.PDU) error {
func federatedEventProvider( func federatedEventProvider(
ctx context.Context, federation fclient.FederationClient, ctx context.Context, federation fclient.FederationClient,
keyRing gomatrixserverlib.JSONVerifier, origin, server spec.ServerName, keyRing gomatrixserverlib.JSONVerifier, origin, server spec.ServerName,
userIDForSender spec.UserIDForSender,
) gomatrixserverlib.EventProvider { ) gomatrixserverlib.EventProvider {
// A list of events that we have retried, if they were not included in // A list of events that we have retried, if they were not included in
// the auth events supplied in the send_join. // the auth events supplied in the send_join.
@ -684,7 +703,7 @@ func federatedEventProvider(
} }
// Check the signatures of the event. // Check the signatures of the event.
if err := gomatrixserverlib.VerifyEventSignatures(ctx, ev, keyRing); err != nil { if err := gomatrixserverlib.VerifyEventSignatures(ctx, ev, keyRing, userIDForSender); err != nil {
return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err) return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err)
} }

View file

@ -95,7 +95,7 @@ func Backfill(
} }
} }
// Query the roomserver. // Query the Roomserver.
if err = rsAPI.PerformBackfill(httpReq.Context(), &req, &res); err != nil { if err = rsAPI.PerformBackfill(httpReq.Context(), &req, &res); err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("query.PerformBackfill failed") util.GetLogger(httpReq.Context()).WithError(err).Error("query.PerformBackfill failed")
return util.JSONResponse{ return util.JSONResponse{

View file

@ -95,6 +95,9 @@ func InviteV2(
StateQuerier: rsAPI.StateQuerier(), StateQuerier: rsAPI.StateQuerier(),
InviteEvent: inviteReq.Event(), InviteEvent: inviteReq.Event(),
StrippedState: inviteReq.InviteRoomState(), StrippedState: inviteReq.InviteRoomState(),
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
},
} }
event, jsonErr := handleInvite(httpReq.Context(), input, rsAPI) event, jsonErr := handleInvite(httpReq.Context(), input, rsAPI)
if jsonErr != nil { if jsonErr != nil {
@ -185,6 +188,9 @@ func InviteV1(
StateQuerier: rsAPI.StateQuerier(), StateQuerier: rsAPI.StateQuerier(),
InviteEvent: event, InviteEvent: event,
StrippedState: strippedState, StrippedState: strippedState,
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
},
} }
event, jsonErr := handleInvite(httpReq.Context(), input, rsAPI) event, jsonErr := handleInvite(httpReq.Context(), input, rsAPI)
if jsonErr != nil { if jsonErr != nil {

View file

@ -15,7 +15,6 @@
package routing package routing
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"sort" "sort"
@ -33,53 +32,6 @@ import (
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
) )
type JoinRoomQuerier struct {
roomserver api.FederationRoomserverAPI
}
func (rq *JoinRoomQuerier) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) {
return rq.roomserver.CurrentStateEvent(ctx, roomID, eventType, stateKey)
}
func (rq *JoinRoomQuerier) InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) {
return rq.roomserver.InvitePending(ctx, roomID, userID)
}
func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) {
roomInfo, err := rq.roomserver.QueryRoomInfo(ctx, roomID)
if err != nil || roomInfo == nil || roomInfo.IsStub() {
return nil, err
}
req := api.QueryServerJoinedToRoomRequest{
ServerName: localServerName,
RoomID: roomID.String(),
}
res := api.QueryServerJoinedToRoomResponse{}
if err = rq.roomserver.QueryServerJoinedToRoom(ctx, &req, &res); err != nil {
util.GetLogger(ctx).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed")
return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err)
}
userJoinedToRoom, err := rq.roomserver.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed")
return nil, fmt.Errorf("InternalServerError: %w", err)
}
locallyJoinedUsers, err := rq.roomserver.LocallyJoinedUsers(ctx, roomInfo.RoomVersion, types.RoomNID(roomInfo.RoomNID))
if err != nil {
util.GetLogger(ctx).WithError(err).Error("rsAPI.GetLocallyJoinedUsers failed")
return nil, fmt.Errorf("InternalServerError: %w", err)
}
return &gomatrixserverlib.RestrictedRoomJoinInfo{
LocalServerInRoom: res.RoomExists && res.IsInRoom,
UserJoinedToRoom: userJoinedToRoom,
JoinedUsers: locallyJoinedUsers,
}, nil
}
// MakeJoin implements the /make_join API // MakeJoin implements the /make_join API
func MakeJoin( func MakeJoin(
httpReq *http.Request, httpReq *http.Request,
@ -103,7 +55,7 @@ func MakeJoin(
RoomID: roomID.String(), RoomID: roomID.String(),
} }
res := api.QueryServerJoinedToRoomResponse{} res := api.QueryServerJoinedToRoomResponse{}
if err := rsAPI.QueryServerJoinedToRoom(httpReq.Context(), &req, &res); err != nil { if err = rsAPI.QueryServerJoinedToRoom(httpReq.Context(), &req, &res); err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed") util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
@ -112,26 +64,26 @@ func MakeJoin(
} }
createJoinTemplate := func(proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, []gomatrixserverlib.PDU, error) { createJoinTemplate := func(proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, []gomatrixserverlib.PDU, error) {
identity, err := cfg.Matrix.SigningIdentityFor(request.Destination()) identity, signErr := cfg.Matrix.SigningIdentityFor(request.Destination())
if err != nil { if signErr != nil {
util.GetLogger(httpReq.Context()).WithError(err).Errorf("obtaining signing identity for %s failed", request.Destination()) util.GetLogger(httpReq.Context()).WithError(signErr).Errorf("obtaining signing identity for %s failed", request.Destination())
return nil, nil, spec.NotFound(fmt.Sprintf("Server name %q does not exist", request.Destination())) return nil, nil, spec.NotFound(fmt.Sprintf("Server name %q does not exist", request.Destination()))
} }
queryRes := api.QueryLatestEventsAndStateResponse{ queryRes := api.QueryLatestEventsAndStateResponse{
RoomVersion: roomVersion, RoomVersion: roomVersion,
} }
event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), proto, identity, time.Now(), rsAPI, &queryRes) event, signErr := eventutil.QueryAndBuildEvent(httpReq.Context(), proto, identity, time.Now(), rsAPI, &queryRes)
switch e := err.(type) { switch e := signErr.(type) {
case nil: case nil:
case eventutil.ErrRoomNoExists: case eventutil.ErrRoomNoExists:
util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") util.GetLogger(httpReq.Context()).WithError(signErr).Error("eventutil.BuildEvent failed")
return nil, nil, spec.NotFound("Room does not exist") return nil, nil, spec.NotFound("Room does not exist")
case gomatrixserverlib.BadJSONError: case gomatrixserverlib.BadJSONError:
util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") util.GetLogger(httpReq.Context()).WithError(signErr).Error("eventutil.BuildEvent failed")
return nil, nil, spec.BadJSON(e.Error()) return nil, nil, spec.BadJSON(e.Error())
default: default:
util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") util.GetLogger(httpReq.Context()).WithError(signErr).Error("eventutil.BuildEvent failed")
return nil, nil, spec.InternalServerError{} return nil, nil, spec.InternalServerError{}
} }
@ -142,20 +94,33 @@ func MakeJoin(
return event, stateEvents, nil return event, stateEvents, nil
} }
roomQuerier := JoinRoomQuerier{ roomQuerier := api.JoinRoomQuerier{
roomserver: rsAPI, Roomserver: rsAPI,
}
senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID.String(), userID)
if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
} }
input := gomatrixserverlib.HandleMakeJoinInput{ input := gomatrixserverlib.HandleMakeJoinInput{
Context: httpReq.Context(), Context: httpReq.Context(),
UserID: userID, UserID: userID,
RoomID: roomID, SenderID: senderID,
RoomVersion: roomVersion, RoomID: roomID,
RemoteVersions: remoteVersions, RoomVersion: roomVersion,
RequestOrigin: request.Origin(), RemoteVersions: remoteVersions,
LocalServerName: cfg.Matrix.ServerName, RequestOrigin: request.Origin(),
LocalServerInRoom: res.RoomExists && res.IsInRoom, LocalServerName: cfg.Matrix.ServerName,
RoomQuerier: &roomQuerier, LocalServerInRoom: res.RoomExists && res.IsInRoom,
RoomQuerier: &roomQuerier,
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
},
BuildEventTemplate: createJoinTemplate, BuildEventTemplate: createJoinTemplate,
} }
response, internalErr := gomatrixserverlib.HandleMakeJoin(input) response, internalErr := gomatrixserverlib.HandleMakeJoin(input)
@ -250,6 +215,9 @@ func SendJoin(
PrivateKey: cfg.Matrix.PrivateKey, PrivateKey: cfg.Matrix.PrivateKey,
Verifier: keys, Verifier: keys,
MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI}, MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI},
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
},
} }
response, joinErr := gomatrixserverlib.HandleSendJoin(input) response, joinErr := gomatrixserverlib.HandleSendJoin(input)
switch e := joinErr.(type) { switch e := joinErr.(type) {

View file

@ -50,7 +50,7 @@ func MakeLeave(
RoomID: roomID.String(), RoomID: roomID.String(),
} }
res := api.QueryServerJoinedToRoomResponse{} res := api.QueryServerJoinedToRoomResponse{}
if err := rsAPI.QueryServerJoinedToRoom(httpReq.Context(), &req, &res); err != nil { if err = rsAPI.QueryServerJoinedToRoom(httpReq.Context(), &req, &res); err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed") util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
@ -59,24 +59,24 @@ func MakeLeave(
} }
createLeaveTemplate := func(proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, []gomatrixserverlib.PDU, error) { createLeaveTemplate := func(proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, []gomatrixserverlib.PDU, error) {
identity, err := cfg.Matrix.SigningIdentityFor(request.Destination()) identity, signErr := cfg.Matrix.SigningIdentityFor(request.Destination())
if err != nil { if signErr != nil {
util.GetLogger(httpReq.Context()).WithError(err).Errorf("obtaining signing identity for %s failed", request.Destination()) util.GetLogger(httpReq.Context()).WithError(signErr).Errorf("obtaining signing identity for %s failed", request.Destination())
return nil, nil, spec.NotFound(fmt.Sprintf("Server name %q does not exist", request.Destination())) return nil, nil, spec.NotFound(fmt.Sprintf("Server name %q does not exist", request.Destination()))
} }
queryRes := api.QueryLatestEventsAndStateResponse{} queryRes := api.QueryLatestEventsAndStateResponse{}
event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), proto, identity, time.Now(), rsAPI, &queryRes) event, buildErr := eventutil.QueryAndBuildEvent(httpReq.Context(), proto, identity, time.Now(), rsAPI, &queryRes)
switch e := err.(type) { switch e := buildErr.(type) {
case nil: case nil:
case eventutil.ErrRoomNoExists: case eventutil.ErrRoomNoExists:
util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") util.GetLogger(httpReq.Context()).WithError(buildErr).Error("eventutil.BuildEvent failed")
return nil, nil, spec.NotFound("Room does not exist") return nil, nil, spec.NotFound("Room does not exist")
case gomatrixserverlib.BadJSONError: case gomatrixserverlib.BadJSONError:
util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") util.GetLogger(httpReq.Context()).WithError(buildErr).Error("eventutil.BuildEvent failed")
return nil, nil, spec.BadJSON(e.Error()) return nil, nil, spec.BadJSON(e.Error())
default: default:
util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") util.GetLogger(httpReq.Context()).WithError(buildErr).Error("eventutil.BuildEvent failed")
return nil, nil, spec.InternalServerError{} return nil, nil, spec.InternalServerError{}
} }
@ -87,14 +87,27 @@ func MakeLeave(
return event, stateEvents, nil return event, stateEvents, nil
} }
senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID.String(), userID)
if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
input := gomatrixserverlib.HandleMakeLeaveInput{ input := gomatrixserverlib.HandleMakeLeaveInput{
UserID: userID, UserID: userID,
SenderID: senderID,
RoomID: roomID, RoomID: roomID,
RoomVersion: roomVersion, RoomVersion: roomVersion,
RequestOrigin: request.Origin(), RequestOrigin: request.Origin(),
LocalServerName: cfg.Matrix.ServerName, LocalServerName: cfg.Matrix.ServerName,
LocalServerInRoom: res.RoomExists && res.IsInRoom, LocalServerInRoom: res.RoomExists && res.IsInRoom,
BuildEventTemplate: createLeaveTemplate, BuildEventTemplate: createLeaveTemplate,
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
},
} }
response, internalErr := gomatrixserverlib.HandleMakeLeave(input) response, internalErr := gomatrixserverlib.HandleMakeLeave(input)
@ -213,7 +226,7 @@ func SendLeave(
JSON: spec.BadJSON("No state key was provided in the leave event."), JSON: spec.BadJSON("No state key was provided in the leave event."),
} }
} }
if !event.StateKeyEquals(event.Sender()) { if !event.StateKeyEquals(string(event.SenderID())) {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: spec.BadJSON("Event state key must match the event sender."), JSON: spec.BadJSON("Event state key must match the event sender."),
@ -223,13 +236,13 @@ func SendLeave(
// Check that the sender belongs to the server that is sending us // Check that the sender belongs to the server that is sending us
// the request. By this point we've already asserted that the sender // the request. By this point we've already asserted that the sender
// and the state key are equal so we don't need to check both. // and the state key are equal so we don't need to check both.
var serverName spec.ServerName sender, err := rsAPI.QueryUserIDForSender(httpReq.Context(), event.RoomID(), event.SenderID())
if _, serverName, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: spec.Forbidden("The sender of the join is invalid"), JSON: spec.Forbidden("The sender of the join is invalid"),
} }
} else if serverName != request.Origin() { } else if sender.Domain() != request.Origin() {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: spec.Forbidden("The sender does not match the server that originated the request"), JSON: spec.Forbidden("The sender does not match the server that originated the request"),
@ -291,10 +304,10 @@ func SendLeave(
} }
} }
verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{
ServerName: serverName, ServerName: sender.Domain(),
Message: redacted, Message: redacted,
AtTS: event.OriginServerTS(), AtTS: event.OriginServerTS(),
StrictValidityChecking: true, ValidityCheckingFunc: gomatrixserverlib.StrictValiditySignatureCheck,
}} }}
verifyResults, err := keys.VerifyJSONs(httpReq.Context(), verifyRequests) verifyResults, err := keys.VerifyJSONs(httpReq.Context(), verifyRequests)
if err != nil { if err != nil {

View file

@ -34,7 +34,7 @@ import (
) )
const ( const (
// Event was passed to the roomserver // Event was passed to the Roomserver
MetricsOutcomeOK = "ok" MetricsOutcomeOK = "ok"
// Event failed to be processed // Event failed to be processed
MetricsOutcomeFail = "fail" MetricsOutcomeFail = "fail"

View file

@ -140,22 +140,24 @@ func ExchangeThirdPartyInvite(
} }
} }
_, senderDomain, err := cfg.Matrix.SplitLocalID('@', proto.Sender) userID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, spec.SenderID(proto.SenderID))
if err != nil { if err != nil || userID == nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: spec.BadJSON("Invalid sender ID: " + err.Error()), JSON: spec.BadJSON("Invalid sender ID"),
} }
} }
senderDomain := userID.Domain()
// Check that the state key is correct. // Check that the state key is correct.
_, targetDomain, err := gomatrixserverlib.SplitID('@', *proto.StateKey) targetUserID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, spec.SenderID(*proto.StateKey))
if err != nil { if err != nil || targetUserID == nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: spec.BadJSON("The event's state key isn't a Matrix user ID"), JSON: spec.BadJSON("The event's state key isn't a Matrix user ID"),
} }
} }
targetDomain := targetUserID.Domain()
// Check that the target user is from the requesting homeserver. // Check that the target user is from the requesting homeserver.
if targetDomain != request.Origin() { if targetDomain != request.Origin() {
@ -223,7 +225,7 @@ func ExchangeThirdPartyInvite(
} }
} }
// Send the event to the roomserver // Send the event to the Roomserver
if err = api.SendEvents( if err = api.SendEvents(
httpReq.Context(), rsAPI, httpReq.Context(), rsAPI,
api.KindNew, api.KindNew,
@ -271,7 +273,7 @@ func createInviteFrom3PIDInvite(
// Build the event // Build the event
proto := &gomatrixserverlib.ProtoEvent{ proto := &gomatrixserverlib.ProtoEvent{
Type: "m.room.member", Type: "m.room.member",
Sender: inv.Sender, SenderID: inv.Sender,
RoomID: inv.RoomID, RoomID: inv.RoomID,
StateKey: &inv.MXID, StateKey: &inv.MXID,
} }
@ -324,7 +326,7 @@ func buildMembershipEvent(
return nil, errors.New("expecting state tuples for event builder, got none") return nil, errors.New("expecting state tuples for event builder, got none")
} }
// Ask the roomserver for information about this room // Ask the Roomserver for information about this room
queryReq := api.QueryLatestEventsAndStateRequest{ queryReq := api.QueryLatestEventsAndStateRequest{
RoomID: protoEvent.RoomID, RoomID: protoEvent.RoomID,
StateToFetch: eventsNeeded.Tuples(), StateToFetch: eventsNeeded.Tuples(),

4
go.mod
View file

@ -22,7 +22,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20230531155817-0e3adf17bee6 github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
github.com/mattn/go-sqlite3 v1.14.16 github.com/mattn/go-sqlite3 v1.14.16
@ -34,7 +34,7 @@ require (
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.13.0 github.com/prometheus/client_golang v1.13.0
github.com/sirupsen/logrus v1.9.2 github.com/sirupsen/logrus v1.9.3
github.com/stretchr/testify v1.8.2 github.com/stretchr/testify v1.8.2
github.com/tidwall/gjson v1.14.4 github.com/tidwall/gjson v1.14.4
github.com/tidwall/sjson v1.2.5 github.com/tidwall/sjson v1.2.5

8
go.sum
View file

@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230531155817-0e3adf17bee6 h1:Kh1TNvJDhWN5CdgtICNUC4G0wV2km51LGr46Dvl153A= github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d h1:MjL8SXRzhO61aXDFL+gA3Bx1SicqLGL9gCWXDv8jkD8=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230531155817-0e3adf17bee6/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
@ -444,8 +444,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/sirupsen/logrus v1.9.2 h1:oxx1eChJGI6Uks2ZC4W1zpLlVgqB8ner4EuQwV4Ik1Y= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.2/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s= github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s=
github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=

View file

@ -28,7 +28,7 @@ func (c Caches) GetServerKey(
) (gomatrixserverlib.PublicKeyLookupResult, bool) { ) (gomatrixserverlib.PublicKeyLookupResult, bool) {
key := fmt.Sprintf("%s/%s", request.ServerName, request.KeyID) key := fmt.Sprintf("%s/%s", request.ServerName, request.KeyID)
val, found := c.ServerKeys.Get(key) val, found := c.ServerKeys.Get(key)
if found && !val.WasValidAt(timestamp, true) { if found && !val.WasValidAt(timestamp, gomatrixserverlib.StrictValiditySignatureCheck) {
// The key wasn't valid at the requested timestamp so don't // The key wasn't valid at the requested timestamp so don't
// return it. The caller will have to work out what to do. // return it. The caller will have to work out what to do.
c.ServerKeys.Unset(key) c.ServerKeys.Unset(key)

View file

@ -6,6 +6,7 @@ import (
"strings" "strings"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
) )
// A RuleSetEvaluator encapsulates context to evaluate an event // A RuleSetEvaluator encapsulates context to evaluate an event
@ -27,7 +28,7 @@ type EvaluationContext interface {
// HasPowerLevel returns whether the user has at least the given // HasPowerLevel returns whether the user has at least the given
// power in the room of the current event. // power in the room of the current event.
HasPowerLevel(userID, levelKey string) (bool, error) HasPowerLevel(senderID spec.SenderID, levelKey string) (bool, error)
} }
// A kindAndRules is just here to simplify iteration of the (ordered) // A kindAndRules is just here to simplify iteration of the (ordered)
@ -53,7 +54,7 @@ func NewRuleSetEvaluator(ec EvaluationContext, ruleSet *RuleSet) *RuleSetEvaluat
// MatchEvent returns the first matching rule. Returns nil if there // MatchEvent returns the first matching rule. Returns nil if there
// was no match rule. // was no match rule.
func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU) (*Rule, error) { func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender) (*Rule, error) {
// TODO: server-default rules have lower priority than user rules, // TODO: server-default rules have lower priority than user rules,
// but they are stored together with the user rules. It's a bit // but they are stored together with the user rules. It's a bit
// unclear what the specification (11.14.1.4 Predefined rules) // unclear what the specification (11.14.1.4 Predefined rules)
@ -68,7 +69,7 @@ func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU) (*Rule, err
if rule.Default != defRules { if rule.Default != defRules {
continue continue
} }
ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec) ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec, userIDForSender)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -83,7 +84,7 @@ func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU) (*Rule, err
return nil, nil return nil, nil
} }
func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec EvaluationContext) (bool, error) { func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec EvaluationContext, userIDForSender spec.UserIDForSender) (bool, error) {
if !rule.Enabled { if !rule.Enabled {
return false, nil return false, nil
} }
@ -113,7 +114,12 @@ func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec Evaluati
return rule.RuleID == event.RoomID(), nil return rule.RuleID == event.RoomID(), nil
case SenderKind: case SenderKind:
return rule.RuleID == event.Sender(), nil userID := ""
sender, err := userIDForSender(event.RoomID(), event.SenderID())
if err == nil {
userID = sender.String()
}
return rule.RuleID == userID, nil
default: default:
return false, nil return false, nil
@ -143,7 +149,7 @@ func conditionMatches(cond *Condition, event gomatrixserverlib.PDU, ec Evaluatio
return cmp(n), nil return cmp(n), nil
case SenderNotificationPermissionCondition: case SenderNotificationPermissionCondition:
return ec.HasPowerLevel(event.Sender(), cond.Key) return ec.HasPowerLevel(event.SenderID(), cond.Key)
default: default:
return false, nil return false, nil

View file

@ -5,8 +5,13 @@ import (
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
) )
func UserIDForSender(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}
func TestRuleSetEvaluatorMatchEvent(t *testing.T) { func TestRuleSetEvaluatorMatchEvent(t *testing.T) {
ev := mustEventFromJSON(t, `{}`) ev := mustEventFromJSON(t, `{}`)
defaultEnabled := &Rule{ defaultEnabled := &Rule{
@ -45,7 +50,7 @@ func TestRuleSetEvaluatorMatchEvent(t *testing.T) {
for _, tst := range tsts { for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) { t.Run(tst.Name, func(t *testing.T) {
rse := NewRuleSetEvaluator(fakeEvaluationContext{3}, &tst.RuleSet) rse := NewRuleSetEvaluator(fakeEvaluationContext{3}, &tst.RuleSet)
got, err := rse.MatchEvent(tst.Event) got, err := rse.MatchEvent(tst.Event, UserIDForSender)
if err != nil { if err != nil {
t.Fatalf("MatchEvent failed: %v", err) t.Fatalf("MatchEvent failed: %v", err)
} }
@ -82,15 +87,15 @@ func TestRuleMatches(t *testing.T) {
{"contentMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("b")}, `{"content":{"body":"abc"}}`, true}, {"contentMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("b")}, `{"content":{"body":"abc"}}`, true},
{"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("d")}, `{"content":{"body":"abc"}}`, false}, {"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("d")}, `{"content":{"body":"abc"}}`, false},
{"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true}, {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!room:example.com"}`, true},
{"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false}, {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!otherroom:example.com"}`, false},
{"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@user@example.com"}`, true}, {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@user:example.com"}`, true},
{"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@otheruser@example.com"}`, false}, {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@otheruser:example.com"}`, false},
} }
for _, tst := range tsts { for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) { t.Run(tst.Name, func(t *testing.T) {
got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil) got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil, UserIDForSender)
if err != nil { if err != nil {
t.Fatalf("ruleMatches failed: %v", err) t.Fatalf("ruleMatches failed: %v", err)
} }
@ -153,8 +158,8 @@ type fakeEvaluationContext struct{ memberCount int }
func (fakeEvaluationContext) UserDisplayName() string { return "Dear User" } func (fakeEvaluationContext) UserDisplayName() string { return "Dear User" }
func (f fakeEvaluationContext) RoomMemberCount() (int, error) { return f.memberCount, nil } func (f fakeEvaluationContext) RoomMemberCount() (int, error) { return f.memberCount, nil }
func (fakeEvaluationContext) HasPowerLevel(userID, levelKey string) (bool, error) { func (fakeEvaluationContext) HasPowerLevel(senderID spec.SenderID, levelKey string) (bool, error) {
return userID == "@poweruser:example.com" && levelKey == "powerlevel", nil return senderID == "@poweruser:example.com" && levelKey == "powerlevel", nil
} }
func TestPatternMatches(t *testing.T) { func TestPatternMatches(t *testing.T) {

View file

@ -167,7 +167,9 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut
} }
continue continue
} }
if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys); err != nil { if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID())
results[event.EventID()] = fclient.PDUResult{ results[event.EventID()] = fclient.PDUResult{
Error: err.Error(), Error: err.Error(),

View file

@ -70,6 +70,10 @@ type FakeRsAPI struct {
bannedFromRoom bool bannedFromRoom bool
} }
func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}
func (r *FakeRsAPI) QueryRoomVersionForRoom( func (r *FakeRsAPI) QueryRoomVersionForRoom(
ctx context.Context, ctx context.Context,
roomID string, roomID string,
@ -638,6 +642,10 @@ type testRoomserverAPI struct {
queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse
} }
func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}
func (t *testRoomserverAPI) InputRoomEvents( func (t *testRoomserverAPI) InputRoomEvents(
ctx context.Context, ctx context.Context,
request *rsAPI.InputRoomEventsRequest, request *rsAPI.InputRoomEventsRequest,

View file

@ -14,7 +14,11 @@
package api package api
import "regexp" import (
"regexp"
"github.com/matrix-org/gomatrixserverlib/spec"
)
// SetRoomAliasRequest is a request to SetRoomAlias // SetRoomAliasRequest is a request to SetRoomAlias
type SetRoomAliasRequest struct { type SetRoomAliasRequest struct {
@ -62,7 +66,7 @@ type GetAliasesForRoomIDResponse struct {
// RemoveRoomAliasRequest is a request to RemoveRoomAlias // RemoveRoomAliasRequest is a request to RemoveRoomAlias
type RemoveRoomAliasRequest struct { type RemoveRoomAliasRequest struct {
// ID of the user removing the alias // ID of the user removing the alias
UserID string `json:"user_id"` SenderID spec.SenderID `json:"user_id"`
// The room alias to remove // The room alias to remove
Alias string `json:"alias"` Alias string `json:"alias"`
} }

View file

@ -32,6 +32,16 @@ func (e ErrNotAllowed) Error() string {
return e.Err.Error() return e.Err.Error()
} }
type RestrictedJoinAPI interface {
CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error)
InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error)
RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error)
QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types.RoomInfo, error)
QueryServerJoinedToRoom(ctx context.Context, req *QueryServerJoinedToRoomRequest, res *QueryServerJoinedToRoomResponse) error
UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, userID spec.UserID) (bool, error)
LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error)
}
// RoomserverInputAPI is used to write events to the room server. // RoomserverInputAPI is used to write events to the room server.
type RoomserverInternalAPI interface { type RoomserverInternalAPI interface {
SyncRoomserverAPI SyncRoomserverAPI
@ -39,6 +49,7 @@ type RoomserverInternalAPI interface {
ClientRoomserverAPI ClientRoomserverAPI
UserRoomserverAPI UserRoomserverAPI
FederationRoomserverAPI FederationRoomserverAPI
QuerySenderIDAPI
// needed to avoid chicken and egg scenario when setting up the // needed to avoid chicken and egg scenario when setting up the
// interdependencies between the roomserver and other input APIs // interdependencies between the roomserver and other input APIs
@ -65,6 +76,11 @@ type InputRoomEventsAPI interface {
) )
} }
type QuerySenderIDAPI interface {
QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error)
QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
}
// Query the latest events and state for a room from the room server. // Query the latest events and state for a room from the room server.
type QueryLatestEventsAndStateAPI interface { type QueryLatestEventsAndStateAPI interface {
QueryLatestEventsAndState(ctx context.Context, req *QueryLatestEventsAndStateRequest, res *QueryLatestEventsAndStateResponse) error QueryLatestEventsAndState(ctx context.Context, req *QueryLatestEventsAndStateRequest, res *QueryLatestEventsAndStateResponse) error
@ -92,6 +108,7 @@ type QueryEventsAPI interface {
type SyncRoomserverAPI interface { type SyncRoomserverAPI interface {
QueryLatestEventsAndStateAPI QueryLatestEventsAndStateAPI
QueryBulkStateContentAPI QueryBulkStateContentAPI
QuerySenderIDAPI
// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
@ -132,6 +149,7 @@ type SyncRoomserverAPI interface {
} }
type AppserviceRoomserverAPI interface { type AppserviceRoomserverAPI interface {
QuerySenderIDAPI
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
// which room to use by querying the first events roomID. // which room to use by querying the first events roomID.
QueryEventsByID( QueryEventsByID(
@ -158,6 +176,7 @@ type ClientRoomserverAPI interface {
QueryLatestEventsAndStateAPI QueryLatestEventsAndStateAPI
QueryBulkStateContentAPI QueryBulkStateContentAPI
QueryEventsAPI QueryEventsAPI
QuerySenderIDAPI
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
@ -190,6 +209,7 @@ type ClientRoomserverAPI interface {
} }
type UserRoomserverAPI interface { type UserRoomserverAPI interface {
QuerySenderIDAPI
QueryLatestEventsAndStateAPI QueryLatestEventsAndStateAPI
KeyserverRoomserverAPI KeyserverRoomserverAPI
QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error
@ -199,9 +219,12 @@ type UserRoomserverAPI interface {
} }
type FederationRoomserverAPI interface { type FederationRoomserverAPI interface {
RestrictedJoinAPI
InputRoomEventsAPI InputRoomEventsAPI
QueryLatestEventsAndStateAPI QueryLatestEventsAndStateAPI
QueryBulkStateContentAPI QueryBulkStateContentAPI
QuerySenderIDAPI
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error
@ -223,7 +246,7 @@ type FederationRoomserverAPI interface {
// Query whether a server is allowed to see an event // Query whether a server is allowed to see an event
QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string) (allowed bool, err error) QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string) (allowed bool, err error)
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
QueryRestrictedJoinAllowed(ctx context.Context, req *QueryRestrictedJoinAllowedRequest, res *QueryRestrictedJoinAllowedResponse) error QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error)
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
HandleInvite(ctx context.Context, event *types.HeaderedEvent) error HandleInvite(ctx context.Context, event *types.HeaderedEvent) error

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
@ -351,26 +352,6 @@ type QueryServerBannedFromRoomResponse struct {
Banned bool `json:"banned"` Banned bool `json:"banned"`
} }
type QueryRestrictedJoinAllowedRequest struct {
UserID string `json:"user_id"`
RoomID string `json:"room_id"`
}
type QueryRestrictedJoinAllowedResponse struct {
// True if the room membership is restricted by the join rule being set to "restricted"
Restricted bool `json:"restricted"`
// True if our local server is joined to all of the allowed rooms specified in the "allow"
// key of the join rule, false if we are missing from some of them and therefore can't
// reliably decide whether or not we can satisfy the join
Resident bool `json:"resident"`
// True if the restricted join is allowed because we found the membership in one of the
// allowed rooms from the join rule, false if not
Allowed bool `json:"allowed"`
// Contains the user ID of the selected user ID that has power to issue invites, this will
// get populated into the "join_authorised_via_users_server" content in the membership
AuthorisedVia string `json:"authorised_via,omitempty"`
}
// MarshalJSON stringifies the room ID and StateKeyTuple keys so they can be sent over the wire in HTTP API mode. // MarshalJSON stringifies the room ID and StateKeyTuple keys so they can be sent over the wire in HTTP API mode.
func (r *QueryBulkStateContentResponse) MarshalJSON() ([]byte, error) { func (r *QueryBulkStateContentResponse) MarshalJSON() ([]byte, error) {
se := make(map[string]string) se := make(map[string]string)
@ -459,14 +440,61 @@ type QueryLeftUsersResponse struct {
LeftUsers []string `json:"user_ids"` LeftUsers []string `json:"user_ids"`
} }
type JoinRoomQuerier struct {
Roomserver RestrictedJoinAPI
}
func (rq *JoinRoomQuerier) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) {
return rq.Roomserver.CurrentStateEvent(ctx, roomID, eventType, stateKey)
}
func (rq *JoinRoomQuerier) InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) {
return rq.Roomserver.InvitePending(ctx, roomID, userID)
}
func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) {
roomInfo, err := rq.Roomserver.QueryRoomInfo(ctx, roomID)
if err != nil || roomInfo == nil || roomInfo.IsStub() {
return nil, err
}
req := QueryServerJoinedToRoomRequest{
ServerName: localServerName,
RoomID: roomID.String(),
}
res := QueryServerJoinedToRoomResponse{}
if err = rq.Roomserver.QueryServerJoinedToRoom(ctx, &req, &res); err != nil {
util.GetLogger(ctx).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed")
return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err)
}
userJoinedToRoom, err := rq.Roomserver.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed")
return nil, fmt.Errorf("InternalServerError: %w", err)
}
locallyJoinedUsers, err := rq.Roomserver.LocallyJoinedUsers(ctx, roomInfo.RoomVersion, types.RoomNID(roomInfo.RoomNID))
if err != nil {
util.GetLogger(ctx).WithError(err).Error("rsAPI.GetLocallyJoinedUsers failed")
return nil, fmt.Errorf("InternalServerError: %w", err)
}
return &gomatrixserverlib.RestrictedRoomJoinInfo{
LocalServerInRoom: res.RoomExists && res.IsInRoom,
UserJoinedToRoom: userJoinedToRoom,
JoinedUsers: locallyJoinedUsers,
}, nil
}
type MembershipQuerier struct { type MembershipQuerier struct {
Roomserver FederationRoomserverAPI Roomserver FederationRoomserverAPI
} }
func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) { func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
req := QueryMembershipForUserRequest{ req := QueryMembershipForUserRequest{
RoomID: roomID.String(), RoomID: roomID.String(),
UserID: userID.String(), UserID: string(senderID),
} }
res := QueryMembershipForUserResponse{} res := QueryMembershipForUserResponse{}
err := mq.Roomserver.QueryMembershipForUser(ctx, &req, &res) err := mq.Roomserver.QueryMembershipForUser(ctx, &req, &res)

View file

@ -119,11 +119,6 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
request *api.RemoveRoomAliasRequest, request *api.RemoveRoomAliasRequest,
response *api.RemoveRoomAliasResponse, response *api.RemoveRoomAliasResponse,
) error { ) error {
_, virtualHost, err := r.Cfg.Global.SplitLocalID('@', request.UserID)
if err != nil {
return err
}
roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias) roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.GetRoomIDForAlias: %w", err) return fmt.Errorf("r.DB.GetRoomIDForAlias: %w", err)
@ -134,13 +129,19 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
return nil return nil
} }
sender, err := r.QueryUserIDForSender(ctx, roomID, request.SenderID)
if err != nil || sender == nil {
return fmt.Errorf("r.QueryUserIDForSender: %w", err)
}
virtualHost := sender.Domain()
response.Found = true response.Found = true
creatorID, err := r.DB.GetCreatorIDForAlias(ctx, request.Alias) creatorID, err := r.DB.GetCreatorIDForAlias(ctx, request.Alias)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.GetCreatorIDForAlias: %w", err) return fmt.Errorf("r.DB.GetCreatorIDForAlias: %w", err)
} }
if creatorID != request.UserID { if spec.SenderID(creatorID) != request.SenderID {
var plEvent *types.HeaderedEvent var plEvent *types.HeaderedEvent
var pls *gomatrixserverlib.PowerLevelContent var pls *gomatrixserverlib.PowerLevelContent
@ -154,7 +155,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
return fmt.Errorf("plEvent.PowerLevels: %w", err) return fmt.Errorf("plEvent.PowerLevels: %w", err)
} }
if pls.UserLevel(request.UserID) < pls.EventLevel(spec.MRoomCanonicalAlias, true) { if pls.UserLevel(request.SenderID) < pls.EventLevel(spec.MRoomCanonicalAlias, true) {
response.Removed = false response.Removed = false
return nil return nil
} }
@ -172,23 +173,24 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
return err return err
} }
sender := request.UserID senderID := request.SenderID
if request.UserID != ev.Sender() { if request.SenderID != ev.SenderID() {
sender = ev.Sender() senderID = ev.SenderID()
} }
sender, err := r.QueryUserIDForSender(ctx, roomID, senderID)
_, senderDomain, err := r.Cfg.Global.SplitLocalID('@', sender) if err != nil || sender == nil {
if err != nil {
return err return err
} }
senderDomain := sender.Domain()
identity, err := r.Cfg.Global.SigningIdentityFor(senderDomain) identity, err := r.Cfg.Global.SigningIdentityFor(senderDomain)
if err != nil { if err != nil {
return err return err
} }
proto := &gomatrixserverlib.ProtoEvent{ proto := &gomatrixserverlib.ProtoEvent{
Sender: sender, SenderID: string(senderID),
RoomID: ev.RoomID(), RoomID: ev.RoomID(),
Type: ev.Type(), Type: ev.Type(),
StateKey: ev.StateKey(), StateKey: ev.StateKey(),

View file

@ -94,6 +94,7 @@ func NewRoomserverAPI(
Cache: caches, Cache: caches,
IsLocalServerName: dendriteCfg.Global.IsLocalServerName, IsLocalServerName: dendriteCfg.Global.IsLocalServerName,
ServerACLs: serverACLs, ServerACLs: serverACLs,
Cfg: dendriteCfg,
}, },
enableMetrics: enableMetrics, enableMetrics: enableMetrics,
// perform-er structs get initialised when we have a federation sender to use // perform-er structs get initialised when we have a federation sender to use

View file

@ -76,7 +76,9 @@ func CheckForSoftFail(
} }
// Check if the event is allowed. // Check if the event is allowed.
if err = gomatrixserverlib.Allowed(event.PDU, &authEvents); err != nil { if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return db.GetUserIDForSender(ctx, roomID, senderID)
}); err != nil {
// return true, nil // return true, nil
return true, err return true, err
} }
@ -137,8 +139,8 @@ func (ae *authEvents) JoinRules() (gomatrixserverlib.PDU, error) {
} }
// Memmber implements gomatrixserverlib.AuthEventProvider // Memmber implements gomatrixserverlib.AuthEventProvider
func (ae *authEvents) Member(stateKey string) (gomatrixserverlib.PDU, error) { func (ae *authEvents) Member(stateKey spec.SenderID) (gomatrixserverlib.PDU, error) {
return ae.lookupEvent(types.MRoomMemberNID, stateKey), nil return ae.lookupEvent(types.MRoomMemberNID, string(stateKey)), nil
} }
// ThirdPartyInvite implements gomatrixserverlib.AuthEventProvider // ThirdPartyInvite implements gomatrixserverlib.AuthEventProvider

View file

@ -128,9 +128,13 @@ func (r *Inputer) processRoomEvent(
if roomInfo == nil && !isCreateEvent { if roomInfo == nil && !isCreateEvent {
return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID())
} }
_, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender()) sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err != nil { if err != nil {
return fmt.Errorf("event has invalid sender %q", input.Event.Sender()) return fmt.Errorf("failed getting userID for sender %q. %w", event.SenderID(), err)
}
senderDomain := spec.ServerName("")
if sender != nil {
senderDomain = sender.Domain()
} }
// If we already know about this outlier and it hasn't been rejected // If we already know about this outlier and it hasn't been rejected
@ -193,7 +197,9 @@ func (r *Inputer) processRoomEvent(
serverRes.ServerNames = append(serverRes.ServerNames, input.Origin) serverRes.ServerNames = append(serverRes.ServerNames, input.Origin)
delete(servers, input.Origin) delete(servers, input.Origin)
} }
if senderDomain != input.Origin && senderDomain != r.Cfg.Matrix.ServerName { // Only perform this check if the sender mxid_mapping can be resolved.
// Don't fail processing the event if we have no mxid_maping.
if sender != nil && senderDomain != input.Origin && senderDomain != r.Cfg.Matrix.ServerName {
serverRes.ServerNames = append(serverRes.ServerNames, senderDomain) serverRes.ServerNames = append(serverRes.ServerNames, senderDomain)
delete(servers, senderDomain) delete(servers, senderDomain)
} }
@ -276,7 +282,9 @@ func (r *Inputer) processRoomEvent(
// Check if the event is allowed by its auth events. If it isn't then // Check if the event is allowed by its auth events. If it isn't then
// we consider the event to be "rejected" — it will still be persisted. // we consider the event to be "rejected" — it will still be persisted.
if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
}); err != nil {
isRejected = true isRejected = true
rejectionErr = err rejectionErr = err
logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID()) logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID())
@ -493,7 +501,7 @@ func (r *Inputer) processRoomEvent(
func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event gomatrixserverlib.PDU) error { func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event gomatrixserverlib.PDU) error {
oldRoomID := event.RoomID() oldRoomID := event.RoomID()
newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str
return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, event.Sender()) return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, string(event.SenderID()))
} }
// processStateBefore works out what the state is before the event and // processStateBefore works out what the state is before the event and
@ -579,7 +587,9 @@ func (r *Inputer) processStateBefore(
stateBeforeAuth := gomatrixserverlib.NewAuthEvents( stateBeforeAuth := gomatrixserverlib.NewAuthEvents(
gomatrixserverlib.ToPDUs(stateBeforeEvent), gomatrixserverlib.ToPDUs(stateBeforeEvent),
) )
if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth); rejectionErr != nil { if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
}); rejectionErr != nil {
rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr) rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr)
return return
} }
@ -690,7 +700,9 @@ nextAuthEvent:
// Check the signatures of the event. If this fails then we'll simply // Check the signatures of the event. If this fails then we'll simply
// skip it, because gomatrixserverlib.Allowed() will notice a problem // skip it, because gomatrixserverlib.Allowed() will notice a problem
// if a critical event is missing anyway. // if a critical event is missing anyway.
if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing()); err != nil { if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
}); err != nil {
continue nextAuthEvent continue nextAuthEvent
} }
@ -706,7 +718,9 @@ nextAuthEvent:
} }
// Check if the auth event should be rejected. // Check if the auth event should be rejected.
err := gomatrixserverlib.Allowed(authEvent, auth) err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
})
if isRejected = err != nil; isRejected { if isRejected = err != nil; isRejected {
logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID()) logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID())
} }
@ -828,11 +842,13 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r
continue continue
} }
// TODO: pseudoIDs: get userID for room using state key (which is now senderID)
localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey()) localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey())
if err != nil { if err != nil {
continue continue
} }
// TODO: pseudoIDs: query account by state key (which is now senderID)
accountRes := &userAPI.QueryAccountByLocalpartResponse{} accountRes := &userAPI.QueryAccountByLocalpartResponse{}
if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{ if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{
Localpart: localpart, Localpart: localpart,
@ -859,7 +875,7 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r
RoomID: event.RoomID(), RoomID: event.RoomID(),
Type: spec.MRoomMember, Type: spec.MRoomMember,
StateKey: &stateKey, StateKey: &stateKey,
Sender: stateKey, SenderID: stateKey,
PrevEvents: prevEvents, PrevEvents: prevEvents,
} }

View file

@ -58,7 +58,9 @@ func Test_EventAuth(t *testing.T) {
} }
// Finally check that the event is NOT allowed // Finally check that the event is NOT allowed
if err := gomatrixserverlib.Allowed(ev.PDU, &allower); err == nil { if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}); err == nil {
t.Fatalf("event should not be allowed, but it was") t.Fatalf("event should not be allowed, but it was")
} }
} }

View file

@ -473,14 +473,18 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion
stateEventList = append(stateEventList, state.StateEvents...) stateEventList = append(stateEventList, state.StateEvents...)
} }
resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts( resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(
roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return t.db.GetUserIDForSender(ctx, roomID, senderID)
},
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
// apply the current event // apply the current event
retryAllowedState: retryAllowedState:
if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents); err != nil { if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return t.db.GetUserIDForSender(ctx, roomID, senderID)
}); err != nil {
switch missing := err.(type) { switch missing := err.(type) {
case gomatrixserverlib.MissingAuthEventError: case gomatrixserverlib.MissingAuthEventError:
h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true) h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true)
@ -565,7 +569,9 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver
// will be added and duplicates will be removed. // will be added and duplicates will be removed.
missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events)) missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events))
for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) { for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys); err != nil { if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return t.db.GetUserIDForSender(ctx, roomID, senderID)
}); err != nil {
continue continue
} }
missingEvents = append(missingEvents, t.cacheAndReturn(ev)) missingEvents = append(missingEvents, t.cacheAndReturn(ev))
@ -654,7 +660,9 @@ func (t *missingStateReq) lookupMissingStateViaState(
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{ authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{
StateEvents: state.GetStateEvents(), StateEvents: state.GetStateEvents(),
AuthEvents: state.GetAuthEvents(), AuthEvents: state.GetAuthEvents(),
}, roomVersion, t.keys, nil) }, roomVersion, t.keys, nil, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return t.db.GetUserIDForSender(ctx, roomID, senderID)
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -889,14 +897,16 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers)) t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers))
return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers)) return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers))
} }
if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys); err != nil { if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return t.db.GetUserIDForSender(ctx, roomID, senderID)
}); err != nil {
t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())
return nil, verifySigError{event.EventID(), err} return nil, verifySigError{event.EventID(), err}
} }
return t.cacheAndReturn(event), nil return t.cacheAndReturn(event), nil
} }
func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU) error { func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender) error {
authUsingState := gomatrixserverlib.NewAuthEvents(nil) authUsingState := gomatrixserverlib.NewAuthEvents(nil)
for i := range stateEvents { for i := range stateEvents {
err := authUsingState.AddEvent(stateEvents[i]) err := authUsingState.AddEvent(stateEvents[i])
@ -904,7 +914,7 @@ func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverli
return err return err
} }
} }
return gomatrixserverlib.Allowed(e, &authUsingState) return gomatrixserverlib.Allowed(e, &authUsingState, userIDForSender)
} }
func (t *missingStateReq) hadEvent(eventID string) { func (t *missingStateReq) hadEvent(eventID string) {

View file

@ -96,14 +96,15 @@ func (r *Admin) PerformAdminEvacuateRoom(
RoomID: roomID, RoomID: roomID,
Type: spec.MRoomMember, Type: spec.MRoomMember,
StateKey: &stateKey, StateKey: &stateKey,
Sender: stateKey, SenderID: stateKey,
PrevEvents: prevEvents, PrevEvents: prevEvents,
} }
_, senderDomain, err = gomatrixserverlib.SplitID('@', fledglingEvent.Sender) userID, err := r.Queryer.QueryUserIDForSender(ctx, roomID, spec.SenderID(fledglingEvent.SenderID))
if err != nil { if err != nil || userID == nil {
continue continue
} }
senderDomain = userID.Domain()
if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil { if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil {
return nil, err return nil, err
@ -233,10 +234,11 @@ func (r *Admin) PerformAdminDownloadState(
ctx context.Context, ctx context.Context,
roomID, userID string, serverName spec.ServerName, roomID, userID string, serverName spec.ServerName,
) error { ) error {
_, senderDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID) fullUserID, err := spec.NewUserID(userID, true)
if err != nil { if err != nil {
return err return err
} }
senderDomain := fullUserID.Domain()
roomInfo, err := r.DB.RoomInfo(ctx, roomID) roomInfo, err := r.DB.RoomInfo(ctx, roomID)
if err != nil { if err != nil {
@ -262,13 +264,17 @@ func (r *Admin) PerformAdminDownloadState(
return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err) return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err)
} }
for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) { for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing); err != nil { if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
}); err != nil {
continue continue
} }
authEventMap[authEvent.EventID()] = authEvent authEventMap[authEvent.EventID()] = authEvent
} }
for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) { for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing); err != nil { if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
}); err != nil {
continue continue
} }
stateEventMap[stateEvent.EventID()] = stateEvent stateEventMap[stateEvent.EventID()] = stateEvent
@ -287,11 +293,15 @@ func (r *Admin) PerformAdminDownloadState(
stateIDs = append(stateIDs, stateEvent.EventID()) stateIDs = append(stateIDs, stateEvent.EventID())
} }
senderID, err := r.Queryer.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return err
}
proto := &gomatrixserverlib.ProtoEvent{ proto := &gomatrixserverlib.ProtoEvent{
Type: "org.matrix.dendrite.state_download", Type: "org.matrix.dendrite.state_download",
Sender: userID, SenderID: string(senderID),
RoomID: roomID, RoomID: roomID,
Content: spec.RawJSON("{}"), Content: spec.RawJSON("{}"),
} }
eventsNeeded, err := gomatrixserverlib.StateNeededForProtoEvent(proto) eventsNeeded, err := gomatrixserverlib.StateNeededForProtoEvent(proto)

View file

@ -121,7 +121,9 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
// Specifically the test "Outbound federation can backfill events" // Specifically the test "Outbound federation can backfill events"
events, err := gomatrixserverlib.RequestBackfill( events, err := gomatrixserverlib.RequestBackfill(
ctx, req.VirtualHost, requester, ctx, req.VirtualHost, requester,
r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
},
) )
// Only return an error if we really couldn't get any events. // Only return an error if we really couldn't get any events.
if err != nil && len(events) == 0 { if err != nil && len(events) == 0 {
@ -210,7 +212,9 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
continue continue
} }
loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false) loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false)
result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents) result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
})
if err != nil { if err != nil {
logger.WithError(err).Warn("failed to load and verify event") logger.WithError(err).Warn("failed to load and verify event")
continue continue
@ -484,8 +488,8 @@ FindSuccessor:
// Store the server names in a temporary map to avoid duplicates. // Store the server names in a temporary map to avoid duplicates.
serverSet := make(map[spec.ServerName]bool) serverSet := make(map[spec.ServerName]bool)
for _, event := range memberEvents { for _, event := range memberEvents {
if _, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender()); err == nil { if sender, err := b.db.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()); err == nil {
serverSet[senderDomain] = true serverSet[sender.Domain()] = true
} }
} }
var servers []spec.ServerName var servers []spec.ServerName

View file

@ -270,11 +270,19 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
var builtEvents []*types.HeaderedEvent var builtEvents []*types.HeaderedEvent
authEvents := gomatrixserverlib.NewAuthEvents(nil) authEvents := gomatrixserverlib.NewAuthEvents(nil)
senderID, err := c.RSAPI.QuerySenderIDForUser(ctx, roomID.String(), userID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("rsapi.QuerySenderIDForUser failed")
return "", &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
for i, e := range eventsToMake { for i, e := range eventsToMake {
depth := i + 1 // depth starts at 1 depth := i + 1 // depth starts at 1
builder := verImpl.NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{ builder := verImpl.NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{
Sender: userID.String(), SenderID: string(senderID),
RoomID: roomID.String(), RoomID: roomID.String(),
Type: e.Type, Type: e.Type,
StateKey: &e.StateKey, StateKey: &e.StateKey,
@ -308,7 +316,9 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
} }
} }
if err = gomatrixserverlib.Allowed(ev, &authEvents); err != nil { if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return c.DB.GetUserIDForSender(ctx, roomID, senderID)
}); err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed") util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed")
return "", &util.JSONResponse{ return "", &util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
@ -407,11 +417,28 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
// Process the invites. // Process the invites.
var inviteEvent *types.HeaderedEvent var inviteEvent *types.HeaderedEvent
for _, invitee := range createRequest.InvitedUsers { for _, invitee := range createRequest.InvitedUsers {
inviteeUserID, userIDErr := spec.NewUserID(invitee, true)
if userIDErr != nil {
util.GetLogger(ctx).WithError(userIDErr).Error("invalid UserID")
return "", &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
inviteeSenderID, queryErr := c.RSAPI.QuerySenderIDForUser(ctx, roomID.String(), *inviteeUserID)
if queryErr != nil {
util.GetLogger(ctx).WithError(queryErr).Error("rsapi.QuerySenderIDForUser failed")
return "", &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
inviteeString := string(inviteeSenderID)
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
Sender: userID.String(), SenderID: string(senderID),
RoomID: roomID.String(), RoomID: roomID.String(),
Type: "m.room.member", Type: "m.room.member",
StateKey: &invitee, StateKey: &inviteeString,
} }
content := gomatrixserverlib.MemberContent{ content := gomatrixserverlib.MemberContent{

View file

@ -97,11 +97,12 @@ func (r *Inviter) ProcessInviteMembership(
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
var outputUpdates []api.OutputEvent var outputUpdates []api.OutputEvent
var updater *shared.MembershipUpdater var updater *shared.MembershipUpdater
_, domain, err := gomatrixserverlib.SplitID('@', *inviteEvent.StateKey())
userID, err := r.RSAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey()))
if err != nil { if err != nil {
return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())} return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())}
} }
isTargetLocal := r.Cfg.Matrix.IsLocalServerName(domain) isTargetLocal := r.Cfg.Matrix.IsLocalServerName(userID.Domain())
if updater, err = r.DB.MembershipUpdater(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey(), isTargetLocal, inviteEvent.Version()); err != nil { if updater, err = r.DB.MembershipUpdater(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey(), isTargetLocal, inviteEvent.Version()); err != nil {
return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err) return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err)
} }
@ -125,9 +126,9 @@ func (r *Inviter) PerformInvite(
) error { ) error {
event := req.Event event := req.Event
sender, err := spec.NewUserID(event.Sender(), true) sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err != nil { if err != nil {
return spec.InvalidParam("The user ID is invalid") return spec.InvalidParam("The sender user ID is invalid")
} }
if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) { if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) {
return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")} return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")}
@ -147,14 +148,23 @@ func (r *Inviter) PerformInvite(
return err return err
} }
invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, event.RoomID(), *invitedUser)
if err != nil {
return fmt.Errorf("failed looking up senderID for invited user")
}
input := gomatrixserverlib.PerformInviteInput{ input := gomatrixserverlib.PerformInviteInput{
RoomID: *validRoomID, RoomID: *validRoomID,
InviteEvent: event.PDU, InviteEvent: event.PDU,
InvitedUser: *invitedUser, InvitedUser: *invitedUser,
InvitedSenderID: invitedSenderID,
IsTargetLocal: isTargetLocal, IsTargetLocal: isTargetLocal,
StrippedState: req.InviteRoomState, StrippedState: req.InviteRoomState,
MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI}, MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI},
StateQuerier: &QueryState{r.DB}, StateQuerier: &QueryState{r.DB},
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
},
} }
inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI) inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI)
if err != nil { if err != nil {

View file

@ -175,15 +175,20 @@ func (r *Joiner) performJoinRoomByID(
} }
// Prepare the template for the join event. // Prepare the template for the join event.
userID := req.UserID userID, err := spec.NewUserID(req.UserID, true)
_, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID)
if err != nil { if err != nil {
return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", userID, err)} return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)}
} }
senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomIDOrAlias, *userID)
if err != nil {
return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)}
}
senderIDString := string(senderID)
userDomain := userID.Domain()
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
Type: spec.MRoomMember, Type: spec.MRoomMember,
Sender: userID, SenderID: senderIDString,
StateKey: &userID, StateKey: &senderIDString,
RoomID: req.RoomIDOrAlias, RoomID: req.RoomIDOrAlias,
Redacts: "", Redacts: "",
} }
@ -295,7 +300,7 @@ func (r *Joiner) performJoinRoomByID(
// is really no harm in just sending another membership event. // is really no harm in just sending another membership event.
membershipReq := &api.QueryMembershipForUserRequest{ membershipReq := &api.QueryMembershipForUserRequest{
RoomID: req.RoomIDOrAlias, RoomID: req.RoomIDOrAlias,
UserID: userID, UserID: userID.String(),
} }
membershipRes := &api.QueryMembershipForUserResponse{} membershipRes := &api.QueryMembershipForUserResponse{}
_ = r.Queryer.QueryMembershipForUser(ctx, membershipReq, membershipRes) _ = r.Queryer.QueryMembershipForUser(ctx, membershipReq, membershipRes)
@ -372,22 +377,14 @@ func (r *Joiner) populateAuthorisedViaUserForRestrictedJoin(
ctx context.Context, ctx context.Context,
joinReq *rsAPI.PerformJoinRequest, joinReq *rsAPI.PerformJoinRequest,
) (string, error) { ) (string, error) {
req := &api.QueryRestrictedJoinAllowedRequest{ roomID, err := spec.NewRoomID(joinReq.RoomIDOrAlias)
UserID: joinReq.UserID, if err != nil {
RoomID: joinReq.RoomIDOrAlias, return "", err
} }
res := &api.QueryRestrictedJoinAllowedResponse{} userID, err := spec.NewUserID(joinReq.UserID, true)
if err := r.Queryer.QueryRestrictedJoinAllowed(ctx, req, res); err != nil { if err != nil {
return "", fmt.Errorf("r.Queryer.QueryRestrictedJoinAllowed: %w", err) return "", err
} }
if !res.Restricted {
return "", nil return r.Queryer.QueryRestrictedJoinAllowed(ctx, *roomID, *userID)
}
if !res.Resident {
return "", nil
}
if !res.Allowed {
return "", rsAPI.ErrNotAllowed{Err: fmt.Errorf("the join to room %s was not allowed", joinReq.RoomIDOrAlias)}
}
return res.AuthorisedVia, nil
} }

View file

@ -152,11 +152,19 @@ func (r *Leaver) performLeaveRoomByID(
} }
// Prepare the template for the leave event. // Prepare the template for the leave event.
userID := req.UserID fullUserID, err := spec.NewUserID(req.UserID, true)
if err != nil {
return nil, err
}
senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomID, *fullUserID)
if err != nil {
return nil, err
}
senderIDString := string(senderID)
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
Type: spec.MRoomMember, Type: spec.MRoomMember,
Sender: userID, SenderID: senderIDString,
StateKey: &userID, StateKey: &senderIDString,
RoomID: req.RoomID, RoomID: req.RoomID,
Redacts: "", Redacts: "",
} }
@ -168,10 +176,7 @@ func (r *Leaver) performLeaveRoomByID(
} }
// Get the sender domain. // Get the sender domain.
_, senderDomain, serr := r.Cfg.Matrix.SplitLocalID('@', proto.Sender) senderDomain := fullUserID.Domain()
if serr != nil {
return nil, fmt.Errorf("sender %q is invalid", proto.Sender)
}
// We know that the user is in the room at this point so let's build // We know that the user is in the room at this point so let's build
// a leave event. // a leave event.

View file

@ -175,8 +175,16 @@ func moveLocalAliases(ctx context.Context,
return fmt.Errorf("Failed to get old room aliases: %w", err) return fmt.Errorf("Failed to get old room aliases: %w", err)
} }
fullUserID, err := spec.NewUserID(userID, true)
if err != nil {
return fmt.Errorf("Failed to get userID: %w", err)
}
senderID, err := URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return fmt.Errorf("Failed to get senderID: %w", err)
}
for _, alias := range aliasRes.Aliases { for _, alias := range aliasRes.Aliases {
removeAliasReq := api.RemoveRoomAliasRequest{UserID: userID, Alias: alias} removeAliasReq := api.RemoveRoomAliasRequest{SenderID: senderID, Alias: alias}
removeAliasRes := api.RemoveRoomAliasResponse{} removeAliasRes := api.RemoveRoomAliasResponse{}
if err = URSAPI.RemoveRoomAlias(ctx, &removeAliasReq, &removeAliasRes); err != nil { if err = URSAPI.RemoveRoomAlias(ctx, &removeAliasReq, &removeAliasRes); err != nil {
return fmt.Errorf("Failed to remove old room alias: %w", err) return fmt.Errorf("Failed to remove old room alias: %w", err)
@ -287,7 +295,15 @@ func (r *Upgrader) userIsAuthorized(ctx context.Context, userID, roomID string,
} }
// Check for power level required to send tombstone event (marks the current room as obsolete), // Check for power level required to send tombstone event (marks the current room as obsolete),
// if not found, use the StateDefault power level // if not found, use the StateDefault power level
return pl.UserLevel(userID) >= pl.EventLevel("m.room.tombstone", true) fullUserID, err := spec.NewUserID(userID, true)
if err != nil {
return false
}
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return false
}
return pl.UserLevel(senderID) >= pl.EventLevel("m.room.tombstone", true)
} }
// nolint:gocyclo // nolint:gocyclo
@ -383,7 +399,16 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
util.GetLogger(ctx).WithError(err).Error() util.GetLogger(ctx).WithError(err).Error()
return nil, fmt.Errorf("Power level event content was invalid") return nil, fmt.Errorf("Power level event content was invalid")
} }
tempPowerLevelsEvent, powerLevelsOverridden := createTemporaryPowerLevels(powerLevelContent, userID)
fullUserID, err := spec.NewUserID(userID, true)
if err != nil {
return nil, err
}
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return nil, err
}
tempPowerLevelsEvent, powerLevelsOverridden := createTemporaryPowerLevels(powerLevelContent, senderID)
// Now do the join rules event, same as the create and membership // Now do the join rules event, same as the create and membership
// events. We'll set a sane default of "invite" so that if the // events. We'll set a sane default of "invite" so that if the
@ -452,8 +477,16 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user
for i, e := range eventsToMake { for i, e := range eventsToMake {
depth := i + 1 // depth starts at 1 depth := i + 1 // depth starts at 1
fullUserID, userIDErr := spec.NewUserID(userID, true)
if userIDErr != nil {
return userIDErr
}
senderID, queryErr := r.URSAPI.QuerySenderIDForUser(ctx, newRoomID, *fullUserID)
if queryErr != nil {
return queryErr
}
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
Sender: userID, SenderID: string(senderID),
RoomID: newRoomID, RoomID: newRoomID,
Type: e.Type, Type: e.Type,
StateKey: &e.StateKey, StateKey: &e.StateKey,
@ -484,7 +517,9 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user
} }
if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err) return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err)
} }
@ -528,21 +563,26 @@ func (r *Upgrader) makeTombstoneEvent(
} }
func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, userID, roomID string, event gomatrixserverlib.FledglingEvent) (*types.HeaderedEvent, error) { func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, userID, roomID string, event gomatrixserverlib.FledglingEvent) (*types.HeaderedEvent, error) {
fullUserID, err := spec.NewUserID(userID, true)
if err != nil {
return nil, err
}
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return nil, err
}
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
Sender: userID, SenderID: string(senderID),
RoomID: roomID, RoomID: roomID,
Type: event.Type, Type: event.Type,
StateKey: &event.StateKey, StateKey: &event.StateKey,
} }
err := proto.SetContent(event.Content) err = proto.SetContent(event.Content)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to set new %q event content: %w", proto.Type, err) return nil, fmt.Errorf("failed to set new %q event content: %w", proto.Type, err)
} }
// Get the sender domain. // Get the sender domain.
_, senderDomain, serr := r.Cfg.Matrix.SplitLocalID('@', proto.Sender) senderDomain := fullUserID.Domain()
if serr != nil {
return nil, fmt.Errorf("Failed to split user ID %q: %w", proto.Sender, err)
}
identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get signing identity for %q: %w", senderDomain, err) return nil, fmt.Errorf("failed to get signing identity for %q: %w", senderDomain, err)
@ -567,14 +607,16 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user
stateEvents[i] = queryRes.StateEvents[i].PDU stateEvents[i] = queryRes.StateEvents[i].PDU
} }
provider := gomatrixserverlib.NewAuthEvents(stateEvents) provider := gomatrixserverlib.NewAuthEvents(stateEvents)
if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider); err != nil { if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client? return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client?
} }
return headeredEvent, nil return headeredEvent, nil
} }
func createTemporaryPowerLevels(powerLevelContent *gomatrixserverlib.PowerLevelContent, userID string) (gomatrixserverlib.FledglingEvent, bool) { func createTemporaryPowerLevels(powerLevelContent *gomatrixserverlib.PowerLevelContent, senderID spec.SenderID) (gomatrixserverlib.FledglingEvent, bool) {
// Work out what power level we need in order to be able to send events // Work out what power level we need in order to be able to send events
// of all types into the room. // of all types into the room.
neededPowerLevel := powerLevelContent.StateDefault neededPowerLevel := powerLevelContent.StateDefault
@ -599,8 +641,8 @@ func createTemporaryPowerLevels(powerLevelContent *gomatrixserverlib.PowerLevelC
// If the user who is upgrading the room doesn't already have sufficient // If the user who is upgrading the room doesn't already have sufficient
// power, then elevate their power levels. // power, then elevate their power levels.
if tempPowerLevelContent.UserLevel(userID) < neededPowerLevel { if tempPowerLevelContent.UserLevel(senderID) < neededPowerLevel {
tempPowerLevelContent.Users[userID] = neededPowerLevel tempPowerLevelContent.Users[string(senderID)] = neededPowerLevel
powerLevelsOverridden = true powerLevelsOverridden = true
} }

View file

@ -17,10 +17,11 @@ package query
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
//"github.com/matrix-org/dendrite/roomserver/internal"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -44,6 +45,42 @@ type Queryer struct {
Cache caching.RoomServerCaches Cache caching.RoomServerCaches
IsLocalServerName func(spec.ServerName) bool IsLocalServerName func(spec.ServerName) bool
ServerACLs *acls.ServerACLs ServerACLs *acls.ServerACLs
Cfg *config.Dendrite
}
func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) {
roomInfo, err := r.QueryRoomInfo(ctx, roomID)
if err != nil || roomInfo == nil || roomInfo.IsStub() {
return nil, err
}
req := api.QueryServerJoinedToRoomRequest{
ServerName: localServerName,
RoomID: roomID.String(),
}
res := api.QueryServerJoinedToRoomResponse{}
if err = r.QueryServerJoinedToRoom(ctx, &req, &res); err != nil {
util.GetLogger(ctx).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed")
return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err)
}
userJoinedToRoom, err := r.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed")
return nil, fmt.Errorf("InternalServerError: %w", err)
}
locallyJoinedUsers, err := r.LocallyJoinedUsers(ctx, roomInfo.RoomVersion, types.RoomNID(roomInfo.RoomNID))
if err != nil {
util.GetLogger(ctx).WithError(err).Error("rsAPI.GetLocallyJoinedUsers failed")
return nil, fmt.Errorf("InternalServerError: %w", err)
}
return &gomatrixserverlib.RestrictedRoomJoinInfo{
LocalServerInRoom: res.RoomExists && res.IsInRoom,
UserJoinedToRoom: userJoinedToRoom,
JoinedUsers: locallyJoinedUsers,
}, nil
} }
// QueryLatestEventsAndState implements api.RoomserverInternalAPI // QueryLatestEventsAndState implements api.RoomserverInternalAPI
@ -122,7 +159,9 @@ func (r *Queryer) QueryStateAfterEvents(
} }
stateEvents, err = gomatrixserverlib.ResolveConflicts( stateEvents, err = gomatrixserverlib.ResolveConflicts(
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
},
) )
if err != nil { if err != nil {
return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err) return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err)
@ -349,7 +388,12 @@ func (r *Queryer) QueryMembershipsForRoom(
return fmt.Errorf("r.DB.Events: %w", err) return fmt.Errorf("r.DB.Events: %w", err)
} }
for _, event := range events { for _, event := range events {
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll) sender := spec.UserID{}
userID, queryErr := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if queryErr == nil && userID != nil {
sender = *userID
}
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender)
response.JoinEvents = append(response.JoinEvents, clientEvent) response.JoinEvents = append(response.JoinEvents, clientEvent)
} }
return nil return nil
@ -398,7 +442,12 @@ func (r *Queryer) QueryMembershipsForRoom(
} }
for _, event := range events { for _, event := range events {
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll) sender := spec.UserID{}
userID, err := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err == nil && userID != nil {
sender = *userID
}
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender)
response.JoinEvents = append(response.JoinEvents, clientEvent) response.JoinEvents = append(response.JoinEvents, clientEvent)
} }
@ -588,7 +637,9 @@ func (r *Queryer) QueryStateAndAuthChain(
if request.ResolveState { if request.ResolveState {
stateEvents, err = gomatrixserverlib.ResolveConflicts( stateEvents, err = gomatrixserverlib.ResolveConflicts(
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
},
) )
if err != nil { if err != nil {
return err return err
@ -906,131 +957,28 @@ func (r *Queryer) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixse
} }
// nolint:gocyclo // nolint:gocyclo
func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.QueryRestrictedJoinAllowedRequest, res *api.QueryRestrictedJoinAllowedResponse) error { func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) {
// Look up if we know anything about the room. If it doesn't exist // Look up if we know anything about the room. If it doesn't exist
// or is a stub entry then we can't do anything. // or is a stub entry then we can't do anything.
roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID) roomInfo, err := r.DB.RoomInfo(ctx, roomID.String())
if err != nil { if err != nil {
return fmt.Errorf("r.DB.RoomInfo: %w", err) return "", fmt.Errorf("r.DB.RoomInfo: %w", err)
} }
if roomInfo == nil || roomInfo.IsStub() { if roomInfo == nil || roomInfo.IsStub() {
return nil // fmt.Errorf("room %q doesn't exist or is stub room", req.RoomID) return "", nil // fmt.Errorf("room %q doesn't exist or is stub room", req.RoomID)
} }
verImpl, err := gomatrixserverlib.GetRoomVersion(roomInfo.RoomVersion) verImpl, err := gomatrixserverlib.GetRoomVersion(roomInfo.RoomVersion)
if err != nil { if err != nil {
return err return "", err
} }
// If the room version doesn't allow restricted joins then don't
// try to process any further. return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, userID)
allowRestrictedJoins := verImpl.MayAllowRestrictedJoinsInEventAuth() }
if !allowRestrictedJoins {
return nil func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) {
} return r.DB.GetSenderIDForUser(ctx, roomID, userID)
// Start off by populating the "resident" flag in the response. If we }
// come across any rooms in the request that are missing, we will unset
// the flag. func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
res.Resident = true return r.DB.GetUserIDForSender(ctx, roomID, senderID)
// Get the join rules to work out if the join rule is "restricted".
joinRulesEvent, err := r.DB.GetStateEvent(ctx, req.RoomID, spec.MRoomJoinRules, "")
if err != nil {
return fmt.Errorf("r.DB.GetStateEvent: %w", err)
}
if joinRulesEvent == nil {
return nil
}
var joinRules gomatrixserverlib.JoinRuleContent
if err = json.Unmarshal(joinRulesEvent.Content(), &joinRules); err != nil {
return fmt.Errorf("json.Unmarshal: %w", err)
}
// If the join rule isn't "restricted" or "knock_restricted" then there's nothing more to do.
res.Restricted = joinRules.JoinRule == spec.Restricted || joinRules.JoinRule == spec.KnockRestricted
if !res.Restricted {
return nil
}
// If the user is already invited to the room then the join is allowed
// but we don't specify an authorised via user, since the event auth
// will allow the join anyway.
var pending bool
if pending, _, _, _, err = helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID); err != nil {
return fmt.Errorf("helpers.IsInvitePending: %w", err)
} else if pending {
res.Allowed = true
return nil
}
// We need to get the power levels content so that we can determine which
// users in the room are entitled to issue invites. We need to use one of
// these users as the authorising user.
powerLevelsEvent, err := r.DB.GetStateEvent(ctx, req.RoomID, spec.MRoomPowerLevels, "")
if err != nil {
return fmt.Errorf("r.DB.GetStateEvent: %w", err)
}
powerLevels, err := powerLevelsEvent.PowerLevels()
if err != nil {
return fmt.Errorf("unable to get powerlevels: %w", err)
}
// Step through the join rules and see if the user matches any of them.
for _, rule := range joinRules.Allow {
// We only understand "m.room_membership" rules at this point in
// time, so skip any rule that doesn't match those.
if rule.Type != spec.MRoomMembership {
continue
}
// See if the room exists. If it doesn't exist or if it's a stub
// room entry then we can't check memberships.
targetRoomInfo, err := r.DB.RoomInfo(ctx, rule.RoomID)
if err != nil || targetRoomInfo == nil || targetRoomInfo.IsStub() {
res.Resident = false
continue
}
// First of all work out if *we* are still in the room, otherwise
// it's possible that the memberships will be out of date.
isIn, err := r.DB.GetLocalServerInRoom(ctx, targetRoomInfo.RoomNID)
if err != nil || !isIn {
// If we aren't in the room, we can no longer tell if the room
// memberships are up-to-date.
res.Resident = false
continue
}
// At this point we're happy that we are in the room, so now let's
// see if the target user is in the room.
_, isIn, _, err = r.DB.GetMembership(ctx, targetRoomInfo.RoomNID, req.UserID)
if err != nil {
continue
}
// If the user is not in the room then we will skip them.
if !isIn {
continue
}
// The user is in the room, so now we will need to authorise the
// join using the user ID of one of our own users in the room. Pick
// one.
joinNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, targetRoomInfo.RoomNID, true, true)
if err != nil || len(joinNIDs) == 0 {
// There should always be more than one join NID at this point
// because we are gated behind GetLocalServerInRoom, but y'know,
// sometimes strange things happen.
continue
}
// For each of the joined users, let's see if we can get a valid
// membership event.
for _, joinNID := range joinNIDs {
events, err := r.DB.Events(ctx, roomInfo.RoomVersion, []types.EventNID{joinNID})
if err != nil || len(events) != 1 {
continue
}
event := events[0]
if event.Type() != spec.MRoomMember || event.StateKey() == nil {
continue // shouldn't happen
}
// Only users that have the power to invite should be chosen.
if powerLevels.UserLevel(*event.StateKey()) < powerLevels.Invite {
continue
}
res.Resident = true
res.Allowed = true
res.AuthorisedVia = *event.StateKey()
return nil
}
}
return nil
} }

View file

@ -60,7 +60,7 @@ func (r *RoomEventProducer) ProduceRoomEvents(roomID string, updates []api.Outpu
"adds_state": len(update.NewRoomEvent.AddsStateEventIDs), "adds_state": len(update.NewRoomEvent.AddsStateEventIDs),
"removes_state": len(update.NewRoomEvent.RemovesStateEventIDs), "removes_state": len(update.NewRoomEvent.RemovesStateEventIDs),
"send_as_server": update.NewRoomEvent.SendAsServer, "send_as_server": update.NewRoomEvent.SendAsServer,
"sender": update.NewRoomEvent.Event.Sender(), "sender": update.NewRoomEvent.Event.SenderID(),
}) })
if update.NewRoomEvent.Event.StateKey() != nil { if update.NewRoomEvent.Event.StateKey() != nil {
logger = logger.WithField("state_key", *update.NewRoomEvent.Event.StateKey()) logger = logger.WithField("state_key", *update.NewRoomEvent.Event.StateKey())

View file

@ -392,7 +392,7 @@ func TestPurgeRoom(t *testing.T) {
type fledglingEvent struct { type fledglingEvent struct {
Type string Type string
StateKey *string StateKey *string
Sender string SenderID string
RoomID string RoomID string
Redacts string Redacts string
Depth int64 Depth int64
@ -405,7 +405,7 @@ func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *types.HeaderedEve
seed := make([]byte, ed25519.SeedSize) // zero seed seed := make([]byte, ed25519.SeedSize) // zero seed
key := ed25519.NewKeyFromSeed(seed) key := ed25519.NewKeyFromSeed(seed)
eb := gomatrixserverlib.MustGetRoomVersion(roomVer).NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{ eb := gomatrixserverlib.MustGetRoomVersion(roomVer).NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{
Sender: ev.Sender, SenderID: ev.SenderID,
Type: ev.Type, Type: ev.Type,
StateKey: ev.StateKey, StateKey: ev.StateKey,
RoomID: ev.RoomID, RoomID: ev.RoomID,
@ -444,7 +444,7 @@ func TestRedaction(t *testing.T) {
builderEv := mustCreateEvent(t, fledglingEvent{ builderEv := mustCreateEvent(t, fledglingEvent{
Type: spec.MRoomRedaction, Type: spec.MRoomRedaction,
Sender: alice.ID, SenderID: alice.ID,
RoomID: room.ID, RoomID: room.ID,
Redacts: redactedEvent.EventID(), Redacts: redactedEvent.EventID(),
Depth: redactedEvent.Depth() + 1, Depth: redactedEvent.Depth() + 1,
@ -461,7 +461,7 @@ func TestRedaction(t *testing.T) {
builderEv := mustCreateEvent(t, fledglingEvent{ builderEv := mustCreateEvent(t, fledglingEvent{
Type: spec.MRoomRedaction, Type: spec.MRoomRedaction,
Sender: alice.ID, SenderID: alice.ID,
RoomID: room.ID, RoomID: room.ID,
Redacts: redactedEvent.EventID(), Redacts: redactedEvent.EventID(),
Depth: redactedEvent.Depth() + 1, Depth: redactedEvent.Depth() + 1,
@ -478,7 +478,7 @@ func TestRedaction(t *testing.T) {
builderEv := mustCreateEvent(t, fledglingEvent{ builderEv := mustCreateEvent(t, fledglingEvent{
Type: spec.MRoomRedaction, Type: spec.MRoomRedaction,
Sender: bob.ID, SenderID: bob.ID,
RoomID: room.ID, RoomID: room.ID,
Redacts: redactedEvent.EventID(), Redacts: redactedEvent.EventID(),
Depth: redactedEvent.Depth() + 1, Depth: redactedEvent.Depth() + 1,
@ -494,7 +494,7 @@ func TestRedaction(t *testing.T) {
builderEv := mustCreateEvent(t, fledglingEvent{ builderEv := mustCreateEvent(t, fledglingEvent{
Type: spec.MRoomRedaction, Type: spec.MRoomRedaction,
Sender: charlie.ID, SenderID: charlie.ID,
RoomID: room.ID, RoomID: room.ID,
Redacts: redactedEvent.EventID(), Redacts: redactedEvent.EventID(),
Depth: redactedEvent.Depth() + 1, Depth: redactedEvent.Depth() + 1,
@ -598,16 +598,15 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
prepareRoomFunc func(t *testing.T) *test.Room prepareRoomFunc func(t *testing.T) *test.Room
wantResponse api.QueryRestrictedJoinAllowedResponse wantResponse string
wantError bool
}{ }{
{ {
name: "public room unrestricted", name: "public room unrestricted",
prepareRoomFunc: func(t *testing.T) *test.Room { prepareRoomFunc: func(t *testing.T) *test.Room {
return test.NewRoom(t, alice) return test.NewRoom(t, alice)
}, },
wantResponse: api.QueryRestrictedJoinAllowedResponse{ wantResponse: "",
Resident: true,
},
}, },
{ {
name: "room version without restrictions", name: "room version without restrictions",
@ -624,10 +623,7 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) {
}, test.WithStateKey("")) }, test.WithStateKey(""))
return r return r
}, },
wantResponse: api.QueryRestrictedJoinAllowedResponse{ wantError: true,
Resident: true,
Restricted: true,
},
}, },
{ {
name: "knock_restricted", name: "knock_restricted",
@ -638,10 +634,7 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) {
}, test.WithStateKey("")) }, test.WithStateKey(""))
return r return r
}, },
wantResponse: api.QueryRestrictedJoinAllowedResponse{ wantError: true,
Resident: true,
Restricted: true,
},
}, },
{ {
name: "restricted with pending invite", // bob should be allowed to join name: "restricted with pending invite", // bob should be allowed to join
@ -655,11 +648,7 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) {
}, test.WithStateKey(bob.ID)) }, test.WithStateKey(bob.ID))
return r return r
}, },
wantResponse: api.QueryRestrictedJoinAllowedResponse{ wantResponse: "",
Resident: true,
Restricted: true,
Allowed: true,
},
}, },
{ {
name: "restricted with allowed room_id, but missing room", // bob should not be allowed to join, as we don't know about the room name: "restricted with allowed room_id, but missing room", // bob should not be allowed to join, as we don't know about the room
@ -680,9 +669,7 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) {
}, test.WithStateKey(bob.ID)) }, test.WithStateKey(bob.ID))
return r return r
}, },
wantResponse: api.QueryRestrictedJoinAllowedResponse{ wantError: true,
Restricted: true,
},
}, },
{ {
name: "restricted with allowed room_id", // bob should be allowed to join, as we know about the room name: "restricted with allowed room_id", // bob should be allowed to join, as we know about the room
@ -703,12 +690,7 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) {
}, test.WithStateKey(bob.ID)) }, test.WithStateKey(bob.ID))
return r return r
}, },
wantResponse: api.QueryRestrictedJoinAllowedResponse{ wantResponse: alice.ID,
Resident: true,
Restricted: true,
Allowed: true,
AuthorisedVia: alice.ID,
},
}, },
} }
@ -738,16 +720,17 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) {
t.Errorf("failed to send events: %v", err) t.Errorf("failed to send events: %v", err)
} }
req := api.QueryRestrictedJoinAllowedRequest{ roomID, _ := spec.NewRoomID(testRoom.ID)
UserID: bob.ID, userID, _ := spec.NewUserID(bob.ID, true)
RoomID: testRoom.ID, got, err := rsAPI.QueryRestrictedJoinAllowed(processCtx.Context(), *roomID, *userID)
if tc.wantError && err == nil {
t.Fatal("expected error, got none")
} }
res := api.QueryRestrictedJoinAllowedResponse{} if !tc.wantError && err != nil {
if err := rsAPI.QueryRestrictedJoinAllowed(processCtx.Context(), &req, &res); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !reflect.DeepEqual(tc.wantResponse, res) { if !reflect.DeepEqual(tc.wantResponse, got) {
t.Fatalf("unexpected response, want %#v - got %#v", tc.wantResponse, res) t.Fatalf("unexpected response, want %#v - got %#v", tc.wantResponse, got)
} }
}) })
} }

View file

@ -24,6 +24,7 @@ import (
"time" "time"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
@ -43,6 +44,7 @@ type StateResolutionStorage interface {
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
} }
type StateResolution struct { type StateResolution struct {
@ -945,7 +947,9 @@ func (v *StateResolution) resolveConflictsV1(
} }
// Resolve the conflicts. // Resolve the conflicts.
resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents) resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return v.db.GetUserIDForSender(ctx, roomID, senderID)
})
// Map from the full events back to numeric state entries. // Map from the full events back to numeric state entries.
for _, resolvedEvent := range resolvedEvents { for _, resolvedEvent := range resolvedEvents {
@ -1057,6 +1061,9 @@ func (v *StateResolution) resolveConflictsV2(
conflictedEvents, conflictedEvents,
nonConflictedEvents, nonConflictedEvents,
authEvents, authEvents,
func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return v.db.GetUserIDForSender(ctx, roomID, senderID)
},
) )
}() }()

View file

@ -166,6 +166,10 @@ type Database interface {
GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error) GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error)
// GetKnownUsers searches all users that userID knows about. // GetKnownUsers searches all users that userID knows about.
GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error)
// GetKnownUsers tries to obtain the current mxid for a given user.
GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
// GetKnownUsers tries to obtain the current senderID for a given user.
GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error)
// GetKnownRooms returns a list of all rooms we know about. // GetKnownRooms returns a list of all rooms we know about.
GetKnownRooms(ctx context.Context) ([]string, error) GetKnownRooms(ctx context.Context) ([]string, error)
// ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room // ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room
@ -211,6 +215,7 @@ type RoomDatabase interface {
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error)
GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
} }
type EventDatabase interface { type EventDatabase interface {

View file

@ -101,7 +101,7 @@ func (u *MembershipUpdater) Update(newMembership tables.MembershipState, event *
var inserted bool // Did the query result in a membership change? var inserted bool // Did the query result in a membership change?
var retired []string // Did we retire any updates in the process? var retired []string // Did we retire any updates in the process?
return inserted, retired, u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { return inserted, retired, u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, string(event.SenderID()))
if err != nil { if err != nil {
return fmt.Errorf("u.d.AssignStateKeyNID: %w", err) return fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
} }

View file

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
@ -250,3 +251,7 @@ func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error {
func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
} }
func (u *RoomUpdater) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return u.d.GetUserIDForSender(ctx, roomID, senderID)
}

View file

@ -988,8 +988,18 @@ func (d *EventDatabase) MaybeRedactEvent(
return nil return nil
} }
_, sender1, _ := gomatrixserverlib.SplitID('@', redactedEvent.Sender()) // TODO: Don't hack senderID into userID here (pseudoIDs)
_, sender2, _ := gomatrixserverlib.SplitID('@', redactionEvent.Sender()) sender1Domain := ""
sender1, err1 := spec.NewUserID(string(redactedEvent.SenderID()), true)
if err1 == nil {
sender1Domain = string(sender1.Domain())
}
// TODO: Don't hack senderID into userID here (pseudoIDs)
sender2Domain := ""
sender2, err2 := spec.NewUserID(string(redactionEvent.SenderID()), true)
if err2 == nil {
sender2Domain = string(sender2.Domain())
}
var powerlevels *gomatrixserverlib.PowerLevelContent var powerlevels *gomatrixserverlib.PowerLevelContent
powerlevels, err = plResolver.Resolve(ctx, redactionEvent.EventID()) powerlevels, err = plResolver.Resolve(ctx, redactionEvent.EventID())
if err != nil { if err != nil {
@ -997,9 +1007,9 @@ func (d *EventDatabase) MaybeRedactEvent(
} }
switch { switch {
case powerlevels.UserLevel(redactionEvent.Sender()) >= powerlevels.Redact: case powerlevels.UserLevel(redactionEvent.SenderID()) >= powerlevels.Redact:
// 1. The power level of the redaction events sender is greater than or equal to the redact level. // 1. The power level of the redaction events sender is greater than or equal to the redact level.
case sender1 == sender2: case sender1Domain == sender2Domain:
// 2. The domain of the redaction events sender matches that of the original events sender. // 2. The domain of the redaction events sender matches that of the original events sender.
default: default:
ignoreRedaction = true ignoreRedaction = true
@ -1514,6 +1524,16 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin
return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit) return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit)
} }
func (d *Database) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
// TODO: Use real logic once DB for pseudoIDs is in place
return spec.NewUserID(string(senderID), true)
}
func (d *Database) GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) {
// TODO: Use real logic once DB for pseudoIDs is in place
return spec.SenderID(userID.String()), nil
}
// GetKnownRooms returns a list of all rooms we know about. // GetKnownRooms returns a list of all rooms we know about.
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil) return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil)

View file

@ -92,9 +92,11 @@ type MSC2836EventRelationshipsResponse struct {
ParsedAuthChain []gomatrixserverlib.PDU ParsedAuthChain []gomatrixserverlib.PDU
} }
func toClientResponse(res *MSC2836EventRelationshipsResponse) *EventRelationshipResponse { func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsResponse, rsAPI roomserver.RoomserverInternalAPI) *EventRelationshipResponse {
out := &EventRelationshipResponse{ out := &EventRelationshipResponse{
Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll), Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}),
Limited: res.Limited, Limited: res.Limited,
NextBatch: res.NextBatch, NextBatch: res.NextBatch,
} }
@ -187,7 +189,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
return util.JSONResponse{ return util.JSONResponse{
Code: 200, Code: 200,
JSON: toClientResponse(res), JSON: toClientResponse(req.Context(), res, rsAPI),
} }
} }
} }

View file

@ -525,6 +525,10 @@ type testRoomserverAPI struct {
events map[string]*types.HeaderedEvent events map[string]*types.HeaderedEvent
} }
func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}
func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error { func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error {
for _, eventID := range req.EventIDs { for _, eventID := range req.EventIDs {
ev := r.events[eventID] ev := r.events[eventID]
@ -586,7 +590,7 @@ func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *types.HeaderedEve
seed := make([]byte, ed25519.SeedSize) // zero seed seed := make([]byte, ed25519.SeedSize) // zero seed
key := ed25519.NewKeyFromSeed(seed) key := ed25519.NewKeyFromSeed(seed)
eb := gomatrixserverlib.MustGetRoomVersion(roomVer).NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{ eb := gomatrixserverlib.MustGetRoomVersion(roomVer).NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{
Sender: ev.Sender, SenderID: ev.Sender,
Depth: 999, Depth: 999,
Type: ev.Type, Type: ev.Type,
StateKey: ev.StateKey, StateKey: ev.StateKey,

View file

@ -730,7 +730,7 @@ func stripped(ev gomatrixserverlib.PDU) *fclient.MSC2946StrippedEvent {
Type: ev.Type(), Type: ev.Type(),
StateKey: *ev.StateKey(), StateKey: *ev.StateKey(),
Content: ev.Content(), Content: ev.Content(),
Sender: ev.Sender(), Sender: string(ev.SenderID()),
OriginServerTS: ev.OriginServerTS(), OriginServerTS: ev.OriginServerTS(),
} }
} }

View file

@ -523,7 +523,7 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *rstypes.HeaderedEvent)
prev := types.PrevEventRef{ prev := types.PrevEventRef{
PrevContent: prevEvent.Content(), PrevContent: prevEvent.Content(),
ReplacesState: prevEvent.EventID(), ReplacesState: prevEvent.EventID(),
PrevSender: prevEvent.Sender(), PrevSenderID: string(prevEvent.SenderID()),
} }
event.PDU, err = event.SetUnsigned(prev) event.PDU, err = event.SetUnsigned(prev)

View file

@ -193,14 +193,20 @@ func Context(
} }
} }
eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll) eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
newState := state newState := state
if filter.LazyLoadMembers { if filter.LazyLoadMembers {
allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...) allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...)
allEvents = append(allEvents, &requestedEvent) allEvents = append(allEvents, &requestedEvent)
evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll) evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache) newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache)
if err != nil { if err != nil {
logrus.WithError(err).Error("unable to load membership events") logrus.WithError(err).Error("unable to load membership events")
@ -211,12 +217,19 @@ func Context(
} }
} }
ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll) sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(ctx, requestedEvent.RoomID(), requestedEvent.SenderID())
if err == nil && userID != nil {
sender = *userID
}
ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, sender)
response := ContextRespsonse{ response := ContextRespsonse{
Event: &ev, Event: &ev,
EventsAfter: eventsAfterClient, EventsAfter: eventsAfterClient,
EventsBefore: eventsBeforeClient, EventsBefore: eventsBeforeClient,
State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll), State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}),
} }
if len(response.State) > filter.Limit { if len(response.State) > filter.Limit {

View file

@ -101,8 +101,13 @@ func GetEvent(
} }
} }
sender := spec.UserID{}
senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), roomID, events[0].SenderID())
if err == nil && senderUserID != nil {
sender = *senderUserID
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll), JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender),
} }
} }

View file

@ -144,7 +144,22 @@ func GetMemberships(
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
} }
} }
res.Joined[ev.Sender()] = joinedMember(content)
userID, err := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID())
if err != nil || userID == nil {
util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryUserIDForSender failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"),
}
}
res.Joined[userID.String()] = joinedMember(content)
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
@ -153,6 +168,8 @@ func GetMemberships(
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll)}, JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
})},
} }
} }

View file

@ -241,7 +241,7 @@ func OnIncomingMessagesRequest(
device: device, device: device,
} }
clientEvents, start, end, err := mReq.retrieveEvents() clientEvents, start, end, err := mReq.retrieveEvents(req.Context(), rsAPI)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("mreq.retrieveEvents failed") util.GetLogger(req.Context()).WithError(err).Error("mreq.retrieveEvents failed")
return util.JSONResponse{ return util.JSONResponse{
@ -273,7 +273,9 @@ func OnIncomingMessagesRequest(
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
} }
} }
res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll)...) res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
})...)
} }
// If we didn't return any events, set the end to an empty string, so it will be omitted // If we didn't return any events, set the end to an empty string, so it will be omitted
@ -310,7 +312,7 @@ func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.
// homeserver in the room for older events. // homeserver in the room for older events.
// Returns an error if there was an issue talking to the database or with the // Returns an error if there was an issue talking to the database or with the
// remote homeserver. // remote homeserver.
func (r *messagesReq) retrieveEvents() ( func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserverAPI) (
clientEvents []synctypes.ClientEvent, start, clientEvents []synctypes.ClientEvent, start,
end types.TopologyToken, err error, end types.TopologyToken, err error,
) { ) {
@ -383,7 +385,9 @@ func (r *messagesReq) retrieveEvents() (
"events_before": len(events), "events_before": len(events),
"events_after": len(filteredEvents), "events_after": len(filteredEvents),
}).Debug("applied history visibility (messages)") }).Debug("applied history visibility (messages)")
return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll), start, end, err return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}), start, end, err
} }
func (r *messagesReq) getStartEnd(events []*rstypes.HeaderedEvent) (start, end types.TopologyToken, err error) { func (r *messagesReq) getStartEnd(events []*rstypes.HeaderedEvent) (start, end types.TopologyToken, err error) {
@ -491,7 +495,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
} }
// Append the events ve previously retrieved locally. // Append the events ve previously retrieved locally.
events = append(events, r.snapshot.StreamEventsToEvents(nil, streamEvents)...) events = append(events, r.snapshot.StreamEventsToEvents(r.ctx, nil, streamEvents, r.rsAPI)...)
sort.Sort(eventsByDepth(events)) sort.Sort(eventsByDepth(events))
return return

View file

@ -114,9 +114,14 @@ func Relations(
// type if it was specified. // type if it was specified.
res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents)) res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents))
for _, event := range filteredEvents { for _, event := range filteredEvents {
sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID())
if err == nil && userID != nil {
sender = *userID
}
res.Chunk = append( res.Chunk = append(
res.Chunk, res.Chunk,
synctypes.ToClientEvent(event.PDU, synctypes.FormatAll), synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender),
) )
} }

View file

@ -171,7 +171,7 @@ func Setup(
nb := req.FormValue("next_batch") nb := req.FormValue("next_batch")
nextBatch = &nb nextBatch = &nb
} }
return Search(req, device, syncDB, fts, nextBatch) return Search(req, device, syncDB, fts, nextBatch, rsAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)

View file

@ -31,6 +31,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/fulltext"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/synctypes"
@ -38,7 +39,7 @@ import (
) )
// nolint:gocyclo // nolint:gocyclo
func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts fulltext.Indexer, from *string) util.JSONResponse { func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts fulltext.Indexer, from *string, rsAPI roomserverAPI.SyncRoomserverAPI) util.JSONResponse {
start := time.Now() start := time.Now()
var ( var (
searchReq SearchRequest searchReq SearchRequest
@ -204,11 +205,17 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
profileInfos := make(map[string]ProfileInfoResponse) profileInfos := make(map[string]ProfileInfoResponse)
for _, ev := range append(eventsBefore, eventsAfter...) { for _, ev := range append(eventsBefore, eventsAfter...) {
profile, ok := knownUsersProfiles[event.Sender()] userID, queryErr := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID())
if queryErr != nil {
logrus.WithError(queryErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile")
continue
}
profile, ok := knownUsersProfiles[userID.String()]
if !ok { if !ok {
stateEvent, err := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, ev.Sender()) stateEvent, stateErr := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, string(ev.SenderID()))
if err != nil { if stateErr != nil {
logrus.WithError(err).WithField("user_id", event.Sender()).Warn("failed to query userprofile") logrus.WithError(stateErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile")
continue continue
} }
if stateEvent == nil { if stateEvent == nil {
@ -218,21 +225,30 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
AvatarURL: gjson.GetBytes(stateEvent.Content(), "avatar_url").Str, AvatarURL: gjson.GetBytes(stateEvent.Content(), "avatar_url").Str,
DisplayName: gjson.GetBytes(stateEvent.Content(), "displayname").Str, DisplayName: gjson.GetBytes(stateEvent.Content(), "displayname").Str,
} }
knownUsersProfiles[event.Sender()] = profile knownUsersProfiles[userID.String()] = profile
} }
profileInfos[ev.Sender()] = profile profileInfos[userID.String()] = profile
} }
sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID())
if err == nil && userID != nil {
sender = *userID
}
results = append(results, Result{ results = append(results, Result{
Context: SearchContextResponse{ Context: SearchContextResponse{
Start: startToken.String(), Start: startToken.String(),
End: endToken.String(), End: endToken.String(),
EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync), EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync), return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
ProfileInfo: profileInfos, }),
EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
}),
ProfileInfo: profileInfos,
}, },
Rank: eventScore[event.EventID()].Score, Rank: eventScore[event.EventID()].Score,
Result: synctypes.ToClientEvent(event, synctypes.FormatAll), Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender),
}) })
roomGroup := groups[event.RoomID()] roomGroup := groups[event.RoomID()]
roomGroup.Results = append(roomGroup.Results, event.EventID()) roomGroup.Results = append(roomGroup.Results, event.EventID())
@ -247,7 +263,9 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
} }
} }
stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync) stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
})
} }
} }

View file

@ -2,6 +2,7 @@ package routing
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -9,6 +10,7 @@ import (
"github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/fulltext"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
rstypes "github.com/matrix-org/dendrite/roomserver/types" rstypes "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/synctypes"
@ -21,6 +23,12 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
type FakeSyncRoomserverAPI struct{ rsapi.SyncRoomserverAPI }
func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}
func TestSearch(t *testing.T) { func TestSearch(t *testing.T) {
alice := test.NewUser(t) alice := test.NewUser(t)
aliceDevice := userapi.Device{UserID: alice.ID} aliceDevice := userapi.Device{UserID: alice.ID}
@ -247,7 +255,7 @@ func TestSearch(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/", reqBody) req := httptest.NewRequest(http.MethodPost, "/", reqBody)
res := Search(req, tc.device, db, fts, tc.from) res := Search(req, tc.device, db, fts, tc.from, &FakeSyncRoomserverAPI{})
if !tc.wantOK && !res.Is2xx() { if !tc.wantOK && !res.Is2xx() {
return return
} }

View file

@ -44,8 +44,8 @@ type DatabaseTransaction interface {
MaxStreamPositionForRelations(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForRelations(ctx context.Context) (types.StreamPosition, error)
CurrentState(ctx context.Context, roomID string, stateFilterPart *synctypes.StateFilter, excludeEventIDs []string) ([]*rstypes.HeaderedEvent, error) CurrentState(ctx context.Context, roomID string, stateFilterPart *synctypes.StateFilter, excludeEventIDs []string) ([]*rstypes.HeaderedEvent, error)
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter, rsAPI api.SyncRoomserverAPI) ([]types.StateDelta, []string, error)
GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter, rsAPI api.SyncRoomserverAPI) ([]types.StateDelta, []string, error)
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error)
GetRoomSummary(ctx context.Context, roomID, userID string) (summary *types.Summary, err error) GetRoomSummary(ctx context.Context, roomID, userID string) (summary *types.Summary, err error)
@ -90,7 +90,7 @@ type DatabaseTransaction interface {
// StreamEventsToEvents converts streamEvent to Event. If device is non-nil and // StreamEventsToEvents converts streamEvent to Event. If device is non-nil and
// matches the streamevent.transactionID device then the transaction ID gets // matches the streamevent.transactionID device then the transaction ID gets
// added to the unsigned section of the output event. // added to the unsigned section of the output event.
StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*rstypes.HeaderedEvent StreamEventsToEvents(ctx context.Context, device *userapi.Device, in []types.StreamEvent, rsAPI api.SyncRoomserverAPI) []*rstypes.HeaderedEvent
// SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the // SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the
// relevant events within the given ranges for the supplied user ID and device ID. // relevant events within the given ranges for the supplied user ID and device ID.
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error) SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error)

View file

@ -343,7 +343,7 @@ func (s *currentRoomStateStatements) UpsertRoomState(
event.RoomID(), event.RoomID(),
event.EventID(), event.EventID(),
event.Type(), event.Type(),
event.Sender(), event.SenderID(),
containsURL, containsURL,
*event.StateKey(), *event.StateKey(),
headeredJSON, headeredJSON,

View file

@ -407,7 +407,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
event.EventID(), event.EventID(),
headeredJSON, headeredJSON,
event.Type(), event.Type(),
event.Sender(), event.SenderID(),
containsURL, containsURL,
pq.StringArray(addState), pq.StringArray(addState),
pq.StringArray(removeState), pq.StringArray(removeState),

View file

@ -99,7 +99,41 @@ func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*rstypes.He
// We don't include a device here as we only include transaction IDs in // We don't include a device here as we only include transaction IDs in
// incremental syncs. // incremental syncs.
return d.StreamEventsToEvents(nil, streamEvents), nil return d.StreamEventsToEvents(ctx, nil, streamEvents, nil), nil
}
func (d *Database) StreamEventsToEvents(ctx context.Context, device *userapi.Device, in []types.StreamEvent, rsAPI api.SyncRoomserverAPI) []*rstypes.HeaderedEvent {
out := make([]*rstypes.HeaderedEvent, len(in))
for i := 0; i < len(in); i++ {
out[i] = in[i].HeaderedEvent
if device != nil && in[i].TransactionID != nil {
userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
logrus.WithFields(logrus.Fields{
"event_id": out[i].EventID(),
}).WithError(err).Warnf("Failed to add transaction ID to event")
continue
}
deviceSenderID, err := rsAPI.QuerySenderIDForUser(ctx, in[i].RoomID(), *userID)
if err != nil {
logrus.WithFields(logrus.Fields{
"event_id": out[i].EventID(),
}).WithError(err).Warnf("Failed to add transaction ID to event")
continue
}
if deviceSenderID == in[i].SenderID() && device.SessionID == in[i].TransactionID.SessionID {
err := out[i].SetUnsignedField(
"transaction_id", in[i].TransactionID.TransactionID,
)
if err != nil {
logrus.WithFields(logrus.Fields{
"event_id": out[i].EventID(),
}).WithError(err).Warnf("Failed to add transaction ID to event")
}
}
}
}
return out
} }
// AddInviteEvent stores a new invite event for a user. // AddInviteEvent stores a new invite event for a user.
@ -190,26 +224,6 @@ func (d *Database) UpsertAccountData(
return return
} }
func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*rstypes.HeaderedEvent {
out := make([]*rstypes.HeaderedEvent, len(in))
for i := 0; i < len(in); i++ {
out[i] = in[i].HeaderedEvent
if device != nil && in[i].TransactionID != nil {
if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID {
err := out[i].SetUnsignedField(
"transaction_id", in[i].TransactionID.TransactionID,
)
if err != nil {
logrus.WithFields(logrus.Fields{
"event_id": out[i].EventID(),
}).WithError(err).Warnf("Failed to add transaction ID to event")
}
}
}
}
return out
}
// handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of // handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of
// the events listed in the event's 'prev_events'. This function also updates the backwards extremities table // the events listed in the event's 'prev_events'. This function also updates the backwards extremities table
// to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. // to account for the fact that the given event is no longer a backwards extremity, but may be marked as such.

View file

@ -10,6 +10,7 @@ import (
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api"
rstypes "github.com/matrix-org/dendrite/roomserver/types" rstypes "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/synctypes"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
@ -186,7 +187,7 @@ func (d *DatabaseTransaction) Events(ctx context.Context, eventIDs []string) ([]
// We don't include a device here as we only include transaction IDs in // We don't include a device here as we only include transaction IDs in
// incremental syncs. // incremental syncs.
return d.StreamEventsToEvents(nil, streamEvents), nil return d.StreamEventsToEvents(ctx, nil, streamEvents, nil), nil
} }
func (d *DatabaseTransaction) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { func (d *DatabaseTransaction) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) {
@ -325,7 +326,7 @@ func (d *DatabaseTransaction) GetBackwardTopologyPos(
func (d *DatabaseTransaction) GetStateDeltas( func (d *DatabaseTransaction) GetStateDeltas(
ctx context.Context, device *userapi.Device, ctx context.Context, device *userapi.Device,
r types.Range, userID string, r types.Range, userID string,
stateFilter *synctypes.StateFilter, stateFilter *synctypes.StateFilter, rsAPI api.SyncRoomserverAPI,
) (deltas []types.StateDelta, joinedRoomsIDs []string, err error) { ) (deltas []types.StateDelta, joinedRoomsIDs []string, err error) {
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
// - Get membership list changes for this user in this sync response // - Get membership list changes for this user in this sync response
@ -417,7 +418,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
if !peek.Deleted { if !peek.Deleted {
deltas = append(deltas, types.StateDelta{ deltas = append(deltas, types.StateDelta{
Membership: spec.Peek, Membership: spec.Peek,
StateEvents: d.StreamEventsToEvents(device, state[peek.RoomID]), StateEvents: d.StreamEventsToEvents(ctx, device, state[peek.RoomID], rsAPI),
RoomID: peek.RoomID, RoomID: peek.RoomID,
}) })
} }
@ -462,7 +463,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
deltas = append(deltas, types.StateDelta{ deltas = append(deltas, types.StateDelta{
Membership: membership, Membership: membership,
MembershipPos: ev.StreamPosition, MembershipPos: ev.StreamPosition,
StateEvents: d.StreamEventsToEvents(device, stateFiltered[roomID]), StateEvents: d.StreamEventsToEvents(ctx, device, stateFiltered[roomID], rsAPI),
RoomID: roomID, RoomID: roomID,
}) })
break break
@ -474,7 +475,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
for _, joinedRoomID := range joinedRoomIDs { for _, joinedRoomID := range joinedRoomIDs {
deltas = append(deltas, types.StateDelta{ deltas = append(deltas, types.StateDelta{
Membership: spec.Join, Membership: spec.Join,
StateEvents: d.StreamEventsToEvents(device, stateFiltered[joinedRoomID]), StateEvents: d.StreamEventsToEvents(ctx, device, stateFiltered[joinedRoomID], rsAPI),
RoomID: joinedRoomID, RoomID: joinedRoomID,
NewlyJoined: newlyJoinedRooms[joinedRoomID], NewlyJoined: newlyJoinedRooms[joinedRoomID],
}) })
@ -490,7 +491,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( func (d *DatabaseTransaction) GetStateDeltasForFullStateSync(
ctx context.Context, device *userapi.Device, ctx context.Context, device *userapi.Device,
r types.Range, userID string, r types.Range, userID string,
stateFilter *synctypes.StateFilter, stateFilter *synctypes.StateFilter, rsAPI api.SyncRoomserverAPI,
) ([]types.StateDelta, []string, error) { ) ([]types.StateDelta, []string, error) {
// Look up all memberships for the user. We only care about rooms that a // Look up all memberships for the user. We only care about rooms that a
// user has ever interacted with — joined to, kicked/banned from, left. // user has ever interacted with — joined to, kicked/banned from, left.
@ -531,7 +532,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync(
} }
deltas[peek.RoomID] = types.StateDelta{ deltas[peek.RoomID] = types.StateDelta{
Membership: spec.Peek, Membership: spec.Peek,
StateEvents: d.StreamEventsToEvents(device, s), StateEvents: d.StreamEventsToEvents(ctx, device, s, rsAPI),
RoomID: peek.RoomID, RoomID: peek.RoomID,
} }
} }
@ -560,7 +561,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync(
deltas[roomID] = types.StateDelta{ deltas[roomID] = types.StateDelta{
Membership: membership, Membership: membership,
MembershipPos: ev.StreamPosition, MembershipPos: ev.StreamPosition,
StateEvents: d.StreamEventsToEvents(device, stateStreamEvents), StateEvents: d.StreamEventsToEvents(ctx, device, stateStreamEvents, rsAPI),
RoomID: roomID, RoomID: roomID,
} }
} }
@ -581,7 +582,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync(
} }
deltas[joinedRoomID] = types.StateDelta{ deltas[joinedRoomID] = types.StateDelta{
Membership: spec.Join, Membership: spec.Join,
StateEvents: d.StreamEventsToEvents(device, s), StateEvents: d.StreamEventsToEvents(ctx, device, s, rsAPI),
RoomID: joinedRoomID, RoomID: joinedRoomID,
} }
} }

View file

@ -342,7 +342,7 @@ func (s *currentRoomStateStatements) UpsertRoomState(
event.RoomID(), event.RoomID(),
event.EventID(), event.EventID(),
event.Type(), event.Type(),
event.Sender(), event.SenderID(),
containsURL, containsURL,
*event.StateKey(), *event.StateKey(),
headeredJSON, headeredJSON,

View file

@ -348,7 +348,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
event.EventID(), event.EventID(),
headeredJSON, headeredJSON,
event.Type(), event.Type(),
event.Sender(), event.SenderID(),
containsURL, containsURL,
string(addStateJSON), string(addStateJSON),
string(removeStateJSON), string(removeStateJSON),

View file

@ -214,7 +214,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err) t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
} }
gots := snapshot.StreamEventsToEvents(nil, paginatedEvents) gots := snapshot.StreamEventsToEvents(context.Background(), nil, paginatedEvents, nil)
test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:])) test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
}) })
}) })

View file

@ -10,6 +10,7 @@ import (
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/synctypes"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
@ -17,6 +18,7 @@ import (
type InviteStreamProvider struct { type InviteStreamProvider struct {
DefaultStreamProvider DefaultStreamProvider
rsAPI api.SyncRoomserverAPI
} }
func (p *InviteStreamProvider) Setup( func (p *InviteStreamProvider) Setup(
@ -62,11 +64,17 @@ func (p *InviteStreamProvider) IncrementalSync(
} }
for roomID, inviteEvent := range invites { for roomID, inviteEvent := range invites {
user := spec.UserID{}
sender, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), inviteEvent.SenderID())
if err == nil && sender != nil {
user = *sender
}
// skip ignored user events // skip ignored user events
if _, ok := req.IgnoredUsers.List[inviteEvent.Sender()]; ok { if _, ok := req.IgnoredUsers.List[user.String()]; ok {
continue continue
} }
ir := types.NewInviteResponse(inviteEvent) ir := types.NewInviteResponse(inviteEvent, user)
req.Response.Rooms.Invite[roomID] = ir req.Response.Rooms.Invite[roomID] = ir
} }

View file

@ -175,12 +175,12 @@ func (p *PDUStreamProvider) IncrementalSync(
eventFilter := req.Filter.Room.Timeline eventFilter := req.Filter.Room.Timeline
if req.WantFullState { if req.WantFullState {
if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter, p.rsAPI); err != nil {
req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed") req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed")
return from return from
} }
} else { } else {
if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter, p.rsAPI); err != nil {
req.Log.WithError(err).Error("p.DB.GetStateDeltas failed") req.Log.WithError(err).Error("p.DB.GetStateDeltas failed")
return from return from
} }
@ -275,7 +275,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
limited := dbEvents[delta.RoomID].Limited limited := dbEvents[delta.RoomID].Limited
recEvents := gomatrixserverlib.ReverseTopologicalOrdering( recEvents := gomatrixserverlib.ReverseTopologicalOrdering(
gomatrixserverlib.ToPDUs(snapshot.StreamEventsToEvents(device, recentStreamEvents)), gomatrixserverlib.ToPDUs(snapshot.StreamEventsToEvents(ctx, device, recentStreamEvents, p.rsAPI)),
gomatrixserverlib.TopologicalOrderByPrevEvents, gomatrixserverlib.TopologicalOrderByPrevEvents,
) )
recentEvents := make([]*rstypes.HeaderedEvent, len(recEvents)) recentEvents := make([]*rstypes.HeaderedEvent, len(recEvents))
@ -376,20 +376,28 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
} }
} }
jr.Timeline.PrevBatch = &prevBatch jr.Timeline.PrevBatch = &prevBatch
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync) jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
// If we are limited by the filter AND the history visibility filter // If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited. // didn't "remove" events, return that the response is limited.
jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync) jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
req.Response.Rooms.Join[delta.RoomID] = jr req.Response.Rooms.Join[delta.RoomID] = jr
case spec.Peek: case spec.Peek:
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = &prevBatch jr.Timeline.PrevBatch = &prevBatch
// TODO: Apply history visibility on peeked rooms // TODO: Apply history visibility on peeked rooms
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync) jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
jr.Timeline.Limited = limited jr.Timeline.Limited = limited
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync) jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
req.Response.Rooms.Peek[delta.RoomID] = jr req.Response.Rooms.Peek[delta.RoomID] = jr
case spec.Leave: case spec.Leave:
@ -398,11 +406,15 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
case spec.Ban: case spec.Ban:
lr := types.NewLeaveResponse() lr := types.NewLeaveResponse()
lr.Timeline.PrevBatch = &prevBatch lr.Timeline.PrevBatch = &prevBatch
lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync) lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
// If we are limited by the filter AND the history visibility filter // If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited. // didn't "remove" events, return that the response is limited.
lr.Timeline.Limited = limited && len(events) == len(recentEvents) lr.Timeline.Limited = limited && len(events) == len(recentEvents)
lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync) lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
req.Response.Rooms.Leave[delta.RoomID] = lr req.Response.Rooms.Leave[delta.RoomID] = lr
} }
@ -425,7 +437,7 @@ func applyHistoryVisibilityFilter(
for _, ev := range recentEvents { for _, ev := range recentEvents {
if ev.StateKey() != nil { if ev.StateKey() != nil {
stateTypes = append(stateTypes, ev.Type()) stateTypes = append(stateTypes, ev.Type())
senders = append(senders, ev.Sender()) senders = append(senders, string(ev.SenderID()))
} }
} }
@ -500,7 +512,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
// We don't include a device here as we don't need to send down // We don't include a device here as we don't need to send down
// transaction IDs for complete syncs, but we do it anyway because Sytest demands it for: // transaction IDs for complete syncs, but we do it anyway because Sytest demands it for:
// "Can sync a room with a message with a transaction id" - which does a complete sync to check. // "Can sync a room with a message with a transaction id" - which does a complete sync to check.
recentEvents := snapshot.StreamEventsToEvents(device, recentStreamEvents) recentEvents := snapshot.StreamEventsToEvents(ctx, device, recentStreamEvents, p.rsAPI)
events := recentEvents events := recentEvents
// Only apply history visibility checks if the response is for joined rooms // Only apply history visibility checks if the response is for joined rooms
@ -552,11 +564,15 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
} }
jr.Timeline.PrevBatch = prevBatch jr.Timeline.PrevBatch = prevBatch
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync) jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
// If we are limited by the filter AND the history visibility filter // If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited. // didn't "remove" events, return that the response is limited.
jr.Timeline.Limited = limited && len(events) == len(recentEvents) jr.Timeline.Limited = limited && len(events) == len(recentEvents)
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync) jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
return jr, nil return jr, nil
} }
@ -577,8 +593,8 @@ func (p *PDUStreamProvider) lazyLoadMembers(
// Add all users the client doesn't know about yet to a list // Add all users the client doesn't know about yet to a list
for _, event := range timelineEvents { for _, event := range timelineEvents {
// Membership is not yet cached, add it to the list // Membership is not yet cached, add it to the list
if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, event.Sender()); !ok { if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, string(event.SenderID())); !ok {
timelineUsers[event.Sender()] = struct{}{} timelineUsers[string(event.SenderID())] = struct{}{}
} }
} }
// Preallocate with the same amount, even if it will end up with fewer values // Preallocate with the same amount, even if it will end up with fewer values

View file

@ -45,6 +45,7 @@ func NewSyncStreamProviders(
}, },
InviteStreamProvider: &InviteStreamProvider{ InviteStreamProvider: &InviteStreamProvider{
DefaultStreamProvider: DefaultStreamProvider{DB: d}, DefaultStreamProvider: DefaultStreamProvider{DB: d},
rsAPI: rsAPI,
}, },
SendToDeviceStreamProvider: &SendToDeviceStreamProvider{ SendToDeviceStreamProvider: &SendToDeviceStreamProvider{
DefaultStreamProvider: DefaultStreamProvider{DB: d}, DefaultStreamProvider: DefaultStreamProvider{DB: d},

View file

@ -40,6 +40,10 @@ type syncRoomserverAPI struct {
rooms []*test.Room rooms []*test.Room
} }
func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}
func (s *syncRoomserverAPI) QueryLatestEventsAndState(ctx context.Context, req *rsapi.QueryLatestEventsAndStateRequest, res *rsapi.QueryLatestEventsAndStateResponse) error { func (s *syncRoomserverAPI) QueryLatestEventsAndState(ctx context.Context, req *rsapi.QueryLatestEventsAndStateRequest, res *rsapi.QueryLatestEventsAndStateResponse) error {
var room *test.Room var room *test.Room
for _, r := range s.rooms { for _, r := range s.rooms {

View file

@ -44,22 +44,27 @@ type ClientEvent struct {
} }
// ToClientEvents converts server events to client events. // ToClientEvents converts server events to client events.
func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat) []ClientEvent { func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat, userIDForSender spec.UserIDForSender) []ClientEvent {
evs := make([]ClientEvent, 0, len(serverEvs)) evs := make([]ClientEvent, 0, len(serverEvs))
for _, se := range serverEvs { for _, se := range serverEvs {
if se == nil { if se == nil {
continue // TODO: shouldn't happen? continue // TODO: shouldn't happen?
} }
evs = append(evs, ToClientEvent(se, format)) sender := spec.UserID{}
userID, err := userIDForSender(se.RoomID(), se.SenderID())
if err == nil && userID != nil {
sender = *userID
}
evs = append(evs, ToClientEvent(se, format, sender))
} }
return evs return evs
} }
// ToClientEvent converts a single server event to a client event. // ToClientEvent converts a single server event to a client event.
func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat) ClientEvent { func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID) ClientEvent {
ce := ClientEvent{ ce := ClientEvent{
Content: spec.RawJSON(se.Content()), Content: spec.RawJSON(se.Content()),
Sender: se.Sender(), Sender: sender.String(),
Type: se.Type(), Type: se.Type(),
StateKey: se.StateKey(), StateKey: se.StateKey(),
Unsigned: spec.RawJSON(se.Unsigned()), Unsigned: spec.RawJSON(se.Unsigned()),

View file

@ -21,6 +21,7 @@ import (
"testing" "testing"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
) )
func TestToClientEvent(t *testing.T) { // nolint: gocyclo func TestToClientEvent(t *testing.T) { // nolint: gocyclo
@ -43,7 +44,11 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo
if err != nil { if err != nil {
t.Fatalf("failed to create Event: %s", err) t.Fatalf("failed to create Event: %s", err)
} }
ce := ToClientEvent(ev, FormatAll) userID, err := spec.NewUserID("@test:localhost", true)
if err != nil {
t.Fatalf("failed to create userID: %s", err)
}
ce := ToClientEvent(ev, FormatAll, *userID)
if ce.EventID != ev.EventID() { if ce.EventID != ev.EventID() {
t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID) t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID)
} }
@ -62,8 +67,8 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo
if !bytes.Equal(ce.Unsigned, ev.Unsigned()) { if !bytes.Equal(ce.Unsigned, ev.Unsigned()) {
t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned)) t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned))
} }
if ce.Sender != ev.Sender() { if ce.Sender != userID.String() {
t.Errorf("ClientEvent.Sender: wanted %s, got %s", ev.Sender(), ce.Sender) t.Errorf("ClientEvent.Sender: wanted %s, got %s", userID.String(), ce.Sender)
} }
j, err := json.Marshal(ce) j, err := json.Marshal(ce)
if err != nil { if err != nil {
@ -98,7 +103,11 @@ func TestToClientFormatSync(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to create Event: %s", err) t.Fatalf("failed to create Event: %s", err)
} }
ce := ToClientEvent(ev, FormatSync) userID, err := spec.NewUserID("@test:localhost", true)
if err != nil {
t.Fatalf("failed to create userID: %s", err)
}
ce := ToClientEvent(ev, FormatSync, *userID)
if ce.RoomID != "" { if ce.RoomID != "" {
t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID) t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID)
} }

View file

@ -343,7 +343,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
type PrevEventRef struct { type PrevEventRef struct {
PrevContent json.RawMessage `json:"prev_content"` PrevContent json.RawMessage `json:"prev_content"`
ReplacesState string `json:"replaces_state"` ReplacesState string `json:"replaces_state"`
PrevSender string `json:"prev_sender"` PrevSenderID string `json:"prev_sender"`
} }
type DeviceLists struct { type DeviceLists struct {
@ -539,7 +539,7 @@ type InviteResponse struct {
} }
// NewInviteResponse creates an empty response with initialised arrays. // NewInviteResponse creates an empty response with initialised arrays.
func NewInviteResponse(event *types.HeaderedEvent) *InviteResponse { func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteResponse {
res := InviteResponse{} res := InviteResponse{}
res.InviteState.Events = []json.RawMessage{} res.InviteState.Events = []json.RawMessage{}
@ -552,7 +552,7 @@ func NewInviteResponse(event *types.HeaderedEvent) *InviteResponse {
// Then we'll see if we can create a partial of the invite event itself. // Then we'll see if we can create a partial of the invite event itself.
// This is needed for clients to work out *who* sent the invite. // This is needed for clients to work out *who* sent the invite.
inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync) inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID)
inviteEvent.Unsigned = nil inviteEvent.Unsigned = nil
if ev, err := json.Marshal(inviteEvent); err == nil { if ev, err := json.Marshal(inviteEvent); err == nil {
res.InviteState.Events = append(res.InviteState.Events, ev) res.InviteState.Events = append(res.InviteState.Events, ev)

View file

@ -8,8 +8,13 @@ import (
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/synctypes"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
) )
func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) {
return spec.NewUserID(senderID, true)
}
func TestSyncTokens(t *testing.T) { func TestSyncTokens(t *testing.T) {
shouldPass := map[string]string{ shouldPass := map[string]string{
"s4_0_0_0_0_0_0_0_3": StreamingToken{4, 0, 0, 0, 0, 0, 0, 0, 3}.String(), "s4_0_0_0_0_0_0_0_3": StreamingToken{4, 0, 0, 0, 0, 0, 0, 0, 3}.String(),
@ -56,7 +61,12 @@ func TestNewInviteResponse(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}) sender, err := spec.NewUserID("@neilalexander:matrix.org", true)
if err != nil {
t.Fatal(err)
}
res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender)
j, err := json.Marshal(res) j, err := json.Marshal(res)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View file

@ -39,6 +39,10 @@ var (
roomIDCounter = int64(0) roomIDCounter = int64(0)
) )
func UserIDForSender(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}
type Room struct { type Room struct {
ID string ID string
Version gomatrixserverlib.RoomVersion Version gomatrixserverlib.RoomVersion
@ -164,7 +168,7 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten
} }
builder := gomatrixserverlib.MustGetRoomVersion(r.Version).NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{ builder := gomatrixserverlib.MustGetRoomVersion(r.Version).NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{
Sender: creator.ID, SenderID: creator.ID,
RoomID: r.ID, RoomID: r.ID,
Type: eventType, Type: eventType,
StateKey: mod.stateKey, StateKey: mod.stateKey,
@ -195,7 +199,7 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten
if err != nil { if err != nil {
t.Fatalf("CreateEvent[%s]: failed to build event: %s", eventType, err) t.Fatalf("CreateEvent[%s]: failed to build event: %s", eventType, err)
} }
if err = gomatrixserverlib.Allowed(ev, &r.authEvents); err != nil { if err = gomatrixserverlib.Allowed(ev, &r.authEvents, UserIDForSender); err != nil {
t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err) t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err)
} }
headeredEvent := &rstypes.HeaderedEvent{PDU: ev} headeredEvent := &rstypes.HeaderedEvent{PDU: ev}

View file

@ -108,7 +108,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms
} }
if s.cfg.Matrix.ReportStats.Enabled { if s.cfg.Matrix.ReportStats.Enabled {
go s.storeMessageStats(ctx, event.Type(), event.Sender(), event.RoomID()) go s.storeMessageStats(ctx, event.Type(), string(event.SenderID()), event.RoomID())
} }
log.WithFields(log.Fields{ log.WithFields(log.Fields{
@ -301,7 +301,12 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst
switch { switch {
case event.Type() == spec.MRoomMember: case event.Type() == spec.MRoomMember:
cevent := synctypes.ToClientEvent(event, synctypes.FormatAll) sender := spec.UserID{}
userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if queryErr == nil && userID != nil {
sender = *userID
}
cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender)
var member *localMembership var member *localMembership
member, err = newLocalMembership(&cevent) member, err = newLocalMembership(&cevent)
if err != nil { if err != nil {
@ -529,12 +534,17 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
return fmt.Errorf("s.localPushDevices: %w", err) return fmt.Errorf("s.localPushDevices: %w", err)
} }
sender := spec.UserID{}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err == nil && userID != nil {
sender = *userID
}
n := &api.Notification{ n := &api.Notification{
Actions: actions, Actions: actions,
// UNSPEC: the spec doesn't say this is a ClientEvent, but the // UNSPEC: the spec doesn't say this is a ClientEvent, but the
// fields seem to match. room_id should be missing, which // fields seem to match. room_id should be missing, which
// matches the behaviour of FormatSync. // matches the behaviour of FormatSync.
Event: synctypes.ToClientEvent(event, synctypes.FormatSync), Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender),
// TODO: this is per-device, but it's not part of the primary // TODO: this is per-device, but it's not part of the primary
// key. So inserting one notification per profile tag doesn't // key. So inserting one notification per profile tag doesn't
// make sense. What is this supposed to be? Sytests require it // make sense. What is this supposed to be? Sytests require it
@ -615,7 +625,12 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
// evaluatePushRules fetches and evaluates the push rules of a local // evaluatePushRules fetches and evaluates the push rules of a local
// user. Returns actions (including dont_notify). // user. Returns actions (including dont_notify).
func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
if event.Sender() == mem.UserID { user := ""
sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err == nil {
user = sender.String()
}
if user == mem.UserID {
// SPEC: Homeservers MUST NOT notify the Push Gateway for // SPEC: Homeservers MUST NOT notify the Push Gateway for
// events that the user has sent themselves. // events that the user has sent themselves.
return nil, nil return nil, nil
@ -632,9 +647,8 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
if err != nil { if err != nil {
return nil, err return nil, err
} }
sender := event.Sender() if _, ok := ignored.List[sender.String()]; ok {
if _, ok := ignored.List[sender]; ok { return nil, fmt.Errorf("user %s is ignored", sender.String())
return nil, fmt.Errorf("user %s is ignored", sender)
} }
} }
ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart, mem.Domain) ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart, mem.Domain)
@ -650,7 +664,9 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
roomSize: roomSize, roomSize: roomSize,
} }
eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global) eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global)
rule, err := eval.MatchEvent(event.PDU) rule, err := eval.MatchEvent(event.PDU, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -682,7 +698,7 @@ func (rse *ruleSetEvalContext) UserDisplayName() string { return rse.mem.Display
func (rse *ruleSetEvalContext) RoomMemberCount() (int, error) { return rse.roomSize, nil } func (rse *ruleSetEvalContext) RoomMemberCount() (int, error) { return rse.roomSize, nil }
func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, error) { func (rse *ruleSetEvalContext) HasPowerLevel(senderID spec.SenderID, levelKey string) (bool, error) {
req := &rsapi.QueryLatestEventsAndStateRequest{ req := &rsapi.QueryLatestEventsAndStateRequest{
RoomID: rse.roomID, RoomID: rse.roomID,
StateToFetch: []gomatrixserverlib.StateKeyTuple{ StateToFetch: []gomatrixserverlib.StateKeyTuple{
@ -702,7 +718,7 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err
if err != nil { if err != nil {
return false, err return false, err
} }
return plc.UserLevel(userID) >= plc.NotificationLevel(levelKey), nil return plc.UserLevel(senderID) >= plc.NotificationLevel(levelKey), nil
} }
return true, nil return true, nil
} }
@ -756,6 +772,11 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes
} }
default: default:
sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err != nil {
logger.WithError(err).Errorf("Failed to get userID for sender %s", event.SenderID())
return nil, err
}
req = pushgateway.NotifyRequest{ req = pushgateway.NotifyRequest{
Notification: pushgateway.Notification{ Notification: pushgateway.Notification{
Content: event.Content(), Content: event.Content(),
@ -767,7 +788,7 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes
ID: event.EventID(), ID: event.EventID(),
RoomID: event.RoomID(), RoomID: event.RoomID(),
RoomName: roomName, RoomName: roomName,
Sender: event.Sender(), Sender: sender.String(),
Type: event.Type(), Type: event.Type(),
}, },
} }

View file

@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/pushrules"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage"
@ -44,13 +45,19 @@ func mustCreateEvent(t *testing.T, content string) *types.HeaderedEvent {
return &types.HeaderedEvent{PDU: ev} return &types.HeaderedEvent{PDU: ev}
} }
type FakeUserRoomserverAPI struct{ rsapi.UserRoomserverAPI }
func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}
func Test_evaluatePushRules(t *testing.T) { func Test_evaluatePushRules(t *testing.T) {
ctx := context.Background() ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType) db, close := mustCreateDatabase(t, dbType)
defer close() defer close()
consumer := OutputRoomEventConsumer{db: db} consumer := OutputRoomEventConsumer{db: db, rsAPI: &FakeUserRoomserverAPI{}}
testCases := []struct { testCases := []struct {
name string name string
@ -86,7 +93,7 @@ func Test_evaluatePushRules(t *testing.T) {
}, },
{ {
name: "m.room.message highlights", name: "m.room.message highlights",
eventContent: `{"type":"m.room.message", "content": {"body": "test"} }`, eventContent: `{"type":"m.room.message", "content": {"body": "test"}}`,
wantNotify: true, wantNotify: true,
wantAction: pushrules.NotifyAction, wantAction: pushrules.NotifyAction,
wantActions: []*pushrules.Action{ wantActions: []*pushrules.Action{

View file

@ -11,6 +11,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/synctypes"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
@ -87,7 +88,7 @@ func TestNotifyUserCountsAsync(t *testing.T) {
} }
// Prepare pusher with our test server URL // Prepare pusher with our test server URL
if err := db.UpsertPusher(ctx, api.Pusher{ if err = db.UpsertPusher(ctx, api.Pusher{
Kind: api.HTTPKind, Kind: api.HTTPKind,
AppID: appID, AppID: appID,
PushKey: pushKey, PushKey: pushKey,
@ -99,8 +100,12 @@ func TestNotifyUserCountsAsync(t *testing.T) {
} }
// Insert a dummy event // Insert a dummy event
sender, err := spec.NewUserID(alice.ID, true)
if err != nil {
t.Error(err)
}
if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{ if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{
Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll), Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender),
}); err != nil { }); err != nil {
t.Error(err) t.Error(err)
} }