Fix statekey usage in roomserver/perform_upgrade & create_room

This commit is contained in:
Devon Hudson 2023-06-08 15:22:13 -06:00
parent d3a0101eb5
commit 58d3452b7f
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
7 changed files with 88 additions and 110 deletions

View file

@ -224,7 +224,18 @@ func createRoom(
PrivateKey: privateKey, PrivateKey: privateKey,
EventTime: evTime, EventTime: evTime,
} }
roomAlias, createRes := rsAPI.PerformCreateRoom(ctx, *userID, *roomID, &req)
senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID.String(), *userID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("rsAPI.QuerySenderIDForUser failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
roomAlias, createRes := rsAPI.PerformCreateRoom(ctx, roomserverAPI.SenderUserIDPair{
SenderID: senderID, UserID: *userID,
}, *roomID, &req)
if createRes != nil { if createRes != nil {
return *createRes return *createRes
} }

View file

@ -59,7 +59,25 @@ func UpgradeRoom(
} }
} }
newRoomID, err := rsAPI.PerformRoomUpgrade(req.Context(), roomID, device.UserID, gomatrixserverlib.RoomVersion(r.NewVersion)) userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("device UserID is invalid")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *userID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("Failed getting senderID for user")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
newRoomID, err := rsAPI.PerformRoomUpgrade(req.Context(), roomID, roomserverAPI.SenderUserIDPair{
SenderID: senderID, UserID: *userID,
}, gomatrixserverlib.RoomVersion(r.NewVersion))
switch e := err.(type) { switch e := err.(type) {
case nil: case nil:
case roomserverAPI.ErrNotAllowed: case roomserverAPI.ErrNotAllowed:

View file

@ -189,9 +189,9 @@ type ClientRoomserverAPI interface {
GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error
GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error
PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) (string, *util.JSONResponse) PerformCreateRoom(ctx context.Context, user SenderUserIDPair, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) (string, *util.JSONResponse)
// PerformRoomUpgrade upgrades a room to a newer version // PerformRoomUpgrade upgrades a room to a newer version
PerformRoomUpgrade(ctx context.Context, roomID, userID string, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error) PerformRoomUpgrade(ctx context.Context, roomID string, user SenderUserIDPair, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error)
PerformAdminEvacuateRoom(ctx context.Context, roomID string) (affected []string, err error) PerformAdminEvacuateRoom(ctx context.Context, roomID string) (affected []string, err error)
PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error) PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error)
PerformAdminPurgeRoom(ctx context.Context, roomID string) error PerformAdminPurgeRoom(ctx context.Context, roomID string) error

View file

@ -235,9 +235,9 @@ func (r *RoomserverInternalAPI) HandleInvite(
} }
func (r *RoomserverInternalAPI) PerformCreateRoom( func (r *RoomserverInternalAPI) PerformCreateRoom(
ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *api.PerformCreateRoomRequest, ctx context.Context, user api.SenderUserIDPair, roomID spec.RoomID, createRequest *api.PerformCreateRoomRequest,
) (string, *util.JSONResponse) { ) (string, *util.JSONResponse) {
return r.Creator.PerformCreateRoom(ctx, userID, roomID, createRequest) return r.Creator.PerformCreateRoom(ctx, user, roomID, createRequest)
} }
func (r *RoomserverInternalAPI) PerformInvite( func (r *RoomserverInternalAPI) PerformInvite(

View file

@ -44,7 +44,7 @@ type Creator struct {
// PerformCreateRoom handles all the steps necessary to create a new room. // PerformCreateRoom handles all the steps necessary to create a new room.
// nolint: gocyclo // nolint: gocyclo
func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *api.PerformCreateRoomRequest) (string, *util.JSONResponse) { func (c *Creator) PerformCreateRoom(ctx context.Context, user api.SenderUserIDPair, roomID spec.RoomID, createRequest *api.PerformCreateRoomRequest) (string, *util.JSONResponse) {
verImpl, err := gomatrixserverlib.GetRoomVersion(createRequest.RoomVersion) verImpl, err := gomatrixserverlib.GetRoomVersion(createRequest.RoomVersion)
if err != nil { if err != nil {
return "", &util.JSONResponse{ return "", &util.JSONResponse{
@ -63,10 +63,9 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
} }
} }
} }
senderID, err := c.RSAPI.QuerySenderIDForUser(ctx, roomID.String(), userID) createContent["creator"] = user.SenderID
createContent["creator"] = senderID
createContent["room_version"] = createRequest.RoomVersion createContent["room_version"] = createRequest.RoomVersion
powerLevelContent := eventutil.InitialPowerLevelsContent(string(senderID)) powerLevelContent := eventutil.InitialPowerLevelsContent(string(user.SenderID))
joinRuleContent := gomatrixserverlib.JoinRuleContent{ joinRuleContent := gomatrixserverlib.JoinRuleContent{
JoinRule: spec.Invite, JoinRule: spec.Invite,
} }
@ -122,7 +121,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
} }
membershipEvent := gomatrixserverlib.FledglingEvent{ membershipEvent := gomatrixserverlib.FledglingEvent{
Type: spec.MRoomMember, Type: spec.MRoomMember,
StateKey: string(senderID), StateKey: string(user.SenderID),
Content: gomatrixserverlib.MemberContent{ Content: gomatrixserverlib.MemberContent{
Membership: spec.Join, Membership: spec.Join,
DisplayName: createRequest.UserDisplayName, DisplayName: createRequest.UserDisplayName,
@ -164,7 +163,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
var roomAlias string var roomAlias string
if createRequest.RoomAliasName != "" { if createRequest.RoomAliasName != "" {
roomAlias = fmt.Sprintf("#%s:%s", createRequest.RoomAliasName, userID.Domain()) roomAlias = fmt.Sprintf("#%s:%s", createRequest.RoomAliasName, user.UserID.Domain())
// check it's free // check it's free
// TODO: This races but is better than nothing // TODO: This races but is better than nothing
hasAliasReq := api.GetRoomIDForAliasRequest{ hasAliasReq := api.GetRoomIDForAliasRequest{
@ -282,7 +281,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
depth := i + 1 // depth starts at 1 depth := i + 1 // depth starts at 1
builder := verImpl.NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{ builder := verImpl.NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{
SenderID: string(senderID), SenderID: string(user.SenderID),
RoomID: roomID.String(), RoomID: roomID.String(),
Type: e.Type, Type: e.Type,
StateKey: &e.StateKey, StateKey: &e.StateKey,
@ -307,7 +306,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
} }
} }
ev, err = builder.Build(createRequest.EventTime, userID.Domain(), createRequest.KeyID, createRequest.PrivateKey) ev, err = builder.Build(createRequest.EventTime, user.UserID.Domain(), createRequest.KeyID, createRequest.PrivateKey)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("buildEvent failed") util.GetLogger(ctx).WithError(err).Error("buildEvent failed")
return "", &util.JSONResponse{ return "", &util.JSONResponse{
@ -343,11 +342,11 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
inputs = append(inputs, api.InputRoomEvent{ inputs = append(inputs, api.InputRoomEvent{
Kind: api.KindNew, Kind: api.KindNew,
Event: event, Event: event,
Origin: userID.Domain(), Origin: user.UserID.Domain(),
SendAsServer: api.DoNotSendToOtherServers, SendAsServer: api.DoNotSendToOtherServers,
}) })
} }
if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs, false); err != nil { if err = api.SendInputRoomEvents(ctx, c.RSAPI, user.UserID.Domain(), inputs, false); err != nil {
util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed")
return "", &util.JSONResponse{ return "", &util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
@ -362,7 +361,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
aliasReq := api.SetRoomAliasRequest{ aliasReq := api.SetRoomAliasRequest{
Alias: roomAlias, Alias: roomAlias,
RoomID: roomID.String(), RoomID: roomID.String(),
UserID: userID.String(), UserID: user.UserID.String(),
} }
var aliasResp api.SetRoomAliasResponse var aliasResp api.SetRoomAliasResponse
@ -435,7 +434,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
} }
inviteeString := string(inviteeSenderID) inviteeString := string(inviteeSenderID)
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
SenderID: string(senderID), SenderID: string(user.SenderID),
RoomID: roomID.String(), RoomID: roomID.String(),
Type: "m.room.member", Type: "m.room.member",
StateKey: &inviteeString, StateKey: &inviteeString,
@ -458,7 +457,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
// Build the invite event. // Build the invite event.
identity := &fclient.SigningIdentity{ identity := &fclient.SigningIdentity{
ServerName: userID.Domain(), ServerName: user.UserID.Domain(),
KeyID: createRequest.KeyID, KeyID: createRequest.KeyID,
PrivateKey: createRequest.PrivateKey, PrivateKey: createRequest.PrivateKey,
} }
@ -478,7 +477,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
Event: event, Event: event,
InviteRoomState: inviteStrippedState, InviteRoomState: inviteStrippedState,
RoomVersion: event.Version(), RoomVersion: event.Version(),
SendAsServer: string(userID.Domain()), SendAsServer: string(user.UserID.Domain()),
}) })
switch e := err.(type) { switch e := err.(type) {
case api.ErrInvalidID: case api.ErrInvalidID:

View file

@ -38,19 +38,15 @@ type Upgrader struct {
// PerformRoomUpgrade upgrades a room from one version to another // PerformRoomUpgrade upgrades a room from one version to another
func (r *Upgrader) PerformRoomUpgrade( func (r *Upgrader) PerformRoomUpgrade(
ctx context.Context, ctx context.Context,
roomID, userID string, roomVersion gomatrixserverlib.RoomVersion, roomID string, user api.SenderUserIDPair, roomVersion gomatrixserverlib.RoomVersion,
) (newRoomID string, err error) { ) (newRoomID string, err error) {
return r.performRoomUpgrade(ctx, roomID, userID, roomVersion) return r.performRoomUpgrade(ctx, roomID, user, roomVersion)
} }
func (r *Upgrader) performRoomUpgrade( func (r *Upgrader) performRoomUpgrade(
ctx context.Context, ctx context.Context,
roomID, userID string, roomVersion gomatrixserverlib.RoomVersion, roomID string, user api.SenderUserIDPair, roomVersion gomatrixserverlib.RoomVersion,
) (string, error) { ) (string, error) {
_, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID)
if err != nil {
return "", api.ErrNotAllowed{Err: fmt.Errorf("error validating the user ID")}
}
evTime := time.Now() evTime := time.Now()
// Return an immediate error if the room does not exist // Return an immediate error if the room does not exist
@ -59,13 +55,13 @@ func (r *Upgrader) performRoomUpgrade(
} }
// 1. Check if the user is authorized to actually perform the upgrade (can send m.room.tombstone) // 1. Check if the user is authorized to actually perform the upgrade (can send m.room.tombstone)
if !r.userIsAuthorized(ctx, userID, roomID) { if !r.userIsAuthorized(ctx, user.SenderID, roomID) {
return "", api.ErrNotAllowed{Err: fmt.Errorf("You don't have permission to upgrade the room, power level too low.")} return "", api.ErrNotAllowed{Err: fmt.Errorf("You don't have permission to upgrade the room, power level too low.")}
} }
// TODO (#267): Check room ID doesn't clash with an existing one, and we // TODO (#267): Check room ID doesn't clash with an existing one, and we
// probably shouldn't be using pseudo-random strings, maybe GUIDs? // probably shouldn't be using pseudo-random strings, maybe GUIDs?
newRoomID := fmt.Sprintf("!%s:%s", util.RandomString(16), userDomain) newRoomID := fmt.Sprintf("!%s:%s", util.RandomString(16), user.UserID.Domain())
// Get the existing room state for the old room. // Get the existing room state for the old room.
oldRoomReq := &api.QueryLatestEventsAndStateRequest{ oldRoomReq := &api.QueryLatestEventsAndStateRequest{
@ -77,25 +73,25 @@ func (r *Upgrader) performRoomUpgrade(
} }
// Make the tombstone event // Make the tombstone event
tombstoneEvent, pErr := r.makeTombstoneEvent(ctx, evTime, userID, roomID, newRoomID) tombstoneEvent, pErr := r.makeTombstoneEvent(ctx, evTime, user.SenderID, user.UserID.Domain(), roomID, newRoomID)
if pErr != nil { if pErr != nil {
return "", pErr return "", pErr
} }
// Generate the initial events we need to send into the new room. This includes copied state events and bans // Generate the initial events we need to send into the new room. This includes copied state events and bans
// as well as the power level events needed to set up the room // as well as the power level events needed to set up the room
eventsToMake, pErr := r.generateInitialEvents(ctx, oldRoomRes, userID, roomID, roomVersion, tombstoneEvent) eventsToMake, pErr := r.generateInitialEvents(ctx, oldRoomRes, user.SenderID, roomID, roomVersion, tombstoneEvent)
if pErr != nil { if pErr != nil {
return "", pErr return "", pErr
} }
// Send the setup events to the new room // Send the setup events to the new room
if pErr = r.sendInitialEvents(ctx, evTime, userID, userDomain, newRoomID, roomVersion, eventsToMake); pErr != nil { if pErr = r.sendInitialEvents(ctx, evTime, user.SenderID, user.UserID.Domain(), newRoomID, roomVersion, eventsToMake); pErr != nil {
return "", pErr return "", pErr
} }
// 5. Send the tombstone event to the old room // 5. Send the tombstone event to the old room
if pErr = r.sendHeaderedEvent(ctx, userDomain, tombstoneEvent, string(userDomain)); pErr != nil { if pErr = r.sendHeaderedEvent(ctx, user.UserID.Domain(), tombstoneEvent, string(user.UserID.Domain())); pErr != nil {
return "", pErr return "", pErr
} }
@ -105,17 +101,17 @@ func (r *Upgrader) performRoomUpgrade(
} }
// If the old room had a canonical alias event, it should be deleted in the old room // If the old room had a canonical alias event, it should be deleted in the old room
if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, userID, userDomain, roomID); pErr != nil { if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, user.SenderID, user.UserID.Domain(), roomID); pErr != nil {
return "", pErr return "", pErr
} }
// 4. Move local aliases to the new room // 4. Move local aliases to the new room
if pErr = moveLocalAliases(ctx, roomID, newRoomID, userID, r.URSAPI); pErr != nil { if pErr = moveLocalAliases(ctx, roomID, newRoomID, user.SenderID, user.UserID, r.URSAPI); pErr != nil {
return "", pErr return "", pErr
} }
// 6. Restrict power levels in the old room // 6. Restrict power levels in the old room
if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, userID, userDomain, roomID); pErr != nil { if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, user.SenderID, user.UserID.Domain(), roomID); pErr != nil {
return "", pErr return "", pErr
} }
@ -130,7 +126,7 @@ func (r *Upgrader) getRoomPowerLevels(ctx context.Context, roomID string) (*goma
return oldPowerLevelsEvent.PowerLevels() return oldPowerLevelsEvent.PowerLevels()
} }
func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, userID string, userDomain spec.ServerName, roomID string) error { func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, senderID spec.SenderID, userDomain spec.ServerName, roomID string) error {
restrictedPowerLevelContent, pErr := r.getRoomPowerLevels(ctx, roomID) restrictedPowerLevelContent, pErr := r.getRoomPowerLevels(ctx, roomID)
if pErr != nil { if pErr != nil {
return pErr return pErr
@ -147,7 +143,7 @@ func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.T
restrictedPowerLevelContent.EventsDefault = restrictedDefaultPowerLevel restrictedPowerLevelContent.EventsDefault = restrictedDefaultPowerLevel
restrictedPowerLevelContent.Invite = restrictedDefaultPowerLevel restrictedPowerLevelContent.Invite = restrictedDefaultPowerLevel
restrictedPowerLevelsHeadered, resErr := r.makeHeaderedEvent(ctx, evTime, userID, roomID, gomatrixserverlib.FledglingEvent{ restrictedPowerLevelsHeadered, resErr := r.makeHeaderedEvent(ctx, evTime, senderID, userDomain, roomID, gomatrixserverlib.FledglingEvent{
Type: spec.MRoomPowerLevels, Type: spec.MRoomPowerLevels,
StateKey: "", StateKey: "",
Content: restrictedPowerLevelContent, Content: restrictedPowerLevelContent,
@ -165,7 +161,7 @@ func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.T
} }
func moveLocalAliases(ctx context.Context, func moveLocalAliases(ctx context.Context,
roomID, newRoomID, userID string, roomID, newRoomID string, senderID spec.SenderID, userID spec.UserID,
URSAPI api.RoomserverInternalAPI, URSAPI api.RoomserverInternalAPI,
) (err error) { ) (err error) {
@ -175,14 +171,6 @@ func moveLocalAliases(ctx context.Context,
return fmt.Errorf("Failed to get old room aliases: %w", err) return fmt.Errorf("Failed to get old room aliases: %w", err)
} }
fullUserID, err := spec.NewUserID(userID, true)
if err != nil {
return fmt.Errorf("Failed to get userID: %w", err)
}
senderID, err := URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return fmt.Errorf("Failed to get senderID: %w", err)
}
for _, alias := range aliasRes.Aliases { for _, alias := range aliasRes.Aliases {
removeAliasReq := api.RemoveRoomAliasRequest{SenderID: senderID, Alias: alias} removeAliasReq := api.RemoveRoomAliasRequest{SenderID: senderID, Alias: alias}
removeAliasRes := api.RemoveRoomAliasResponse{} removeAliasRes := api.RemoveRoomAliasResponse{}
@ -190,7 +178,7 @@ func moveLocalAliases(ctx context.Context,
return fmt.Errorf("Failed to remove old room alias: %w", err) return fmt.Errorf("Failed to remove old room alias: %w", err)
} }
setAliasReq := api.SetRoomAliasRequest{UserID: userID, Alias: alias, RoomID: newRoomID} setAliasReq := api.SetRoomAliasRequest{UserID: userID.String(), Alias: alias, RoomID: newRoomID}
setAliasRes := api.SetRoomAliasResponse{} setAliasRes := api.SetRoomAliasResponse{}
if err = URSAPI.SetRoomAlias(ctx, &setAliasReq, &setAliasRes); err != nil { if err = URSAPI.SetRoomAlias(ctx, &setAliasReq, &setAliasRes); err != nil {
return fmt.Errorf("Failed to set new room alias: %w", err) return fmt.Errorf("Failed to set new room alias: %w", err)
@ -199,7 +187,7 @@ func moveLocalAliases(ctx context.Context,
return nil return nil
} }
func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, userID string, userDomain spec.ServerName, roomID string) error { func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, senderID spec.SenderID, userDomain spec.ServerName, roomID string) error {
for _, event := range oldRoom.StateEvents { for _, event := range oldRoom.StateEvents {
if event.Type() != spec.MRoomCanonicalAlias || !event.StateKeyEquals("") { if event.Type() != spec.MRoomCanonicalAlias || !event.StateKeyEquals("") {
continue continue
@ -217,7 +205,7 @@ func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api
} }
} }
emptyCanonicalAliasEvent, resErr := r.makeHeaderedEvent(ctx, evTime, userID, roomID, gomatrixserverlib.FledglingEvent{ emptyCanonicalAliasEvent, resErr := r.makeHeaderedEvent(ctx, evTime, senderID, userDomain, roomID, gomatrixserverlib.FledglingEvent{
Type: spec.MRoomCanonicalAlias, Type: spec.MRoomCanonicalAlias,
Content: map[string]interface{}{}, Content: map[string]interface{}{},
}) })
@ -280,7 +268,7 @@ func (r *Upgrader) validateRoomExists(ctx context.Context, roomID string) error
return nil return nil
} }
func (r *Upgrader) userIsAuthorized(ctx context.Context, userID, roomID string, func (r *Upgrader) userIsAuthorized(ctx context.Context, senderID spec.SenderID, roomID string,
) bool { ) bool {
plEvent := api.GetStateEvent(ctx, r.URSAPI, roomID, gomatrixserverlib.StateKeyTuple{ plEvent := api.GetStateEvent(ctx, r.URSAPI, roomID, gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomPowerLevels, EventType: spec.MRoomPowerLevels,
@ -295,26 +283,18 @@ func (r *Upgrader) userIsAuthorized(ctx context.Context, userID, roomID string,
} }
// Check for power level required to send tombstone event (marks the current room as obsolete), // Check for power level required to send tombstone event (marks the current room as obsolete),
// if not found, use the StateDefault power level // if not found, use the StateDefault power level
fullUserID, err := spec.NewUserID(userID, true)
if err != nil {
return false
}
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return false
}
return pl.UserLevel(senderID) >= pl.EventLevel("m.room.tombstone", true) return pl.UserLevel(senderID) >= pl.EventLevel("m.room.tombstone", true)
} }
// nolint:gocyclo // nolint:gocyclo
func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, userID, roomID string, newVersion gomatrixserverlib.RoomVersion, tombstoneEvent *types.HeaderedEvent) ([]gomatrixserverlib.FledglingEvent, error) { func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, senderID spec.SenderID, roomID string, newVersion gomatrixserverlib.RoomVersion, tombstoneEvent *types.HeaderedEvent) ([]gomatrixserverlib.FledglingEvent, error) {
state := make(map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent, len(oldRoom.StateEvents)) state := make(map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent, len(oldRoom.StateEvents))
for _, event := range oldRoom.StateEvents { for _, event := range oldRoom.StateEvents {
if event.StateKey() == nil { if event.StateKey() == nil {
// This shouldn't ever happen, but better to be safe than sorry. // This shouldn't ever happen, but better to be safe than sorry.
continue continue
} }
if event.Type() == spec.MRoomMember && !event.StateKeyEquals(userID) { if event.Type() == spec.MRoomMember && !event.StateKeyEquals(string(senderID)) {
// With the exception of bans which we do want to copy, we // With the exception of bans which we do want to copy, we
// should ignore membership events that aren't our own, as event auth will // should ignore membership events that aren't our own, as event auth will
// prevent us from being able to create membership events on behalf of other // prevent us from being able to create membership events on behalf of other
@ -341,10 +321,10 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
// The following events are ones that we are going to override manually // The following events are ones that we are going to override manually
// in the following section. // in the following section.
override := map[gomatrixserverlib.StateKeyTuple]struct{}{ override := map[gomatrixserverlib.StateKeyTuple]struct{}{
{EventType: spec.MRoomCreate, StateKey: ""}: {}, {EventType: spec.MRoomCreate, StateKey: ""}: {},
{EventType: spec.MRoomMember, StateKey: userID}: {}, {EventType: spec.MRoomMember, StateKey: string(senderID)}: {},
{EventType: spec.MRoomPowerLevels, StateKey: ""}: {}, {EventType: spec.MRoomPowerLevels, StateKey: ""}: {},
{EventType: spec.MRoomJoinRules, StateKey: ""}: {}, {EventType: spec.MRoomJoinRules, StateKey: ""}: {},
} }
// The overridden events are essential events that must be present in the // The overridden events are essential events that must be present in the
@ -356,7 +336,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
} }
oldCreateEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomCreate, StateKey: ""}] oldCreateEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomCreate, StateKey: ""}]
oldMembershipEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomMember, StateKey: userID}] oldMembershipEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomMember, StateKey: string(senderID)}]
oldPowerLevelsEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomPowerLevels, StateKey: ""}] oldPowerLevelsEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomPowerLevels, StateKey: ""}]
oldJoinRulesEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomJoinRules, StateKey: ""}] oldJoinRulesEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomJoinRules, StateKey: ""}]
@ -365,7 +345,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
// in the create event (such as for the room types MSC). // in the create event (such as for the room types MSC).
newCreateContent := map[string]interface{}{} newCreateContent := map[string]interface{}{}
_ = json.Unmarshal(oldCreateEvent.Content(), &newCreateContent) _ = json.Unmarshal(oldCreateEvent.Content(), &newCreateContent)
newCreateContent["creator"] = userID newCreateContent["creator"] = string(senderID)
newCreateContent["room_version"] = newVersion newCreateContent["room_version"] = newVersion
newCreateContent["predecessor"] = gomatrixserverlib.PreviousRoom{ newCreateContent["predecessor"] = gomatrixserverlib.PreviousRoom{
EventID: tombstoneEvent.EventID(), EventID: tombstoneEvent.EventID(),
@ -386,7 +366,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
newMembershipContent["membership"] = spec.Join newMembershipContent["membership"] = spec.Join
newMembershipEvent := gomatrixserverlib.FledglingEvent{ newMembershipEvent := gomatrixserverlib.FledglingEvent{
Type: spec.MRoomMember, Type: spec.MRoomMember,
StateKey: userID, StateKey: string(senderID),
Content: newMembershipContent, Content: newMembershipContent,
} }
@ -401,14 +381,6 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
return nil, fmt.Errorf("Power level event content was invalid") return nil, fmt.Errorf("Power level event content was invalid")
} }
fullUserID, err := spec.NewUserID(userID, true)
if err != nil {
return nil, err
}
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return nil, err
}
tempPowerLevelsEvent, powerLevelsOverridden := createTemporaryPowerLevels(powerLevelContent, senderID) tempPowerLevelsEvent, powerLevelsOverridden := createTemporaryPowerLevels(powerLevelContent, senderID)
// Now do the join rules event, same as the create and membership // Now do the join rules event, same as the create and membership
@ -471,21 +443,13 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
return eventsToMake, nil return eventsToMake, nil
} }
func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, userID string, userDomain spec.ServerName, newRoomID string, newVersion gomatrixserverlib.RoomVersion, eventsToMake []gomatrixserverlib.FledglingEvent) error { func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, senderID spec.SenderID, userDomain spec.ServerName, newRoomID string, newVersion gomatrixserverlib.RoomVersion, eventsToMake []gomatrixserverlib.FledglingEvent) error {
var err error var err error
var builtEvents []*types.HeaderedEvent var builtEvents []*types.HeaderedEvent
authEvents := gomatrixserverlib.NewAuthEvents(nil) authEvents := gomatrixserverlib.NewAuthEvents(nil)
for i, e := range eventsToMake { for i, e := range eventsToMake {
depth := i + 1 // depth starts at 1 depth := i + 1 // depth starts at 1
fullUserID, userIDErr := spec.NewUserID(userID, true)
if userIDErr != nil {
return userIDErr
}
senderID, queryErr := r.URSAPI.QuerySenderIDForUser(ctx, newRoomID, *fullUserID)
if queryErr != nil {
return queryErr
}
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
SenderID: string(senderID), SenderID: string(senderID),
RoomID: newRoomID, RoomID: newRoomID,
@ -550,7 +514,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user
func (r *Upgrader) makeTombstoneEvent( func (r *Upgrader) makeTombstoneEvent(
ctx context.Context, ctx context.Context,
evTime time.Time, evTime time.Time,
userID, roomID, newRoomID string, senderID spec.SenderID, senderDomain spec.ServerName, roomID, newRoomID string,
) (*types.HeaderedEvent, error) { ) (*types.HeaderedEvent, error) {
content := map[string]interface{}{ content := map[string]interface{}{
"body": "This room has been replaced", "body": "This room has been replaced",
@ -560,30 +524,21 @@ func (r *Upgrader) makeTombstoneEvent(
Type: "m.room.tombstone", Type: "m.room.tombstone",
Content: content, Content: content,
} }
return r.makeHeaderedEvent(ctx, evTime, userID, roomID, event) return r.makeHeaderedEvent(ctx, evTime, senderID, senderDomain, roomID, event)
} }
func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, userID, roomID string, event gomatrixserverlib.FledglingEvent) (*types.HeaderedEvent, error) { func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, senderID spec.SenderID, senderDomain spec.ServerName, roomID string, event gomatrixserverlib.FledglingEvent) (*types.HeaderedEvent, error) {
fullUserID, err := spec.NewUserID(userID, true)
if err != nil {
return nil, err
}
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
if err != nil {
return nil, err
}
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
SenderID: string(senderID), SenderID: string(senderID),
RoomID: roomID, RoomID: roomID,
Type: event.Type, Type: event.Type,
StateKey: &event.StateKey, StateKey: &event.StateKey,
} }
err = proto.SetContent(event.Content) err := proto.SetContent(event.Content)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to set new %q event content: %w", proto.Type, err) return nil, fmt.Errorf("failed to set new %q event content: %w", proto.Type, err)
} }
// Get the sender domain. // Get the sender domain.
senderDomain := fullUserID.Domain()
identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get signing identity for %q: %w", senderDomain, err) return nil, fmt.Errorf("failed to get signing identity for %q: %w", senderDomain, err)

View file

@ -821,17 +821,6 @@ func TestUpgrade(t *testing.T) {
validateFunc func(t *testing.T, oldRoomID, newRoomID string, rsAPI api.RoomserverInternalAPI) validateFunc func(t *testing.T, oldRoomID, newRoomID string, rsAPI api.RoomserverInternalAPI)
wantNewRoom bool wantNewRoom bool
}{ }{
{
name: "invalid userID",
upgradeUser: "!notvalid:test",
roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
room := test.NewRoom(t, alice)
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
t.Errorf("failed to send events: %v", err)
}
return room.ID
},
},
{ {
name: "invalid roomID", name: "invalid roomID",
upgradeUser: alice.ID, upgradeUser: alice.ID,
@ -1049,7 +1038,13 @@ func TestUpgrade(t *testing.T) {
} }
roomID := tc.roomFunc(rsAPI) roomID := tc.roomFunc(rsAPI)
newRoomID, err := rsAPI.PerformRoomUpgrade(processCtx.Context(), roomID, tc.upgradeUser, version.DefaultRoomVersion()) userID, err := spec.NewUserID(tc.upgradeUser, true)
if err != nil {
t.Fatalf("upgrade userID is invalid")
}
newRoomID, err := rsAPI.PerformRoomUpgrade(processCtx.Context(), roomID,
api.SenderUserIDPair{SenderID: spec.SenderID(tc.upgradeUser), UserID: *userID},
version.DefaultRoomVersion())
if err != nil && tc.wantNewRoom { if err != nil && tc.wantNewRoom {
t.Fatal(err) t.Fatal(err)
} }