Cleanup remaining statekey usage for senderIDs (#3106)

This commit is contained in:
devonh 2023-06-12 11:19:25 +00:00 committed by GitHub
parent 832ccc32f6
commit 77d9e4e93d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
62 changed files with 760 additions and 455 deletions

View file

@ -145,8 +145,16 @@ func SaveReadMarker(
userAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI, userAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
syncProducer *producers.SyncAPIProducer, device *api.Device, roomID string, syncProducer *producers.SyncAPIProducer, device *api.Device, roomID string,
) util.JSONResponse { ) util.JSONResponse {
deviceUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("userID for this device is invalid"),
}
}
// Verify that the user is a member of this room // Verify that the user is a member of this room
resErr := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }

View file

@ -55,9 +55,16 @@ func GetAliases(
visibility = content.HistoryVisibility visibility = content.HistoryVisibility
} }
if visibility != spec.WorldReadable { if visibility != spec.WorldReadable {
deviceUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
}
}
queryReq := api.QueryMembershipForUserRequest{ queryReq := api.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
UserID: device.UserID, UserID: *deviceUserID,
} }
var queryRes api.QueryMembershipForUserResponse var queryRes api.QueryMembershipForUserResponse
if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil { if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil {

View file

@ -224,6 +224,7 @@ func createRoom(
PrivateKey: privateKey, PrivateKey: privateKey,
EventTime: evTime, EventTime: evTime,
} }
roomAlias, createRes := rsAPI.PerformCreateRoom(ctx, *userID, *roomID, &req) roomAlias, createRes := rsAPI.PerformCreateRoom(ctx, *userID, *roomID, &req)
if createRes != nil { if createRes != nil {
return *createRes return *createRes

View file

@ -314,7 +314,22 @@ func SetVisibility(
req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, dev *userapi.Device, req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, dev *userapi.Device,
roomID string, roomID string,
) util.JSONResponse { ) util.JSONResponse {
resErr := checkMemberInRoom(req.Context(), rsAPI, dev.UserID, roomID) deviceUserID, err := spec.NewUserID(dev.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("userID for this device is invalid"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.Unknown("failed to find senderID for this user"),
}
}
resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }
@ -327,7 +342,7 @@ func SetVisibility(
}}, }},
} }
var queryEventsRes roomserverAPI.QueryLatestEventsAndStateResponse var queryEventsRes roomserverAPI.QueryLatestEventsAndStateResponse
err := rsAPI.QueryLatestEventsAndState(req.Context(), &queryEventsReq, &queryEventsRes) err = rsAPI.QueryLatestEventsAndState(req.Context(), &queryEventsReq, &queryEventsRes)
if err != nil || len(queryEventsRes.StateEvents) == 0 { if err != nil || len(queryEventsRes.StateEvents) == 0 {
util.GetLogger(req.Context()).WithError(err).Error("could not query events from room") util.GetLogger(req.Context()).WithError(err).Error("could not query events from room")
return util.JSONResponse{ return util.JSONResponse{
@ -338,20 +353,6 @@ func SetVisibility(
// NOTSPEC: Check if the user's power is greater than power required to change m.room.canonical_alias event // NOTSPEC: Check if the user's power is greater than power required to change m.room.canonical_alias event
power, _ := gomatrixserverlib.NewPowerLevelContentFromEvent(queryEventsRes.StateEvents[0].PDU) power, _ := gomatrixserverlib.NewPowerLevelContentFromEvent(queryEventsRes.StateEvents[0].PDU)
fullUserID, err := spec.NewUserID(dev.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
}
}
if power.UserLevel(senderID) < power.EventLevel(spec.MRoomCanonicalAlias, true) { if power.UserLevel(senderID) < power.EventLevel(spec.MRoomCanonicalAlias, true) {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,

View file

@ -29,10 +29,18 @@ func LeaveRoomByID(
rsAPI roomserverAPI.ClientRoomserverAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
roomID string, roomID string,
) util.JSONResponse { ) util.JSONResponse {
userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.Unknown("device userID is invalid"),
}
}
// Prepare to ask the roomserver to perform the room join. // Prepare to ask the roomserver to perform the room join.
leaveReq := roomserverAPI.PerformLeaveRequest{ leaveReq := roomserverAPI.PerformLeaveRequest{
RoomID: roomID, RoomID: roomID,
UserID: device.UserID, Leaver: *userID,
} }
leaveRes := roomserverAPI.PerformLeaveResponse{} leaveRes := roomserverAPI.PerformLeaveResponse{}

View file

@ -57,7 +57,22 @@ func SendBan(
} }
} }
errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) deviceUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to ban this user, unknown senderID"),
}
}
errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
if errRes != nil { if errRes != nil {
return *errRes return *errRes
} }
@ -66,20 +81,6 @@ func SendBan(
if errRes != nil { if errRes != nil {
return *errRes return *errRes
} }
fullUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to ban this user, unknown senderID"),
}
}
allowedToBan := pl.UserLevel(senderID) >= pl.Ban allowedToBan := pl.UserLevel(senderID) >= pl.Ban
if !allowedToBan { if !allowedToBan {
return util.JSONResponse{ return util.JSONResponse{
@ -147,7 +148,22 @@ func SendKick(
} }
} }
errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) deviceUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"),
}
}
errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
if errRes != nil { if errRes != nil {
return *errRes return *errRes
} }
@ -156,20 +172,6 @@ func SendKick(
if errRes != nil { if errRes != nil {
return *errRes return *errRes
} }
fullUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"),
}
}
allowedToKick := pl.UserLevel(senderID) >= pl.Kick allowedToKick := pl.UserLevel(senderID) >= pl.Kick
if !allowedToKick { if !allowedToKick {
return util.JSONResponse{ return util.JSONResponse{
@ -178,10 +180,17 @@ func SendKick(
} }
} }
bodyUserID, err := spec.NewUserID(body.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("body userID is invalid"),
}
}
var queryRes roomserverAPI.QueryMembershipForUserResponse var queryRes roomserverAPI.QueryMembershipForUserResponse
err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
UserID: body.UserID, UserID: *bodyUserID,
}, &queryRes) }, &queryRes)
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -213,15 +222,30 @@ func SendUnban(
} }
} }
errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) deviceUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
}
}
errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
if errRes != nil { if errRes != nil {
return *errRes return *errRes
} }
bodyUserID, err := spec.NewUserID(body.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("body userID is invalid"),
}
}
var queryRes roomserverAPI.QueryMembershipForUserResponse var queryRes roomserverAPI.QueryMembershipForUserResponse
err := rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
UserID: body.UserID, UserID: *bodyUserID,
}, &queryRes) }, &queryRes)
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -272,7 +296,15 @@ func SendInvite(
} }
} }
errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) deviceUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
}
}
errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
if errRes != nil { if errRes != nil {
return *errRes return *errRes
} }
@ -340,17 +372,18 @@ func sendInvite(
func buildMembershipEventDirect( func buildMembershipEventDirect(
ctx context.Context, ctx context.Context,
targetUserID, reason string, userDisplayName, userAvatarURL string, targetSenderID spec.SenderID, reason string, userDisplayName, userAvatarURL string,
sender string, senderDomain spec.ServerName, sender spec.SenderID, senderDomain spec.ServerName,
membership, roomID string, isDirect bool, membership, roomID string, isDirect bool,
keyID gomatrixserverlib.KeyID, privateKey ed25519.PrivateKey, evTime time.Time, keyID gomatrixserverlib.KeyID, privateKey ed25519.PrivateKey, evTime time.Time,
rsAPI roomserverAPI.ClientRoomserverAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
) (*types.HeaderedEvent, error) { ) (*types.HeaderedEvent, error) {
targetSenderString := string(targetSenderID)
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
SenderID: sender, SenderID: string(sender),
RoomID: roomID, RoomID: roomID,
Type: "m.room.member", Type: "m.room.member",
StateKey: &targetUserID, StateKey: &targetSenderString,
} }
content := gomatrixserverlib.MemberContent{ content := gomatrixserverlib.MemberContent{
@ -391,8 +424,25 @@ func buildMembershipEvent(
return nil, err return nil, err
} }
return buildMembershipEventDirect(ctx, targetUserID, reason, profile.DisplayName, profile.AvatarURL, userID, err := spec.NewUserID(device.UserID, true)
device.UserID, device.UserDomain(), membership, roomID, isDirect, identity.KeyID, identity.PrivateKey, evTime, rsAPI) if err != nil {
return nil, err
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID)
if err != nil {
return nil, err
}
targetID, err := spec.NewUserID(targetUserID, true)
if err != nil {
return nil, err
}
targetSenderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *targetID)
if err != nil {
return nil, err
}
return buildMembershipEventDirect(ctx, targetSenderID, reason, profile.DisplayName, profile.AvatarURL,
senderID, device.UserDomain(), membership, roomID, isDirect, identity.KeyID, identity.PrivateKey, evTime, rsAPI)
} }
// loadProfile lookups the profile of a given user from the database and returns // loadProfile lookups the profile of a given user from the database and returns
@ -490,7 +540,7 @@ func checkAndProcessThreepid(
return return
} }
func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, userID, roomID string) *util.JSONResponse { func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, userID spec.UserID, roomID string) *util.JSONResponse {
var membershipRes roomserverAPI.QueryMembershipForUserResponse var membershipRes roomserverAPI.QueryMembershipForUserResponse
err := rsAPI.QueryMembershipForUser(ctx, &roomserverAPI.QueryMembershipForUserRequest{ err := rsAPI.QueryMembershipForUser(ctx, &roomserverAPI.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
@ -518,12 +568,21 @@ func SendForget(
) util.JSONResponse { ) util.JSONResponse {
ctx := req.Context() ctx := req.Context()
logger := util.GetLogger(ctx).WithField("roomID", roomID).WithField("userID", device.UserID) logger := util.GetLogger(ctx).WithField("roomID", roomID).WithField("userID", device.UserID)
deviceUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
}
}
var membershipRes roomserverAPI.QueryMembershipForUserResponse var membershipRes roomserverAPI.QueryMembershipForUserResponse
membershipReq := roomserverAPI.QueryMembershipForUserRequest{ membershipReq := roomserverAPI.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
UserID: device.UserID, UserID: *deviceUserID,
} }
err := rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes) err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes)
if err != nil { if err != nil {
logger.WithError(err).Error("QueryMembershipForUser: could not query membership for user") logger.WithError(err).Error("QueryMembershipForUser: could not query membership for user")
return util.JSONResponse{ return util.JSONResponse{

View file

@ -47,7 +47,22 @@ func SendRedaction(
txnID *string, txnID *string,
txnCache *transactions.Cache, txnCache *transactions.Cache,
) util.JSONResponse { ) util.JSONResponse {
resErr := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) deviceUserID, userIDErr := spec.NewUserID(device.UserID, true)
if userIDErr != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to redact"),
}
}
senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID)
if queryErr != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to redact"),
}
}
resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }
@ -73,25 +88,10 @@ func SendRedaction(
} }
} }
fullUserID, userIDErr := spec.NewUserID(device.UserID, true)
if userIDErr != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to redact"),
}
}
senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if queryErr != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to redact"),
}
}
// "Users may redact their own events, and any user with a power level greater than or equal // "Users may redact their own events, and any user with a power level greater than or equal
// to the redact power level of the room may redact events there" // to the redact power level of the room may redact events there"
// https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid // https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid
allowedToRedact := ev.SenderID() == senderID // TODO: Should replace device.UserID with device...PerRoomKey allowedToRedact := ev.SenderID() == senderID
if !allowedToRedact { if !allowedToRedact {
plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{ plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomPowerLevels, EventType: spec.MRoomPowerLevels,

View file

@ -43,8 +43,16 @@ func SendTyping(
} }
} }
deviceUserID, err := spec.NewUserID(userID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
}
}
// Verify that the user is a member of this room // Verify that the user is a member of this room
resErr := checkMemberInRoom(req.Context(), rsAPI, userID, roomID) resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }

View file

@ -52,6 +52,7 @@ type sendServerNoticeRequest struct {
StateKey string `json:"state_key,omitempty"` StateKey string `json:"state_key,omitempty"`
} }
// nolint:gocyclo
// SendServerNotice sends a message to a specific user. It can only be invoked by an admin. // SendServerNotice sends a message to a specific user. It can only be invoked by an admin.
func SendServerNotice( func SendServerNotice(
req *http.Request, req *http.Request,
@ -187,9 +188,17 @@ func SendServerNotice(
} }
} else { } else {
// we've found a room in common, check the membership // we've found a room in common, check the membership
deviceUserID, err := spec.NewUserID(r.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
}
}
roomID = commonRooms[0] roomID = commonRooms[0]
membershipRes := api.QueryMembershipForUserResponse{} membershipRes := api.QueryMembershipForUserResponse{}
err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: r.UserID, RoomID: roomID}, &membershipRes) err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: *deviceUserID, RoomID: roomID}, &membershipRes)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("unable to query membership for user") util.GetLogger(ctx).WithError(err).Error("unable to query membership for user")
return util.JSONResponse{ return util.JSONResponse{
@ -234,7 +243,7 @@ func SendServerNotice(
ctx, rsAPI, ctx, rsAPI,
api.KindNew, api.KindNew,
[]*types.HeaderedEvent{ []*types.HeaderedEvent{
&types.HeaderedEvent{PDU: e}, {PDU: e},
}, },
device.UserDomain(), device.UserDomain(),
cfgClient.Matrix.ServerName, cfgClient.Matrix.ServerName,

View file

@ -99,9 +99,17 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
if !worldReadable { if !worldReadable {
// The room isn't world-readable so try to work out based on the // The room isn't world-readable so try to work out based on the
// user's membership if we want the latest state or not. // user's membership if we want the latest state or not.
err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("UserID is invalid")
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.Unknown("Device UserID is invalid"),
}
}
err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
UserID: device.UserID, UserID: *userID,
}, &membershipRes) }, &membershipRes)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser")
@ -140,14 +148,11 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
// use the result of the previous QueryLatestEventsAndState response // use the result of the previous QueryLatestEventsAndState response
// to find the state event, if provided. // to find the state event, if provided.
for _, ev := range stateRes.StateEvents { for _, ev := range stateRes.StateEvents {
sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID())
if err == nil && userID != nil {
sender = *userID
}
stateEvents = append( stateEvents = append(
stateEvents, stateEvents,
synctypes.ToClientEvent(ev, synctypes.FormatAll, sender), synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, ev),
) )
} }
} else { } else {
@ -172,9 +177,18 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := ev.StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
stateEvents = append( stateEvents = append(
stateEvents, stateEvents,
synctypes.ToClientEvent(ev, synctypes.FormatAll, sender), synctypes.ToClientEvent(ev, synctypes.FormatAll, sender, sk),
) )
} }
} }
@ -259,11 +273,19 @@ func OnIncomingStateTypeRequest(
// membershipRes will only be populated if the room is not world-readable. // membershipRes will only be populated if the room is not world-readable.
var membershipRes api.QueryMembershipForUserResponse var membershipRes api.QueryMembershipForUserResponse
if !worldReadable { if !worldReadable {
userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("UserID is invalid")
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.Unknown("Device UserID is invalid"),
}
}
// The room isn't world-readable so try to work out based on the // The room isn't world-readable so try to work out based on the
// user's membership if we want the latest state or not. // user's membership if we want the latest state or not.
err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
UserID: device.UserID, UserID: *userID,
}, &membershipRes) }, &membershipRes)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser")
@ -344,13 +366,10 @@ func OnIncomingStateTypeRequest(
} }
} }
sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err == nil && userID != nil {
sender = *userID
}
stateEvent := stateEventInStateResp{ stateEvent := stateEventInStateResp{
ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, sender), ClientEvent: synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, event),
} }
var res interface{} var res interface{}

View file

@ -59,7 +59,15 @@ func UpgradeRoom(
} }
} }
newRoomID, err := rsAPI.PerformRoomUpgrade(req.Context(), roomID, device.UserID, gomatrixserverlib.RoomVersion(r.NewVersion)) userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("device UserID is invalid")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
newRoomID, err := rsAPI.PerformRoomUpgrade(req.Context(), roomID, *userID, gomatrixserverlib.RoomVersion(r.NewVersion))
switch e := err.(type) { switch e := err.(type) {
case nil: case nil:
case roomserverAPI.ErrNotAllowed: case roomserverAPI.ErrNotAllowed:

View file

@ -45,7 +45,7 @@ func GetEventAuth(
if event.RoomID() != roomID { if event.RoomID() != roomID {
return util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")} return util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")}
} }
resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID) resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID())
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }

View file

@ -35,10 +35,6 @@ func GetEvent(
eventID string, eventID string,
origin spec.ServerName, origin spec.ServerName,
) util.JSONResponse { ) util.JSONResponse {
err := allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID)
if err != nil {
return *err
}
// /_matrix/federation/v1/event/{eventId} doesn't have a roomID, we use an empty string, // /_matrix/federation/v1/event/{eventId} doesn't have a roomID, we use an empty string,
// which results in `QueryEventsByID` to first get the event and use that to determine the roomID. // which results in `QueryEventsByID` to first get the event and use that to determine the roomID.
event, err := fetchEvent(ctx, rsAPI, "", eventID) event, err := fetchEvent(ctx, rsAPI, "", eventID)
@ -46,6 +42,11 @@ func GetEvent(
return *err return *err
} }
err = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID())
if err != nil {
return *err
}
return util.JSONResponse{Code: http.StatusOK, JSON: gomatrixserverlib.Transaction{ return util.JSONResponse{Code: http.StatusOK, JSON: gomatrixserverlib.Transaction{
Origin: origin, Origin: origin,
OriginServerTS: spec.AsTimestamp(time.Now()), OriginServerTS: spec.AsTimestamp(time.Now()),
@ -62,8 +63,9 @@ func allowedToSeeEvent(
origin spec.ServerName, origin spec.ServerName,
rsAPI api.FederationRoomserverAPI, rsAPI api.FederationRoomserverAPI,
eventID string, eventID string,
roomID string,
) *util.JSONResponse { ) *util.JSONResponse {
allowed, err := rsAPI.QueryServerAllowedToSeeEvent(ctx, origin, eventID) allowed, err := rsAPI.QueryServerAllowedToSeeEvent(ctx, origin, eventID, roomID)
if err != nil { if err != nil {
resErr := util.ErrorResponse(err) resErr := util.ErrorResponse(err)
return &resErr return &resErr

View file

@ -116,7 +116,7 @@ func getState(
if event.RoomID() != roomID { if event.RoomID() != roomID {
return nil, nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")} return nil, nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")}
} }
resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID) resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID())
if resErr != nil { if resErr != nil {
return nil, nil, resErr return nil, nil, resErr
} }

2
go.mod
View file

@ -22,7 +22,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
github.com/mattn/go-sqlite3 v1.14.16 github.com/mattn/go-sqlite3 v1.14.16

4
go.sum
View file

@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d h1:MjL8SXRzhO61aXDFL+gA3Bx1SicqLGL9gCWXDv8jkD8= github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077 h1:AmKkAUjy9rZA2K+qHXm/O/dPEPnUYfRE2I6SL+Dj+LU=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=

View file

@ -34,11 +34,11 @@ func (e ErrNotAllowed) Error() string {
type RestrictedJoinAPI interface { type RestrictedJoinAPI interface {
CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error)
InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error)
RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error)
QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types.RoomInfo, error) QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types.RoomInfo, error)
QueryServerJoinedToRoom(ctx context.Context, req *QueryServerJoinedToRoomRequest, res *QueryServerJoinedToRoomResponse) error QueryServerJoinedToRoom(ctx context.Context, req *QueryServerJoinedToRoomRequest, res *QueryServerJoinedToRoomResponse) error
UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, userID spec.UserID) (bool, error) UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, senderID spec.SenderID) (bool, error)
LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error)
} }
@ -191,7 +191,7 @@ type ClientRoomserverAPI interface {
PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) (string, *util.JSONResponse) PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) (string, *util.JSONResponse)
// PerformRoomUpgrade upgrades a room to a newer version // PerformRoomUpgrade upgrades a room to a newer version
PerformRoomUpgrade(ctx context.Context, roomID, userID string, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error) PerformRoomUpgrade(ctx context.Context, roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error)
PerformAdminEvacuateRoom(ctx context.Context, roomID string) (affected []string, err error) PerformAdminEvacuateRoom(ctx context.Context, roomID string) (affected []string, err error)
PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error) PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error)
PerformAdminPurgeRoom(ctx context.Context, roomID string) error PerformAdminPurgeRoom(ctx context.Context, roomID string) error
@ -228,6 +228,7 @@ type FederationRoomserverAPI interface {
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error
QueryMembershipForSenderID(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, res *QueryMembershipForUserResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error
@ -238,15 +239,13 @@ type FederationRoomserverAPI interface {
// Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate // Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate
// the state and auth chain to return. // the state and auth chain to return.
QueryStateAndAuthChain(ctx context.Context, req *QueryStateAndAuthChainRequest, res *QueryStateAndAuthChainResponse) error QueryStateAndAuthChain(ctx context.Context, req *QueryStateAndAuthChainRequest, res *QueryStateAndAuthChainResponse) error
// Query if we think we're still in a room.
QueryServerJoinedToRoom(ctx context.Context, req *QueryServerJoinedToRoomRequest, res *QueryServerJoinedToRoomResponse) error
QueryPublishedRooms(ctx context.Context, req *QueryPublishedRoomsRequest, res *QueryPublishedRoomsResponse) error QueryPublishedRooms(ctx context.Context, req *QueryPublishedRoomsRequest, res *QueryPublishedRoomsResponse) error
// Query missing events for a room from roomserver // Query missing events for a room from roomserver
QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error
// Query whether a server is allowed to see an event // Query whether a server is allowed to see an event
QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string) (allowed bool, err error) QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string, roomID string) (allowed bool, err error)
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error)
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
HandleInvite(ctx context.Context, event *types.HeaderedEvent) error HandleInvite(ctx context.Context, event *types.HeaderedEvent) error
@ -254,12 +253,6 @@ type FederationRoomserverAPI interface {
// Query a given amount (or less) of events prior to a given set of events. // Query a given amount (or less) of events prior to a given set of events.
PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error
CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error)
InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error)
QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types.RoomInfo, error)
UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, userID spec.UserID) (bool, error)
LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error)
IsKnownRoom(ctx context.Context, roomID spec.RoomID) (bool, error) IsKnownRoom(ctx context.Context, roomID spec.RoomID) (bool, error)
StateQuerier() gomatrixserverlib.StateQuerier StateQuerier() gomatrixserverlib.StateQuerier
} }

View file

@ -215,8 +215,10 @@ type OutputNewInviteEvent struct {
type OutputRetireInviteEvent struct { type OutputRetireInviteEvent struct {
// The ID of the "m.room.member" invite event. // The ID of the "m.room.member" invite event.
EventID string EventID string
// The target user ID of the "m.room.member" invite event that was retired. // The room ID of the "m.room.member" invite event.
TargetUserID string RoomID string
// The target sender ID of the "m.room.member" invite event that was retired.
TargetSenderID spec.SenderID
// Optional event ID of the event that replaced the invite. // Optional event ID of the event that replaced the invite.
// This can be empty if the invite was rejected locally and we were unable // This can be empty if the invite was rejected locally and we were unable
// to reach the server that originally sent the invite. // to reach the server that originally sent the invite.

View file

@ -41,8 +41,8 @@ type PerformJoinRequest struct {
} }
type PerformLeaveRequest struct { type PerformLeaveRequest struct {
RoomID string `json:"room_id"` RoomID string
UserID string `json:"user_id"` Leaver spec.UserID
} }
type PerformLeaveResponse struct { type PerformLeaveResponse struct {

View file

@ -113,9 +113,9 @@ type QueryEventsByIDResponse struct {
// QueryMembershipForUserRequest is a request to QueryMembership // QueryMembershipForUserRequest is a request to QueryMembership
type QueryMembershipForUserRequest struct { type QueryMembershipForUserRequest struct {
// ID of the room to fetch membership from // ID of the room to fetch membership from
RoomID string `json:"room_id"` RoomID string
// ID of the user for whom membership is requested // ID of the user for whom membership is requested
UserID string `json:"user_id"` UserID spec.UserID
} }
// QueryMembershipForUserResponse is a response to QueryMembership // QueryMembershipForUserResponse is a response to QueryMembership
@ -145,7 +145,7 @@ type QueryMembershipsForRoomRequest struct {
// Optional - ID of the user sending the request, for checking if the // Optional - ID of the user sending the request, for checking if the
// user is allowed to see the memberships. If not specified then all // user is allowed to see the memberships. If not specified then all
// room memberships will be returned. // room memberships will be returned.
Sender string `json:"sender"` SenderID spec.SenderID `json:"sender"`
} }
// QueryMembershipsForRoomResponse is a response to QueryMembershipsForRoom // QueryMembershipsForRoomResponse is a response to QueryMembershipsForRoom
@ -448,11 +448,11 @@ func (rq *JoinRoomQuerier) CurrentStateEvent(ctx context.Context, roomID spec.Ro
return rq.Roomserver.CurrentStateEvent(ctx, roomID, eventType, stateKey) return rq.Roomserver.CurrentStateEvent(ctx, roomID, eventType, stateKey)
} }
func (rq *JoinRoomQuerier) InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) { func (rq *JoinRoomQuerier) InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error) {
return rq.Roomserver.InvitePending(ctx, roomID, userID) return rq.Roomserver.InvitePending(ctx, roomID, senderID)
} }
func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) {
roomInfo, err := rq.Roomserver.QueryRoomInfo(ctx, roomID) roomInfo, err := rq.Roomserver.QueryRoomInfo(ctx, roomID)
if err != nil || roomInfo == nil || roomInfo.IsStub() { if err != nil || roomInfo == nil || roomInfo.IsStub() {
return nil, err return nil, err
@ -468,7 +468,7 @@ func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID sp
return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err) return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err)
} }
userJoinedToRoom, err := rq.Roomserver.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID) userJoinedToRoom, err := rq.Roomserver.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), senderID)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed") util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed")
return nil, fmt.Errorf("InternalServerError: %w", err) return nil, fmt.Errorf("InternalServerError: %w", err)
@ -492,12 +492,8 @@ type MembershipQuerier struct {
} }
func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) { func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
req := QueryMembershipForUserRequest{
RoomID: roomID.String(),
UserID: string(senderID),
}
res := QueryMembershipForUserResponse{} res := QueryMembershipForUserResponse{}
err := mq.Roomserver.QueryMembershipForUser(ctx, &req, &res) err := mq.Roomserver.QueryMembershipForSenderID(ctx, roomID, senderID, &res)
membership := "" membership := ""
if err == nil { if err == nil {

View file

@ -13,6 +13,9 @@
package auth package auth
import ( import (
"context"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
) )
@ -22,6 +25,7 @@ import (
// IsServerAllowed returns true if the server is allowed to see events in the room // IsServerAllowed returns true if the server is allowed to see events in the room
// at this particular state. This function implements https://matrix.org/docs/spec/client_server/r0.6.0#id87 // at this particular state. This function implements https://matrix.org/docs/spec/client_server/r0.6.0#id87
func IsServerAllowed( func IsServerAllowed(
ctx context.Context, db storage.RoomDatabase,
serverName spec.ServerName, serverName spec.ServerName,
serverCurrentlyInRoom bool, serverCurrentlyInRoom bool,
authEvents []gomatrixserverlib.PDU, authEvents []gomatrixserverlib.PDU,
@ -37,7 +41,7 @@ func IsServerAllowed(
return true return true
} }
// 2. If the user's membership was join, allow. // 2. If the user's membership was join, allow.
joinedUserExists := IsAnyUserOnServerWithMembership(serverName, authEvents, spec.Join) joinedUserExists := IsAnyUserOnServerWithMembership(ctx, db, serverName, authEvents, spec.Join)
if joinedUserExists { if joinedUserExists {
return true return true
} }
@ -46,7 +50,7 @@ func IsServerAllowed(
return true return true
} }
// 4. If the user's membership was invite, and the history_visibility was set to invited, allow. // 4. If the user's membership was invite, and the history_visibility was set to invited, allow.
invitedUserExists := IsAnyUserOnServerWithMembership(serverName, authEvents, spec.Invite) invitedUserExists := IsAnyUserOnServerWithMembership(ctx, db, serverName, authEvents, spec.Invite)
if invitedUserExists && historyVisibility == gomatrixserverlib.HistoryVisibilityInvited { if invitedUserExists && historyVisibility == gomatrixserverlib.HistoryVisibilityInvited {
return true return true
} }
@ -70,7 +74,7 @@ func HistoryVisibilityForRoom(authEvents []gomatrixserverlib.PDU) gomatrixserver
return visibility return visibility
} }
func IsAnyUserOnServerWithMembership(serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool { func IsAnyUserOnServerWithMembership(ctx context.Context, db storage.RoomDatabase, serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool {
for _, ev := range authEvents { for _, ev := range authEvents {
if ev.Type() != spec.MRoomMember { if ev.Type() != spec.MRoomMember {
continue continue
@ -85,12 +89,12 @@ func IsAnyUserOnServerWithMembership(serverName spec.ServerName, authEvents []go
continue continue
} }
_, domain, err := gomatrixserverlib.SplitID('@', *stateKey) userID, err := db.GetUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*stateKey))
if err != nil { if err != nil {
continue continue
} }
if domain == serverName { if userID.Domain() == serverName {
return true return true
} }
} }

View file

@ -1,13 +1,23 @@
package auth package auth
import ( import (
"context"
"testing" "testing"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
) )
type FakeStorageDB struct {
storage.RoomDatabase
}
func (f *FakeStorageDB) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}
func TestIsServerAllowed(t *testing.T) { func TestIsServerAllowed(t *testing.T) {
alice := test.NewUser(t) alice := test.NewUser(t)
@ -77,7 +87,7 @@ func TestIsServerAllowed(t *testing.T) {
authEvents = append(authEvents, ev.PDU) authEvents = append(authEvents, ev.PDU)
} }
if got := IsServerAllowed(tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want { if got := IsServerAllowed(context.Background(), &FakeStorageDB{}, tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want {
t.Errorf("IsServerAllowed() = %v, want %v", got, tt.want) t.Errorf("IsServerAllowed() = %v, want %v", got, tt.want)
} }
}) })

View file

@ -6,7 +6,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"sort" "sort"
"strings"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
@ -55,9 +54,10 @@ func UpdateToInviteMembership(
Type: api.OutputTypeRetireInviteEvent, Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &api.OutputRetireInviteEvent{ RetireInviteEvent: &api.OutputRetireInviteEvent{
EventID: eventID, EventID: eventID,
RoomID: add.RoomID(),
Membership: spec.Join, Membership: spec.Join,
RetiredByEventID: add.EventID(), RetiredByEventID: add.EventID(),
TargetUserID: *add.StateKey(), TargetSenderID: spec.SenderID(*add.StateKey()),
}, },
}) })
} }
@ -94,13 +94,13 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam
for i := range events { for i := range events {
gmslEvents[i] = events[i].PDU gmslEvents[i] = events[i].PDU
} }
return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, spec.Join), nil return auth.IsAnyUserOnServerWithMembership(ctx, db, serverName, gmslEvents, spec.Join), nil
} }
func IsInvitePending( func IsInvitePending(
ctx context.Context, db storage.Database, ctx context.Context, db storage.Database,
roomID, userID string, roomID string, senderID spec.SenderID,
) (bool, string, string, gomatrixserverlib.PDU, error) { ) (bool, spec.SenderID, string, gomatrixserverlib.PDU, error) {
// Look up the room NID for the supplied room ID. // Look up the room NID for the supplied room ID.
info, err := db.RoomInfo(ctx, roomID) info, err := db.RoomInfo(ctx, roomID)
if err != nil { if err != nil {
@ -111,13 +111,13 @@ func IsInvitePending(
} }
// Look up the state key NID for the supplied user ID. // Look up the state key NID for the supplied user ID.
targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{userID}) targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{string(senderID)})
if err != nil { if err != nil {
return false, "", "", nil, fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err) return false, "", "", nil, fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err)
} }
targetUserNID, targetUserFound := targetUserNIDs[userID] targetUserNID, targetUserFound := targetUserNIDs[string(senderID)]
if !targetUserFound { if !targetUserFound {
return false, "", "", nil, fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs) return false, "", "", nil, fmt.Errorf("missing NID for user %q (%+v)", senderID, targetUserNIDs)
} }
// Let's see if we have an event active for the user in the room. If // Let's see if we have an event active for the user in the room. If
@ -156,7 +156,7 @@ func IsInvitePending(
event, err := verImpl.NewEventFromTrustedJSON(eventJSON, false) event, err := verImpl.NewEventFromTrustedJSON(eventJSON, false)
return true, senderUser, userNIDToEventID[senderUserNIDs[0]], event, err return true, spec.SenderID(senderUser), userNIDToEventID[senderUserNIDs[0]], event, err
} }
// GetMembershipsAtState filters the state events to // GetMembershipsAtState filters the state events to
@ -264,7 +264,7 @@ func LoadStateEvents(
} }
func CheckServerAllowedToSeeEvent( func CheckServerAllowedToSeeEvent(
ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName spec.ServerName, isServerInRoom bool, ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool,
) (bool, error) { ) (bool, error) {
stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName)) stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName))
switch err { switch err {
@ -273,7 +273,7 @@ func CheckServerAllowedToSeeEvent(
case tables.OptimisationNotSupportedError: case tables.OptimisationNotSupportedError:
// The database engine didn't support this optimisation, so fall back to using // The database engine didn't support this optimisation, so fall back to using
// the old and slow method // the old and slow method
stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, eventID, serverName) stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -288,11 +288,11 @@ func CheckServerAllowedToSeeEvent(
return false, err return false, err
} }
} }
return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil return auth.IsServerAllowed(ctx, db, serverName, isServerInRoom, stateAtEvent), nil
} }
func slowGetHistoryVisibilityState( func slowGetHistoryVisibilityState(
ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName spec.ServerName, ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName,
) ([]gomatrixserverlib.PDU, error) { ) ([]gomatrixserverlib.PDU, error) {
roomState := state.NewStateResolution(db, info) roomState := state.NewStateResolution(db, info)
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
@ -319,10 +319,15 @@ func slowGetHistoryVisibilityState(
// then we'll filter it out. This does preserve state keys that // then we'll filter it out. This does preserve state keys that
// are "" since these will contain history visibility etc. // are "" since these will contain history visibility etc.
for nid, key := range stateKeys { for nid, key := range stateKeys {
if key != "" && !strings.HasSuffix(key, ":"+string(serverName)) { if key != "" {
userID, err := db.GetUserIDForSender(ctx, roomID, spec.SenderID(key))
if err == nil && userID != nil {
if userID.Domain() != serverName {
delete(stateKeys, nid) delete(stateKeys, nid)
} }
} }
}
}
// Now filter through all of the state events for the room. // Now filter through all of the state events for the room.
// If the state key NID appears in the list of valid state // If the state key NID appears in the list of valid state
@ -410,7 +415,7 @@ BFSLoop:
// hasn't been seen before. // hasn't been seen before.
if !visited[pre] { if !visited[pre] {
visited[pre] = true visited[pre] = true
allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, pre, serverName, isServerInRoom) allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom)
if err != nil { if err != nil {
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error( util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
"Error checking if allowed to see event", "Error checking if allowed to see event",

View file

@ -8,6 +8,7 @@ import (
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
@ -58,12 +59,12 @@ func TestIsInvitePendingWithoutNID(t *testing.T) {
} }
// Alice should have no pending invites and should have a NID // Alice should have no pending invites and should have a NID
pendingInvite, _, _, _, err := IsInvitePending(context.Background(), db, room.ID, alice.ID) pendingInvite, _, _, _, err := IsInvitePending(context.Background(), db, room.ID, spec.SenderID(alice.ID))
assert.NoError(t, err, "failed to get pending invites") assert.NoError(t, err, "failed to get pending invites")
assert.False(t, pendingInvite, "unexpected pending invite") assert.False(t, pendingInvite, "unexpected pending invite")
// Bob should have no pending invites and receive a new NID // Bob should have no pending invites and receive a new NID
pendingInvite, _, _, _, err = IsInvitePending(context.Background(), db, room.ID, bob.ID) pendingInvite, _, _, _, err = IsInvitePending(context.Background(), db, room.ID, spec.SenderID(bob.ID))
assert.NoError(t, err, "failed to get pending invites") assert.NoError(t, err, "failed to get pending invites")
assert.False(t, pendingInvite, "unexpected pending invite") assert.False(t, pendingInvite, "unexpected pending invite")
}) })

View file

@ -842,17 +842,15 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r
continue continue
} }
// TODO: pseudoIDs: get userID for room using state key (which is now senderID) memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, memberEvent.RoomID(), spec.SenderID(*memberEvent.StateKey()))
localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey())
if err != nil { if err != nil {
continue continue
} }
// TODO: pseudoIDs: query account by state key (which is now senderID)
accountRes := &userAPI.QueryAccountByLocalpartResponse{} accountRes := &userAPI.QueryAccountByLocalpartResponse{}
if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{ if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{
Localpart: localpart, Localpart: memberUserID.Local(),
ServerName: senderDomain, ServerName: memberUserID.Domain(),
}, accountRes); err != nil { }, accountRes); err != nil {
return err return err
} }
@ -896,8 +894,8 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r
inputEvents = append(inputEvents, api.InputRoomEvent{ inputEvents = append(inputEvents, api.InputRoomEvent{
Kind: api.KindNew, Kind: api.KindNew,
Event: event, Event: event,
Origin: senderDomain, Origin: memberUserID.Domain(),
SendAsServer: string(senderDomain), SendAsServer: string(memberUserID.Domain()),
}) })
prevEvents = []string{event.EventID()} prevEvents = []string{event.EventID()}
} }

View file

@ -18,7 +18,6 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
@ -72,7 +71,7 @@ func (r *Inputer) updateMemberships(
if change.addedEventNID != 0 { if change.addedEventNID != 0 {
ae, _ = helpers.EventMap(events).Lookup(change.addedEventNID) ae, _ = helpers.EventMap(events).Lookup(change.addedEventNID)
} }
if updates, err = r.updateMembership(updater, targetUserNID, re, ae, updates); err != nil { if updates, err = r.updateMembership(ctx, updater, targetUserNID, re, ae, updates); err != nil {
return nil, err return nil, err
} }
} }
@ -80,6 +79,7 @@ func (r *Inputer) updateMemberships(
} }
func (r *Inputer) updateMembership( func (r *Inputer) updateMembership(
ctx context.Context,
updater *shared.RoomUpdater, updater *shared.RoomUpdater,
targetUserNID types.EventStateKeyNID, targetUserNID types.EventStateKeyNID,
remove, add *types.Event, remove, add *types.Event,
@ -97,7 +97,7 @@ func (r *Inputer) updateMembership(
var targetLocal bool var targetLocal bool
if add != nil { if add != nil {
targetLocal = r.isLocalTarget(add) targetLocal = r.isLocalTarget(ctx, add)
} }
mu, err := updater.MembershipUpdater(targetUserNID, targetLocal) mu, err := updater.MembershipUpdater(targetUserNID, targetLocal)
@ -136,11 +136,14 @@ func (r *Inputer) updateMembership(
} }
} }
func (r *Inputer) isLocalTarget(event *types.Event) bool { func (r *Inputer) isLocalTarget(ctx context.Context, event *types.Event) bool {
isTargetLocalUser := false isTargetLocalUser := false
if statekey := event.StateKey(); statekey != nil { if statekey := event.StateKey(); statekey != nil {
_, domain, _ := gomatrixserverlib.SplitID('@', *statekey) userID, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*statekey))
isTargetLocalUser = domain == r.ServerName if err != nil || userID == nil {
return isTargetLocalUser
}
isTargetLocalUser = userID.Domain() == r.ServerName
} }
return isTargetLocalUser return isTargetLocalUser
} }
@ -161,9 +164,10 @@ func updateToJoinMembership(
Type: api.OutputTypeRetireInviteEvent, Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &api.OutputRetireInviteEvent{ RetireInviteEvent: &api.OutputRetireInviteEvent{
EventID: eventID, EventID: eventID,
RoomID: add.RoomID(),
Membership: spec.Join, Membership: spec.Join,
RetiredByEventID: add.EventID(), RetiredByEventID: add.EventID(),
TargetUserID: *add.StateKey(), TargetSenderID: spec.SenderID(*add.StateKey()),
}, },
}) })
} }
@ -187,9 +191,10 @@ func updateToLeaveMembership(
Type: api.OutputTypeRetireInviteEvent, Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &api.OutputRetireInviteEvent{ RetireInviteEvent: &api.OutputRetireInviteEvent{
EventID: eventID, EventID: eventID,
RoomID: add.RoomID(),
Membership: newMembership, Membership: newMembership,
RetiredByEventID: add.EventID(), RetiredByEventID: add.EventID(),
TargetUserID: *add.StateKey(), TargetSenderID: spec.SenderID(*add.StateKey()),
}, },
}) })
} }

View file

@ -149,11 +149,11 @@ func (r *Admin) PerformAdminEvacuateUser(
ctx context.Context, ctx context.Context,
userID string, userID string,
) (affected []string, err error) { ) (affected []string, err error) {
_, domain, err := gomatrixserverlib.SplitID('@', userID) fullUserID, err := spec.NewUserID(userID, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !r.Cfg.Matrix.IsLocalServerName(domain) { if !r.Cfg.Matrix.IsLocalServerName(fullUserID.Domain()) {
return nil, fmt.Errorf("can only evacuate local users using this endpoint") return nil, fmt.Errorf("can only evacuate local users using this endpoint")
} }
@ -172,7 +172,7 @@ func (r *Admin) PerformAdminEvacuateUser(
for _, roomID := range allRooms { for _, roomID := range allRooms {
leaveReq := &api.PerformLeaveRequest{ leaveReq := &api.PerformLeaveRequest{
RoomID: roomID, RoomID: roomID,
UserID: userID, Leaver: *fullUserID,
} }
leaveRes := &api.PerformLeaveResponse{} leaveRes := &api.PerformLeaveResponse{}
outputEvents, err := r.Leaver.PerformLeave(ctx, leaveReq, leaveRes) outputEvents, err := r.Leaver.PerformLeave(ctx, leaveReq, leaveRes)

View file

@ -582,7 +582,7 @@ func joinEventsFromHistoryVisibility(
} }
// Can we see events in the room? // Can we see events in the room?
canSeeEvents := auth.IsServerAllowed(thisServer, true, events) canSeeEvents := auth.IsServerAllowed(ctx, db, thisServer, true, events)
visibility := auth.HistoryVisibilityForRoom(events) visibility := auth.HistoryVisibilityForRoom(events)
if !canSeeEvents { if !canSeeEvents {
logrus.Infof("ServersAtEvent history not visible to us: %s", visibility) logrus.Infof("ServersAtEvent history not visible to us: %s", visibility)

View file

@ -63,9 +63,17 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
} }
} }
} }
createContent["creator"] = userID.String() senderID, err := c.DB.GetSenderIDForUser(ctx, roomID.String(), userID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user")
return "", &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
createContent["creator"] = senderID
createContent["room_version"] = createRequest.RoomVersion createContent["room_version"] = createRequest.RoomVersion
powerLevelContent := eventutil.InitialPowerLevelsContent(userID.String()) powerLevelContent := eventutil.InitialPowerLevelsContent(string(senderID))
joinRuleContent := gomatrixserverlib.JoinRuleContent{ joinRuleContent := gomatrixserverlib.JoinRuleContent{
JoinRule: spec.Invite, JoinRule: spec.Invite,
} }
@ -121,7 +129,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
} }
membershipEvent := gomatrixserverlib.FledglingEvent{ membershipEvent := gomatrixserverlib.FledglingEvent{
Type: spec.MRoomMember, Type: spec.MRoomMember,
StateKey: userID.String(), StateKey: string(senderID),
Content: gomatrixserverlib.MemberContent{ Content: gomatrixserverlib.MemberContent{
Membership: spec.Join, Membership: spec.Join,
DisplayName: createRequest.UserDisplayName, DisplayName: createRequest.UserDisplayName,
@ -270,7 +278,6 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
var builtEvents []*types.HeaderedEvent var builtEvents []*types.HeaderedEvent
authEvents := gomatrixserverlib.NewAuthEvents(nil) authEvents := gomatrixserverlib.NewAuthEvents(nil)
senderID, err := c.RSAPI.QuerySenderIDForUser(ctx, roomID.String(), userID)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("rsapi.QuerySenderIDForUser failed") util.GetLogger(ctx).WithError(err).Error("rsapi.QuerySenderIDForUser failed")
return "", &util.JSONResponse{ return "", &util.JSONResponse{

View file

@ -134,12 +134,12 @@ func (r *Inviter) PerformInvite(
return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")} return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")}
} }
if event.StateKey() == nil { if event.StateKey() == nil || *event.StateKey() == "" {
return fmt.Errorf("invite must be a state event") return fmt.Errorf("invite must be a state event")
} }
invitedUser, err := spec.NewUserID(*event.StateKey(), true) invitedUser, err := r.RSAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
if err != nil { if err != nil || invitedUser == nil {
return spec.InvalidParam("The user ID is invalid") return spec.InvalidParam("Could not find the matching senderID for this user")
} }
isTargetLocal := r.Cfg.Matrix.IsLocalServerName(invitedUser.Domain()) isTargetLocal := r.Cfg.Matrix.IsLocalServerName(invitedUser.Domain())

View file

@ -162,7 +162,7 @@ func (r *Joiner) performJoinRoomByID(
} }
// Get the domain part of the room ID. // Get the domain part of the room ID.
_, domain, err := gomatrixserverlib.SplitID('!', req.RoomIDOrAlias) roomID, err := spec.NewRoomID(req.RoomIDOrAlias)
if err != nil { if err != nil {
return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid: %w", req.RoomIDOrAlias, err)} return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid: %w", req.RoomIDOrAlias, err)}
} }
@ -170,8 +170,8 @@ func (r *Joiner) performJoinRoomByID(
// If the server name in the room ID isn't ours then it's a // If the server name in the room ID isn't ours then it's a
// possible candidate for finding the room via federation. Add // possible candidate for finding the room via federation. Add
// it to the list of servers to try. // it to the list of servers to try.
if !r.Cfg.Matrix.IsLocalServerName(domain) { if !r.Cfg.Matrix.IsLocalServerName(roomID.Domain()) {
req.ServerNames = append(req.ServerNames, domain) req.ServerNames = append(req.ServerNames, roomID.Domain())
} }
// Prepare the template for the join event. // Prepare the template for the join event.
@ -203,7 +203,7 @@ func (r *Joiner) performJoinRoomByID(
req.Content = map[string]interface{}{} req.Content = map[string]interface{}{}
} }
req.Content["membership"] = spec.Join req.Content["membership"] = spec.Join
if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req); aerr != nil { if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req, senderID); aerr != nil {
return "", "", aerr return "", "", aerr
} else if authorisedVia != "" { } else if authorisedVia != "" {
req.Content["join_authorised_via_users_server"] = authorisedVia req.Content["join_authorised_via_users_server"] = authorisedVia
@ -226,17 +226,17 @@ func (r *Joiner) performJoinRoomByID(
// Force a federated join if we're dealing with a pending invite // Force a federated join if we're dealing with a pending invite
// and we aren't in the room. // and we aren't in the room.
isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, req.UserID) isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID)
if err == nil && !serverInRoom && isInvitePending { if err == nil && !serverInRoom && isInvitePending {
_, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender) inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomIDOrAlias, inviteSender)
if ierr != nil { if queryErr != nil {
return "", "", fmt.Errorf("gomatrixserverlib.SplitID: %w", err) return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr)
} }
// If we were invited by someone from another server then we can // If we were invited by someone from another server then we can
// assume they are in the room so we can join via them. // assume they are in the room so we can join via them.
if !r.Cfg.Matrix.IsLocalServerName(inviterDomain) { if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) {
req.ServerNames = append(req.ServerNames, inviterDomain) req.ServerNames = append(req.ServerNames, inviter.Domain())
forceFederatedJoin = true forceFederatedJoin = true
memberEvent := gjson.Parse(string(inviteEvent.JSON())) memberEvent := gjson.Parse(string(inviteEvent.JSON()))
// only set unsigned if we've got a content.membership, which we _should_ // only set unsigned if we've got a content.membership, which we _should_
@ -298,12 +298,8 @@ func (r *Joiner) performJoinRoomByID(
// a member of the room. This is best-effort (as in we won't // a member of the room. This is best-effort (as in we won't
// fail if we can't find the existing membership) because there // fail if we can't find the existing membership) because there
// is really no harm in just sending another membership event. // is really no harm in just sending another membership event.
membershipReq := &api.QueryMembershipForUserRequest{
RoomID: req.RoomIDOrAlias,
UserID: userID.String(),
}
membershipRes := &api.QueryMembershipForUserResponse{} membershipRes := &api.QueryMembershipForUserResponse{}
_ = r.Queryer.QueryMembershipForUser(ctx, membershipReq, membershipRes) _ = r.Queryer.QueryMembershipForSenderID(ctx, *roomID, senderID, membershipRes)
// If we haven't already joined the room then send an event // If we haven't already joined the room then send an event
// into the room changing our membership status. // into the room changing our membership status.
@ -328,7 +324,7 @@ func (r *Joiner) performJoinRoomByID(
// The room doesn't exist locally. If the room ID looks like it should // The room doesn't exist locally. If the room ID looks like it should
// be ours then this probably means that we've nuked our database at // be ours then this probably means that we've nuked our database at
// some point. // some point.
if r.Cfg.Matrix.IsLocalServerName(domain) { if r.Cfg.Matrix.IsLocalServerName(roomID.Domain()) {
// If there are no more server names to try then give up here. // If there are no more server names to try then give up here.
// Otherwise we'll try a federated join as normal, since it's quite // Otherwise we'll try a federated join as normal, since it's quite
// possible that the room still exists on other servers. // possible that the room still exists on other servers.
@ -376,15 +372,12 @@ func (r *Joiner) performFederatedJoinRoomByID(
func (r *Joiner) populateAuthorisedViaUserForRestrictedJoin( func (r *Joiner) populateAuthorisedViaUserForRestrictedJoin(
ctx context.Context, ctx context.Context,
joinReq *rsAPI.PerformJoinRequest, joinReq *rsAPI.PerformJoinRequest,
senderID spec.SenderID,
) (string, error) { ) (string, error) {
roomID, err := spec.NewRoomID(joinReq.RoomIDOrAlias) roomID, err := spec.NewRoomID(joinReq.RoomIDOrAlias)
if err != nil { if err != nil {
return "", err return "", err
} }
userID, err := spec.NewUserID(joinReq.UserID, true)
if err != nil {
return "", err
}
return r.Queryer.QueryRestrictedJoinAllowed(ctx, *roomID, *userID) return r.Queryer.QueryRestrictedJoinAllowed(ctx, *roomID, senderID)
} }

View file

@ -53,16 +53,12 @@ func (r *Leaver) PerformLeave(
req *api.PerformLeaveRequest, req *api.PerformLeaveRequest,
res *api.PerformLeaveResponse, res *api.PerformLeaveResponse,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if !r.Cfg.Matrix.IsLocalServerName(req.Leaver.Domain()) {
if err != nil { return nil, fmt.Errorf("user %q does not belong to this homeserver", req.Leaver.String())
return nil, fmt.Errorf("supplied user ID %q in incorrect format", req.UserID)
}
if !r.Cfg.Matrix.IsLocalServerName(domain) {
return nil, fmt.Errorf("user %q does not belong to this homeserver", req.UserID)
} }
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
"room_id": req.RoomID, "room_id": req.RoomID,
"user_id": req.UserID, "user_id": req.Leaver.String(),
}) })
logger.Info("User requested to leave join") logger.Info("User requested to leave join")
if strings.HasPrefix(req.RoomID, "!") { if strings.HasPrefix(req.RoomID, "!") {
@ -82,21 +78,26 @@ func (r *Leaver) performLeaveRoomByID(
req *api.PerformLeaveRequest, req *api.PerformLeaveRequest,
res *api.PerformLeaveResponse, // nolint:unparam res *api.PerformLeaveResponse, // nolint:unparam
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomID, req.Leaver)
if err != nil {
return nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String())
}
// If there's an invite outstanding for the room then respond to // If there's an invite outstanding for the room then respond to
// that. // that.
isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID) isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, leaver)
if err == nil && isInvitePending { if err == nil && isInvitePending {
_, senderDomain, serr := gomatrixserverlib.SplitID('@', senderUser) sender, serr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomID, senderUser)
if serr != nil { if serr != nil || sender == nil {
return nil, fmt.Errorf("sender %q is invalid", senderUser) return nil, fmt.Errorf("sender %q has no matching userID", senderUser)
} }
if !r.Cfg.Matrix.IsLocalServerName(senderDomain) { if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) {
return r.performFederatedRejectInvite(ctx, req, res, senderUser, eventID) return r.performFederatedRejectInvite(ctx, req, res, *sender, eventID, leaver)
} }
// check that this is not a "server notice room" // check that this is not a "server notice room"
accData := &userapi.QueryAccountDataResponse{} accData := &userapi.QueryAccountDataResponse{}
if err = r.UserAPI.QueryAccountData(ctx, &userapi.QueryAccountDataRequest{ if err = r.UserAPI.QueryAccountData(ctx, &userapi.QueryAccountDataRequest{
UserID: req.UserID, UserID: req.Leaver.String(),
RoomID: req.RoomID, RoomID: req.RoomID,
DataType: "m.tag", DataType: "m.tag",
}, accData); err != nil { }, accData); err != nil {
@ -127,7 +128,7 @@ func (r *Leaver) performLeaveRoomByID(
StateToFetch: []gomatrixserverlib.StateKeyTuple{ StateToFetch: []gomatrixserverlib.StateKeyTuple{
{ {
EventType: spec.MRoomMember, EventType: spec.MRoomMember,
StateKey: req.UserID, StateKey: string(leaver),
}, },
}, },
} }
@ -141,26 +142,18 @@ func (r *Leaver) performLeaveRoomByID(
// Now let's see if the user is in the room. // Now let's see if the user is in the room.
if len(latestRes.StateEvents) == 0 { if len(latestRes.StateEvents) == 0 {
return nil, fmt.Errorf("user %q is not a member of room %q", req.UserID, req.RoomID) return nil, fmt.Errorf("user %q is not a member of room %q", req.Leaver.String(), req.RoomID)
} }
membership, err := latestRes.StateEvents[0].Membership() membership, err := latestRes.StateEvents[0].Membership()
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting membership: %w", err) return nil, fmt.Errorf("error getting membership: %w", err)
} }
if membership != spec.Join && membership != spec.Invite { if membership != spec.Join && membership != spec.Invite {
return nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.UserID, membership) return nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.Leaver.String(), membership)
} }
// Prepare the template for the leave event. // Prepare the template for the leave event.
fullUserID, err := spec.NewUserID(req.UserID, true) senderIDString := string(leaver)
if err != nil {
return nil, err
}
senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomID, *fullUserID)
if err != nil {
return nil, err
}
senderIDString := string(senderID)
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
Type: spec.MRoomMember, Type: spec.MRoomMember,
SenderID: senderIDString, SenderID: senderIDString,
@ -175,16 +168,13 @@ func (r *Leaver) performLeaveRoomByID(
return nil, fmt.Errorf("eb.SetUnsigned: %w", err) return nil, fmt.Errorf("eb.SetUnsigned: %w", err)
} }
// Get the sender domain.
senderDomain := fullUserID.Domain()
// We know that the user is in the room at this point so let's build // We know that the user is in the room at this point so let's build
// a leave event. // a leave event.
// TODO: Check what happens if the room exists on the server // TODO: Check what happens if the room exists on the server
// but everyone has since left. I suspect it does the wrong thing. // but everyone has since left. I suspect it does the wrong thing.
var buildRes rsAPI.QueryLatestEventsAndStateResponse var buildRes rsAPI.QueryLatestEventsAndStateResponse
identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) identity, err := r.Cfg.Matrix.SigningIdentityFor(req.Leaver.Domain())
if err != nil { if err != nil {
return nil, fmt.Errorf("SigningIdentityFor: %w", err) return nil, fmt.Errorf("SigningIdentityFor: %w", err)
} }
@ -201,8 +191,8 @@ func (r *Leaver) performLeaveRoomByID(
{ {
Kind: api.KindNew, Kind: api.KindNew,
Event: event, Event: event,
Origin: senderDomain, Origin: req.Leaver.Domain(),
SendAsServer: string(senderDomain), SendAsServer: string(req.Leaver.Domain()),
}, },
}, },
} }
@ -219,21 +209,17 @@ func (r *Leaver) performFederatedRejectInvite(
ctx context.Context, ctx context.Context,
req *api.PerformLeaveRequest, req *api.PerformLeaveRequest,
res *api.PerformLeaveResponse, // nolint:unparam res *api.PerformLeaveResponse, // nolint:unparam
senderUser, eventID string, inviteSender spec.UserID, eventID string,
leaver spec.SenderID,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
_, domain, err := gomatrixserverlib.SplitID('@', senderUser)
if err != nil {
return nil, fmt.Errorf("user ID %q invalid: %w", senderUser, err)
}
// Ask the federation sender to perform a federated leave for us. // Ask the federation sender to perform a federated leave for us.
leaveReq := fsAPI.PerformLeaveRequest{ leaveReq := fsAPI.PerformLeaveRequest{
RoomID: req.RoomID, RoomID: req.RoomID,
UserID: req.UserID, UserID: req.Leaver.String(),
ServerNames: []spec.ServerName{domain}, ServerNames: []spec.ServerName{inviteSender.Domain()},
} }
leaveRes := fsAPI.PerformLeaveResponse{} leaveRes := fsAPI.PerformLeaveResponse{}
if err = r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil { if err := r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil {
// failures in PerformLeave should NEVER stop us from telling other components like the // failures in PerformLeave should NEVER stop us from telling other components like the
// sync API that the invite was withdrawn. Otherwise we can end up with stuck invites. // sync API that the invite was withdrawn. Otherwise we can end up with stuck invites.
util.GetLogger(ctx).WithError(err).Errorf("failed to PerformLeave, still retiring invite event") util.GetLogger(ctx).WithError(err).Errorf("failed to PerformLeave, still retiring invite event")
@ -244,7 +230,7 @@ func (r *Leaver) performFederatedRejectInvite(
util.GetLogger(ctx).WithError(err).Errorf("failed to get RoomInfo, still retiring invite event") util.GetLogger(ctx).WithError(err).Errorf("failed to get RoomInfo, still retiring invite event")
} }
updater, err := r.DB.MembershipUpdater(ctx, req.RoomID, req.UserID, true, info.RoomVersion) updater, err := r.DB.MembershipUpdater(ctx, req.RoomID, string(leaver), true, info.RoomVersion)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Errorf("failed to get MembershipUpdater, still retiring invite event") util.GetLogger(ctx).WithError(err).Errorf("failed to get MembershipUpdater, still retiring invite event")
} }
@ -268,8 +254,9 @@ func (r *Leaver) performFederatedRejectInvite(
Type: api.OutputTypeRetireInviteEvent, Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &api.OutputRetireInviteEvent{ RetireInviteEvent: &api.OutputRetireInviteEvent{
EventID: eventID, EventID: eventID,
RoomID: req.RoomID,
Membership: "leave", Membership: "leave",
TargetUserID: req.UserID, TargetSenderID: leaver,
}, },
}, },
}, nil }, nil

View file

@ -38,19 +38,15 @@ type Upgrader struct {
// PerformRoomUpgrade upgrades a room from one version to another // PerformRoomUpgrade upgrades a room from one version to another
func (r *Upgrader) PerformRoomUpgrade( func (r *Upgrader) PerformRoomUpgrade(
ctx context.Context, ctx context.Context,
roomID, userID string, roomVersion gomatrixserverlib.RoomVersion, roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion,
) (newRoomID string, err error) { ) (newRoomID string, err error) {
return r.performRoomUpgrade(ctx, roomID, userID, roomVersion) return r.performRoomUpgrade(ctx, roomID, userID, roomVersion)
} }
func (r *Upgrader) performRoomUpgrade( func (r *Upgrader) performRoomUpgrade(
ctx context.Context, ctx context.Context,
roomID, userID string, roomVersion gomatrixserverlib.RoomVersion, roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion,
) (string, error) { ) (string, error) {
_, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID)
if err != nil {
return "", api.ErrNotAllowed{Err: fmt.Errorf("error validating the user ID")}
}
evTime := time.Now() evTime := time.Now()
// Return an immediate error if the room does not exist // Return an immediate error if the room does not exist
@ -58,14 +54,20 @@ func (r *Upgrader) performRoomUpgrade(
return "", err return "", err
} }
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, userID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user")
return "", err
}
// 1. Check if the user is authorized to actually perform the upgrade (can send m.room.tombstone) // 1. Check if the user is authorized to actually perform the upgrade (can send m.room.tombstone)
if !r.userIsAuthorized(ctx, userID, roomID) { if !r.userIsAuthorized(ctx, senderID, roomID) {
return "", api.ErrNotAllowed{Err: fmt.Errorf("You don't have permission to upgrade the room, power level too low.")} return "", api.ErrNotAllowed{Err: fmt.Errorf("You don't have permission to upgrade the room, power level too low.")}
} }
// TODO (#267): Check room ID doesn't clash with an existing one, and we // TODO (#267): Check room ID doesn't clash with an existing one, and we
// probably shouldn't be using pseudo-random strings, maybe GUIDs? // probably shouldn't be using pseudo-random strings, maybe GUIDs?
newRoomID := fmt.Sprintf("!%s:%s", util.RandomString(16), userDomain) newRoomID := fmt.Sprintf("!%s:%s", util.RandomString(16), userID.Domain())
// Get the existing room state for the old room. // Get the existing room state for the old room.
oldRoomReq := &api.QueryLatestEventsAndStateRequest{ oldRoomReq := &api.QueryLatestEventsAndStateRequest{
@ -77,25 +79,25 @@ func (r *Upgrader) performRoomUpgrade(
} }
// Make the tombstone event // Make the tombstone event
tombstoneEvent, pErr := r.makeTombstoneEvent(ctx, evTime, userID, roomID, newRoomID) tombstoneEvent, pErr := r.makeTombstoneEvent(ctx, evTime, senderID, userID.Domain(), roomID, newRoomID)
if pErr != nil { if pErr != nil {
return "", pErr return "", pErr
} }
// Generate the initial events we need to send into the new room. This includes copied state events and bans // Generate the initial events we need to send into the new room. This includes copied state events and bans
// as well as the power level events needed to set up the room // as well as the power level events needed to set up the room
eventsToMake, pErr := r.generateInitialEvents(ctx, oldRoomRes, userID, roomID, roomVersion, tombstoneEvent) eventsToMake, pErr := r.generateInitialEvents(ctx, oldRoomRes, senderID, roomID, roomVersion, tombstoneEvent)
if pErr != nil { if pErr != nil {
return "", pErr return "", pErr
} }
// Send the setup events to the new room // Send the setup events to the new room
if pErr = r.sendInitialEvents(ctx, evTime, userID, userDomain, newRoomID, roomVersion, eventsToMake); pErr != nil { if pErr = r.sendInitialEvents(ctx, evTime, senderID, userID.Domain(), newRoomID, roomVersion, eventsToMake); pErr != nil {
return "", pErr return "", pErr
} }
// 5. Send the tombstone event to the old room // 5. Send the tombstone event to the old room
if pErr = r.sendHeaderedEvent(ctx, userDomain, tombstoneEvent, string(userDomain)); pErr != nil { if pErr = r.sendHeaderedEvent(ctx, userID.Domain(), tombstoneEvent, string(userID.Domain())); pErr != nil {
return "", pErr return "", pErr
} }
@ -105,17 +107,17 @@ func (r *Upgrader) performRoomUpgrade(
} }
// If the old room had a canonical alias event, it should be deleted in the old room // If the old room had a canonical alias event, it should be deleted in the old room
if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, userID, userDomain, roomID); pErr != nil { if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, senderID, userID.Domain(), roomID); pErr != nil {
return "", pErr return "", pErr
} }
// 4. Move local aliases to the new room // 4. Move local aliases to the new room
if pErr = moveLocalAliases(ctx, roomID, newRoomID, userID, r.URSAPI); pErr != nil { if pErr = moveLocalAliases(ctx, roomID, newRoomID, senderID, userID, r.URSAPI); pErr != nil {
return "", pErr return "", pErr
} }
// 6. Restrict power levels in the old room // 6. Restrict power levels in the old room
if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, userID, userDomain, roomID); pErr != nil { if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, senderID, userID.Domain(), roomID); pErr != nil {
return "", pErr return "", pErr
} }
@ -130,7 +132,7 @@ func (r *Upgrader) getRoomPowerLevels(ctx context.Context, roomID string) (*goma
return oldPowerLevelsEvent.PowerLevels() return oldPowerLevelsEvent.PowerLevels()
} }
func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, userID string, userDomain spec.ServerName, roomID string) error { func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, senderID spec.SenderID, userDomain spec.ServerName, roomID string) error {
restrictedPowerLevelContent, pErr := r.getRoomPowerLevels(ctx, roomID) restrictedPowerLevelContent, pErr := r.getRoomPowerLevels(ctx, roomID)
if pErr != nil { if pErr != nil {
return pErr return pErr
@ -147,7 +149,7 @@ func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.T
restrictedPowerLevelContent.EventsDefault = restrictedDefaultPowerLevel restrictedPowerLevelContent.EventsDefault = restrictedDefaultPowerLevel
restrictedPowerLevelContent.Invite = restrictedDefaultPowerLevel restrictedPowerLevelContent.Invite = restrictedDefaultPowerLevel
restrictedPowerLevelsHeadered, resErr := r.makeHeaderedEvent(ctx, evTime, userID, roomID, gomatrixserverlib.FledglingEvent{ restrictedPowerLevelsHeadered, resErr := r.makeHeaderedEvent(ctx, evTime, senderID, userDomain, roomID, gomatrixserverlib.FledglingEvent{
Type: spec.MRoomPowerLevels, Type: spec.MRoomPowerLevels,
StateKey: "", StateKey: "",
Content: restrictedPowerLevelContent, Content: restrictedPowerLevelContent,
@ -165,7 +167,7 @@ func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.T
} }
func moveLocalAliases(ctx context.Context, func moveLocalAliases(ctx context.Context,
roomID, newRoomID, userID string, roomID, newRoomID string, senderID spec.SenderID, userID spec.UserID,
URSAPI api.RoomserverInternalAPI, URSAPI api.RoomserverInternalAPI,
) (err error) { ) (err error) {
@ -175,14 +177,6 @@ func moveLocalAliases(ctx context.Context,
return fmt.Errorf("Failed to get old room aliases: %w", err) return fmt.Errorf("Failed to get old room aliases: %w", err)
} }
fullUserID, err := spec.NewUserID(userID, true)
if err != nil {
return fmt.Errorf("Failed to get userID: %w", err)
}
senderID, err := URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return fmt.Errorf("Failed to get senderID: %w", err)
}
for _, alias := range aliasRes.Aliases { for _, alias := range aliasRes.Aliases {
removeAliasReq := api.RemoveRoomAliasRequest{SenderID: senderID, Alias: alias} removeAliasReq := api.RemoveRoomAliasRequest{SenderID: senderID, Alias: alias}
removeAliasRes := api.RemoveRoomAliasResponse{} removeAliasRes := api.RemoveRoomAliasResponse{}
@ -190,7 +184,7 @@ func moveLocalAliases(ctx context.Context,
return fmt.Errorf("Failed to remove old room alias: %w", err) return fmt.Errorf("Failed to remove old room alias: %w", err)
} }
setAliasReq := api.SetRoomAliasRequest{UserID: userID, Alias: alias, RoomID: newRoomID} setAliasReq := api.SetRoomAliasRequest{UserID: userID.String(), Alias: alias, RoomID: newRoomID}
setAliasRes := api.SetRoomAliasResponse{} setAliasRes := api.SetRoomAliasResponse{}
if err = URSAPI.SetRoomAlias(ctx, &setAliasReq, &setAliasRes); err != nil { if err = URSAPI.SetRoomAlias(ctx, &setAliasReq, &setAliasRes); err != nil {
return fmt.Errorf("Failed to set new room alias: %w", err) return fmt.Errorf("Failed to set new room alias: %w", err)
@ -199,7 +193,7 @@ func moveLocalAliases(ctx context.Context,
return nil return nil
} }
func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, userID string, userDomain spec.ServerName, roomID string) error { func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, senderID spec.SenderID, userDomain spec.ServerName, roomID string) error {
for _, event := range oldRoom.StateEvents { for _, event := range oldRoom.StateEvents {
if event.Type() != spec.MRoomCanonicalAlias || !event.StateKeyEquals("") { if event.Type() != spec.MRoomCanonicalAlias || !event.StateKeyEquals("") {
continue continue
@ -217,7 +211,7 @@ func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api
} }
} }
emptyCanonicalAliasEvent, resErr := r.makeHeaderedEvent(ctx, evTime, userID, roomID, gomatrixserverlib.FledglingEvent{ emptyCanonicalAliasEvent, resErr := r.makeHeaderedEvent(ctx, evTime, senderID, userDomain, roomID, gomatrixserverlib.FledglingEvent{
Type: spec.MRoomCanonicalAlias, Type: spec.MRoomCanonicalAlias,
Content: map[string]interface{}{}, Content: map[string]interface{}{},
}) })
@ -280,7 +274,7 @@ func (r *Upgrader) validateRoomExists(ctx context.Context, roomID string) error
return nil return nil
} }
func (r *Upgrader) userIsAuthorized(ctx context.Context, userID, roomID string, func (r *Upgrader) userIsAuthorized(ctx context.Context, senderID spec.SenderID, roomID string,
) bool { ) bool {
plEvent := api.GetStateEvent(ctx, r.URSAPI, roomID, gomatrixserverlib.StateKeyTuple{ plEvent := api.GetStateEvent(ctx, r.URSAPI, roomID, gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomPowerLevels, EventType: spec.MRoomPowerLevels,
@ -295,26 +289,18 @@ func (r *Upgrader) userIsAuthorized(ctx context.Context, userID, roomID string,
} }
// Check for power level required to send tombstone event (marks the current room as obsolete), // Check for power level required to send tombstone event (marks the current room as obsolete),
// if not found, use the StateDefault power level // if not found, use the StateDefault power level
fullUserID, err := spec.NewUserID(userID, true)
if err != nil {
return false
}
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return false
}
return pl.UserLevel(senderID) >= pl.EventLevel("m.room.tombstone", true) return pl.UserLevel(senderID) >= pl.EventLevel("m.room.tombstone", true)
} }
// nolint:gocyclo // nolint:gocyclo
func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, userID, roomID string, newVersion gomatrixserverlib.RoomVersion, tombstoneEvent *types.HeaderedEvent) ([]gomatrixserverlib.FledglingEvent, error) { func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, senderID spec.SenderID, roomID string, newVersion gomatrixserverlib.RoomVersion, tombstoneEvent *types.HeaderedEvent) ([]gomatrixserverlib.FledglingEvent, error) {
state := make(map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent, len(oldRoom.StateEvents)) state := make(map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent, len(oldRoom.StateEvents))
for _, event := range oldRoom.StateEvents { for _, event := range oldRoom.StateEvents {
if event.StateKey() == nil { if event.StateKey() == nil {
// This shouldn't ever happen, but better to be safe than sorry. // This shouldn't ever happen, but better to be safe than sorry.
continue continue
} }
if event.Type() == spec.MRoomMember && !event.StateKeyEquals(userID) { if event.Type() == spec.MRoomMember && !event.StateKeyEquals(string(senderID)) {
// With the exception of bans which we do want to copy, we // With the exception of bans which we do want to copy, we
// should ignore membership events that aren't our own, as event auth will // should ignore membership events that aren't our own, as event auth will
// prevent us from being able to create membership events on behalf of other // prevent us from being able to create membership events on behalf of other
@ -330,6 +316,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
} }
} }
// skip events that rely on a specific user being present // skip events that rely on a specific user being present
// TODO: What to do here for pseudoIDs? It's checking non-member events for state keys with userIDs.
sKey := *event.StateKey() sKey := *event.StateKey()
if event.Type() != spec.MRoomMember && len(sKey) > 0 && sKey[:1] == "@" { if event.Type() != spec.MRoomMember && len(sKey) > 0 && sKey[:1] == "@" {
continue continue
@ -341,7 +328,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
// in the following section. // in the following section.
override := map[gomatrixserverlib.StateKeyTuple]struct{}{ override := map[gomatrixserverlib.StateKeyTuple]struct{}{
{EventType: spec.MRoomCreate, StateKey: ""}: {}, {EventType: spec.MRoomCreate, StateKey: ""}: {},
{EventType: spec.MRoomMember, StateKey: userID}: {}, {EventType: spec.MRoomMember, StateKey: string(senderID)}: {},
{EventType: spec.MRoomPowerLevels, StateKey: ""}: {}, {EventType: spec.MRoomPowerLevels, StateKey: ""}: {},
{EventType: spec.MRoomJoinRules, StateKey: ""}: {}, {EventType: spec.MRoomJoinRules, StateKey: ""}: {},
} }
@ -355,7 +342,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
} }
oldCreateEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomCreate, StateKey: ""}] oldCreateEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomCreate, StateKey: ""}]
oldMembershipEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomMember, StateKey: userID}] oldMembershipEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomMember, StateKey: string(senderID)}]
oldPowerLevelsEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomPowerLevels, StateKey: ""}] oldPowerLevelsEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomPowerLevels, StateKey: ""}]
oldJoinRulesEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomJoinRules, StateKey: ""}] oldJoinRulesEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomJoinRules, StateKey: ""}]
@ -364,7 +351,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
// in the create event (such as for the room types MSC). // in the create event (such as for the room types MSC).
newCreateContent := map[string]interface{}{} newCreateContent := map[string]interface{}{}
_ = json.Unmarshal(oldCreateEvent.Content(), &newCreateContent) _ = json.Unmarshal(oldCreateEvent.Content(), &newCreateContent)
newCreateContent["creator"] = userID newCreateContent["creator"] = string(senderID)
newCreateContent["room_version"] = newVersion newCreateContent["room_version"] = newVersion
newCreateContent["predecessor"] = gomatrixserverlib.PreviousRoom{ newCreateContent["predecessor"] = gomatrixserverlib.PreviousRoom{
EventID: tombstoneEvent.EventID(), EventID: tombstoneEvent.EventID(),
@ -385,7 +372,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
newMembershipContent["membership"] = spec.Join newMembershipContent["membership"] = spec.Join
newMembershipEvent := gomatrixserverlib.FledglingEvent{ newMembershipEvent := gomatrixserverlib.FledglingEvent{
Type: spec.MRoomMember, Type: spec.MRoomMember,
StateKey: userID, StateKey: string(senderID),
Content: newMembershipContent, Content: newMembershipContent,
} }
@ -400,14 +387,6 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
return nil, fmt.Errorf("Power level event content was invalid") return nil, fmt.Errorf("Power level event content was invalid")
} }
fullUserID, err := spec.NewUserID(userID, true)
if err != nil {
return nil, err
}
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return nil, err
}
tempPowerLevelsEvent, powerLevelsOverridden := createTemporaryPowerLevels(powerLevelContent, senderID) tempPowerLevelsEvent, powerLevelsOverridden := createTemporaryPowerLevels(powerLevelContent, senderID)
// Now do the join rules event, same as the create and membership // Now do the join rules event, same as the create and membership
@ -470,21 +449,13 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
return eventsToMake, nil return eventsToMake, nil
} }
func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, userID string, userDomain spec.ServerName, newRoomID string, newVersion gomatrixserverlib.RoomVersion, eventsToMake []gomatrixserverlib.FledglingEvent) error { func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, senderID spec.SenderID, userDomain spec.ServerName, newRoomID string, newVersion gomatrixserverlib.RoomVersion, eventsToMake []gomatrixserverlib.FledglingEvent) error {
var err error var err error
var builtEvents []*types.HeaderedEvent var builtEvents []*types.HeaderedEvent
authEvents := gomatrixserverlib.NewAuthEvents(nil) authEvents := gomatrixserverlib.NewAuthEvents(nil)
for i, e := range eventsToMake { for i, e := range eventsToMake {
depth := i + 1 // depth starts at 1 depth := i + 1 // depth starts at 1
fullUserID, userIDErr := spec.NewUserID(userID, true)
if userIDErr != nil {
return userIDErr
}
senderID, queryErr := r.URSAPI.QuerySenderIDForUser(ctx, newRoomID, *fullUserID)
if queryErr != nil {
return queryErr
}
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
SenderID: string(senderID), SenderID: string(senderID),
RoomID: newRoomID, RoomID: newRoomID,
@ -549,7 +520,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user
func (r *Upgrader) makeTombstoneEvent( func (r *Upgrader) makeTombstoneEvent(
ctx context.Context, ctx context.Context,
evTime time.Time, evTime time.Time,
userID, roomID, newRoomID string, senderID spec.SenderID, senderDomain spec.ServerName, roomID, newRoomID string,
) (*types.HeaderedEvent, error) { ) (*types.HeaderedEvent, error) {
content := map[string]interface{}{ content := map[string]interface{}{
"body": "This room has been replaced", "body": "This room has been replaced",
@ -559,30 +530,21 @@ func (r *Upgrader) makeTombstoneEvent(
Type: "m.room.tombstone", Type: "m.room.tombstone",
Content: content, Content: content,
} }
return r.makeHeaderedEvent(ctx, evTime, userID, roomID, event) return r.makeHeaderedEvent(ctx, evTime, senderID, senderDomain, roomID, event)
} }
func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, userID, roomID string, event gomatrixserverlib.FledglingEvent) (*types.HeaderedEvent, error) { func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, senderID spec.SenderID, senderDomain spec.ServerName, roomID string, event gomatrixserverlib.FledglingEvent) (*types.HeaderedEvent, error) {
fullUserID, err := spec.NewUserID(userID, true)
if err != nil {
return nil, err
}
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return nil, err
}
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
SenderID: string(senderID), SenderID: string(senderID),
RoomID: roomID, RoomID: roomID,
Type: event.Type, Type: event.Type,
StateKey: &event.StateKey, StateKey: &event.StateKey,
} }
err = proto.SetContent(event.Content) err := proto.SetContent(event.Content)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to set new %q event content: %w", proto.Type, err) return nil, fmt.Errorf("failed to set new %q event content: %w", proto.Type, err)
} }
// Get the sender domain. // Get the sender domain.
senderDomain := fullUserID.Domain()
identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get signing identity for %q: %w", senderDomain, err) return nil, fmt.Errorf("failed to get signing identity for %q: %w", senderDomain, err)

View file

@ -48,7 +48,7 @@ type Queryer struct {
Cfg *config.Dendrite Cfg *config.Dendrite
} }
func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) {
roomInfo, err := r.QueryRoomInfo(ctx, roomID) roomInfo, err := r.QueryRoomInfo(ctx, roomID)
if err != nil || roomInfo == nil || roomInfo.IsStub() { if err != nil || roomInfo == nil || roomInfo.IsStub() {
return nil, err return nil, err
@ -64,7 +64,7 @@ func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID
return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err) return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err)
} }
userJoinedToRoom, err := r.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID) userJoinedToRoom, err := r.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), senderID)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed") util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed")
return nil, fmt.Errorf("InternalServerError: %w", err) return nil, fmt.Errorf("InternalServerError: %w", err)
@ -220,13 +220,14 @@ func (r *Queryer) QueryEventsByID(
return nil return nil
} }
// QueryMembershipForUser implements api.RoomserverInternalAPI // QueryMembershipForSenderID implements api.RoomserverInternalAPI
func (r *Queryer) QueryMembershipForUser( func (r *Queryer) QueryMembershipForSenderID(
ctx context.Context, ctx context.Context,
request *api.QueryMembershipForUserRequest, roomID spec.RoomID,
senderID spec.SenderID,
response *api.QueryMembershipForUserResponse, response *api.QueryMembershipForUserResponse,
) error { ) error {
info, err := r.DB.RoomInfo(ctx, request.RoomID) info, err := r.DB.RoomInfo(ctx, roomID.String())
if err != nil { if err != nil {
return err return err
} }
@ -236,7 +237,7 @@ func (r *Queryer) QueryMembershipForUser(
} }
response.RoomExists = true response.RoomExists = true
membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.UserID) membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, senderID)
if err != nil { if err != nil {
return err return err
} }
@ -264,6 +265,24 @@ func (r *Queryer) QueryMembershipForUser(
return err return err
} }
// QueryMembershipForUser implements api.RoomserverInternalAPI
func (r *Queryer) QueryMembershipForUser(
ctx context.Context,
request *api.QueryMembershipForUserRequest,
response *api.QueryMembershipForUserResponse,
) error {
senderID, err := r.DB.GetSenderIDForUser(ctx, request.RoomID, request.UserID)
if err != nil {
return err
}
roomID, err := spec.NewRoomID(request.RoomID)
if err != nil {
return err
}
return r.QueryMembershipForSenderID(ctx, *roomID, senderID, response)
}
// QueryMembershipAtEvent returns the known memberships at a given event. // QueryMembershipAtEvent returns the known memberships at a given event.
// If the state before an event is not known, an empty list will be returned // If the state before an event is not known, an empty list will be returned
// for that event instead. // for that event instead.
@ -373,7 +392,7 @@ func (r *Queryer) QueryMembershipsForRoom(
// If no sender is specified then we will just return the entire // If no sender is specified then we will just return the entire
// set of memberships for the room, regardless of whether a specific // set of memberships for the room, regardless of whether a specific
// user is allowed to see them or not. // user is allowed to see them or not.
if request.Sender == "" { if request.SenderID == "" {
var events []types.Event var events []types.Event
var eventNIDs []types.EventNID var eventNIDs []types.EventNID
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, request.LocalOnly) eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, request.LocalOnly)
@ -388,18 +407,15 @@ func (r *Queryer) QueryMembershipsForRoom(
return fmt.Errorf("r.DB.Events: %w", err) return fmt.Errorf("r.DB.Events: %w", err)
} }
for _, event := range events { for _, event := range events {
sender := spec.UserID{} clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
userID, queryErr := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) return r.QueryUserIDForSender(ctx, roomID, senderID)
if queryErr == nil && userID != nil { }, event)
sender = *userID
}
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender)
response.JoinEvents = append(response.JoinEvents, clientEvent) response.JoinEvents = append(response.JoinEvents, clientEvent)
} }
return nil return nil
} }
membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.Sender) membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.SenderID)
if err != nil { if err != nil {
return err return err
} }
@ -442,12 +458,9 @@ func (r *Queryer) QueryMembershipsForRoom(
} }
for _, event := range events { for _, event := range events {
sender := spec.UserID{} clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
userID, err := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) return r.QueryUserIDForSender(ctx, roomID, senderID)
if err == nil && userID != nil { }, event)
sender = *userID
}
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender)
response.JoinEvents = append(response.JoinEvents, clientEvent) response.JoinEvents = append(response.JoinEvents, clientEvent)
} }
@ -489,6 +502,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent(
ctx context.Context, ctx context.Context,
serverName spec.ServerName, serverName spec.ServerName,
eventID string, eventID string,
roomID string,
) (allowed bool, err error) { ) (allowed bool, err error) {
events, err := r.DB.EventNIDs(ctx, []string{eventID}) events, err := r.DB.EventNIDs(ctx, []string{eventID})
if err != nil { if err != nil {
@ -518,7 +532,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent(
} }
return helpers.CheckServerAllowedToSeeEvent( return helpers.CheckServerAllowedToSeeEvent(
ctx, r.DB, info, eventID, serverName, isInRoom, ctx, r.DB, info, roomID, eventID, serverName, isInRoom,
) )
} }
@ -909,8 +923,8 @@ func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainReq
return nil return nil
} }
func (r *Queryer) InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) { func (r *Queryer) InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error) {
pending, _, _, _, err := helpers.IsInvitePending(ctx, r.DB, roomID.String(), userID.String()) pending, _, _, _, err := helpers.IsInvitePending(ctx, r.DB, roomID.String(), senderID)
return pending, err return pending, err
} }
@ -926,8 +940,8 @@ func (r *Queryer) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eve
return res, err return res, err
} }
func (r *Queryer) UserJoinedToRoom(ctx context.Context, roomNID types.RoomNID, userID spec.UserID) (bool, error) { func (r *Queryer) UserJoinedToRoom(ctx context.Context, roomNID types.RoomNID, senderID spec.SenderID) (bool, error) {
_, isIn, _, err := r.DB.GetMembership(ctx, roomNID, userID.String()) _, isIn, _, err := r.DB.GetMembership(ctx, roomNID, senderID)
return isIn, err return isIn, err
} }
@ -957,7 +971,7 @@ func (r *Queryer) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixse
} }
// nolint:gocyclo // nolint:gocyclo
func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) { func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
// Look up if we know anything about the room. If it doesn't exist // Look up if we know anything about the room. If it doesn't exist
// or is a stub entry then we can't do anything. // or is a stub entry then we can't do anything.
roomInfo, err := r.DB.RoomInfo(ctx, roomID.String()) roomInfo, err := r.DB.RoomInfo(ctx, roomID.String())
@ -972,7 +986,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro
return "", err return "", err
} }
return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, userID) return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, senderID)
} }
func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) {

View file

@ -722,7 +722,7 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) {
roomID, _ := spec.NewRoomID(testRoom.ID) roomID, _ := spec.NewRoomID(testRoom.ID)
userID, _ := spec.NewUserID(bob.ID, true) userID, _ := spec.NewUserID(bob.ID, true)
got, err := rsAPI.QueryRestrictedJoinAllowed(processCtx.Context(), *roomID, *userID) got, err := rsAPI.QueryRestrictedJoinAllowed(processCtx.Context(), *roomID, spec.SenderID(userID.String()))
if tc.wantError && err == nil { if tc.wantError && err == nil {
t.Fatal("expected error, got none") t.Fatal("expected error, got none")
} }
@ -821,17 +821,6 @@ func TestUpgrade(t *testing.T) {
validateFunc func(t *testing.T, oldRoomID, newRoomID string, rsAPI api.RoomserverInternalAPI) validateFunc func(t *testing.T, oldRoomID, newRoomID string, rsAPI api.RoomserverInternalAPI)
wantNewRoom bool wantNewRoom bool
}{ }{
{
name: "invalid userID",
upgradeUser: "!notvalid:test",
roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
room := test.NewRoom(t, alice)
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
t.Errorf("failed to send events: %v", err)
}
return room.ID
},
},
{ {
name: "invalid roomID", name: "invalid roomID",
upgradeUser: alice.ID, upgradeUser: alice.ID,
@ -1049,7 +1038,11 @@ func TestUpgrade(t *testing.T) {
} }
roomID := tc.roomFunc(rsAPI) roomID := tc.roomFunc(rsAPI)
newRoomID, err := rsAPI.PerformRoomUpgrade(processCtx.Context(), roomID, tc.upgradeUser, version.DefaultRoomVersion()) userID, err := spec.NewUserID(tc.upgradeUser, true)
if err != nil {
t.Fatalf("upgrade userID is invalid")
}
newRoomID, err := rsAPI.PerformRoomUpgrade(processCtx.Context(), roomID, *userID, version.DefaultRoomVersion())
if err != nil && tc.wantNewRoom { if err != nil && tc.wantNewRoom {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -131,7 +131,7 @@ type Database interface {
// in this room, along a boolean set to true if the user is still in this room, // in this room, along a boolean set to true if the user is still in this room,
// false if not. // false if not.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom, isRoomForgotten bool, err error) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderID spec.SenderID) (membershipEventNID types.EventNID, stillInRoom, isRoomForgotten bool, err error)
// Lookup the membership event numeric IDs for all user that are or have // Lookup the membership event numeric IDs for all user that are or have
// been members of a given room. Only lookup events of "join" membership if // been members of a given room. Only lookup events of "join" membership if
// joinOnly is set to true. // joinOnly is set to true.

View file

@ -490,10 +490,10 @@ func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
}) })
} }
func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom, isRoomforgotten bool, err error) { func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderID spec.SenderID) (membershipEventNID types.EventNID, stillInRoom, isRoomforgotten bool, err error) {
var requestSenderUserNID types.EventStateKeyNID var requestSenderUserNID types.EventStateKeyNID
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
requestSenderUserNID, err = d.assignStateKeyNID(ctx, txn, requestSenderUserID) requestSenderUserNID, err = d.assignStateKeyNID(ctx, txn, string(requestSenderID))
return err return err
}) })
if err != nil { if err != nil {
@ -936,6 +936,7 @@ func extractRoomVersionFromCreateEvent(event gomatrixserverlib.PDU) (
return roomVersion, err return roomVersion, err
} }
// nolint:gocyclo
// MaybeRedactEvent manages the redacted status of events. There's two cases to consider in order to comply with the spec: // MaybeRedactEvent manages the redacted status of events. There's two cases to consider in order to comply with the spec:
// "servers should not apply or send redactions to clients until both the redaction event and original event have been seen, and are valid." // "servers should not apply or send redactions to clients until both the redaction event and original event have been seen, and are valid."
// https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events // https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events
@ -1014,7 +1015,7 @@ func (d *EventDatabase) MaybeRedactEvent(
switch { switch {
case powerlevels.UserLevel(redactionEvent.SenderID()) >= powerlevels.Redact: case powerlevels.UserLevel(redactionEvent.SenderID()) >= powerlevels.Redact:
// 1. The power level of the redaction events sender is greater than or equal to the redact level. // 1. The power level of the redaction events sender is greater than or equal to the redact level.
case sender1Domain == sender2Domain: case sender1Domain != "" && sender2Domain != "" && sender1Domain == sender2Domain:
// 2. The domain of the redaction events sender matches that of the original events sender. // 2. The domain of the redaction events sender matches that of the original events sender.
default: default:
ignoreRedaction = true ignoreRedaction = true

View file

@ -154,7 +154,7 @@ type reqCtx struct {
rsAPI roomserver.RoomserverInternalAPI rsAPI roomserver.RoomserverInternalAPI
db Database db Database
req *EventRelationshipRequest req *EventRelationshipRequest
userID string userID spec.UserID
roomVersion gomatrixserverlib.RoomVersion roomVersion gomatrixserverlib.RoomVersion
// federated request args // federated request args
@ -173,10 +173,17 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)), JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)),
} }
} }
userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: 400,
JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)),
}
}
rc := reqCtx{ rc := reqCtx{
ctx: req.Context(), ctx: req.Context(),
req: relation, req: relation,
userID: device.UserID, userID: *userID,
rsAPI: rsAPI, rsAPI: rsAPI,
fsAPI: fsAPI, fsAPI: fsAPI,
isFederatedRequest: false, isFederatedRequest: false,

View file

@ -529,6 +529,10 @@ func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID str
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }
func (r *testRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) {
return spec.SenderID(userID.String()), nil
}
func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error { func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error {
for _, eventID := range req.EventIDs { for _, eventID := range req.EventIDs {
ev := r.events[eventID] ev := r.events[eventID]
@ -540,7 +544,7 @@ func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver
} }
func (r *testRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *roomserver.QueryMembershipForUserRequest, res *roomserver.QueryMembershipForUserResponse) error { func (r *testRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *roomserver.QueryMembershipForUserRequest, res *roomserver.QueryMembershipForUserResponse) error {
rooms := r.userToJoinedRooms[req.UserID] rooms := r.userToJoinedRooms[req.UserID.String()]
for _, roomID := range rooms { for _, roomID := range rooms {
if roomID == req.RoomID { if roomID == req.RoomID {
res.IsInRoom = true res.IsInRoom = true

View file

@ -373,7 +373,15 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *rst
// TODO: check that it's a join and not a profile change (means unmarshalling prev_content) // TODO: check that it's a join and not a profile change (means unmarshalling prev_content)
if membership == spec.Join { if membership == spec.Join {
// check it's a local join // check it's a local join
if _, _, err := s.cfg.Matrix.SplitLocalID('@', *ev.StateKey()); err != nil { if ev.StateKey() == nil {
return sp, fmt.Errorf("unexpected nil state_key")
}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey()))
if err != nil || userID == nil {
return sp, fmt.Errorf("failed getting userID for sender: %w", err)
}
if !s.cfg.Matrix.IsLocalServerName(userID.Domain()) {
return sp, nil return sp, nil
} }
@ -395,9 +403,15 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
if msg.Event.StateKey() == nil { if msg.Event.StateKey() == nil {
return return
} }
if _, _, err := s.cfg.Matrix.SplitLocalID('@', *msg.Event.StateKey()); err != nil {
userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.Event.RoomID(), spec.SenderID(*msg.Event.StateKey()))
if err != nil || userID == nil {
return return
} }
if !s.cfg.Matrix.IsLocalServerName(userID.Domain()) {
return
}
pduPos, err := s.db.AddInviteEvent(ctx, msg.Event) pduPos, err := s.db.AddInviteEvent(ctx, msg.Event)
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
@ -440,7 +454,16 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
// Notify any active sync requests that the invite has been retired. // Notify any active sync requests that the invite has been retired.
s.inviteStream.Advance(pduPos) s.inviteStream.Advance(pduPos)
s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, msg.TargetUserID) userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.RoomID, msg.TargetSenderID)
if err != nil || userID == nil {
log.WithFields(log.Fields{
"event_id": msg.EventID,
"sender_id": msg.TargetSenderID,
log.ErrorKey: err,
}).Errorf("failed to find userID for sender")
return
}
s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, userID.String())
} }
func (s *OutputRoomEventConsumer) onNewPeek( func (s *OutputRoomEventConsumer) onNewPeek(

View file

@ -134,10 +134,18 @@ func ApplyHistoryVisibilityFilter(
} }
} }
// NOTSPEC: Always allow user to see their own membership events (spec contains more "rules") // NOTSPEC: Always allow user to see their own membership events (spec contains more "rules")
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(userID) {
user, err := spec.NewUserID(userID, true)
if err != nil {
return nil, err
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *user)
if err == nil {
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) {
eventsFiltered = append(eventsFiltered, ev) eventsFiltered = append(eventsFiltered, ev)
continue continue
} }
}
// Always allow history evVis events on boundaries. This is done // Always allow history evVis events on boundaries. This is done
// by setting the effective evVis to the least restrictive // by setting the effective evVis to the least restrictive
// of the old vs new. // of the old vs new.

View file

@ -169,12 +169,16 @@ func TrackChangedUsers(
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
for _, state := range stateRes.Rooms { for roomID, state := range stateRes.Rooms {
for tuple, membership := range state { for tuple, membership := range state {
if membership != spec.Join { if membership != spec.Join {
continue continue
} }
queryRes.UserIDsToCount[tuple.StateKey]-- user, queryErr := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey))
if queryErr != nil || user == nil {
continue
}
queryRes.UserIDsToCount[user.String()]--
} }
} }
@ -211,14 +215,18 @@ func TrackChangedUsers(
if err != nil { if err != nil {
return nil, left, err return nil, left, err
} }
for _, state := range stateRes.Rooms { for roomID, state := range stateRes.Rooms {
for tuple, membership := range state { for tuple, membership := range state {
if membership != spec.Join { if membership != spec.Join {
continue continue
} }
// new user who we weren't previously sharing rooms with // new user who we weren't previously sharing rooms with
if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok { if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok {
changed = append(changed, tuple.StateKey) // changed is returned user, err := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey))
if err != nil || user == nil {
continue
}
changed = append(changed, user.String()) // changed is returned
} }
} }
} }

View file

@ -64,6 +64,10 @@ type mockRoomserverAPI struct {
roomIDToJoinedMembers map[string][]string roomIDToJoinedMembers map[string][]string
} }
func (s *mockRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}
// QueryRoomsForUser retrieves a list of room IDs matching the given query. // QueryRoomsForUser retrieves a list of room IDs matching the given query.
func (s *mockRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error { func (s *mockRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error {
return nil return nil

View file

@ -20,6 +20,7 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api"
rstypes "github.com/matrix-org/dendrite/roomserver/types" rstypes "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
@ -37,6 +38,7 @@ import (
// in missed events. // in missed events.
type Notifier struct { type Notifier struct {
lock *sync.RWMutex lock *sync.RWMutex
rsAPI api.SyncRoomserverAPI
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine // A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
roomIDToJoinedUsers map[string]*userIDSet roomIDToJoinedUsers map[string]*userIDSet
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine // A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
@ -55,8 +57,9 @@ type Notifier struct {
// NewNotifier creates a new notifier set to the given sync position. // NewNotifier creates a new notifier set to the given sync position.
// In order for this to be of any use, the Notifier needs to be told all rooms and // In order for this to be of any use, the Notifier needs to be told all rooms and
// the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase).
func NewNotifier() *Notifier { func NewNotifier(rsAPI api.SyncRoomserverAPI) *Notifier {
return &Notifier{ return &Notifier{
rsAPI: rsAPI,
roomIDToJoinedUsers: make(map[string]*userIDSet), roomIDToJoinedUsers: make(map[string]*userIDSet),
roomIDToPeekingDevices: make(map[string]peekingDeviceSet), roomIDToPeekingDevices: make(map[string]peekingDeviceSet),
userDeviceStreams: make(map[string]map[string]*UserDeviceStream), userDeviceStreams: make(map[string]map[string]*UserDeviceStream),
@ -104,7 +107,12 @@ func (n *Notifier) OnNewEvent(
peekingDevicesToNotify := n._peekingDevices(ev.RoomID()) peekingDevicesToNotify := n._peekingDevices(ev.RoomID())
// If this is an invite, also add in the invitee to this list. // If this is an invite, also add in the invitee to this list.
if ev.Type() == "m.room.member" && ev.StateKey() != nil { if ev.Type() == "m.room.member" && ev.StateKey() != nil {
targetUserID := *ev.StateKey() targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), ev.RoomID(), spec.SenderID(*ev.StateKey()))
if err != nil {
log.WithError(err).WithField("event_id", ev.EventID()).Errorf(
"Notifier.OnNewEvent: Failed to find the userID for this event",
)
} else {
membership, err := ev.Membership() membership, err := ev.Membership()
if err != nil { if err != nil {
log.WithError(err).WithField("event_id", ev.EventID()).Errorf( log.WithError(err).WithField("event_id", ev.EventID()).Errorf(
@ -114,16 +122,17 @@ func (n *Notifier) OnNewEvent(
// Keep the joined user map up-to-date // Keep the joined user map up-to-date
switch membership { switch membership {
case spec.Invite: case spec.Invite:
usersToNotify = append(usersToNotify, targetUserID) usersToNotify = append(usersToNotify, targetUserID.String())
case spec.Join: case spec.Join:
// Manually append the new user's ID so they get notified // Manually append the new user's ID so they get notified
// along all members in the room // along all members in the room
usersToNotify = append(usersToNotify, targetUserID) usersToNotify = append(usersToNotify, targetUserID.String())
n._addJoinedUser(ev.RoomID(), targetUserID) n._addJoinedUser(ev.RoomID(), targetUserID.String())
case spec.Leave: case spec.Leave:
fallthrough fallthrough
case spec.Ban: case spec.Ban:
n._removeJoinedUser(ev.RoomID(), targetUserID) n._removeJoinedUser(ev.RoomID(), targetUserID.String())
}
} }
} }
} }

View file

@ -22,9 +22,11 @@ import (
"testing" "testing"
"time" "time"
"github.com/matrix-org/dendrite/roomserver/api"
rstypes "github.com/matrix-org/dendrite/roomserver/types" rstypes "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -105,9 +107,15 @@ func mustEqualPositions(t *testing.T, got, want types.StreamingToken) {
} }
} }
type TestRoomServer struct{ api.SyncRoomserverAPI }
func (t *TestRoomServer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}
// Test that the current position is returned if a request is already behind. // Test that the current position is returned if a request is already behind.
func TestImmediateNotification(t *testing.T) { func TestImmediateNotification(t *testing.T) {
n := NewNotifier() n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore) n.SetCurrentPosition(syncPositionBefore)
pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionVeryOld)) pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionVeryOld))
if err != nil { if err != nil {
@ -118,7 +126,7 @@ func TestImmediateNotification(t *testing.T) {
// Test that new events to a joined room unblocks the request. // Test that new events to a joined room unblocks the request.
func TestNewEventAndJoinedToRoom(t *testing.T) { func TestNewEventAndJoinedToRoom(t *testing.T) {
n := NewNotifier() n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore) n.SetCurrentPosition(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{ n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob}, roomID: {alice, bob},
@ -144,7 +152,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) {
} }
func TestCorrectStream(t *testing.T) { func TestCorrectStream(t *testing.T) {
n := NewNotifier() n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore) n.SetCurrentPosition(syncPositionBefore)
stream := lockedFetchUserStream(n, bob, bobDev) stream := lockedFetchUserStream(n, bob, bobDev)
if stream.UserID != bob { if stream.UserID != bob {
@ -156,7 +164,7 @@ func TestCorrectStream(t *testing.T) {
} }
func TestCorrectStreamWakeup(t *testing.T) { func TestCorrectStreamWakeup(t *testing.T) {
n := NewNotifier() n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore) n.SetCurrentPosition(syncPositionBefore)
awoken := make(chan string) awoken := make(chan string)
@ -184,7 +192,7 @@ func TestCorrectStreamWakeup(t *testing.T) {
// Test that an invite unblocks the request // Test that an invite unblocks the request
func TestNewInviteEventForUser(t *testing.T) { func TestNewInviteEventForUser(t *testing.T) {
n := NewNotifier() n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore) n.SetCurrentPosition(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{ n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob}, roomID: {alice, bob},
@ -241,7 +249,7 @@ func TestEDUWakeup(t *testing.T) {
// Test that all blocked requests get woken up on a new event. // Test that all blocked requests get woken up on a new event.
func TestMultipleRequestWakeup(t *testing.T) { func TestMultipleRequestWakeup(t *testing.T) {
n := NewNotifier() n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore) n.SetCurrentPosition(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{ n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob}, roomID: {alice, bob},
@ -278,7 +286,7 @@ func TestMultipleRequestWakeup(t *testing.T) {
func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
// listen as bob. Make bob leave room. Make alice send event to room. // listen as bob. Make bob leave room. Make alice send event to room.
// Make sure alice gets woken up only and not bob as well. // Make sure alice gets woken up only and not bob as well.
n := NewNotifier() n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore) n.SetCurrentPosition(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{ n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob}, roomID: {alice, bob},

View file

@ -85,9 +85,16 @@ func Context(
*filter.Rooms = append(*filter.Rooms, roomID) *filter.Rooms = append(*filter.Rooms, roomID)
} }
userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam("Device UserID is invalid"),
}
}
ctx := req.Context() ctx := req.Context()
membershipRes := roomserver.QueryMembershipForUserResponse{} membershipRes := roomserver.QueryMembershipForUserResponse{}
membershipReq := roomserver.QueryMembershipForUserRequest{UserID: device.UserID, RoomID: roomID} membershipReq := roomserver.QueryMembershipForUserRequest{UserID: *userID, RoomID: roomID}
if err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes); err != nil { if err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes); err != nil {
logrus.WithError(err).Error("unable to query membership") logrus.WithError(err).Error("unable to query membership")
return util.JSONResponse{ return util.JSONResponse{
@ -217,12 +224,9 @@ func Context(
} }
} }
sender := spec.UserID{} ev := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
userID, err := rsAPI.QueryUserIDForSender(ctx, requestedEvent.RoomID(), requestedEvent.SenderID()) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
if err == nil && userID != nil { }, requestedEvent)
sender = *userID
}
ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, sender)
response := ContextRespsonse{ response := ContextRespsonse{
Event: &ev, Event: &ev,
EventsAfter: eventsAfterClient, EventsAfter: eventsAfterClient,

View file

@ -106,8 +106,17 @@ func GetEvent(
if err == nil && senderUserID != nil { if err == nil && senderUserID != nil {
sender = *senderUserID sender = *senderUserID
} }
sk := events[0].StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(ctx, events[0].RoomID(), spec.SenderID(*events[0].StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender), JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender, sk),
} }
} }

View file

@ -59,14 +59,21 @@ func GetMemberships(
syncDB storage.Database, rsAPI api.SyncRoomserverAPI, syncDB storage.Database, rsAPI api.SyncRoomserverAPI,
joinedOnly bool, membership, notMembership *string, at string, joinedOnly bool, membership, notMembership *string, at string,
) util.JSONResponse { ) util.JSONResponse {
userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam("Device UserID is invalid"),
}
}
queryReq := api.QueryMembershipForUserRequest{ queryReq := api.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
UserID: device.UserID, UserID: *userID,
} }
var queryRes api.QueryMembershipForUserResponse var queryRes api.QueryMembershipForUserResponse
if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil { if queryErr := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); queryErr != nil {
util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryMembershipsForRoom failed") util.GetLogger(req.Context()).WithError(queryErr).Error("rsAPI.QueryMembershipsForRoom failed")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},

View file

@ -296,9 +296,13 @@ func OnIncomingMessagesRequest(
} }
func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (resp api.QueryMembershipForUserResponse, err error) { func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (resp api.QueryMembershipForUserResponse, err error) {
fullUserID, err := spec.NewUserID(userID, true)
if err != nil {
return resp, err
}
req := api.QueryMembershipForUserRequest{ req := api.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
UserID: userID, UserID: *fullUserID,
} }
if err := rsAPI.QueryMembershipForUser(ctx, &req, &resp); err != nil { if err := rsAPI.QueryMembershipForUser(ctx, &req, &resp); err != nil {
return api.QueryMembershipForUserResponse{}, err return api.QueryMembershipForUserResponse{}, err

View file

@ -119,9 +119,18 @@ func Relations(
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
res.Chunk = append( res.Chunk = append(
res.Chunk, res.Chunk,
synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender), synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender, sk),
) )
} }

View file

@ -235,6 +235,15 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
results = append(results, Result{ results = append(results, Result{
Context: SearchContextResponse{ Context: SearchContextResponse{
Start: startToken.String(), Start: startToken.String(),
@ -248,7 +257,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
ProfileInfo: profileInfos, ProfileInfo: profileInfos,
}, },
Rank: eventScore[event.EventID()].Score, Rank: eventScore[event.EventID()].Score,
Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender), Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk),
}) })
roomGroup := groups[event.RoomID()] roomGroup := groups[event.RoomID()]
roomGroup.Results = append(roomGroup.Results, event.EventID()) roomGroup.Results = append(roomGroup.Results, event.EventID())

View file

@ -507,8 +507,20 @@ func (d *Database) CleanSendToDeviceUpdates(
// getMembershipFromEvent returns the value of content.membership iff the event is a state event // getMembershipFromEvent returns the value of content.membership iff the event is a state event
// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned. // with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned.
func getMembershipFromEvent(ev gomatrixserverlib.PDU, userID string) (string, string) { func getMembershipFromEvent(ctx context.Context, ev gomatrixserverlib.PDU, userID string, rsAPI api.SyncRoomserverAPI) (string, string) {
if ev.Type() != "m.room.member" || !ev.StateKeyEquals(userID) { if ev.StateKey() == nil || *ev.StateKey() == "" {
return "", ""
}
fullUser, err := spec.NewUserID(userID, true)
if err != nil {
return "", ""
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *fullUser)
if err != nil {
return "", ""
}
if ev.Type() != "m.room.member" || !ev.StateKeyEquals(string(senderID)) {
return "", "" return "", ""
} }
membership, err := ev.Membership() membership, err := ev.Membership()

View file

@ -430,7 +430,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
for _, ev := range stateStreamEvents { for _, ev := range stateStreamEvents {
// Look for our membership in the state events and skip over any // Look for our membership in the state events and skip over any
// membership events that are not related to us. // membership events that are not related to us.
membership, prevMembership := getMembershipFromEvent(ev.PDU, userID) membership, prevMembership := getMembershipFromEvent(ctx, ev.PDU, userID, rsAPI)
if membership == "" { if membership == "" {
continue continue
} }
@ -556,7 +556,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync(
for roomID, stateStreamEvents := range state { for roomID, stateStreamEvents := range state {
for _, ev := range stateStreamEvents { for _, ev := range stateStreamEvents {
if membership, _ := getMembershipFromEvent(ev.PDU, userID); membership != "" { if membership, _ := getMembershipFromEvent(ctx, ev.PDU, userID, rsAPI); membership != "" {
if membership != spec.Join { // We've already added full state for all joined rooms above. if membership != spec.Join { // We've already added full state for all joined rooms above.
deltas[roomID] = types.StateDelta{ deltas[roomID] = types.StateDelta{
Membership: membership, Membership: membership,

View file

@ -70,11 +70,20 @@ func (p *InviteStreamProvider) IncrementalSync(
user = *sender user = *sender
} }
sk := inviteEvent.StateKey()
if sk != nil && *sk != "" {
skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
// skip ignored user events // skip ignored user events
if _, ok := req.IgnoredUsers.List[user.String()]; ok { if _, ok := req.IgnoredUsers.List[user.String()]; ok {
continue continue
} }
ir := types.NewInviteResponse(inviteEvent, user) ir := types.NewInviteResponse(inviteEvent, user, sk)
req.Response.Rooms.Invite[roomID] = ir req.Response.Rooms.Invite[roomID] = ir
} }

View file

@ -605,13 +605,17 @@ func (p *PDUStreamProvider) lazyLoadMembers(
// If this is a gapped incremental sync, we still want this membership // If this is a gapped incremental sync, we still want this membership
isGappedIncremental := limited && incremental isGappedIncremental := limited && incremental
// We want this users membership event, keep it in the list // We want this users membership event, keep it in the list
stateKey := *event.StateKey() userID := ""
if _, ok := timelineUsers[stateKey]; ok || isGappedIncremental || stateKey == device.UserID { stateKeyUserID, err := p.rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(*event.StateKey()))
if err == nil && stateKeyUserID != nil {
userID = stateKeyUserID.String()
}
if _, ok := timelineUsers[userID]; ok || isGappedIncremental || userID == device.UserID {
newStateEvents = append(newStateEvents, event) newStateEvents = append(newStateEvents, event)
if !stateFilter.IncludeRedundantMembers { if !stateFilter.IncludeRedundantMembers {
p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, stateKey, event.EventID()) p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, userID, event.EventID())
} }
delete(timelineUsers, stateKey) delete(timelineUsers, userID)
} }
} else { } else {
newStateEvents = append(newStateEvents, event) newStateEvents = append(newStateEvents, event)

View file

@ -60,7 +60,7 @@ func AddPublicRoutes(
} }
eduCache := caching.NewTypingCache() eduCache := caching.NewTypingCache()
notifier := notifier.NewNotifier() notifier := notifier.NewNotifier(rsAPI)
streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, eduCache, caches, notifier) streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, eduCache, caches, notifier)
notifier.SetCurrentPosition(streams.Latest(context.Background())) notifier.SetCurrentPosition(streams.Latest(context.Background()))
if err = notifier.Load(context.Background(), syncDB); err != nil { if err = notifier.Load(context.Background(), syncDB); err != nil {

View file

@ -55,18 +55,27 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat,
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
evs = append(evs, ToClientEvent(se, format, sender))
sk := se.StateKey()
if sk != nil && *sk != "" {
skUserID, err := userIDForSender(se.RoomID(), spec.SenderID(*sk))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
evs = append(evs, ToClientEvent(se, format, sender, sk))
} }
return evs return evs
} }
// ToClientEvent converts a single server event to a client event. // ToClientEvent converts a single server event to a client event.
func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID) ClientEvent { func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID, stateKey *string) ClientEvent {
ce := ClientEvent{ ce := ClientEvent{
Content: spec.RawJSON(se.Content()), Content: spec.RawJSON(se.Content()),
Sender: sender.String(), Sender: sender.String(),
Type: se.Type(), Type: se.Type(),
StateKey: se.StateKey(), StateKey: stateKey,
Unsigned: spec.RawJSON(se.Unsigned()), Unsigned: spec.RawJSON(se.Unsigned()),
OriginServerTS: se.OriginServerTS(), OriginServerTS: se.OriginServerTS(),
EventID: se.EventID(), EventID: se.EventID(),
@ -77,3 +86,23 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender sp
} }
return ce return ce
} }
// ToClientEvent converts a single server event to a client event.
// It provides default logic for event.SenderID & event.StateKey -> userID conversions.
func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent {
sender := spec.UserID{}
userID, err := userIDQuery(event.RoomID(), event.SenderID())
if err == nil && userID != nil {
sender = *userID
}
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := userIDQuery(event.RoomID(), spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
return ToClientEvent(event, FormatAll, sender, sk)
}

View file

@ -48,7 +48,8 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo
if err != nil { if err != nil {
t.Fatalf("failed to create userID: %s", err) t.Fatalf("failed to create userID: %s", err)
} }
ce := ToClientEvent(ev, FormatAll, *userID) sk := ""
ce := ToClientEvent(ev, FormatAll, *userID, &sk)
if ce.EventID != ev.EventID() { if ce.EventID != ev.EventID() {
t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID) t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID)
} }
@ -107,7 +108,8 @@ func TestToClientFormatSync(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to create userID: %s", err) t.Fatalf("failed to create userID: %s", err)
} }
ce := ToClientEvent(ev, FormatSync, *userID) sk := ""
ce := ToClientEvent(ev, FormatSync, *userID, &sk)
if ce.RoomID != "" { if ce.RoomID != "" {
t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID) t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID)
} }

View file

@ -539,7 +539,7 @@ type InviteResponse struct {
} }
// NewInviteResponse creates an empty response with initialised arrays. // NewInviteResponse creates an empty response with initialised arrays.
func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteResponse { func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID, stateKey *string) *InviteResponse {
res := InviteResponse{} res := InviteResponse{}
res.InviteState.Events = []json.RawMessage{} res.InviteState.Events = []json.RawMessage{}
@ -552,7 +552,7 @@ func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteRe
// Then we'll see if we can create a partial of the invite event itself. // Then we'll see if we can create a partial of the invite event itself.
// This is needed for clients to work out *who* sent the invite. // This is needed for clients to work out *who* sent the invite.
inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID) inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID, stateKey)
inviteEvent.Unsigned = nil inviteEvent.Unsigned = nil
if ev, err := json.Marshal(inviteEvent); err == nil { if ev, err := json.Marshal(inviteEvent); err == nil {
res.InviteState.Events = append(res.InviteState.Events, ev) res.InviteState.Events = append(res.InviteState.Events, ev)

View file

@ -65,8 +65,14 @@ func TestNewInviteResponse(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
skUserID, err := spec.NewUserID("@neilalexander:dendrite.neilalexander.dev", true)
if err != nil {
t.Fatal(err)
}
skString := skUserID.String()
sk := &skString
res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender) res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender, sk)
j, err := json.Marshal(res) j, err := json.Marshal(res)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View file

@ -306,7 +306,16 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst
if queryErr == nil && userID != nil { if queryErr == nil && userID != nil {
sender = *userID sender = *userID
} }
cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender)
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
if queryErr == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk)
var member *localMembership var member *localMembership
member, err = newLocalMembership(&cevent) member, err = newLocalMembership(&cevent)
if err != nil { if err != nil {
@ -539,12 +548,21 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
if queryErr == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
n := &api.Notification{ n := &api.Notification{
Actions: actions, Actions: actions,
// UNSPEC: the spec doesn't say this is a ClientEvent, but the // UNSPEC: the spec doesn't say this is a ClientEvent, but the
// fields seem to match. room_id should be missing, which // fields seem to match. room_id should be missing, which
// matches the behaviour of FormatSync. // matches the behaviour of FormatSync.
Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender), Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender, sk),
// TODO: this is per-device, but it's not part of the primary // TODO: this is per-device, but it's not part of the primary
// key. So inserting one notification per profile tag doesn't // key. So inserting one notification per profile tag doesn't
// make sense. What is this supposed to be? Sytests require it // make sense. What is this supposed to be? Sytests require it
@ -792,10 +810,20 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes
Type: event.Type(), Type: event.Type(),
}, },
} }
if mem, err := event.Membership(); err == nil { if mem, memberErr := event.Membership(); memberErr == nil {
req.Notification.Membership = mem req.Notification.Membership = mem
} }
if event.StateKey() != nil && *event.StateKey() == fmt.Sprintf("@%s:%s", localpart, s.cfg.Matrix.ServerName) { userID, err := spec.NewUserID(fmt.Sprintf("@%s:%s", localpart, s.cfg.Matrix.ServerName), true)
if err != nil {
logger.WithError(err).Errorf("Failed to convert local user to userID %s", localpart)
return nil, err
}
localSender, err := s.rsAPI.QuerySenderIDForUser(ctx, event.RoomID(), *userID)
if err != nil {
logger.WithError(err).Errorf("Failed to get local user senderID for room %s: %s", userID.String(), event.RoomID())
return nil, err
}
if event.StateKey() != nil && *event.StateKey() == string(localSender) {
req.Notification.UserIsTarget = true req.Notification.UserIsTarget = true
} }
} }

View file

@ -104,8 +104,9 @@ func TestNotifyUserCountsAsync(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
sk := ""
if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{ if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{
Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender), Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender, &sk),
}); err != nil { }); err != nil {
t.Error(err) t.Error(err)
} }