Use SenderID Type (#3105)

This commit is contained in:
devonh 2023-06-07 17:14:35 +00:00 committed by GitHub
parent 7a1fd7f512
commit 8ea1a11105
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
60 changed files with 502 additions and 275 deletions

View file

@ -181,7 +181,7 @@ func (s *OutputRoomEventConsumer) sendEvents(
// Create the transaction body. // Create the transaction body.
transaction, err := json.Marshal( transaction, err := json.Marshal(
ApplicationServiceTransaction{ ApplicationServiceTransaction{
Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}), }),
}, },

View file

@ -338,7 +338,21 @@ func SetVisibility(
// NOTSPEC: Check if the user's power is greater than power required to change m.room.canonical_alias event // NOTSPEC: Check if the user's power is greater than power required to change m.room.canonical_alias event
power, _ := gomatrixserverlib.NewPowerLevelContentFromEvent(queryEventsRes.StateEvents[0].PDU) power, _ := gomatrixserverlib.NewPowerLevelContentFromEvent(queryEventsRes.StateEvents[0].PDU)
if power.UserLevel(dev.UserID) < power.EventLevel(spec.MRoomCanonicalAlias, true) { fullUserID, err := spec.NewUserID(dev.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
}
}
if power.UserLevel(senderID) < power.EventLevel(spec.MRoomCanonicalAlias, true) {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to change visibility"), JSON: spec.Forbidden("userID doesn't have power level to change visibility"),

View file

@ -66,7 +66,21 @@ func SendBan(
if errRes != nil { if errRes != nil {
return *errRes return *errRes
} }
allowedToBan := pl.UserLevel(device.UserID) >= pl.Ban fullUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to ban this user, unknown senderID"),
}
}
allowedToBan := pl.UserLevel(senderID) >= pl.Ban
if !allowedToBan { if !allowedToBan {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
@ -142,7 +156,21 @@ func SendKick(
if errRes != nil { if errRes != nil {
return *errRes return *errRes
} }
allowedToKick := pl.UserLevel(device.UserID) >= pl.Kick fullUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"),
}
}
allowedToKick := pl.UserLevel(senderID) >= pl.Kick
if !allowedToKick { if !allowedToKick {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
@ -151,7 +179,7 @@ func SendKick(
} }
var queryRes roomserverAPI.QueryMembershipForUserResponse var queryRes roomserverAPI.QueryMembershipForUserResponse
err := rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,
UserID: body.UserID, UserID: body.UserID,
}, &queryRes) }, &queryRes)
@ -319,7 +347,7 @@ func buildMembershipEventDirect(
rsAPI roomserverAPI.ClientRoomserverAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
) (*types.HeaderedEvent, error) { ) (*types.HeaderedEvent, error) {
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
Sender: sender, SenderID: sender,
RoomID: roomID, RoomID: roomID,
Type: "m.room.member", Type: "m.room.member",
StateKey: &targetUserID, StateKey: &targetUserID,

View file

@ -363,12 +363,21 @@ func buildMembershipEvents(
) ([]*types.HeaderedEvent, error) { ) ([]*types.HeaderedEvent, error) {
evs := []*types.HeaderedEvent{} evs := []*types.HeaderedEvent{}
fullUserID, err := spec.NewUserID(userID, true)
if err != nil {
return nil, err
}
for _, roomID := range roomIDs { for _, roomID := range roomIDs {
senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return nil, err
}
senderIDString := string(senderID)
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
Sender: userID, SenderID: senderIDString,
RoomID: roomID, RoomID: roomID,
Type: "m.room.member", Type: "m.room.member",
StateKey: &userID, StateKey: &senderIDString,
} }
content := gomatrixserverlib.MemberContent{ content := gomatrixserverlib.MemberContent{
@ -378,7 +387,7 @@ func buildMembershipEvents(
content.DisplayName = newProfile.DisplayName content.DisplayName = newProfile.DisplayName
content.AvatarURL = newProfile.AvatarURL content.AvatarURL = newProfile.AvatarURL
if err := proto.SetContent(content); err != nil { if err = proto.SetContent(content); err != nil {
return nil, err return nil, err
} }

View file

@ -73,10 +73,25 @@ func SendRedaction(
} }
} }
fullUserID, userIDErr := spec.NewUserID(device.UserID, true)
if userIDErr != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to redact"),
}
}
senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
if queryErr != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to redact"),
}
}
// "Users may redact their own events, and any user with a power level greater than or equal // "Users may redact their own events, and any user with a power level greater than or equal
// to the redact power level of the room may redact events there" // to the redact power level of the room may redact events there"
// https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid // https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid
allowedToRedact := ev.SenderID() == device.UserID // TODO: Should replace device.UserID with device...PerRoomKey allowedToRedact := ev.SenderID() == senderID // TODO: Should replace device.UserID with device...PerRoomKey
if !allowedToRedact { if !allowedToRedact {
plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{ plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomPowerLevels, EventType: spec.MRoomPowerLevels,
@ -97,7 +112,7 @@ func SendRedaction(
), ),
} }
} }
allowedToRedact = pl.UserLevel(device.UserID) >= pl.Redact allowedToRedact = pl.UserLevel(senderID) >= pl.Redact
} }
if !allowedToRedact { if !allowedToRedact {
return util.JSONResponse{ return util.JSONResponse{
@ -114,10 +129,10 @@ func SendRedaction(
// create the new event and set all the fields we can // create the new event and set all the fields we can
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
Sender: device.UserID, SenderID: string(senderID),
RoomID: roomID, RoomID: roomID,
Type: spec.MRoomRedaction, Type: spec.MRoomRedaction,
Redacts: eventID, Redacts: eventID,
} }
err := proto.SetContent(r) err := proto.SetContent(r)
if err != nil { if err != nil {

View file

@ -266,16 +266,29 @@ func generateSendEvent(
evTime time.Time, evTime time.Time,
) (gomatrixserverlib.PDU, *util.JSONResponse) { ) (gomatrixserverlib.PDU, *util.JSONResponse) {
// parse the incoming http request // parse the incoming http request
userID := device.UserID fullUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return nil, &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("Bad userID"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return nil, &util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.NotFound("Unable to find senderID for user"),
}
}
// create the new event and set all the fields we can // create the new event and set all the fields we can
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
Sender: userID, SenderID: string(senderID),
RoomID: roomID, RoomID: roomID,
Type: eventType, Type: eventType,
StateKey: stateKey, StateKey: stateKey,
} }
err := proto.SetContent(r) err = proto.SetContent(r)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("proto.SetContent failed") util.GetLogger(ctx).WithError(err).Error("proto.SetContent failed")
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
@ -331,7 +344,7 @@ func generateSendEvent(
stateEvents[i] = queryRes.StateEvents[i].PDU stateEvents[i] = queryRes.StateEvents[i].PDU
} }
provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents)) provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents))
if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID, senderID string) (*spec.UserID, error) { if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{

View file

@ -355,8 +355,16 @@ func emit3PIDInviteEvent(
rsAPI api.ClientRoomserverAPI, rsAPI api.ClientRoomserverAPI,
evTime time.Time, evTime time.Time,
) error { ) error {
userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return err
}
sender, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID)
if err != nil {
return err
}
proto := &gomatrixserverlib.ProtoEvent{ proto := &gomatrixserverlib.ProtoEvent{
Sender: device.UserID, SenderID: string(sender),
RoomID: roomID, RoomID: roomID,
Type: "m.room.third_party_invite", Type: "m.room.third_party_invite",
StateKey: &res.Token, StateKey: &res.Token,
@ -370,7 +378,7 @@ func emit3PIDInviteEvent(
PublicKeys: res.PublicKeys, PublicKeys: res.PublicKeys,
} }
if err := proto.SetContent(content); err != nil { if err = proto.SetContent(content); err != nil {
return err return err
} }

View file

@ -183,7 +183,7 @@ func main() {
fmt.Println("Resolving state") fmt.Println("Resolving state")
var resolved Events var resolved Events
resolved, err = gomatrixserverlib.ResolveConflicts( resolved, err = gomatrixserverlib.ResolveConflicts(
gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID, senderID string) (*spec.UserID, error) { gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return roomserverDB.GetUserIDForSender(ctx, roomID, senderID) return roomserverDB.GetUserIDForSender(ctx, roomID, senderID)
}, },
) )

View file

@ -36,8 +36,12 @@ type fedRoomserverAPI struct {
queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error
} }
func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(senderID, true) return spec.NewUserID(string(senderID), true)
}
func (f *fedRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) {
return spec.SenderID(userID.String()), nil
} }
// PerformJoin will call this function // PerformJoin will call this function
@ -115,12 +119,13 @@ func (f *fedClient) MakeJoin(ctx context.Context, origin, s spec.ServerName, roo
defer f.fedClientMutex.Unlock() defer f.fedClientMutex.Unlock()
for _, r := range f.allowJoins { for _, r := range f.allowJoins {
if r.ID == roomID { if r.ID == roomID {
senderIDString := userID
res.RoomVersion = r.Version res.RoomVersion = r.Version
res.JoinEvent = gomatrixserverlib.ProtoEvent{ res.JoinEvent = gomatrixserverlib.ProtoEvent{
Sender: userID, SenderID: senderIDString,
RoomID: roomID, RoomID: roomID,
Type: "m.room.member", Type: "m.room.member",
StateKey: &userID, StateKey: &senderIDString,
Content: spec.RawJSON([]byte(`{"membership":"join"}`)), Content: spec.RawJSON([]byte(`{"membership":"join"}`)),
PrevEvents: r.ForwardExtremities(), PrevEvents: r.ForwardExtremities(),
} }

View file

@ -154,9 +154,14 @@ func (r *FederationInternalAPI) performJoinUsingServer(
if err != nil { if err != nil {
return err return err
} }
senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, roomID, *user)
if err != nil {
return err
}
joinInput := gomatrixserverlib.PerformJoinInput{ joinInput := gomatrixserverlib.PerformJoinInput{
UserID: user, UserID: user,
SenderID: senderID,
RoomID: room, RoomID: room,
ServerName: serverName, ServerName: serverName,
Content: content, Content: content,
@ -164,10 +169,10 @@ func (r *FederationInternalAPI) performJoinUsingServer(
PrivateKey: r.cfg.Matrix.PrivateKey, PrivateKey: r.cfg.Matrix.PrivateKey,
KeyID: r.cfg.Matrix.KeyID, KeyID: r.cfg.Matrix.KeyID,
KeyRing: r.keyRing, KeyRing: r.keyRing,
EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID, senderID string) (*spec.UserID, error) { EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}), }),
UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, },
} }
@ -363,7 +368,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
// authenticate the state returned (check its auth events etc) // authenticate the state returned (check its auth events etc)
// the equivalent of CheckSendJoinResponse() // the equivalent of CheckSendJoinResponse()
userIDProvider := func(roomID, senderID string) (*spec.UserID, error) { userIDProvider := func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
} }
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse( authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(
@ -414,7 +419,7 @@ func (r *FederationInternalAPI) PerformLeave(
request *api.PerformLeaveRequest, request *api.PerformLeaveRequest,
response *api.PerformLeaveResponse, response *api.PerformLeaveResponse,
) (err error) { ) (err error) {
_, origin, err := r.cfg.Matrix.SplitLocalID('@', request.UserID) userID, err := spec.NewUserID(request.UserID, true)
if err != nil { if err != nil {
return err return err
} }
@ -433,7 +438,7 @@ func (r *FederationInternalAPI) PerformLeave(
// request. // request.
respMakeLeave, err := r.federation.MakeLeave( respMakeLeave, err := r.federation.MakeLeave(
ctx, ctx,
origin, userID.Domain(),
serverName, serverName,
request.RoomID, request.RoomID,
request.UserID, request.UserID,
@ -454,9 +459,14 @@ func (r *FederationInternalAPI) PerformLeave(
// Set all the fields to be what they should be, this should be a no-op // Set all the fields to be what they should be, this should be a no-op
// but it's possible that the remote server returned us something "odd" // but it's possible that the remote server returned us something "odd"
senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, request.RoomID, *userID)
if err != nil {
return err
}
senderIDString := string(senderID)
respMakeLeave.LeaveEvent.Type = spec.MRoomMember respMakeLeave.LeaveEvent.Type = spec.MRoomMember
respMakeLeave.LeaveEvent.Sender = request.UserID respMakeLeave.LeaveEvent.SenderID = senderIDString
respMakeLeave.LeaveEvent.StateKey = &request.UserID respMakeLeave.LeaveEvent.StateKey = &senderIDString
respMakeLeave.LeaveEvent.RoomID = request.RoomID respMakeLeave.LeaveEvent.RoomID = request.RoomID
respMakeLeave.LeaveEvent.Redacts = "" respMakeLeave.LeaveEvent.Redacts = ""
leaveEB := verImpl.NewEventBuilderFromProtoEvent(&respMakeLeave.LeaveEvent) leaveEB := verImpl.NewEventBuilderFromProtoEvent(&respMakeLeave.LeaveEvent)
@ -478,7 +488,7 @@ func (r *FederationInternalAPI) PerformLeave(
// Build the leave event. // Build the leave event.
event, err := leaveEB.Build( event, err := leaveEB.Build(
time.Now(), time.Now(),
origin, userID.Domain(),
r.cfg.Matrix.KeyID, r.cfg.Matrix.KeyID,
r.cfg.Matrix.PrivateKey, r.cfg.Matrix.PrivateKey,
) )
@ -490,7 +500,7 @@ func (r *FederationInternalAPI) PerformLeave(
// Try to perform a send_leave using the newly built event. // Try to perform a send_leave using the newly built event.
err = r.federation.SendLeave( err = r.federation.SendLeave(
ctx, ctx,
origin, userID.Domain(),
serverName, serverName,
event, event,
) )

View file

@ -95,7 +95,7 @@ func InviteV2(
StateQuerier: rsAPI.StateQuerier(), StateQuerier: rsAPI.StateQuerier(),
InviteEvent: inviteReq.Event(), InviteEvent: inviteReq.Event(),
StrippedState: inviteReq.InviteRoomState(), StrippedState: inviteReq.InviteRoomState(),
UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
}, },
} }
@ -188,7 +188,7 @@ func InviteV1(
StateQuerier: rsAPI.StateQuerier(), StateQuerier: rsAPI.StateQuerier(),
InviteEvent: event, InviteEvent: event,
StrippedState: strippedState, StrippedState: strippedState,
UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
}, },
} }

View file

@ -55,7 +55,7 @@ func MakeJoin(
RoomID: roomID.String(), RoomID: roomID.String(),
} }
res := api.QueryServerJoinedToRoomResponse{} res := api.QueryServerJoinedToRoomResponse{}
if err := rsAPI.QueryServerJoinedToRoom(httpReq.Context(), &req, &res); err != nil { if err = rsAPI.QueryServerJoinedToRoom(httpReq.Context(), &req, &res); err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed") util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
@ -64,26 +64,26 @@ func MakeJoin(
} }
createJoinTemplate := func(proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, []gomatrixserverlib.PDU, error) { createJoinTemplate := func(proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, []gomatrixserverlib.PDU, error) {
identity, err := cfg.Matrix.SigningIdentityFor(request.Destination()) identity, signErr := cfg.Matrix.SigningIdentityFor(request.Destination())
if err != nil { if signErr != nil {
util.GetLogger(httpReq.Context()).WithError(err).Errorf("obtaining signing identity for %s failed", request.Destination()) util.GetLogger(httpReq.Context()).WithError(signErr).Errorf("obtaining signing identity for %s failed", request.Destination())
return nil, nil, spec.NotFound(fmt.Sprintf("Server name %q does not exist", request.Destination())) return nil, nil, spec.NotFound(fmt.Sprintf("Server name %q does not exist", request.Destination()))
} }
queryRes := api.QueryLatestEventsAndStateResponse{ queryRes := api.QueryLatestEventsAndStateResponse{
RoomVersion: roomVersion, RoomVersion: roomVersion,
} }
event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), proto, identity, time.Now(), rsAPI, &queryRes) event, signErr := eventutil.QueryAndBuildEvent(httpReq.Context(), proto, identity, time.Now(), rsAPI, &queryRes)
switch e := err.(type) { switch e := signErr.(type) {
case nil: case nil:
case eventutil.ErrRoomNoExists: case eventutil.ErrRoomNoExists:
util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") util.GetLogger(httpReq.Context()).WithError(signErr).Error("eventutil.BuildEvent failed")
return nil, nil, spec.NotFound("Room does not exist") return nil, nil, spec.NotFound("Room does not exist")
case gomatrixserverlib.BadJSONError: case gomatrixserverlib.BadJSONError:
util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") util.GetLogger(httpReq.Context()).WithError(signErr).Error("eventutil.BuildEvent failed")
return nil, nil, spec.BadJSON(e.Error()) return nil, nil, spec.BadJSON(e.Error())
default: default:
util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") util.GetLogger(httpReq.Context()).WithError(signErr).Error("eventutil.BuildEvent failed")
return nil, nil, spec.InternalServerError{} return nil, nil, spec.InternalServerError{}
} }
@ -98,9 +98,19 @@ func MakeJoin(
Roomserver: rsAPI, Roomserver: rsAPI,
} }
senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID.String(), userID)
if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
input := gomatrixserverlib.HandleMakeJoinInput{ input := gomatrixserverlib.HandleMakeJoinInput{
Context: httpReq.Context(), Context: httpReq.Context(),
UserID: userID, UserID: userID,
SenderID: senderID,
RoomID: roomID, RoomID: roomID,
RoomVersion: roomVersion, RoomVersion: roomVersion,
RemoteVersions: remoteVersions, RemoteVersions: remoteVersions,
@ -108,7 +118,7 @@ func MakeJoin(
LocalServerName: cfg.Matrix.ServerName, LocalServerName: cfg.Matrix.ServerName,
LocalServerInRoom: res.RoomExists && res.IsInRoom, LocalServerInRoom: res.RoomExists && res.IsInRoom,
RoomQuerier: &roomQuerier, RoomQuerier: &roomQuerier,
UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
}, },
BuildEventTemplate: createJoinTemplate, BuildEventTemplate: createJoinTemplate,
@ -205,7 +215,7 @@ func SendJoin(
PrivateKey: cfg.Matrix.PrivateKey, PrivateKey: cfg.Matrix.PrivateKey,
Verifier: keys, Verifier: keys,
MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI}, MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI},
UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
}, },
} }

View file

@ -50,7 +50,7 @@ func MakeLeave(
RoomID: roomID.String(), RoomID: roomID.String(),
} }
res := api.QueryServerJoinedToRoomResponse{} res := api.QueryServerJoinedToRoomResponse{}
if err := rsAPI.QueryServerJoinedToRoom(httpReq.Context(), &req, &res); err != nil { if err = rsAPI.QueryServerJoinedToRoom(httpReq.Context(), &req, &res); err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed") util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
@ -59,24 +59,24 @@ func MakeLeave(
} }
createLeaveTemplate := func(proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, []gomatrixserverlib.PDU, error) { createLeaveTemplate := func(proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, []gomatrixserverlib.PDU, error) {
identity, err := cfg.Matrix.SigningIdentityFor(request.Destination()) identity, signErr := cfg.Matrix.SigningIdentityFor(request.Destination())
if err != nil { if signErr != nil {
util.GetLogger(httpReq.Context()).WithError(err).Errorf("obtaining signing identity for %s failed", request.Destination()) util.GetLogger(httpReq.Context()).WithError(signErr).Errorf("obtaining signing identity for %s failed", request.Destination())
return nil, nil, spec.NotFound(fmt.Sprintf("Server name %q does not exist", request.Destination())) return nil, nil, spec.NotFound(fmt.Sprintf("Server name %q does not exist", request.Destination()))
} }
queryRes := api.QueryLatestEventsAndStateResponse{} queryRes := api.QueryLatestEventsAndStateResponse{}
event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), proto, identity, time.Now(), rsAPI, &queryRes) event, buildErr := eventutil.QueryAndBuildEvent(httpReq.Context(), proto, identity, time.Now(), rsAPI, &queryRes)
switch e := err.(type) { switch e := buildErr.(type) {
case nil: case nil:
case eventutil.ErrRoomNoExists: case eventutil.ErrRoomNoExists:
util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") util.GetLogger(httpReq.Context()).WithError(buildErr).Error("eventutil.BuildEvent failed")
return nil, nil, spec.NotFound("Room does not exist") return nil, nil, spec.NotFound("Room does not exist")
case gomatrixserverlib.BadJSONError: case gomatrixserverlib.BadJSONError:
util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") util.GetLogger(httpReq.Context()).WithError(buildErr).Error("eventutil.BuildEvent failed")
return nil, nil, spec.BadJSON(e.Error()) return nil, nil, spec.BadJSON(e.Error())
default: default:
util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") util.GetLogger(httpReq.Context()).WithError(buildErr).Error("eventutil.BuildEvent failed")
return nil, nil, spec.InternalServerError{} return nil, nil, spec.InternalServerError{}
} }
@ -87,15 +87,25 @@ func MakeLeave(
return event, stateEvents, nil return event, stateEvents, nil
} }
senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID.String(), userID)
if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
input := gomatrixserverlib.HandleMakeLeaveInput{ input := gomatrixserverlib.HandleMakeLeaveInput{
UserID: userID, UserID: userID,
SenderID: senderID,
RoomID: roomID, RoomID: roomID,
RoomVersion: roomVersion, RoomVersion: roomVersion,
RequestOrigin: request.Origin(), RequestOrigin: request.Origin(),
LocalServerName: cfg.Matrix.ServerName, LocalServerName: cfg.Matrix.ServerName,
LocalServerInRoom: res.RoomExists && res.IsInRoom, LocalServerInRoom: res.RoomExists && res.IsInRoom,
BuildEventTemplate: createLeaveTemplate, BuildEventTemplate: createLeaveTemplate,
UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
}, },
} }
@ -216,7 +226,7 @@ func SendLeave(
JSON: spec.BadJSON("No state key was provided in the leave event."), JSON: spec.BadJSON("No state key was provided in the leave event."),
} }
} }
if !event.StateKeyEquals(event.SenderID()) { if !event.StateKeyEquals(string(event.SenderID())) {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: spec.BadJSON("Event state key must match the event sender."), JSON: spec.BadJSON("Event state key must match the event sender."),

View file

@ -140,22 +140,24 @@ func ExchangeThirdPartyInvite(
} }
} }
_, senderDomain, err := cfg.Matrix.SplitLocalID('@', proto.Sender) userID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, spec.SenderID(proto.SenderID))
if err != nil { if err != nil || userID == nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: spec.BadJSON("Invalid sender ID: " + err.Error()), JSON: spec.BadJSON("Invalid sender ID"),
} }
} }
senderDomain := userID.Domain()
// Check that the state key is correct. // Check that the state key is correct.
_, targetDomain, err := gomatrixserverlib.SplitID('@', *proto.StateKey) targetUserID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, spec.SenderID(*proto.StateKey))
if err != nil { if err != nil || targetUserID == nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: spec.BadJSON("The event's state key isn't a Matrix user ID"), JSON: spec.BadJSON("The event's state key isn't a Matrix user ID"),
} }
} }
targetDomain := targetUserID.Domain()
// Check that the target user is from the requesting homeserver. // Check that the target user is from the requesting homeserver.
if targetDomain != request.Origin() { if targetDomain != request.Origin() {
@ -271,7 +273,7 @@ func createInviteFrom3PIDInvite(
// Build the event // Build the event
proto := &gomatrixserverlib.ProtoEvent{ proto := &gomatrixserverlib.ProtoEvent{
Type: "m.room.member", Type: "m.room.member",
Sender: inv.Sender, SenderID: inv.Sender,
RoomID: inv.RoomID, RoomID: inv.RoomID,
StateKey: &inv.MXID, StateKey: &inv.MXID,
} }

2
go.mod
View file

@ -22,7 +22,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20230606202811-a644d5d8fb66 github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
github.com/mattn/go-sqlite3 v1.14.16 github.com/mattn/go-sqlite3 v1.14.16

4
go.sum
View file

@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230606202811-a644d5d8fb66 h1:6SixhMmB5Ir10xUJ6zh3A4NBxSaZCSz2s5U63Wg0eEU= github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d h1:MjL8SXRzhO61aXDFL+gA3Bx1SicqLGL9gCWXDv8jkD8=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230606202811-a644d5d8fb66/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=

View file

@ -28,7 +28,7 @@ type EvaluationContext interface {
// HasPowerLevel returns whether the user has at least the given // HasPowerLevel returns whether the user has at least the given
// power in the room of the current event. // power in the room of the current event.
HasPowerLevel(userID, levelKey string) (bool, error) HasPowerLevel(senderID spec.SenderID, levelKey string) (bool, error)
} }
// A kindAndRules is just here to simplify iteration of the (ordered) // A kindAndRules is just here to simplify iteration of the (ordered)

View file

@ -8,8 +8,8 @@ import (
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
) )
func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { func UserIDForSender(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(senderID, true) return spec.NewUserID(string(senderID), true)
} }
func TestRuleSetEvaluatorMatchEvent(t *testing.T) { func TestRuleSetEvaluatorMatchEvent(t *testing.T) {
@ -158,8 +158,8 @@ type fakeEvaluationContext struct{ memberCount int }
func (fakeEvaluationContext) UserDisplayName() string { return "Dear User" } func (fakeEvaluationContext) UserDisplayName() string { return "Dear User" }
func (f fakeEvaluationContext) RoomMemberCount() (int, error) { return f.memberCount, nil } func (f fakeEvaluationContext) RoomMemberCount() (int, error) { return f.memberCount, nil }
func (fakeEvaluationContext) HasPowerLevel(userID, levelKey string) (bool, error) { func (fakeEvaluationContext) HasPowerLevel(senderID spec.SenderID, levelKey string) (bool, error) {
return userID == "@poweruser:example.com" && levelKey == "powerlevel", nil return senderID == "@poweruser:example.com" && levelKey == "powerlevel", nil
} }
func TestPatternMatches(t *testing.T) { func TestPatternMatches(t *testing.T) {

View file

@ -167,7 +167,7 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut
} }
continue continue
} }
if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID, senderID string) (*spec.UserID, error) { if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID())

View file

@ -70,8 +70,8 @@ type FakeRsAPI struct {
bannedFromRoom bool bannedFromRoom bool
} }
func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(senderID, true) return spec.NewUserID(string(senderID), true)
} }
func (r *FakeRsAPI) QueryRoomVersionForRoom( func (r *FakeRsAPI) QueryRoomVersionForRoom(
@ -642,8 +642,8 @@ type testRoomserverAPI struct {
queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse
} }
func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(senderID, true) return spec.NewUserID(string(senderID), true)
} }
func (t *testRoomserverAPI) InputRoomEvents( func (t *testRoomserverAPI) InputRoomEvents(

View file

@ -14,7 +14,11 @@
package api package api
import "regexp" import (
"regexp"
"github.com/matrix-org/gomatrixserverlib/spec"
)
// SetRoomAliasRequest is a request to SetRoomAlias // SetRoomAliasRequest is a request to SetRoomAlias
type SetRoomAliasRequest struct { type SetRoomAliasRequest struct {
@ -62,7 +66,7 @@ type GetAliasesForRoomIDResponse struct {
// RemoveRoomAliasRequest is a request to RemoveRoomAlias // RemoveRoomAliasRequest is a request to RemoveRoomAlias
type RemoveRoomAliasRequest struct { type RemoveRoomAliasRequest struct {
// ID of the user removing the alias // ID of the user removing the alias
SenderID string `json:"user_id"` SenderID spec.SenderID `json:"user_id"`
// The room alias to remove // The room alias to remove
Alias string `json:"alias"` Alias string `json:"alias"`
} }

View file

@ -77,8 +77,8 @@ type InputRoomEventsAPI interface {
} }
type QuerySenderIDAPI interface { type QuerySenderIDAPI interface {
QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error)
QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
} }
// Query the latest events and state for a room from the room server. // Query the latest events and state for a room from the room server.

View file

@ -130,7 +130,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
} }
sender, err := r.QueryUserIDForSender(ctx, roomID, request.SenderID) sender, err := r.QueryUserIDForSender(ctx, roomID, request.SenderID)
if err != nil { if err != nil || sender == nil {
return fmt.Errorf("r.QueryUserIDForSender: %w", err) return fmt.Errorf("r.QueryUserIDForSender: %w", err)
} }
virtualHost := sender.Domain() virtualHost := sender.Domain()
@ -141,7 +141,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
return fmt.Errorf("r.DB.GetCreatorIDForAlias: %w", err) return fmt.Errorf("r.DB.GetCreatorIDForAlias: %w", err)
} }
if creatorID != request.SenderID { if spec.SenderID(creatorID) != request.SenderID {
var plEvent *types.HeaderedEvent var plEvent *types.HeaderedEvent
var pls *gomatrixserverlib.PowerLevelContent var pls *gomatrixserverlib.PowerLevelContent
@ -173,23 +173,24 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
return err return err
} }
sender := request.SenderID senderID := request.SenderID
if request.SenderID != ev.SenderID() { if request.SenderID != ev.SenderID() {
sender = ev.SenderID() senderID = ev.SenderID()
} }
sender, err := r.QueryUserIDForSender(ctx, roomID, senderID)
_, senderDomain, err := r.Cfg.Global.SplitLocalID('@', sender) if err != nil || sender == nil {
if err != nil {
return err return err
} }
senderDomain := sender.Domain()
identity, err := r.Cfg.Global.SigningIdentityFor(senderDomain) identity, err := r.Cfg.Global.SigningIdentityFor(senderDomain)
if err != nil { if err != nil {
return err return err
} }
proto := &gomatrixserverlib.ProtoEvent{ proto := &gomatrixserverlib.ProtoEvent{
Sender: sender, SenderID: string(senderID),
RoomID: ev.RoomID(), RoomID: ev.RoomID(),
Type: ev.Type(), Type: ev.Type(),
StateKey: ev.StateKey(), StateKey: ev.StateKey(),

View file

@ -76,7 +76,7 @@ func CheckForSoftFail(
} }
// Check if the event is allowed. // Check if the event is allowed.
if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return db.GetUserIDForSender(ctx, roomID, senderID) return db.GetUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
// return true, nil // return true, nil
@ -139,8 +139,8 @@ func (ae *authEvents) JoinRules() (gomatrixserverlib.PDU, error) {
} }
// Memmber implements gomatrixserverlib.AuthEventProvider // Memmber implements gomatrixserverlib.AuthEventProvider
func (ae *authEvents) Member(stateKey string) (gomatrixserverlib.PDU, error) { func (ae *authEvents) Member(stateKey spec.SenderID) (gomatrixserverlib.PDU, error) {
return ae.lookupEvent(types.MRoomMemberNID, stateKey), nil return ae.lookupEvent(types.MRoomMemberNID, string(stateKey)), nil
} }
// ThirdPartyInvite implements gomatrixserverlib.AuthEventProvider // ThirdPartyInvite implements gomatrixserverlib.AuthEventProvider

View file

@ -282,7 +282,7 @@ func (r *Inputer) processRoomEvent(
// Check if the event is allowed by its auth events. If it isn't then // Check if the event is allowed by its auth events. If it isn't then
// we consider the event to be "rejected" — it will still be persisted. // we consider the event to be "rejected" — it will still be persisted.
if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID) return r.DB.GetUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
isRejected = true isRejected = true
@ -501,7 +501,7 @@ func (r *Inputer) processRoomEvent(
func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event gomatrixserverlib.PDU) error { func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event gomatrixserverlib.PDU) error {
oldRoomID := event.RoomID() oldRoomID := event.RoomID()
newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str
return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, event.SenderID()) return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, string(event.SenderID()))
} }
// processStateBefore works out what the state is before the event and // processStateBefore works out what the state is before the event and
@ -587,7 +587,7 @@ func (r *Inputer) processStateBefore(
stateBeforeAuth := gomatrixserverlib.NewAuthEvents( stateBeforeAuth := gomatrixserverlib.NewAuthEvents(
gomatrixserverlib.ToPDUs(stateBeforeEvent), gomatrixserverlib.ToPDUs(stateBeforeEvent),
) )
if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID, senderID string) (*spec.UserID, error) { if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID) return r.DB.GetUserIDForSender(ctx, roomID, senderID)
}); rejectionErr != nil { }); rejectionErr != nil {
rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr) rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr)
@ -700,7 +700,7 @@ nextAuthEvent:
// Check the signatures of the event. If this fails then we'll simply // Check the signatures of the event. If this fails then we'll simply
// skip it, because gomatrixserverlib.Allowed() will notice a problem // skip it, because gomatrixserverlib.Allowed() will notice a problem
// if a critical event is missing anyway. // if a critical event is missing anyway.
if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID, senderID string) (*spec.UserID, error) { if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID) return r.DB.GetUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
continue nextAuthEvent continue nextAuthEvent
@ -718,7 +718,7 @@ nextAuthEvent:
} }
// Check if the auth event should be rejected. // Check if the auth event should be rejected.
err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID, senderID string) (*spec.UserID, error) { err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID) return r.DB.GetUserIDForSender(ctx, roomID, senderID)
}) })
if isRejected = err != nil; isRejected { if isRejected = err != nil; isRejected {
@ -875,7 +875,7 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r
RoomID: event.RoomID(), RoomID: event.RoomID(),
Type: spec.MRoomMember, Type: spec.MRoomMember,
StateKey: &stateKey, StateKey: &stateKey,
Sender: stateKey, SenderID: stateKey,
PrevEvents: prevEvents, PrevEvents: prevEvents,
} }

View file

@ -58,7 +58,9 @@ func Test_EventAuth(t *testing.T) {
} }
// Finally check that the event is NOT allowed // Finally check that the event is NOT allowed
if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) }); err == nil { if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}); err == nil {
t.Fatalf("event should not be allowed, but it was") t.Fatalf("event should not be allowed, but it was")
} }
} }

View file

@ -473,7 +473,7 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion
stateEventList = append(stateEventList, state.StateEvents...) stateEventList = append(stateEventList, state.StateEvents...)
} }
resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts( resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(
roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID, senderID string) (*spec.UserID, error) { roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return t.db.GetUserIDForSender(ctx, roomID, senderID) return t.db.GetUserIDForSender(ctx, roomID, senderID)
}, },
) )
@ -482,7 +482,7 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion
} }
// apply the current event // apply the current event
retryAllowedState: retryAllowedState:
if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID, senderID string) (*spec.UserID, error) { if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return t.db.GetUserIDForSender(ctx, roomID, senderID) return t.db.GetUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
switch missing := err.(type) { switch missing := err.(type) {
@ -569,7 +569,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver
// will be added and duplicates will be removed. // will be added and duplicates will be removed.
missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events)) missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events))
for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) { for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID, senderID string) (*spec.UserID, error) { if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return t.db.GetUserIDForSender(ctx, roomID, senderID) return t.db.GetUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
continue continue
@ -660,7 +660,7 @@ func (t *missingStateReq) lookupMissingStateViaState(
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{ authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{
StateEvents: state.GetStateEvents(), StateEvents: state.GetStateEvents(),
AuthEvents: state.GetAuthEvents(), AuthEvents: state.GetAuthEvents(),
}, roomVersion, t.keys, nil, func(roomID, senderID string) (*spec.UserID, error) { }, roomVersion, t.keys, nil, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return t.db.GetUserIDForSender(ctx, roomID, senderID) return t.db.GetUserIDForSender(ctx, roomID, senderID)
}) })
if err != nil { if err != nil {
@ -897,7 +897,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers)) t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers))
return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers)) return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers))
} }
if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID, senderID string) (*spec.UserID, error) { if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return t.db.GetUserIDForSender(ctx, roomID, senderID) return t.db.GetUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())

View file

@ -96,14 +96,15 @@ func (r *Admin) PerformAdminEvacuateRoom(
RoomID: roomID, RoomID: roomID,
Type: spec.MRoomMember, Type: spec.MRoomMember,
StateKey: &stateKey, StateKey: &stateKey,
Sender: stateKey, SenderID: stateKey,
PrevEvents: prevEvents, PrevEvents: prevEvents,
} }
_, senderDomain, err = gomatrixserverlib.SplitID('@', fledglingEvent.Sender) userID, err := r.Queryer.QueryUserIDForSender(ctx, roomID, spec.SenderID(fledglingEvent.SenderID))
if err != nil { if err != nil || userID == nil {
continue continue
} }
senderDomain = userID.Domain()
if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil { if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil {
return nil, err return nil, err
@ -233,10 +234,11 @@ func (r *Admin) PerformAdminDownloadState(
ctx context.Context, ctx context.Context,
roomID, userID string, serverName spec.ServerName, roomID, userID string, serverName spec.ServerName,
) error { ) error {
_, senderDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID) fullUserID, err := spec.NewUserID(userID, true)
if err != nil { if err != nil {
return err return err
} }
senderDomain := fullUserID.Domain()
roomInfo, err := r.DB.RoomInfo(ctx, roomID) roomInfo, err := r.DB.RoomInfo(ctx, roomID)
if err != nil { if err != nil {
@ -262,7 +264,7 @@ func (r *Admin) PerformAdminDownloadState(
return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err) return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err)
} }
for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) { for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID, senderID string) (*spec.UserID, error) { if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID) return r.DB.GetUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
continue continue
@ -270,7 +272,7 @@ func (r *Admin) PerformAdminDownloadState(
authEventMap[authEvent.EventID()] = authEvent authEventMap[authEvent.EventID()] = authEvent
} }
for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) { for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID, senderID string) (*spec.UserID, error) { if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID) return r.DB.GetUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
continue continue
@ -291,11 +293,15 @@ func (r *Admin) PerformAdminDownloadState(
stateIDs = append(stateIDs, stateEvent.EventID()) stateIDs = append(stateIDs, stateEvent.EventID())
} }
senderID, err := r.Queryer.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return err
}
proto := &gomatrixserverlib.ProtoEvent{ proto := &gomatrixserverlib.ProtoEvent{
Type: "org.matrix.dendrite.state_download", Type: "org.matrix.dendrite.state_download",
Sender: userID, SenderID: string(senderID),
RoomID: roomID, RoomID: roomID,
Content: spec.RawJSON("{}"), Content: spec.RawJSON("{}"),
} }
eventsNeeded, err := gomatrixserverlib.StateNeededForProtoEvent(proto) eventsNeeded, err := gomatrixserverlib.StateNeededForProtoEvent(proto)

View file

@ -121,7 +121,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
// Specifically the test "Outbound federation can backfill events" // Specifically the test "Outbound federation can backfill events"
events, err := gomatrixserverlib.RequestBackfill( events, err := gomatrixserverlib.RequestBackfill(
ctx, req.VirtualHost, requester, ctx, req.VirtualHost, requester,
r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID, senderID string) (*spec.UserID, error) { r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID) return r.DB.GetUserIDForSender(ctx, roomID, senderID)
}, },
) )
@ -212,7 +212,7 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
continue continue
} }
loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false) loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false)
result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID, senderID string) (*spec.UserID, error) { result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID) return r.DB.GetUserIDForSender(ctx, roomID, senderID)
}) })
if err != nil { if err != nil {

View file

@ -270,11 +270,19 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
var builtEvents []*types.HeaderedEvent var builtEvents []*types.HeaderedEvent
authEvents := gomatrixserverlib.NewAuthEvents(nil) authEvents := gomatrixserverlib.NewAuthEvents(nil)
senderID, err := c.RSAPI.QuerySenderIDForUser(ctx, roomID.String(), userID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("rsapi.QuerySenderIDForUser failed")
return "", &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
for i, e := range eventsToMake { for i, e := range eventsToMake {
depth := i + 1 // depth starts at 1 depth := i + 1 // depth starts at 1
builder := verImpl.NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{ builder := verImpl.NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{
Sender: userID.String(), SenderID: string(senderID),
RoomID: roomID.String(), RoomID: roomID.String(),
Type: e.Type, Type: e.Type,
StateKey: &e.StateKey, StateKey: &e.StateKey,
@ -308,7 +316,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
} }
} }
if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return c.DB.GetUserIDForSender(ctx, roomID, senderID) return c.DB.GetUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed") util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed")
@ -409,11 +417,28 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
// Process the invites. // Process the invites.
var inviteEvent *types.HeaderedEvent var inviteEvent *types.HeaderedEvent
for _, invitee := range createRequest.InvitedUsers { for _, invitee := range createRequest.InvitedUsers {
inviteeUserID, userIDErr := spec.NewUserID(invitee, true)
if userIDErr != nil {
util.GetLogger(ctx).WithError(userIDErr).Error("invalid UserID")
return "", &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
inviteeSenderID, queryErr := c.RSAPI.QuerySenderIDForUser(ctx, roomID.String(), *inviteeUserID)
if queryErr != nil {
util.GetLogger(ctx).WithError(queryErr).Error("rsapi.QuerySenderIDForUser failed")
return "", &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
inviteeString := string(inviteeSenderID)
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
Sender: userID.String(), SenderID: string(senderID),
RoomID: roomID.String(), RoomID: roomID.String(),
Type: "m.room.member", Type: "m.room.member",
StateKey: &invitee, StateKey: &inviteeString,
} }
content := gomatrixserverlib.MemberContent{ content := gomatrixserverlib.MemberContent{

View file

@ -98,7 +98,7 @@ func (r *Inviter) ProcessInviteMembership(
var outputUpdates []api.OutputEvent var outputUpdates []api.OutputEvent
var updater *shared.MembershipUpdater var updater *shared.MembershipUpdater
userID, err := r.RSAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey()) userID, err := r.RSAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey()))
if err != nil { if err != nil {
return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())} return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())}
} }
@ -148,15 +148,21 @@ func (r *Inviter) PerformInvite(
return err return err
} }
invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, event.RoomID(), *invitedUser)
if err != nil {
return fmt.Errorf("failed looking up senderID for invited user")
}
input := gomatrixserverlib.PerformInviteInput{ input := gomatrixserverlib.PerformInviteInput{
RoomID: *validRoomID, RoomID: *validRoomID,
InviteEvent: event.PDU, InviteEvent: event.PDU,
InvitedUser: *invitedUser, InvitedUser: *invitedUser,
InvitedSenderID: invitedSenderID,
IsTargetLocal: isTargetLocal, IsTargetLocal: isTargetLocal,
StrippedState: req.InviteRoomState, StrippedState: req.InviteRoomState,
MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI}, MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI},
StateQuerier: &QueryState{r.DB}, StateQuerier: &QueryState{r.DB},
UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID) return r.DB.GetUserIDForSender(ctx, roomID, senderID)
}, },
} }

View file

@ -175,15 +175,20 @@ func (r *Joiner) performJoinRoomByID(
} }
// Prepare the template for the join event. // Prepare the template for the join event.
userID := req.UserID userID, err := spec.NewUserID(req.UserID, true)
_, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID)
if err != nil { if err != nil {
return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", userID, err)} return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)}
} }
senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomIDOrAlias, *userID)
if err != nil {
return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)}
}
senderIDString := string(senderID)
userDomain := userID.Domain()
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
Type: spec.MRoomMember, Type: spec.MRoomMember,
Sender: userID, SenderID: senderIDString,
StateKey: &userID, StateKey: &senderIDString,
RoomID: req.RoomIDOrAlias, RoomID: req.RoomIDOrAlias,
Redacts: "", Redacts: "",
} }
@ -295,7 +300,7 @@ func (r *Joiner) performJoinRoomByID(
// is really no harm in just sending another membership event. // is really no harm in just sending another membership event.
membershipReq := &api.QueryMembershipForUserRequest{ membershipReq := &api.QueryMembershipForUserRequest{
RoomID: req.RoomIDOrAlias, RoomID: req.RoomIDOrAlias,
UserID: userID, UserID: userID.String(),
} }
membershipRes := &api.QueryMembershipForUserResponse{} membershipRes := &api.QueryMembershipForUserResponse{}
_ = r.Queryer.QueryMembershipForUser(ctx, membershipReq, membershipRes) _ = r.Queryer.QueryMembershipForUser(ctx, membershipReq, membershipRes)

View file

@ -152,11 +152,19 @@ func (r *Leaver) performLeaveRoomByID(
} }
// Prepare the template for the leave event. // Prepare the template for the leave event.
userID := req.UserID fullUserID, err := spec.NewUserID(req.UserID, true)
if err != nil {
return nil, err
}
senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomID, *fullUserID)
if err != nil {
return nil, err
}
senderIDString := string(senderID)
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
Type: spec.MRoomMember, Type: spec.MRoomMember,
Sender: userID, SenderID: senderIDString,
StateKey: &userID, StateKey: &senderIDString,
RoomID: req.RoomID, RoomID: req.RoomID,
Redacts: "", Redacts: "",
} }
@ -168,10 +176,7 @@ func (r *Leaver) performLeaveRoomByID(
} }
// Get the sender domain. // Get the sender domain.
_, senderDomain, serr := r.Cfg.Matrix.SplitLocalID('@', proto.Sender) senderDomain := fullUserID.Domain()
if serr != nil {
return nil, fmt.Errorf("sender %q is invalid", proto.Sender)
}
// We know that the user is in the room at this point so let's build // We know that the user is in the room at this point so let's build
// a leave event. // a leave event.

View file

@ -175,8 +175,16 @@ func moveLocalAliases(ctx context.Context,
return fmt.Errorf("Failed to get old room aliases: %w", err) return fmt.Errorf("Failed to get old room aliases: %w", err)
} }
fullUserID, err := spec.NewUserID(userID, true)
if err != nil {
return fmt.Errorf("Failed to get userID: %w", err)
}
senderID, err := URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return fmt.Errorf("Failed to get senderID: %w", err)
}
for _, alias := range aliasRes.Aliases { for _, alias := range aliasRes.Aliases {
removeAliasReq := api.RemoveRoomAliasRequest{SenderID: userID, Alias: alias} removeAliasReq := api.RemoveRoomAliasRequest{SenderID: senderID, Alias: alias}
removeAliasRes := api.RemoveRoomAliasResponse{} removeAliasRes := api.RemoveRoomAliasResponse{}
if err = URSAPI.RemoveRoomAlias(ctx, &removeAliasReq, &removeAliasRes); err != nil { if err = URSAPI.RemoveRoomAlias(ctx, &removeAliasReq, &removeAliasRes); err != nil {
return fmt.Errorf("Failed to remove old room alias: %w", err) return fmt.Errorf("Failed to remove old room alias: %w", err)
@ -287,7 +295,15 @@ func (r *Upgrader) userIsAuthorized(ctx context.Context, userID, roomID string,
} }
// Check for power level required to send tombstone event (marks the current room as obsolete), // Check for power level required to send tombstone event (marks the current room as obsolete),
// if not found, use the StateDefault power level // if not found, use the StateDefault power level
return pl.UserLevel(userID) >= pl.EventLevel("m.room.tombstone", true) fullUserID, err := spec.NewUserID(userID, true)
if err != nil {
return false
}
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return false
}
return pl.UserLevel(senderID) >= pl.EventLevel("m.room.tombstone", true)
} }
// nolint:gocyclo // nolint:gocyclo
@ -383,7 +399,16 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
util.GetLogger(ctx).WithError(err).Error() util.GetLogger(ctx).WithError(err).Error()
return nil, fmt.Errorf("Power level event content was invalid") return nil, fmt.Errorf("Power level event content was invalid")
} }
tempPowerLevelsEvent, powerLevelsOverridden := createTemporaryPowerLevels(powerLevelContent, userID)
fullUserID, err := spec.NewUserID(userID, true)
if err != nil {
return nil, err
}
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return nil, err
}
tempPowerLevelsEvent, powerLevelsOverridden := createTemporaryPowerLevels(powerLevelContent, senderID)
// Now do the join rules event, same as the create and membership // Now do the join rules event, same as the create and membership
// events. We'll set a sane default of "invite" so that if the // events. We'll set a sane default of "invite" so that if the
@ -452,8 +477,16 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user
for i, e := range eventsToMake { for i, e := range eventsToMake {
depth := i + 1 // depth starts at 1 depth := i + 1 // depth starts at 1
fullUserID, userIDErr := spec.NewUserID(userID, true)
if userIDErr != nil {
return userIDErr
}
senderID, queryErr := r.URSAPI.QuerySenderIDForUser(ctx, newRoomID, *fullUserID)
if queryErr != nil {
return queryErr
}
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
Sender: userID, SenderID: string(senderID),
RoomID: newRoomID, RoomID: newRoomID,
Type: e.Type, Type: e.Type,
StateKey: &e.StateKey, StateKey: &e.StateKey,
@ -484,7 +517,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user
} }
if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err) return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err)
@ -530,21 +563,26 @@ func (r *Upgrader) makeTombstoneEvent(
} }
func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, userID, roomID string, event gomatrixserverlib.FledglingEvent) (*types.HeaderedEvent, error) { func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, userID, roomID string, event gomatrixserverlib.FledglingEvent) (*types.HeaderedEvent, error) {
fullUserID, err := spec.NewUserID(userID, true)
if err != nil {
return nil, err
}
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return nil, err
}
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
Sender: userID, SenderID: string(senderID),
RoomID: roomID, RoomID: roomID,
Type: event.Type, Type: event.Type,
StateKey: &event.StateKey, StateKey: &event.StateKey,
} }
err := proto.SetContent(event.Content) err = proto.SetContent(event.Content)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to set new %q event content: %w", proto.Type, err) return nil, fmt.Errorf("failed to set new %q event content: %w", proto.Type, err)
} }
// Get the sender domain. // Get the sender domain.
_, senderDomain, serr := r.Cfg.Matrix.SplitLocalID('@', proto.Sender) senderDomain := fullUserID.Domain()
if serr != nil {
return nil, fmt.Errorf("Failed to split user ID %q: %w", proto.Sender, err)
}
identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get signing identity for %q: %w", senderDomain, err) return nil, fmt.Errorf("failed to get signing identity for %q: %w", senderDomain, err)
@ -569,7 +607,7 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user
stateEvents[i] = queryRes.StateEvents[i].PDU stateEvents[i] = queryRes.StateEvents[i].PDU
} }
provider := gomatrixserverlib.NewAuthEvents(stateEvents) provider := gomatrixserverlib.NewAuthEvents(stateEvents)
if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID, senderID string) (*spec.UserID, error) { if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client? return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client?
@ -578,7 +616,7 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user
return headeredEvent, nil return headeredEvent, nil
} }
func createTemporaryPowerLevels(powerLevelContent *gomatrixserverlib.PowerLevelContent, userID string) (gomatrixserverlib.FledglingEvent, bool) { func createTemporaryPowerLevels(powerLevelContent *gomatrixserverlib.PowerLevelContent, senderID spec.SenderID) (gomatrixserverlib.FledglingEvent, bool) {
// Work out what power level we need in order to be able to send events // Work out what power level we need in order to be able to send events
// of all types into the room. // of all types into the room.
neededPowerLevel := powerLevelContent.StateDefault neededPowerLevel := powerLevelContent.StateDefault
@ -603,8 +641,8 @@ func createTemporaryPowerLevels(powerLevelContent *gomatrixserverlib.PowerLevelC
// If the user who is upgrading the room doesn't already have sufficient // If the user who is upgrading the room doesn't already have sufficient
// power, then elevate their power levels. // power, then elevate their power levels.
if tempPowerLevelContent.UserLevel(userID) < neededPowerLevel { if tempPowerLevelContent.UserLevel(senderID) < neededPowerLevel {
tempPowerLevelContent.Users[userID] = neededPowerLevel tempPowerLevelContent.Users[string(senderID)] = neededPowerLevel
powerLevelsOverridden = true powerLevelsOverridden = true
} }

View file

@ -159,7 +159,7 @@ func (r *Queryer) QueryStateAfterEvents(
} }
stateEvents, err = gomatrixserverlib.ResolveConflicts( stateEvents, err = gomatrixserverlib.ResolveConflicts(
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID, senderID string) (*spec.UserID, error) { info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID) return r.DB.GetUserIDForSender(ctx, roomID, senderID)
}, },
) )
@ -637,7 +637,7 @@ func (r *Queryer) QueryStateAndAuthChain(
if request.ResolveState { if request.ResolveState {
stateEvents, err = gomatrixserverlib.ResolveConflicts( stateEvents, err = gomatrixserverlib.ResolveConflicts(
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID, senderID string) (*spec.UserID, error) { info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID) return r.DB.GetUserIDForSender(ctx, roomID, senderID)
}, },
) )
@ -975,10 +975,10 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro
return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, userID) return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, userID)
} }
func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) { func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) {
return r.DB.GetSenderIDForUser(ctx, roomID, userID) return r.DB.GetSenderIDForUser(ctx, roomID, userID)
} }
func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID) return r.DB.GetUserIDForSender(ctx, roomID, senderID)
} }

View file

@ -392,7 +392,7 @@ func TestPurgeRoom(t *testing.T) {
type fledglingEvent struct { type fledglingEvent struct {
Type string Type string
StateKey *string StateKey *string
Sender string SenderID string
RoomID string RoomID string
Redacts string Redacts string
Depth int64 Depth int64
@ -405,7 +405,7 @@ func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *types.HeaderedEve
seed := make([]byte, ed25519.SeedSize) // zero seed seed := make([]byte, ed25519.SeedSize) // zero seed
key := ed25519.NewKeyFromSeed(seed) key := ed25519.NewKeyFromSeed(seed)
eb := gomatrixserverlib.MustGetRoomVersion(roomVer).NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{ eb := gomatrixserverlib.MustGetRoomVersion(roomVer).NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{
Sender: ev.Sender, SenderID: ev.SenderID,
Type: ev.Type, Type: ev.Type,
StateKey: ev.StateKey, StateKey: ev.StateKey,
RoomID: ev.RoomID, RoomID: ev.RoomID,
@ -444,7 +444,7 @@ func TestRedaction(t *testing.T) {
builderEv := mustCreateEvent(t, fledglingEvent{ builderEv := mustCreateEvent(t, fledglingEvent{
Type: spec.MRoomRedaction, Type: spec.MRoomRedaction,
Sender: alice.ID, SenderID: alice.ID,
RoomID: room.ID, RoomID: room.ID,
Redacts: redactedEvent.EventID(), Redacts: redactedEvent.EventID(),
Depth: redactedEvent.Depth() + 1, Depth: redactedEvent.Depth() + 1,
@ -461,7 +461,7 @@ func TestRedaction(t *testing.T) {
builderEv := mustCreateEvent(t, fledglingEvent{ builderEv := mustCreateEvent(t, fledglingEvent{
Type: spec.MRoomRedaction, Type: spec.MRoomRedaction,
Sender: alice.ID, SenderID: alice.ID,
RoomID: room.ID, RoomID: room.ID,
Redacts: redactedEvent.EventID(), Redacts: redactedEvent.EventID(),
Depth: redactedEvent.Depth() + 1, Depth: redactedEvent.Depth() + 1,
@ -478,7 +478,7 @@ func TestRedaction(t *testing.T) {
builderEv := mustCreateEvent(t, fledglingEvent{ builderEv := mustCreateEvent(t, fledglingEvent{
Type: spec.MRoomRedaction, Type: spec.MRoomRedaction,
Sender: bob.ID, SenderID: bob.ID,
RoomID: room.ID, RoomID: room.ID,
Redacts: redactedEvent.EventID(), Redacts: redactedEvent.EventID(),
Depth: redactedEvent.Depth() + 1, Depth: redactedEvent.Depth() + 1,
@ -494,7 +494,7 @@ func TestRedaction(t *testing.T) {
builderEv := mustCreateEvent(t, fledglingEvent{ builderEv := mustCreateEvent(t, fledglingEvent{
Type: spec.MRoomRedaction, Type: spec.MRoomRedaction,
Sender: charlie.ID, SenderID: charlie.ID,
RoomID: room.ID, RoomID: room.ID,
Redacts: redactedEvent.EventID(), Redacts: redactedEvent.EventID(),
Depth: redactedEvent.Depth() + 1, Depth: redactedEvent.Depth() + 1,

View file

@ -44,7 +44,7 @@ type StateResolutionStorage interface {
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
} }
type StateResolution struct { type StateResolution struct {
@ -947,7 +947,7 @@ func (v *StateResolution) resolveConflictsV1(
} }
// Resolve the conflicts. // Resolve the conflicts.
resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID, senderID string) (*spec.UserID, error) { resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return v.db.GetUserIDForSender(ctx, roomID, senderID) return v.db.GetUserIDForSender(ctx, roomID, senderID)
}) })
@ -1061,7 +1061,7 @@ func (v *StateResolution) resolveConflictsV2(
conflictedEvents, conflictedEvents,
nonConflictedEvents, nonConflictedEvents,
authEvents, authEvents,
func(roomID, senderID string) (*spec.UserID, error) { func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return v.db.GetUserIDForSender(ctx, roomID, senderID) return v.db.GetUserIDForSender(ctx, roomID, senderID)
}, },
) )

View file

@ -167,9 +167,9 @@ type Database interface {
// GetKnownUsers searches all users that userID knows about. // GetKnownUsers searches all users that userID knows about.
GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error)
// GetKnownUsers tries to obtain the current mxid for a given user. // GetKnownUsers tries to obtain the current mxid for a given user.
GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
// GetKnownUsers tries to obtain the current senderID for a given user. // GetKnownUsers tries to obtain the current senderID for a given user.
GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error)
// GetKnownRooms returns a list of all rooms we know about. // GetKnownRooms returns a list of all rooms we know about.
GetKnownRooms(ctx context.Context) ([]string, error) GetKnownRooms(ctx context.Context) ([]string, error)
// ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room // ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room
@ -215,7 +215,7 @@ type RoomDatabase interface {
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error)
GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
} }
type EventDatabase interface { type EventDatabase interface {

View file

@ -101,7 +101,7 @@ func (u *MembershipUpdater) Update(newMembership tables.MembershipState, event *
var inserted bool // Did the query result in a membership change? var inserted bool // Did the query result in a membership change?
var retired []string // Did we retire any updates in the process? var retired []string // Did we retire any updates in the process?
return inserted, retired, u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { return inserted, retired, u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.SenderID()) senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, string(event.SenderID()))
if err != nil { if err != nil {
return fmt.Errorf("u.d.AssignStateKeyNID: %w", err) return fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
} }

View file

@ -252,6 +252,6 @@ func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, ta
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
} }
func (u *RoomUpdater) GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { func (u *RoomUpdater) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return u.d.GetUserIDForSender(ctx, roomID, senderID) return u.d.GetUserIDForSender(ctx, roomID, senderID)
} }

View file

@ -990,13 +990,13 @@ func (d *EventDatabase) MaybeRedactEvent(
// TODO: Don't hack senderID into userID here (pseudoIDs) // TODO: Don't hack senderID into userID here (pseudoIDs)
sender1Domain := "" sender1Domain := ""
sender1, err1 := spec.NewUserID(redactedEvent.SenderID(), true) sender1, err1 := spec.NewUserID(string(redactedEvent.SenderID()), true)
if err1 == nil { if err1 == nil {
sender1Domain = string(sender1.Domain()) sender1Domain = string(sender1.Domain())
} }
// TODO: Don't hack senderID into userID here (pseudoIDs) // TODO: Don't hack senderID into userID here (pseudoIDs)
sender2Domain := "" sender2Domain := ""
sender2, err2 := spec.NewUserID(redactionEvent.SenderID(), true) sender2, err2 := spec.NewUserID(string(redactionEvent.SenderID()), true)
if err2 == nil { if err2 == nil {
sender2Domain = string(sender2.Domain()) sender2Domain = string(sender2.Domain())
} }
@ -1524,14 +1524,14 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin
return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit) return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit)
} }
func (d *Database) GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { func (d *Database) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
// TODO: Use real logic once DB for pseudoIDs is in place // TODO: Use real logic once DB for pseudoIDs is in place
return spec.NewUserID(senderID, true) return spec.NewUserID(string(senderID), true)
} }
func (d *Database) GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) { func (d *Database) GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) {
// TODO: Use real logic once DB for pseudoIDs is in place // TODO: Use real logic once DB for pseudoIDs is in place
return userID.String(), nil return spec.SenderID(userID.String()), nil
} }
// GetKnownRooms returns a list of all rooms we know about. // GetKnownRooms returns a list of all rooms we know about.

View file

@ -94,7 +94,7 @@ type MSC2836EventRelationshipsResponse struct {
func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsResponse, rsAPI roomserver.RoomserverInternalAPI) *EventRelationshipResponse { func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsResponse, rsAPI roomserver.RoomserverInternalAPI) *EventRelationshipResponse {
out := &EventRelationshipResponse{ out := &EventRelationshipResponse{
Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}), }),
Limited: res.Limited, Limited: res.Limited,

View file

@ -525,8 +525,8 @@ type testRoomserverAPI struct {
events map[string]*types.HeaderedEvent events map[string]*types.HeaderedEvent
} }
func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(senderID, true) return spec.NewUserID(string(senderID), true)
} }
func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error { func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error {
@ -590,7 +590,7 @@ func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *types.HeaderedEve
seed := make([]byte, ed25519.SeedSize) // zero seed seed := make([]byte, ed25519.SeedSize) // zero seed
key := ed25519.NewKeyFromSeed(seed) key := ed25519.NewKeyFromSeed(seed)
eb := gomatrixserverlib.MustGetRoomVersion(roomVer).NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{ eb := gomatrixserverlib.MustGetRoomVersion(roomVer).NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{
Sender: ev.Sender, SenderID: ev.Sender,
Depth: 999, Depth: 999,
Type: ev.Type, Type: ev.Type,
StateKey: ev.StateKey, StateKey: ev.StateKey,

View file

@ -730,7 +730,7 @@ func stripped(ev gomatrixserverlib.PDU) *fclient.MSC2946StrippedEvent {
Type: ev.Type(), Type: ev.Type(),
StateKey: *ev.StateKey(), StateKey: *ev.StateKey(),
Content: ev.Content(), Content: ev.Content(),
Sender: ev.SenderID(), Sender: string(ev.SenderID()),
OriginServerTS: ev.OriginServerTS(), OriginServerTS: ev.OriginServerTS(),
} }
} }

View file

@ -523,7 +523,7 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *rstypes.HeaderedEvent)
prev := types.PrevEventRef{ prev := types.PrevEventRef{
PrevContent: prevEvent.Content(), PrevContent: prevEvent.Content(),
ReplacesState: prevEvent.EventID(), ReplacesState: prevEvent.EventID(),
PrevSender: prevEvent.SenderID(), PrevSenderID: string(prevEvent.SenderID()),
} }
event.PDU, err = event.SetUnsigned(prev) event.PDU, err = event.SetUnsigned(prev)

View file

@ -193,10 +193,10 @@ func Context(
} }
} }
eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
@ -204,7 +204,7 @@ func Context(
if filter.LazyLoadMembers { if filter.LazyLoadMembers {
allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...) allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...)
allEvents = append(allEvents, &requestedEvent) allEvents = append(allEvents, &requestedEvent)
evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache) newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache)
@ -227,7 +227,7 @@ func Context(
Event: &ev, Event: &ev,
EventsAfter: eventsAfterClient, EventsAfter: eventsAfterClient,
EventsBefore: eventsBeforeClient, EventsBefore: eventsBeforeClient,
State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}), }),
} }

View file

@ -144,7 +144,22 @@ func GetMemberships(
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
} }
} }
res.Joined[ev.SenderID()] = joinedMember(content)
userID, err := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID())
if err != nil || userID == nil {
util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryUserIDForSender failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"),
}
}
res.Joined[userID.String()] = joinedMember(content)
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
@ -153,7 +168,7 @@ func GetMemberships(
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
})}, })},
} }

View file

@ -273,7 +273,7 @@ func OnIncomingMessagesRequest(
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
} }
} }
res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
})...) })...)
} }
@ -385,7 +385,7 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv
"events_before": len(events), "events_before": len(events),
"events_after": len(filteredEvents), "events_after": len(filteredEvents),
}).Debug("applied history visibility (messages)") }).Debug("applied history visibility (messages)")
return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}), start, end, err }), start, end, err
} }
@ -495,7 +495,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
} }
// Append the events ve previously retrieved locally. // Append the events ve previously retrieved locally.
events = append(events, r.snapshot.StreamEventsToEvents(nil, streamEvents)...) events = append(events, r.snapshot.StreamEventsToEvents(r.ctx, nil, streamEvents, r.rsAPI)...)
sort.Sort(eventsByDepth(events)) sort.Sort(eventsByDepth(events))
return return

View file

@ -213,7 +213,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
profile, ok := knownUsersProfiles[userID.String()] profile, ok := knownUsersProfiles[userID.String()]
if !ok { if !ok {
stateEvent, stateErr := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, ev.SenderID()) stateEvent, stateErr := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, string(ev.SenderID()))
if stateErr != nil { if stateErr != nil {
logrus.WithError(stateErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") logrus.WithError(stateErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile")
continue continue
@ -239,10 +239,10 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
Context: SearchContextResponse{ Context: SearchContextResponse{
Start: startToken.String(), Start: startToken.String(),
End: endToken.String(), End: endToken.String(),
EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
}), }),
EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
}), }),
ProfileInfo: profileInfos, ProfileInfo: profileInfos,
@ -263,7 +263,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
} }
} }
stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
}) })
} }

View file

@ -25,8 +25,8 @@ import (
type FakeSyncRoomserverAPI struct{ rsapi.SyncRoomserverAPI } type FakeSyncRoomserverAPI struct{ rsapi.SyncRoomserverAPI }
func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(senderID, true) return spec.NewUserID(string(senderID), true)
} }
func TestSearch(t *testing.T) { func TestSearch(t *testing.T) {

View file

@ -44,8 +44,8 @@ type DatabaseTransaction interface {
MaxStreamPositionForRelations(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForRelations(ctx context.Context) (types.StreamPosition, error)
CurrentState(ctx context.Context, roomID string, stateFilterPart *synctypes.StateFilter, excludeEventIDs []string) ([]*rstypes.HeaderedEvent, error) CurrentState(ctx context.Context, roomID string, stateFilterPart *synctypes.StateFilter, excludeEventIDs []string) ([]*rstypes.HeaderedEvent, error)
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter, rsAPI api.SyncRoomserverAPI) ([]types.StateDelta, []string, error)
GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter, rsAPI api.SyncRoomserverAPI) ([]types.StateDelta, []string, error)
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error)
GetRoomSummary(ctx context.Context, roomID, userID string) (summary *types.Summary, err error) GetRoomSummary(ctx context.Context, roomID, userID string) (summary *types.Summary, err error)
@ -90,7 +90,7 @@ type DatabaseTransaction interface {
// StreamEventsToEvents converts streamEvent to Event. If device is non-nil and // StreamEventsToEvents converts streamEvent to Event. If device is non-nil and
// matches the streamevent.transactionID device then the transaction ID gets // matches the streamevent.transactionID device then the transaction ID gets
// added to the unsigned section of the output event. // added to the unsigned section of the output event.
StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*rstypes.HeaderedEvent StreamEventsToEvents(ctx context.Context, device *userapi.Device, in []types.StreamEvent, rsAPI api.SyncRoomserverAPI) []*rstypes.HeaderedEvent
// SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the // SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the
// relevant events within the given ranges for the supplied user ID and device ID. // relevant events within the given ranges for the supplied user ID and device ID.
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error) SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error)

View file

@ -99,7 +99,41 @@ func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*rstypes.He
// We don't include a device here as we only include transaction IDs in // We don't include a device here as we only include transaction IDs in
// incremental syncs. // incremental syncs.
return d.StreamEventsToEvents(nil, streamEvents), nil return d.StreamEventsToEvents(ctx, nil, streamEvents, nil), nil
}
func (d *Database) StreamEventsToEvents(ctx context.Context, device *userapi.Device, in []types.StreamEvent, rsAPI api.SyncRoomserverAPI) []*rstypes.HeaderedEvent {
out := make([]*rstypes.HeaderedEvent, len(in))
for i := 0; i < len(in); i++ {
out[i] = in[i].HeaderedEvent
if device != nil && in[i].TransactionID != nil {
userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
logrus.WithFields(logrus.Fields{
"event_id": out[i].EventID(),
}).WithError(err).Warnf("Failed to add transaction ID to event")
continue
}
deviceSenderID, err := rsAPI.QuerySenderIDForUser(ctx, in[i].RoomID(), *userID)
if err != nil {
logrus.WithFields(logrus.Fields{
"event_id": out[i].EventID(),
}).WithError(err).Warnf("Failed to add transaction ID to event")
continue
}
if deviceSenderID == in[i].SenderID() && device.SessionID == in[i].TransactionID.SessionID {
err := out[i].SetUnsignedField(
"transaction_id", in[i].TransactionID.TransactionID,
)
if err != nil {
logrus.WithFields(logrus.Fields{
"event_id": out[i].EventID(),
}).WithError(err).Warnf("Failed to add transaction ID to event")
}
}
}
}
return out
} }
// AddInviteEvent stores a new invite event for a user. // AddInviteEvent stores a new invite event for a user.
@ -190,45 +224,6 @@ func (d *Database) UpsertAccountData(
return return
} }
func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*rstypes.HeaderedEvent {
out := make([]*rstypes.HeaderedEvent, len(in))
for i := 0; i < len(in); i++ {
out[i] = in[i].HeaderedEvent
if device != nil && in[i].TransactionID != nil {
userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
logrus.WithFields(logrus.Fields{
"event_id": out[i].EventID(),
}).WithError(err).Warnf("Failed to add transaction ID to event")
continue
}
deviceSenderID, err := d.getSenderIDForUser(in[i].RoomID(), *userID)
if err != nil {
logrus.WithFields(logrus.Fields{
"event_id": out[i].EventID(),
}).WithError(err).Warnf("Failed to add transaction ID to event")
continue
}
if deviceSenderID == in[i].SenderID() && device.SessionID == in[i].TransactionID.SessionID {
err := out[i].SetUnsignedField(
"transaction_id", in[i].TransactionID.TransactionID,
)
if err != nil {
logrus.WithFields(logrus.Fields{
"event_id": out[i].EventID(),
}).WithError(err).Warnf("Failed to add transaction ID to event")
}
}
}
}
return out
}
func (d *Database) getSenderIDForUser(roomID string, userID spec.UserID) (string, error) { // nolint
// TODO: Repalce with actual logic for pseudoIDs
return userID.String(), nil
}
// handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of // handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of
// the events listed in the event's 'prev_events'. This function also updates the backwards extremities table // the events listed in the event's 'prev_events'. This function also updates the backwards extremities table
// to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. // to account for the fact that the given event is no longer a backwards extremity, but may be marked as such.

View file

@ -10,6 +10,7 @@ import (
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api"
rstypes "github.com/matrix-org/dendrite/roomserver/types" rstypes "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/synctypes"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
@ -186,7 +187,7 @@ func (d *DatabaseTransaction) Events(ctx context.Context, eventIDs []string) ([]
// We don't include a device here as we only include transaction IDs in // We don't include a device here as we only include transaction IDs in
// incremental syncs. // incremental syncs.
return d.StreamEventsToEvents(nil, streamEvents), nil return d.StreamEventsToEvents(ctx, nil, streamEvents, nil), nil
} }
func (d *DatabaseTransaction) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { func (d *DatabaseTransaction) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) {
@ -325,7 +326,7 @@ func (d *DatabaseTransaction) GetBackwardTopologyPos(
func (d *DatabaseTransaction) GetStateDeltas( func (d *DatabaseTransaction) GetStateDeltas(
ctx context.Context, device *userapi.Device, ctx context.Context, device *userapi.Device,
r types.Range, userID string, r types.Range, userID string,
stateFilter *synctypes.StateFilter, stateFilter *synctypes.StateFilter, rsAPI api.SyncRoomserverAPI,
) (deltas []types.StateDelta, joinedRoomsIDs []string, err error) { ) (deltas []types.StateDelta, joinedRoomsIDs []string, err error) {
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
// - Get membership list changes for this user in this sync response // - Get membership list changes for this user in this sync response
@ -417,7 +418,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
if !peek.Deleted { if !peek.Deleted {
deltas = append(deltas, types.StateDelta{ deltas = append(deltas, types.StateDelta{
Membership: spec.Peek, Membership: spec.Peek,
StateEvents: d.StreamEventsToEvents(device, state[peek.RoomID]), StateEvents: d.StreamEventsToEvents(ctx, device, state[peek.RoomID], rsAPI),
RoomID: peek.RoomID, RoomID: peek.RoomID,
}) })
} }
@ -462,7 +463,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
deltas = append(deltas, types.StateDelta{ deltas = append(deltas, types.StateDelta{
Membership: membership, Membership: membership,
MembershipPos: ev.StreamPosition, MembershipPos: ev.StreamPosition,
StateEvents: d.StreamEventsToEvents(device, stateFiltered[roomID]), StateEvents: d.StreamEventsToEvents(ctx, device, stateFiltered[roomID], rsAPI),
RoomID: roomID, RoomID: roomID,
}) })
break break
@ -474,7 +475,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
for _, joinedRoomID := range joinedRoomIDs { for _, joinedRoomID := range joinedRoomIDs {
deltas = append(deltas, types.StateDelta{ deltas = append(deltas, types.StateDelta{
Membership: spec.Join, Membership: spec.Join,
StateEvents: d.StreamEventsToEvents(device, stateFiltered[joinedRoomID]), StateEvents: d.StreamEventsToEvents(ctx, device, stateFiltered[joinedRoomID], rsAPI),
RoomID: joinedRoomID, RoomID: joinedRoomID,
NewlyJoined: newlyJoinedRooms[joinedRoomID], NewlyJoined: newlyJoinedRooms[joinedRoomID],
}) })
@ -490,7 +491,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( func (d *DatabaseTransaction) GetStateDeltasForFullStateSync(
ctx context.Context, device *userapi.Device, ctx context.Context, device *userapi.Device,
r types.Range, userID string, r types.Range, userID string,
stateFilter *synctypes.StateFilter, stateFilter *synctypes.StateFilter, rsAPI api.SyncRoomserverAPI,
) ([]types.StateDelta, []string, error) { ) ([]types.StateDelta, []string, error) {
// Look up all memberships for the user. We only care about rooms that a // Look up all memberships for the user. We only care about rooms that a
// user has ever interacted with — joined to, kicked/banned from, left. // user has ever interacted with — joined to, kicked/banned from, left.
@ -531,7 +532,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync(
} }
deltas[peek.RoomID] = types.StateDelta{ deltas[peek.RoomID] = types.StateDelta{
Membership: spec.Peek, Membership: spec.Peek,
StateEvents: d.StreamEventsToEvents(device, s), StateEvents: d.StreamEventsToEvents(ctx, device, s, rsAPI),
RoomID: peek.RoomID, RoomID: peek.RoomID,
} }
} }
@ -560,7 +561,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync(
deltas[roomID] = types.StateDelta{ deltas[roomID] = types.StateDelta{
Membership: membership, Membership: membership,
MembershipPos: ev.StreamPosition, MembershipPos: ev.StreamPosition,
StateEvents: d.StreamEventsToEvents(device, stateStreamEvents), StateEvents: d.StreamEventsToEvents(ctx, device, stateStreamEvents, rsAPI),
RoomID: roomID, RoomID: roomID,
} }
} }
@ -581,7 +582,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync(
} }
deltas[joinedRoomID] = types.StateDelta{ deltas[joinedRoomID] = types.StateDelta{
Membership: spec.Join, Membership: spec.Join,
StateEvents: d.StreamEventsToEvents(device, s), StateEvents: d.StreamEventsToEvents(ctx, device, s, rsAPI),
RoomID: joinedRoomID, RoomID: joinedRoomID,
} }
} }

View file

@ -214,7 +214,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err) t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
} }
gots := snapshot.StreamEventsToEvents(nil, paginatedEvents) gots := snapshot.StreamEventsToEvents(context.Background(), nil, paginatedEvents, nil)
test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:])) test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
}) })
}) })

View file

@ -175,12 +175,12 @@ func (p *PDUStreamProvider) IncrementalSync(
eventFilter := req.Filter.Room.Timeline eventFilter := req.Filter.Room.Timeline
if req.WantFullState { if req.WantFullState {
if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter, p.rsAPI); err != nil {
req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed") req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed")
return from return from
} }
} else { } else {
if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter, p.rsAPI); err != nil {
req.Log.WithError(err).Error("p.DB.GetStateDeltas failed") req.Log.WithError(err).Error("p.DB.GetStateDeltas failed")
return from return from
} }
@ -275,7 +275,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
limited := dbEvents[delta.RoomID].Limited limited := dbEvents[delta.RoomID].Limited
recEvents := gomatrixserverlib.ReverseTopologicalOrdering( recEvents := gomatrixserverlib.ReverseTopologicalOrdering(
gomatrixserverlib.ToPDUs(snapshot.StreamEventsToEvents(device, recentStreamEvents)), gomatrixserverlib.ToPDUs(snapshot.StreamEventsToEvents(ctx, device, recentStreamEvents, p.rsAPI)),
gomatrixserverlib.TopologicalOrderByPrevEvents, gomatrixserverlib.TopologicalOrderByPrevEvents,
) )
recentEvents := make([]*rstypes.HeaderedEvent, len(recEvents)) recentEvents := make([]*rstypes.HeaderedEvent, len(recEvents))
@ -376,13 +376,13 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
} }
} }
jr.Timeline.PrevBatch = &prevBatch jr.Timeline.PrevBatch = &prevBatch
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
// If we are limited by the filter AND the history visibility filter // If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited. // didn't "remove" events, return that the response is limited.
jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
req.Response.Rooms.Join[delta.RoomID] = jr req.Response.Rooms.Join[delta.RoomID] = jr
@ -391,11 +391,11 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = &prevBatch jr.Timeline.PrevBatch = &prevBatch
// TODO: Apply history visibility on peeked rooms // TODO: Apply history visibility on peeked rooms
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
jr.Timeline.Limited = limited jr.Timeline.Limited = limited
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
req.Response.Rooms.Peek[delta.RoomID] = jr req.Response.Rooms.Peek[delta.RoomID] = jr
@ -406,13 +406,13 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
case spec.Ban: case spec.Ban:
lr := types.NewLeaveResponse() lr := types.NewLeaveResponse()
lr.Timeline.PrevBatch = &prevBatch lr.Timeline.PrevBatch = &prevBatch
lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
// If we are limited by the filter AND the history visibility filter // If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited. // didn't "remove" events, return that the response is limited.
lr.Timeline.Limited = limited && len(events) == len(recentEvents) lr.Timeline.Limited = limited && len(events) == len(recentEvents)
lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
req.Response.Rooms.Leave[delta.RoomID] = lr req.Response.Rooms.Leave[delta.RoomID] = lr
@ -437,7 +437,7 @@ func applyHistoryVisibilityFilter(
for _, ev := range recentEvents { for _, ev := range recentEvents {
if ev.StateKey() != nil { if ev.StateKey() != nil {
stateTypes = append(stateTypes, ev.Type()) stateTypes = append(stateTypes, ev.Type())
senders = append(senders, ev.SenderID()) senders = append(senders, string(ev.SenderID()))
} }
} }
@ -512,7 +512,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
// We don't include a device here as we don't need to send down // We don't include a device here as we don't need to send down
// transaction IDs for complete syncs, but we do it anyway because Sytest demands it for: // transaction IDs for complete syncs, but we do it anyway because Sytest demands it for:
// "Can sync a room with a message with a transaction id" - which does a complete sync to check. // "Can sync a room with a message with a transaction id" - which does a complete sync to check.
recentEvents := snapshot.StreamEventsToEvents(device, recentStreamEvents) recentEvents := snapshot.StreamEventsToEvents(ctx, device, recentStreamEvents, p.rsAPI)
events := recentEvents events := recentEvents
// Only apply history visibility checks if the response is for joined rooms // Only apply history visibility checks if the response is for joined rooms
@ -564,13 +564,13 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
} }
jr.Timeline.PrevBatch = prevBatch jr.Timeline.PrevBatch = prevBatch
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
// If we are limited by the filter AND the history visibility filter // If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited. // didn't "remove" events, return that the response is limited.
jr.Timeline.Limited = limited && len(events) == len(recentEvents) jr.Timeline.Limited = limited && len(events) == len(recentEvents)
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
return jr, nil return jr, nil
@ -593,8 +593,8 @@ func (p *PDUStreamProvider) lazyLoadMembers(
// Add all users the client doesn't know about yet to a list // Add all users the client doesn't know about yet to a list
for _, event := range timelineEvents { for _, event := range timelineEvents {
// Membership is not yet cached, add it to the list // Membership is not yet cached, add it to the list
if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, event.SenderID()); !ok { if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, string(event.SenderID())); !ok {
timelineUsers[event.SenderID()] = struct{}{} timelineUsers[string(event.SenderID())] = struct{}{}
} }
} }
// Preallocate with the same amount, even if it will end up with fewer values // Preallocate with the same amount, even if it will end up with fewer values

View file

@ -40,8 +40,8 @@ type syncRoomserverAPI struct {
rooms []*test.Room rooms []*test.Room
} }
func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(senderID, true) return spec.NewUserID(string(senderID), true)
} }
func (s *syncRoomserverAPI) QueryLatestEventsAndState(ctx context.Context, req *rsapi.QueryLatestEventsAndStateRequest, res *rsapi.QueryLatestEventsAndStateResponse) error { func (s *syncRoomserverAPI) QueryLatestEventsAndState(ctx context.Context, req *rsapi.QueryLatestEventsAndStateRequest, res *rsapi.QueryLatestEventsAndStateResponse) error {

View file

@ -343,7 +343,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
type PrevEventRef struct { type PrevEventRef struct {
PrevContent json.RawMessage `json:"prev_content"` PrevContent json.RawMessage `json:"prev_content"`
ReplacesState string `json:"replaces_state"` ReplacesState string `json:"replaces_state"`
PrevSender string `json:"prev_sender"` PrevSenderID string `json:"prev_sender"`
} }
type DeviceLists struct { type DeviceLists struct {

View file

@ -39,8 +39,8 @@ var (
roomIDCounter = int64(0) roomIDCounter = int64(0)
) )
func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { func UserIDForSender(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(senderID, true) return spec.NewUserID(string(senderID), true)
} }
type Room struct { type Room struct {
@ -168,7 +168,7 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten
} }
builder := gomatrixserverlib.MustGetRoomVersion(r.Version).NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{ builder := gomatrixserverlib.MustGetRoomVersion(r.Version).NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{
Sender: creator.ID, SenderID: creator.ID,
RoomID: r.ID, RoomID: r.ID,
Type: eventType, Type: eventType,
StateKey: mod.stateKey, StateKey: mod.stateKey,

View file

@ -108,7 +108,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms
} }
if s.cfg.Matrix.ReportStats.Enabled { if s.cfg.Matrix.ReportStats.Enabled {
go s.storeMessageStats(ctx, event.Type(), event.SenderID(), event.RoomID()) go s.storeMessageStats(ctx, event.Type(), string(event.SenderID()), event.RoomID())
} }
log.WithFields(log.Fields{ log.WithFields(log.Fields{
@ -664,7 +664,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
roomSize: roomSize, roomSize: roomSize,
} }
eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global) eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global)
rule, err := eval.MatchEvent(event.PDU, func(roomID, senderID string) (*spec.UserID, error) { rule, err := eval.MatchEvent(event.PDU, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
if err != nil { if err != nil {
@ -698,7 +698,7 @@ func (rse *ruleSetEvalContext) UserDisplayName() string { return rse.mem.Display
func (rse *ruleSetEvalContext) RoomMemberCount() (int, error) { return rse.roomSize, nil } func (rse *ruleSetEvalContext) RoomMemberCount() (int, error) { return rse.roomSize, nil }
func (rse *ruleSetEvalContext) HasPowerLevel(senderID, levelKey string) (bool, error) { func (rse *ruleSetEvalContext) HasPowerLevel(senderID spec.SenderID, levelKey string) (bool, error) {
req := &rsapi.QueryLatestEventsAndStateRequest{ req := &rsapi.QueryLatestEventsAndStateRequest{
RoomID: rse.roomID, RoomID: rse.roomID,
StateToFetch: []gomatrixserverlib.StateKeyTuple{ StateToFetch: []gomatrixserverlib.StateKeyTuple{

View file

@ -47,8 +47,8 @@ func mustCreateEvent(t *testing.T, content string) *types.HeaderedEvent {
type FakeUserRoomserverAPI struct{ rsapi.UserRoomserverAPI } type FakeUserRoomserverAPI struct{ rsapi.UserRoomserverAPI }
func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(senderID, true) return spec.NewUserID(string(senderID), true)
} }
func Test_evaluatePushRules(t *testing.T) { func Test_evaluatePushRules(t *testing.T) {