Reduce senderID code duplication

This commit is contained in:
Devon Hudson 2023-06-12 09:46:31 +01:00
parent 8fb8d5a743
commit 9e3e4afdb7
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
10 changed files with 48 additions and 62 deletions

View file

@ -225,17 +225,7 @@ func createRoom(
EventTime: evTime, EventTime: evTime,
} }
senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID.String(), *userID) roomAlias, createRes := rsAPI.PerformCreateRoom(ctx, *userID, *roomID, &req)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("rsAPI.QuerySenderIDForUser failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
roomAlias, createRes := rsAPI.PerformCreateRoom(ctx, roomserverAPI.SenderUserIDPair{
SenderID: senderID, UserID: *userID,
}, *roomID, &req)
if createRes != nil { if createRes != nil {
return *createRes return *createRes
} }

View file

@ -91,7 +91,7 @@ func SendRedaction(
// "Users may redact their own events, and any user with a power level greater than or equal // "Users may redact their own events, and any user with a power level greater than or equal
// to the redact power level of the room may redact events there" // to the redact power level of the room may redact events there"
// https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid // https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid
allowedToRedact := ev.SenderID() == senderID // TODO: Should replace device.UserID with device...PerRoomKey allowedToRedact := ev.SenderID() == senderID
if !allowedToRedact { if !allowedToRedact {
plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{ plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomPowerLevels, EventType: spec.MRoomPowerLevels,

View file

@ -67,17 +67,7 @@ func UpgradeRoom(
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *userID) newRoomID, err := rsAPI.PerformRoomUpgrade(req.Context(), roomID, *userID, gomatrixserverlib.RoomVersion(r.NewVersion))
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("Failed getting senderID for user")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
newRoomID, err := rsAPI.PerformRoomUpgrade(req.Context(), roomID, roomserverAPI.SenderUserIDPair{
SenderID: senderID, UserID: *userID,
}, gomatrixserverlib.RoomVersion(r.NewVersion))
switch e := err.(type) { switch e := err.(type) {
case nil: case nil:
case roomserverAPI.ErrNotAllowed: case roomserverAPI.ErrNotAllowed:

View file

@ -189,9 +189,9 @@ type ClientRoomserverAPI interface {
GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error
GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error
PerformCreateRoom(ctx context.Context, user SenderUserIDPair, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) (string, *util.JSONResponse) PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) (string, *util.JSONResponse)
// PerformRoomUpgrade upgrades a room to a newer version // PerformRoomUpgrade upgrades a room to a newer version
PerformRoomUpgrade(ctx context.Context, roomID string, user SenderUserIDPair, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error) PerformRoomUpgrade(ctx context.Context, roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error)
PerformAdminEvacuateRoom(ctx context.Context, roomID string) (affected []string, err error) PerformAdminEvacuateRoom(ctx context.Context, roomID string) (affected []string, err error)
PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error) PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error)
PerformAdminPurgeRoom(ctx context.Context, roomID string) error PerformAdminPurgeRoom(ctx context.Context, roomID string) error

View file

@ -11,11 +11,6 @@ import (
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
type SenderUserIDPair struct {
SenderID spec.SenderID
UserID spec.UserID
}
type PerformCreateRoomRequest struct { type PerformCreateRoomRequest struct {
InvitedUsers []string InvitedUsers []string
RoomName string RoomName string

View file

@ -235,9 +235,9 @@ func (r *RoomserverInternalAPI) HandleInvite(
} }
func (r *RoomserverInternalAPI) PerformCreateRoom( func (r *RoomserverInternalAPI) PerformCreateRoom(
ctx context.Context, user api.SenderUserIDPair, roomID spec.RoomID, createRequest *api.PerformCreateRoomRequest, ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *api.PerformCreateRoomRequest,
) (string, *util.JSONResponse) { ) (string, *util.JSONResponse) {
return r.Creator.PerformCreateRoom(ctx, user, roomID, createRequest) return r.Creator.PerformCreateRoom(ctx, userID, roomID, createRequest)
} }
func (r *RoomserverInternalAPI) PerformInvite( func (r *RoomserverInternalAPI) PerformInvite(

View file

@ -44,7 +44,7 @@ type Creator struct {
// PerformCreateRoom handles all the steps necessary to create a new room. // PerformCreateRoom handles all the steps necessary to create a new room.
// nolint: gocyclo // nolint: gocyclo
func (c *Creator) PerformCreateRoom(ctx context.Context, user api.SenderUserIDPair, roomID spec.RoomID, createRequest *api.PerformCreateRoomRequest) (string, *util.JSONResponse) { func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *api.PerformCreateRoomRequest) (string, *util.JSONResponse) {
verImpl, err := gomatrixserverlib.GetRoomVersion(createRequest.RoomVersion) verImpl, err := gomatrixserverlib.GetRoomVersion(createRequest.RoomVersion)
if err != nil { if err != nil {
return "", &util.JSONResponse{ return "", &util.JSONResponse{
@ -63,9 +63,17 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, user api.SenderUserIDPa
} }
} }
} }
createContent["creator"] = user.SenderID senderID, err := c.DB.GetSenderIDForUser(ctx, roomID.String(), userID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user")
return "", &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
createContent["creator"] = senderID
createContent["room_version"] = createRequest.RoomVersion createContent["room_version"] = createRequest.RoomVersion
powerLevelContent := eventutil.InitialPowerLevelsContent(string(user.SenderID)) powerLevelContent := eventutil.InitialPowerLevelsContent(string(senderID))
joinRuleContent := gomatrixserverlib.JoinRuleContent{ joinRuleContent := gomatrixserverlib.JoinRuleContent{
JoinRule: spec.Invite, JoinRule: spec.Invite,
} }
@ -121,7 +129,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, user api.SenderUserIDPa
} }
membershipEvent := gomatrixserverlib.FledglingEvent{ membershipEvent := gomatrixserverlib.FledglingEvent{
Type: spec.MRoomMember, Type: spec.MRoomMember,
StateKey: string(user.SenderID), StateKey: string(senderID),
Content: gomatrixserverlib.MemberContent{ Content: gomatrixserverlib.MemberContent{
Membership: spec.Join, Membership: spec.Join,
DisplayName: createRequest.UserDisplayName, DisplayName: createRequest.UserDisplayName,
@ -163,7 +171,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, user api.SenderUserIDPa
var roomAlias string var roomAlias string
if createRequest.RoomAliasName != "" { if createRequest.RoomAliasName != "" {
roomAlias = fmt.Sprintf("#%s:%s", createRequest.RoomAliasName, user.UserID.Domain()) roomAlias = fmt.Sprintf("#%s:%s", createRequest.RoomAliasName, userID.Domain())
// check it's free // check it's free
// TODO: This races but is better than nothing // TODO: This races but is better than nothing
hasAliasReq := api.GetRoomIDForAliasRequest{ hasAliasReq := api.GetRoomIDForAliasRequest{
@ -281,7 +289,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, user api.SenderUserIDPa
depth := i + 1 // depth starts at 1 depth := i + 1 // depth starts at 1
builder := verImpl.NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{ builder := verImpl.NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{
SenderID: string(user.SenderID), SenderID: string(senderID),
RoomID: roomID.String(), RoomID: roomID.String(),
Type: e.Type, Type: e.Type,
StateKey: &e.StateKey, StateKey: &e.StateKey,
@ -306,7 +314,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, user api.SenderUserIDPa
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
} }
} }
ev, err = builder.Build(createRequest.EventTime, user.UserID.Domain(), createRequest.KeyID, createRequest.PrivateKey) ev, err = builder.Build(createRequest.EventTime, userID.Domain(), createRequest.KeyID, createRequest.PrivateKey)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("buildEvent failed") util.GetLogger(ctx).WithError(err).Error("buildEvent failed")
return "", &util.JSONResponse{ return "", &util.JSONResponse{
@ -342,11 +350,11 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, user api.SenderUserIDPa
inputs = append(inputs, api.InputRoomEvent{ inputs = append(inputs, api.InputRoomEvent{
Kind: api.KindNew, Kind: api.KindNew,
Event: event, Event: event,
Origin: user.UserID.Domain(), Origin: userID.Domain(),
SendAsServer: api.DoNotSendToOtherServers, SendAsServer: api.DoNotSendToOtherServers,
}) })
} }
if err = api.SendInputRoomEvents(ctx, c.RSAPI, user.UserID.Domain(), inputs, false); err != nil { if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs, false); err != nil {
util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed")
return "", &util.JSONResponse{ return "", &util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
@ -361,7 +369,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, user api.SenderUserIDPa
aliasReq := api.SetRoomAliasRequest{ aliasReq := api.SetRoomAliasRequest{
Alias: roomAlias, Alias: roomAlias,
RoomID: roomID.String(), RoomID: roomID.String(),
UserID: user.UserID.String(), UserID: userID.String(),
} }
var aliasResp api.SetRoomAliasResponse var aliasResp api.SetRoomAliasResponse
@ -434,7 +442,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, user api.SenderUserIDPa
} }
inviteeString := string(inviteeSenderID) inviteeString := string(inviteeSenderID)
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
SenderID: string(user.SenderID), SenderID: string(senderID),
RoomID: roomID.String(), RoomID: roomID.String(),
Type: "m.room.member", Type: "m.room.member",
StateKey: &inviteeString, StateKey: &inviteeString,
@ -457,7 +465,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, user api.SenderUserIDPa
// Build the invite event. // Build the invite event.
identity := &fclient.SigningIdentity{ identity := &fclient.SigningIdentity{
ServerName: user.UserID.Domain(), ServerName: userID.Domain(),
KeyID: createRequest.KeyID, KeyID: createRequest.KeyID,
PrivateKey: createRequest.PrivateKey, PrivateKey: createRequest.PrivateKey,
} }
@ -477,7 +485,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, user api.SenderUserIDPa
Event: event, Event: event,
InviteRoomState: inviteStrippedState, InviteRoomState: inviteStrippedState,
RoomVersion: event.Version(), RoomVersion: event.Version(),
SendAsServer: string(user.UserID.Domain()), SendAsServer: string(userID.Domain()),
}) })
switch e := err.(type) { switch e := err.(type) {
case api.ErrInvalidID: case api.ErrInvalidID:

View file

@ -91,9 +91,8 @@ func (r *Leaver) performLeaveRoomByID(
if serr != nil || sender == nil { if serr != nil || sender == nil {
return nil, fmt.Errorf("sender %q has no matching userID", senderUser) return nil, fmt.Errorf("sender %q has no matching userID", senderUser)
} }
inviteSender := api.SenderUserIDPair{SenderID: senderUser, UserID: *sender}
if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) { if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) {
return r.performFederatedRejectInvite(ctx, req, res, inviteSender.UserID, eventID, leaver) return r.performFederatedRejectInvite(ctx, req, res, *sender, eventID, leaver)
} }
// check that this is not a "server notice room" // check that this is not a "server notice room"
accData := &userapi.QueryAccountDataResponse{} accData := &userapi.QueryAccountDataResponse{}

View file

@ -38,14 +38,14 @@ type Upgrader struct {
// PerformRoomUpgrade upgrades a room from one version to another // PerformRoomUpgrade upgrades a room from one version to another
func (r *Upgrader) PerformRoomUpgrade( func (r *Upgrader) PerformRoomUpgrade(
ctx context.Context, ctx context.Context,
roomID string, user api.SenderUserIDPair, roomVersion gomatrixserverlib.RoomVersion, roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion,
) (newRoomID string, err error) { ) (newRoomID string, err error) {
return r.performRoomUpgrade(ctx, roomID, user, roomVersion) return r.performRoomUpgrade(ctx, roomID, userID, roomVersion)
} }
func (r *Upgrader) performRoomUpgrade( func (r *Upgrader) performRoomUpgrade(
ctx context.Context, ctx context.Context,
roomID string, user api.SenderUserIDPair, roomVersion gomatrixserverlib.RoomVersion, roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion,
) (string, error) { ) (string, error) {
evTime := time.Now() evTime := time.Now()
@ -54,14 +54,20 @@ func (r *Upgrader) performRoomUpgrade(
return "", err return "", err
} }
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, userID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user")
return "", err
}
// 1. Check if the user is authorized to actually perform the upgrade (can send m.room.tombstone) // 1. Check if the user is authorized to actually perform the upgrade (can send m.room.tombstone)
if !r.userIsAuthorized(ctx, user.SenderID, roomID) { if !r.userIsAuthorized(ctx, senderID, roomID) {
return "", api.ErrNotAllowed{Err: fmt.Errorf("You don't have permission to upgrade the room, power level too low.")} return "", api.ErrNotAllowed{Err: fmt.Errorf("You don't have permission to upgrade the room, power level too low.")}
} }
// TODO (#267): Check room ID doesn't clash with an existing one, and we // TODO (#267): Check room ID doesn't clash with an existing one, and we
// probably shouldn't be using pseudo-random strings, maybe GUIDs? // probably shouldn't be using pseudo-random strings, maybe GUIDs?
newRoomID := fmt.Sprintf("!%s:%s", util.RandomString(16), user.UserID.Domain()) newRoomID := fmt.Sprintf("!%s:%s", util.RandomString(16), userID.Domain())
// Get the existing room state for the old room. // Get the existing room state for the old room.
oldRoomReq := &api.QueryLatestEventsAndStateRequest{ oldRoomReq := &api.QueryLatestEventsAndStateRequest{
@ -73,25 +79,25 @@ func (r *Upgrader) performRoomUpgrade(
} }
// Make the tombstone event // Make the tombstone event
tombstoneEvent, pErr := r.makeTombstoneEvent(ctx, evTime, user.SenderID, user.UserID.Domain(), roomID, newRoomID) tombstoneEvent, pErr := r.makeTombstoneEvent(ctx, evTime, senderID, userID.Domain(), roomID, newRoomID)
if pErr != nil { if pErr != nil {
return "", pErr return "", pErr
} }
// Generate the initial events we need to send into the new room. This includes copied state events and bans // Generate the initial events we need to send into the new room. This includes copied state events and bans
// as well as the power level events needed to set up the room // as well as the power level events needed to set up the room
eventsToMake, pErr := r.generateInitialEvents(ctx, oldRoomRes, user.SenderID, roomID, roomVersion, tombstoneEvent) eventsToMake, pErr := r.generateInitialEvents(ctx, oldRoomRes, senderID, roomID, roomVersion, tombstoneEvent)
if pErr != nil { if pErr != nil {
return "", pErr return "", pErr
} }
// Send the setup events to the new room // Send the setup events to the new room
if pErr = r.sendInitialEvents(ctx, evTime, user.SenderID, user.UserID.Domain(), newRoomID, roomVersion, eventsToMake); pErr != nil { if pErr = r.sendInitialEvents(ctx, evTime, senderID, userID.Domain(), newRoomID, roomVersion, eventsToMake); pErr != nil {
return "", pErr return "", pErr
} }
// 5. Send the tombstone event to the old room // 5. Send the tombstone event to the old room
if pErr = r.sendHeaderedEvent(ctx, user.UserID.Domain(), tombstoneEvent, string(user.UserID.Domain())); pErr != nil { if pErr = r.sendHeaderedEvent(ctx, userID.Domain(), tombstoneEvent, string(userID.Domain())); pErr != nil {
return "", pErr return "", pErr
} }
@ -101,17 +107,17 @@ func (r *Upgrader) performRoomUpgrade(
} }
// If the old room had a canonical alias event, it should be deleted in the old room // If the old room had a canonical alias event, it should be deleted in the old room
if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, user.SenderID, user.UserID.Domain(), roomID); pErr != nil { if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, senderID, userID.Domain(), roomID); pErr != nil {
return "", pErr return "", pErr
} }
// 4. Move local aliases to the new room // 4. Move local aliases to the new room
if pErr = moveLocalAliases(ctx, roomID, newRoomID, user.SenderID, user.UserID, r.URSAPI); pErr != nil { if pErr = moveLocalAliases(ctx, roomID, newRoomID, senderID, userID, r.URSAPI); pErr != nil {
return "", pErr return "", pErr
} }
// 6. Restrict power levels in the old room // 6. Restrict power levels in the old room
if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, user.SenderID, user.UserID.Domain(), roomID); pErr != nil { if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, senderID, userID.Domain(), roomID); pErr != nil {
return "", pErr return "", pErr
} }

View file

@ -1042,9 +1042,7 @@ func TestUpgrade(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("upgrade userID is invalid") t.Fatalf("upgrade userID is invalid")
} }
newRoomID, err := rsAPI.PerformRoomUpgrade(processCtx.Context(), roomID, newRoomID, err := rsAPI.PerformRoomUpgrade(processCtx.Context(), roomID, *userID, version.DefaultRoomVersion())
api.SenderUserIDPair{SenderID: spec.SenderID(tc.upgradeUser), UserID: *userID},
version.DefaultRoomVersion())
if err != nil && tc.wantNewRoom { if err != nil && tc.wantNewRoom {
t.Fatal(err) t.Fatal(err)
} }