From a49c9f01e227aeb12aa2f27d5bf1915453c23a3b Mon Sep 17 00:00:00 2001 From: devonh Date: Mon, 8 May 2023 19:25:44 +0000 Subject: [PATCH 01/35] Only require room version instead of room info for db.Events() (#3079) This reduces the API requirements for the Events database to align with what is actually required. --- cmd/resolve-state/main.go | 4 ++-- roomserver/internal/helpers/auth.go | 7 +++++- roomserver/internal/helpers/helpers.go | 12 +++++++--- roomserver/internal/input/input_events.go | 5 ++++- roomserver/internal/input/input_membership.go | 2 +- roomserver/internal/input/input_missing.go | 5 ++++- roomserver/internal/perform/perform_admin.go | 2 +- .../internal/perform/perform_backfill.go | 9 +++++--- roomserver/internal/perform/perform_invite.go | 5 ++++- roomserver/internal/query/query.go | 8 +++---- roomserver/roomserver_test.go | 2 +- roomserver/state/state.go | 13 ++++++++--- roomserver/storage/interface.go | 4 ++-- roomserver/storage/shared/room_updater.go | 7 ++++-- roomserver/storage/shared/storage.go | 22 ++++++++++--------- roomserver/types/types.go | 3 +++ 16 files changed, 74 insertions(+), 36 deletions(-) diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go index 1278b1cc8..3a4255bae 100644 --- a/cmd/resolve-state/main.go +++ b/cmd/resolve-state/main.go @@ -91,7 +91,7 @@ func main() { } var eventEntries []types.Event - eventEntries, err = roomserverDB.Events(ctx, roomInfo, eventNIDs) + eventEntries, err = roomserverDB.Events(ctx, roomInfo.RoomVersion, eventNIDs) if err != nil { panic(err) } @@ -149,7 +149,7 @@ func main() { } fmt.Println("Fetching", len(eventNIDMap), "state events") - eventEntries, err := roomserverDB.Events(ctx, roomInfo, eventNIDs) + eventEntries, err := roomserverDB.Events(ctx, roomInfo.RoomVersion, eventNIDs) if err != nil { panic(err) } diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 3d2beab37..24958091b 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -219,7 +219,12 @@ func loadAuthEvents( eventNIDs = append(eventNIDs, eventNID) } } - if result.events, err = db.Events(ctx, roomInfo, eventNIDs); err != nil { + + if roomInfo == nil { + err = types.ErrorInvalidRoomInfo + return + } + if result.events, err = db.Events(ctx, roomInfo.RoomVersion, eventNIDs); err != nil { return } roomID := "" diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index ea0074fc4..95397cd5e 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -86,7 +86,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam return false, err } - events, err := db.Events(ctx, info, eventNIDs) + events, err := db.Events(ctx, info.RoomVersion, eventNIDs) if err != nil { return false, err } @@ -183,7 +183,10 @@ func GetMembershipsAtState( util.Unique(eventNIDs) // Get all of the events in this state - stateEvents, err := db.Events(ctx, roomInfo, eventNIDs) + if roomInfo == nil { + return nil, types.ErrorInvalidRoomInfo + } + stateEvents, err := db.Events(ctx, roomInfo.RoomVersion, eventNIDs) if err != nil { return nil, err } @@ -235,7 +238,10 @@ func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types func LoadEvents( ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, eventNIDs []types.EventNID, ) ([]gomatrixserverlib.PDU, error) { - stateEvents, err := db.Events(ctx, roomInfo, eventNIDs) + if roomInfo == nil { + return nil, types.ErrorInvalidRoomInfo + } + stateEvents, err := db.Events(ctx, roomInfo.RoomVersion, eventNIDs) if err != nil { return nil, err } diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 9ae29c544..c8f5737ff 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -805,7 +805,10 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r return err } - memberEvents, err := r.DB.Events(ctx, roomInfo, membershipNIDs) + if roomInfo == nil { + return types.ErrorInvalidRoomInfo + } + memberEvents, err := r.DB.Events(ctx, roomInfo.RoomVersion, membershipNIDs) if err != nil { return err } diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index 947f6c150..98d7d13b1 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -55,7 +55,7 @@ func (r *Inputer) updateMemberships( // Load the event JSON so we can look up the "membership" key. // TODO: Maybe add a membership key to the events table so we can load that // key without having to load the entire event JSON? - events, err := updater.Events(ctx, nil, eventNIDs) + events, err := updater.Events(ctx, "", eventNIDs) if err != nil { return nil, err } diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 89ba07569..8a1235221 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -398,7 +398,10 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even for _, entry := range stateEntries { stateEventNIDs = append(stateEventNIDs, entry.EventNID) } - stateEvents, err := t.db.Events(ctx, t.roomInfo, stateEventNIDs) + if t.roomInfo == nil { + return nil + } + stateEvents, err := t.db.Events(ctx, t.roomInfo.RoomVersion, stateEventNIDs) if err != nil { t.log.WithError(err).Warnf("failed to load state events locally") return nil diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index 70668a201..375eefbec 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -60,7 +60,7 @@ func (r *Admin) PerformAdminEvacuateRoom( return nil, err } - memberEvents, err := r.DB.Events(ctx, roomInfo, memberNIDs) + memberEvents, err := r.DB.Events(ctx, roomInfo.RoomVersion, memberNIDs) if err != nil { return nil, err } diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 8dbfad9bc..fb579f03a 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -533,7 +533,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, roomNID = nid.RoomNID } } - eventsWithNids, err := b.db.Events(ctx, &b.roomInfo, eventNIDs) + eventsWithNids, err := b.db.Events(ctx, b.roomInfo.RoomVersion, eventNIDs) if err != nil { logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events") return nil, err @@ -563,7 +563,10 @@ func joinEventsFromHistoryVisibility( } // Get all of the events in this state - stateEvents, err := db.Events(ctx, roomInfo, eventNIDs) + if roomInfo == nil { + return nil, gomatrixserverlib.HistoryVisibilityJoined, types.ErrorInvalidRoomInfo + } + stateEvents, err := db.Events(ctx, roomInfo.RoomVersion, eventNIDs) if err != nil { // even though the default should be shared, restricting the visibility to joined // feels more secure here. @@ -586,7 +589,7 @@ func joinEventsFromHistoryVisibility( if err != nil { return nil, visibility, err } - evs, err := db.Events(ctx, roomInfo, joinEventNIDs) + evs, err := db.Events(ctx, roomInfo.RoomVersion, joinEventNIDs) return evs, visibility, err } diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index a920811d8..db0b53fef 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -269,7 +269,10 @@ func buildInviteStrippedState( for _, stateNID := range stateEntries { stateNIDs = append(stateNIDs, stateNID.EventNID) } - stateEvents, err := db.Events(ctx, info, stateNIDs) + if info == nil { + return nil, types.ErrorInvalidRoomInfo + } + stateEvents, err := db.Events(ctx, info.RoomVersion, stateNIDs) if err != nil { return nil, err } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index c74bf21bf..27c0dd0c0 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -212,7 +212,7 @@ func (r *Queryer) QueryMembershipForUser( response.IsInRoom = stillInRoom response.HasBeenInRoom = true - evs, err := r.DB.Events(ctx, info, []types.EventNID{membershipEventNID}) + evs, err := r.DB.Events(ctx, info.RoomVersion, []types.EventNID{membershipEventNID}) if err != nil { return err } @@ -344,7 +344,7 @@ func (r *Queryer) QueryMembershipsForRoom( } return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err) } - events, err = r.DB.Events(ctx, info, eventNIDs) + events, err = r.DB.Events(ctx, info.RoomVersion, eventNIDs) if err != nil { return fmt.Errorf("r.DB.Events: %w", err) } @@ -383,7 +383,7 @@ func (r *Queryer) QueryMembershipsForRoom( return err } - events, err = r.DB.Events(ctx, info, eventNIDs) + events, err = r.DB.Events(ctx, info.RoomVersion, eventNIDs) } else { stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID) if err != nil { @@ -967,7 +967,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query // For each of the joined users, let's see if we can get a valid // membership event. for _, joinNID := range joinNIDs { - events, err := r.DB.Events(ctx, roomInfo, []types.EventNID{joinNID}) + events, err := r.DB.Events(ctx, roomInfo.RoomVersion, []types.EventNID{joinNID}) if err != nil || len(events) != 1 { continue } diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index c0f3e12db..d19ebebe4 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -571,7 +571,7 @@ func TestRedaction(t *testing.T) { if ev.Type() == spec.MRoomRedaction { nids, err := db.EventNIDs(ctx, []string{ev.Redacts()}) assert.NoError(t, err) - evs, err := db.Events(ctx, roomInfo, []types.EventNID{nids[ev.Redacts()].EventNID}) + evs, err := db.Events(ctx, roomInfo.RoomVersion, []types.EventNID{nids[ev.Redacts()].EventNID}) assert.NoError(t, err) assert.Equal(t, 1, len(evs)) assert.Equal(t, tc.wantRedacted, evs[0].Redacted()) diff --git a/roomserver/state/state.go b/roomserver/state/state.go index b2a8a8d90..f38d8f96a 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -41,7 +41,7 @@ type StateResolutionStorage interface { StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) - Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) + Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) } @@ -85,7 +85,10 @@ func (p *StateResolution) Resolve(ctx context.Context, eventID string) (*gomatri return nil, fmt.Errorf("unable to find power level event") } - events, err := p.db.Events(ctx, p.roomInfo, []types.EventNID{plNID}) + if p.roomInfo == nil { + return nil, types.ErrorInvalidRoomInfo + } + events, err := p.db.Events(ctx, p.roomInfo.RoomVersion, []types.EventNID{plNID}) if err != nil { return nil, err } @@ -1134,7 +1137,11 @@ func (v *StateResolution) loadStateEvents( eventNIDs = append(eventNIDs, entry.EventNID) } } - events, err := v.db.Events(ctx, v.roomInfo, eventNIDs) + + if v.roomInfo == nil { + return nil, nil, types.ErrorInvalidRoomInfo + } + events, err := v.db.Events(ctx, v.roomInfo.RoomVersion, eventNIDs) if err != nil { return nil, nil, err } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 8da6b350e..6bc4ce9ab 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -72,7 +72,7 @@ type Database interface { ) ([]types.StateEntryList, error) // Look up the Events for a list of numeric event IDs. // Returns a sorted list of events. - Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) + Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) // Look up snapshot NID for an event ID string SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) @@ -224,7 +224,7 @@ type EventDatabase interface { SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) - Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) + Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) // MaybeRedactEvent returns the redaction event and the redacted event if this call resulted in a redaction, else an error // (nil if there was nothing to do) MaybeRedactEvent( diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index dc1db0825..5a20c67b3 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -116,8 +116,11 @@ func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEvent }) } -func (u *RoomUpdater) Events(ctx context.Context, _ *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) { - return u.d.events(ctx, u.txn, u.roomInfo, eventNIDs) +func (u *RoomUpdater) Events(ctx context.Context, _ gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) { + if u.roomInfo == nil { + return nil, types.ErrorInvalidRoomInfo + } + return u.d.events(ctx, u.txn, u.roomInfo.RoomVersion, eventNIDs) } func (u *RoomUpdater) SnapshotNIDFromEventID( diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index aa8e7341a..60e46c478 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -392,7 +392,10 @@ func (d *EventDatabase) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomInfo nids = append(nids, nid.EventNID) } - return d.events(ctx, txn, roomInfo, nids) + if roomInfo == nil { + return nil, types.ErrorInvalidRoomInfo + } + return d.events(ctx, txn, roomInfo.RoomVersion, nids) } func (d *Database) LatestEventIDs( @@ -531,17 +534,13 @@ func (d *Database) GetInvitesForUser( return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID) } -func (d *EventDatabase) Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) { - return d.events(ctx, nil, roomInfo, eventNIDs) +func (d *EventDatabase) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) { + return d.events(ctx, nil, roomVersion, eventNIDs) } func (d *EventDatabase) events( - ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, inputEventNIDs types.EventNIDs, + ctx context.Context, txn *sql.Tx, roomVersion gomatrixserverlib.RoomVersion, inputEventNIDs types.EventNIDs, ) ([]types.Event, error) { - if roomInfo == nil { // this should never happen - return nil, fmt.Errorf("unable to parse events without roomInfo") - } - sort.Sort(inputEventNIDs) events := make(map[types.EventNID]gomatrixserverlib.PDU, len(inputEventNIDs)) eventNIDs := make([]types.EventNID, 0, len(inputEventNIDs)) @@ -579,7 +578,7 @@ func (d *EventDatabase) events( eventIDs = map[types.EventNID]string{} } - verImpl, err := gomatrixserverlib.GetRoomVersion(roomInfo.RoomVersion) + verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion) if err != nil { return nil, err } @@ -1107,7 +1106,10 @@ func (d *EventDatabase) loadEvent(ctx context.Context, roomInfo *types.RoomInfo, if len(nids) == 0 { return nil } - evs, err := d.Events(ctx, roomInfo, []types.EventNID{nids[eventID].EventNID}) + if roomInfo == nil { + return nil + } + evs, err := d.Events(ctx, roomInfo.RoomVersion, []types.EventNID{nids[eventID].EventNID}) if err != nil { return nil } diff --git a/roomserver/types/types.go b/roomserver/types/types.go index 349345854..e986b9da7 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -17,6 +17,7 @@ package types import ( "encoding/json" + "fmt" "sort" "strings" "sync" @@ -328,3 +329,5 @@ func (r *RoomInfo) CopyFrom(r2 *RoomInfo) { r.stateSnapshotNID = r2.stateSnapshotNID r.isStub = r2.isStub } + +var ErrorInvalidRoomInfo = fmt.Errorf("room info is invalid") From 0489d16f95a3d9f1f5bc532e2060bd2482d7b156 Mon Sep 17 00:00:00 2001 From: devonh Date: Tue, 9 May 2023 22:46:49 +0000 Subject: [PATCH 02/35] Move json errors over to gmsl (#3080) --- clientapi/auth/auth.go | 10 +- clientapi/auth/login.go | 8 +- clientapi/auth/login_test.go | 6 +- clientapi/auth/login_token.go | 6 +- clientapi/auth/password.go | 16 +- clientapi/auth/user_interactive.go | 12 +- clientapi/httputil/httputil.go | 8 +- clientapi/jsonerror/jsonerror.go | 246 ------------------ clientapi/jsonerror/jsonerror_test.go | 44 ---- clientapi/routing/account_data.go | 18 +- clientapi/routing/admin.go | 29 +-- clientapi/routing/admin_whois.go | 6 +- clientapi/routing/aliases.go | 6 +- clientapi/routing/createroom.go | 55 ++-- clientapi/routing/deactivate.go | 8 +- clientapi/routing/device.go | 30 +-- clientapi/routing/directory.go | 41 ++- clientapi/routing/directory_public.go | 11 +- clientapi/routing/joined_rooms.go | 4 +- clientapi/routing/joinroom.go | 15 +- clientapi/routing/key_backup.go | 20 +- clientapi/routing/key_crosssigning.go | 26 +- clientapi/routing/keys.go | 18 +- clientapi/routing/leaveroom.go | 6 +- clientapi/routing/login.go | 10 +- clientapi/routing/logout.go | 6 +- clientapi/routing/membership.go | 57 ++-- clientapi/routing/notification.go | 8 +- clientapi/routing/openid.go | 6 +- clientapi/routing/password.go | 12 +- clientapi/routing/peekroom.go | 11 +- clientapi/routing/presence.go | 11 +- clientapi/routing/profile.go | 40 ++- clientapi/routing/pusher.go | 12 +- clientapi/routing/pushrules.go | 60 ++--- clientapi/routing/receipt.go | 3 +- clientapi/routing/redaction.go | 19 +- clientapi/routing/register.go | 65 +++-- clientapi/routing/register_test.go | 24 +- clientapi/routing/room_tagging.go | 18 +- clientapi/routing/routing.go | 23 +- clientapi/routing/sendevent.go | 33 ++- clientapi/routing/sendtodevice.go | 4 +- clientapi/routing/sendtyping.go | 6 +- clientapi/routing/server_notices.go | 12 +- clientapi/routing/state.go | 25 +- clientapi/routing/thirdparty.go | 14 +- clientapi/routing/threepid.go | 28 +- clientapi/routing/upgrade_room.go | 10 +- clientapi/routing/voip.go | 4 +- federationapi/federationapi_test.go | 5 +- federationapi/routing/backfill.go | 11 +- federationapi/routing/devices.go | 7 +- federationapi/routing/eventauth.go | 4 +- federationapi/routing/events.go | 6 +- federationapi/routing/invite.go | 38 +-- federationapi/routing/join.go | 85 +++--- federationapi/routing/keys.go | 27 +- federationapi/routing/leave.go | 59 ++--- federationapi/routing/missingevents.go | 6 +- federationapi/routing/openid.go | 6 +- federationapi/routing/peek.go | 6 +- federationapi/routing/profile.go | 10 +- federationapi/routing/publicrooms.go | 7 +- federationapi/routing/query.go | 16 +- federationapi/routing/routing.go | 35 ++- federationapi/routing/send.go | 6 +- federationapi/routing/state.go | 12 +- federationapi/routing/threepid.go | 37 ++- go.mod | 10 +- go.sum | 30 +-- internal/httputil/httpapi.go | 6 +- internal/httputil/rate_limiting.go | 4 +- internal/transactionrequest.go | 3 +- internal/transactionrequest_test.go | 26 +- internal/validate.go | 9 +- internal/validate_test.go | 17 +- mediaapi/routing/download.go | 13 +- mediaapi/routing/upload.go | 20 +- relayapi/routing/relaytxn.go | 5 +- relayapi/routing/routing.go | 5 +- relayapi/routing/sendrelay.go | 7 +- roomserver/api/api.go | 2 +- roomserver/api/wrapper.go | 4 +- roomserver/internal/input/input.go | 10 +- roomserver/internal/input/input_events.go | 3 +- roomserver/internal/perform/perform_admin.go | 8 +- roomserver/internal/perform/perform_invite.go | 4 +- roomserver/internal/perform/perform_join.go | 4 +- roomserver/internal/perform/perform_leave.go | 7 +- setup/mscs/msc2836/msc2836.go | 9 +- setup/mscs/msc2946/msc2946.go | 11 +- syncapi/internal/keychange_test.go | 15 +- syncapi/routing/context.go | 27 +- syncapi/routing/filter.go | 20 +- syncapi/routing/getevent.go | 12 +- syncapi/routing/memberships.go | 18 +- syncapi/routing/messages.go | 27 +- syncapi/routing/relations.go | 6 +- syncapi/routing/routing.go | 5 +- syncapi/routing/search.go | 21 +- syncapi/sync/requestpool.go | 17 +- userapi/api/api.go | 20 +- userapi/consumers/signingkeyupdate.go | 5 +- userapi/internal/cross_signing.go | 40 ++- userapi/internal/device_list_update.go | 4 +- userapi/internal/device_list_update_test.go | 3 +- userapi/internal/key_api.go | 12 +- userapi/internal/user_api.go | 3 +- 109 files changed, 808 insertions(+), 1217 deletions(-) delete mode 100644 clientapi/jsonerror/jsonerror.go delete mode 100644 clientapi/jsonerror/jsonerror_test.go diff --git a/clientapi/auth/auth.go b/clientapi/auth/auth.go index 93345f4b9..479b9ac7b 100644 --- a/clientapi/auth/auth.go +++ b/clientapi/auth/auth.go @@ -23,8 +23,8 @@ import ( "net/http" "strings" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -58,7 +58,7 @@ func VerifyUserFromRequest( if err != nil { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.MissingToken(err.Error()), + JSON: spec.MissingToken(err.Error()), } } var res api.QueryAccessTokenResponse @@ -68,21 +68,21 @@ func VerifyUserFromRequest( }, &res) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccessToken failed") - jsonErr := jsonerror.InternalServerError() + jsonErr := spec.InternalServerError() return nil, &jsonErr } if res.Err != "" { if strings.HasPrefix(strings.ToLower(res.Err), "forbidden:") { // TODO: use actual error and no string comparison return nil, &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(res.Err), + JSON: spec.Forbidden(res.Err), } } } if res.Device == nil { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.UnknownToken("Unknown token"), + JSON: spec.UnknownToken("Unknown token"), } } return res.Device, nil diff --git a/clientapi/auth/login.go b/clientapi/auth/login.go index 5467e814d..77835614e 100644 --- a/clientapi/auth/login.go +++ b/clientapi/auth/login.go @@ -21,9 +21,9 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -37,7 +37,7 @@ func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.U if err != nil { err := &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Reading request body failed: " + err.Error()), + JSON: spec.BadJSON("Reading request body failed: " + err.Error()), } return nil, nil, err } @@ -48,7 +48,7 @@ func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.U if err := json.Unmarshal(reqBytes, &header); err != nil { err := &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Reading request body failed: " + err.Error()), + JSON: spec.BadJSON("Reading request body failed: " + err.Error()), } return nil, nil, err } @@ -68,7 +68,7 @@ func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.U default: err := util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("unhandled login type: " + header.Type), + JSON: spec.InvalidParam("unhandled login type: " + header.Type), } return nil, nil, &err } diff --git a/clientapi/auth/login_test.go b/clientapi/auth/login_test.go index c91cba241..eb87d5e8e 100644 --- a/clientapi/auth/login_test.go +++ b/clientapi/auth/login_test.go @@ -21,11 +21,11 @@ import ( "strings" "testing" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" uapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -140,7 +140,7 @@ func TestBadLoginFromJSONReader(t *testing.T) { "type": "m.login.invalid", "device_id": "adevice" }`, - WantErrCode: "M_INVALID_ARGUMENT_VALUE", + WantErrCode: "M_INVALID_PARAM", }, } for _, tst := range tsts { @@ -157,7 +157,7 @@ func TestBadLoginFromJSONReader(t *testing.T) { if errRes == nil { cleanup(ctx, nil) t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode) - } else if merr, ok := errRes.JSON.(*jsonerror.MatrixError); ok && merr.ErrCode != tst.WantErrCode { + } else if merr, ok := errRes.JSON.(*spec.MatrixError); ok && merr.ErrCode != tst.WantErrCode { t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode) } }) diff --git a/clientapi/auth/login_token.go b/clientapi/auth/login_token.go index 845eb5de9..073f728d6 100644 --- a/clientapi/auth/login_token.go +++ b/clientapi/auth/login_token.go @@ -20,9 +20,9 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -48,13 +48,13 @@ func (t *LoginTypeToken) LoginFromJSON(ctx context.Context, reqBytes []byte) (*L var res uapi.QueryLoginTokenResponse if err := t.UserAPI.QueryLoginToken(ctx, &uapi.QueryLoginTokenRequest{Token: r.Token}, &res); err != nil { util.GetLogger(ctx).WithError(err).Error("UserAPI.QueryLoginToken failed") - jsonErr := jsonerror.InternalServerError() + jsonErr := spec.InternalServerError() return nil, nil, &jsonErr } if res.Data == nil { return nil, nil, &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("invalid login token"), + JSON: spec.Forbidden("invalid login token"), } } diff --git a/clientapi/auth/password.go b/clientapi/auth/password.go index f2b0383ab..fb7def024 100644 --- a/clientapi/auth/password.go +++ b/clientapi/auth/password.go @@ -21,10 +21,10 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -65,26 +65,26 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, if username == "" { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.BadJSON("A username must be supplied."), + JSON: spec.BadJSON("A username must be supplied."), } } if len(r.Password) == 0 { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.BadJSON("A password must be supplied."), + JSON: spec.BadJSON("A password must be supplied."), } } localpart, domain, err := userutil.ParseUsernameParam(username, t.Config.Matrix) if err != nil { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.InvalidUsername(err.Error()), + JSON: spec.InvalidUsername(err.Error()), } } if !t.Config.Matrix.IsLocalServerName(domain) { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.InvalidUsername("The server name is not known."), + JSON: spec.InvalidUsername("The server name is not known."), } } // Squash username to all lowercase letters @@ -97,7 +97,7 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, if err != nil { return nil, &util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("Unable to fetch account by password."), + JSON: spec.Unknown("Unable to fetch account by password."), } } @@ -112,7 +112,7 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, if err != nil { return nil, &util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("Unable to fetch account by password."), + JSON: spec.Unknown("Unable to fetch account by password."), } } // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows @@ -120,7 +120,7 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, if !res.Exists { return nil, &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The username or password was incorrect or the account does not exist."), + JSON: spec.Forbidden("The username or password was incorrect or the account does not exist."), } } } diff --git a/clientapi/auth/user_interactive.go b/clientapi/auth/user_interactive.go index 9971bf8a4..58d34865f 100644 --- a/clientapi/auth/user_interactive.go +++ b/clientapi/auth/user_interactive.go @@ -20,9 +20,9 @@ import ( "net/http" "sync" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -178,7 +178,7 @@ func (u *UserInteractive) NewSession() *util.JSONResponse { sessionID, err := GenerateAccessToken() if err != nil { logrus.WithError(err).Error("failed to generate session ID") - res := jsonerror.InternalServerError() + res := spec.InternalServerError() return &res } u.Lock() @@ -193,14 +193,14 @@ func (u *UserInteractive) ResponseWithChallenge(sessionID string, response inter mixedObjects := make(map[string]interface{}) b, err := json.Marshal(response) if err != nil { - ise := jsonerror.InternalServerError() + ise := spec.InternalServerError() return &ise } _ = json.Unmarshal(b, &mixedObjects) challenge := u.challenge(sessionID) b, err = json.Marshal(challenge.JSON) if err != nil { - ise := jsonerror.InternalServerError() + ise := spec.InternalServerError() return &ise } _ = json.Unmarshal(b, &mixedObjects) @@ -234,7 +234,7 @@ func (u *UserInteractive) Verify(ctx context.Context, bodyBytes []byte, device * if !ok { return nil, &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Unknown auth.type: " + authType), + JSON: spec.BadJSON("Unknown auth.type: " + authType), } } @@ -250,7 +250,7 @@ func (u *UserInteractive) Verify(ctx context.Context, bodyBytes []byte, device * if !u.IsSingleStageFlow(authType) { return nil, &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("The auth.session is missing or unknown."), + JSON: spec.Unknown("The auth.session is missing or unknown."), } } } diff --git a/clientapi/httputil/httputil.go b/clientapi/httputil/httputil.go index 74f84f1e7..aea0c3db6 100644 --- a/clientapi/httputil/httputil.go +++ b/clientapi/httputil/httputil.go @@ -20,7 +20,7 @@ import ( "net/http" "unicode/utf8" - "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -32,7 +32,7 @@ func UnmarshalJSONRequest(req *http.Request, iface interface{}) *util.JSONRespon body, err := io.ReadAll(req.Body) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("io.ReadAll failed") - resp := jsonerror.InternalServerError() + resp := spec.InternalServerError() return &resp } @@ -43,7 +43,7 @@ func UnmarshalJSON(body []byte, iface interface{}) *util.JSONResponse { if !utf8.Valid(body) { return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotJSON("Body contains invalid UTF-8"), + JSON: spec.NotJSON("Body contains invalid UTF-8"), } } @@ -53,7 +53,7 @@ func UnmarshalJSON(body []byte, iface interface{}) *util.JSONResponse { // valid JSON with incorrect types for values. return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), + JSON: spec.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), } } return nil diff --git a/clientapi/jsonerror/jsonerror.go b/clientapi/jsonerror/jsonerror.go deleted file mode 100644 index 436e168ab..000000000 --- a/clientapi/jsonerror/jsonerror.go +++ /dev/null @@ -1,246 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package jsonerror - -import ( - "context" - "fmt" - "net/http" - - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/sirupsen/logrus" -) - -// MatrixError represents the "standard error response" in Matrix. -// http://matrix.org/docs/spec/client_server/r0.2.0.html#api-standards -type MatrixError struct { - ErrCode string `json:"errcode"` - Err string `json:"error"` -} - -func (e MatrixError) Error() string { - return fmt.Sprintf("%s: %s", e.ErrCode, e.Err) -} - -// InternalServerError returns a 500 Internal Server Error in a matrix-compliant -// format. -func InternalServerError() util.JSONResponse { - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: Unknown("Internal Server Error"), - } -} - -// Unknown is an unexpected error -func Unknown(msg string) *MatrixError { - return &MatrixError{"M_UNKNOWN", msg} -} - -// Forbidden is an error when the client tries to access a resource -// they are not allowed to access. -func Forbidden(msg string) *MatrixError { - return &MatrixError{"M_FORBIDDEN", msg} -} - -// BadJSON is an error when the client supplies malformed JSON. -func BadJSON(msg string) *MatrixError { - return &MatrixError{"M_BAD_JSON", msg} -} - -// BadAlias is an error when the client supplies a bad alias. -func BadAlias(msg string) *MatrixError { - return &MatrixError{"M_BAD_ALIAS", msg} -} - -// NotJSON is an error when the client supplies something that is not JSON -// to a JSON endpoint. -func NotJSON(msg string) *MatrixError { - return &MatrixError{"M_NOT_JSON", msg} -} - -// NotFound is an error when the client tries to access an unknown resource. -func NotFound(msg string) *MatrixError { - return &MatrixError{"M_NOT_FOUND", msg} -} - -// MissingArgument is an error when the client tries to access a resource -// without providing an argument that is required. -func MissingArgument(msg string) *MatrixError { - return &MatrixError{"M_MISSING_ARGUMENT", msg} -} - -// InvalidArgumentValue is an error when the client tries to provide an -// invalid value for a valid argument -func InvalidArgumentValue(msg string) *MatrixError { - return &MatrixError{"M_INVALID_ARGUMENT_VALUE", msg} -} - -// MissingToken is an error when the client tries to access a resource which -// requires authentication without supplying credentials. -func MissingToken(msg string) *MatrixError { - return &MatrixError{"M_MISSING_TOKEN", msg} -} - -// UnknownToken is an error when the client tries to access a resource which -// requires authentication and supplies an unrecognised token -func UnknownToken(msg string) *MatrixError { - return &MatrixError{"M_UNKNOWN_TOKEN", msg} -} - -// WeakPassword is an error which is returned when the client tries to register -// using a weak password. http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based -func WeakPassword(msg string) *MatrixError { - return &MatrixError{"M_WEAK_PASSWORD", msg} -} - -// InvalidUsername is an error returned when the client tries to register an -// invalid username -func InvalidUsername(msg string) *MatrixError { - return &MatrixError{"M_INVALID_USERNAME", msg} -} - -// UserInUse is an error returned when the client tries to register an -// username that already exists -func UserInUse(msg string) *MatrixError { - return &MatrixError{"M_USER_IN_USE", msg} -} - -// RoomInUse is an error returned when the client tries to make a room -// that already exists -func RoomInUse(msg string) *MatrixError { - return &MatrixError{"M_ROOM_IN_USE", msg} -} - -// ASExclusive is an error returned when an application service tries to -// register an username that is outside of its registered namespace, or if a -// user attempts to register a username or room alias within an exclusive -// namespace. -func ASExclusive(msg string) *MatrixError { - return &MatrixError{"M_EXCLUSIVE", msg} -} - -// GuestAccessForbidden is an error which is returned when the client is -// forbidden from accessing a resource as a guest. -func GuestAccessForbidden(msg string) *MatrixError { - return &MatrixError{"M_GUEST_ACCESS_FORBIDDEN", msg} -} - -// InvalidSignature is an error which is returned when the client tries -// to upload invalid signatures. -func InvalidSignature(msg string) *MatrixError { - return &MatrixError{"M_INVALID_SIGNATURE", msg} -} - -// InvalidParam is an error that is returned when a parameter was invalid, -// traditionally with cross-signing. -func InvalidParam(msg string) *MatrixError { - return &MatrixError{"M_INVALID_PARAM", msg} -} - -// MissingParam is an error that is returned when a parameter was incorrect, -// traditionally with cross-signing. -func MissingParam(msg string) *MatrixError { - return &MatrixError{"M_MISSING_PARAM", msg} -} - -// UnableToAuthoriseJoin is an error that is returned when a server can't -// determine whether to allow a restricted join or not. -func UnableToAuthoriseJoin(msg string) *MatrixError { - return &MatrixError{"M_UNABLE_TO_AUTHORISE_JOIN", msg} -} - -// LeaveServerNoticeError is an error returned when trying to reject an invite -// for a server notice room. -func LeaveServerNoticeError() *MatrixError { - return &MatrixError{ - ErrCode: "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM", - Err: "You cannot reject this invite", - } -} - -// ErrRoomKeysVersion is an error returned by `PUT /room_keys/keys` -type ErrRoomKeysVersion struct { - MatrixError - CurrentVersion string `json:"current_version"` -} - -// WrongBackupVersionError is an error returned by `PUT /room_keys/keys` -func WrongBackupVersionError(currentVersion string) *ErrRoomKeysVersion { - return &ErrRoomKeysVersion{ - MatrixError: MatrixError{ - ErrCode: "M_WRONG_ROOM_KEYS_VERSION", - Err: "Wrong backup version.", - }, - CurrentVersion: currentVersion, - } -} - -type IncompatibleRoomVersionError struct { - RoomVersion string `json:"room_version"` - Error string `json:"error"` - Code string `json:"errcode"` -} - -// IncompatibleRoomVersion is an error which is returned when the client -// requests a room with a version that is unsupported. -func IncompatibleRoomVersion(roomVersion gomatrixserverlib.RoomVersion) *IncompatibleRoomVersionError { - return &IncompatibleRoomVersionError{ - Code: "M_INCOMPATIBLE_ROOM_VERSION", - RoomVersion: string(roomVersion), - Error: "Your homeserver does not support the features required to join this room", - } -} - -// UnsupportedRoomVersion is an error which is returned when the client -// requests a room with a version that is unsupported. -func UnsupportedRoomVersion(msg string) *MatrixError { - return &MatrixError{"M_UNSUPPORTED_ROOM_VERSION", msg} -} - -// LimitExceededError is a rate-limiting error. -type LimitExceededError struct { - MatrixError - RetryAfterMS int64 `json:"retry_after_ms,omitempty"` -} - -// LimitExceeded is an error when the client tries to send events too quickly. -func LimitExceeded(msg string, retryAfterMS int64) *LimitExceededError { - return &LimitExceededError{ - MatrixError: MatrixError{"M_LIMIT_EXCEEDED", msg}, - RetryAfterMS: retryAfterMS, - } -} - -// NotTrusted is an error which is returned when the client asks the server to -// proxy a request (e.g. 3PID association) to a server that isn't trusted -func NotTrusted(serverName string) *MatrixError { - return &MatrixError{ - ErrCode: "M_SERVER_NOT_TRUSTED", - Err: fmt.Sprintf("Untrusted server '%s'", serverName), - } -} - -// InternalAPIError is returned when Dendrite failed to reach an internal API. -func InternalAPIError(ctx context.Context, err error) util.JSONResponse { - logrus.WithContext(ctx).WithError(err).Error("Error reaching an internal API") - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: &MatrixError{ - ErrCode: "M_INTERNAL_SERVER_ERROR", - Err: "Dendrite encountered an error reaching an internal API.", - }, - } -} diff --git a/clientapi/jsonerror/jsonerror_test.go b/clientapi/jsonerror/jsonerror_test.go deleted file mode 100644 index 9f3754cbc..000000000 --- a/clientapi/jsonerror/jsonerror_test.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package jsonerror - -import ( - "encoding/json" - "testing" -) - -func TestLimitExceeded(t *testing.T) { - e := LimitExceeded("too fast", 5000) - jsonBytes, err := json.Marshal(&e) - if err != nil { - t.Fatalf("TestLimitExceeded: Failed to marshal LimitExceeded error. %s", err.Error()) - } - want := `{"errcode":"M_LIMIT_EXCEEDED","error":"too fast","retry_after_ms":5000}` - if string(jsonBytes) != want { - t.Errorf("TestLimitExceeded: want %s, got %s", want, string(jsonBytes)) - } -} - -func TestForbidden(t *testing.T) { - e := Forbidden("you shall not pass") - jsonBytes, err := json.Marshal(&e) - if err != nil { - t.Fatalf("TestForbidden: Failed to marshal Forbidden error. %s", err.Error()) - } - want := `{"errcode":"M_FORBIDDEN","error":"you shall not pass"}` - if string(jsonBytes) != want { - t.Errorf("TestForbidden: want %s, got %s", want, string(jsonBytes)) - } -} diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go index 4742b1240..572b28efb 100644 --- a/clientapi/routing/account_data.go +++ b/clientapi/routing/account_data.go @@ -21,11 +21,11 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/internal/eventutil" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -38,7 +38,7 @@ func GetAccountData( if userID != device.UserID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("userID does not match the current user"), + JSON: spec.Forbidden("userID does not match the current user"), } } @@ -69,7 +69,7 @@ func GetAccountData( return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("data not found"), + JSON: spec.NotFound("data not found"), } } @@ -81,7 +81,7 @@ func SaveAccountData( if userID != device.UserID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("userID does not match the current user"), + JSON: spec.Forbidden("userID does not match the current user"), } } @@ -90,27 +90,27 @@ func SaveAccountData( if req.Body == http.NoBody { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotJSON("Content not JSON"), + JSON: spec.NotJSON("Content not JSON"), } } if dataType == "m.fully_read" || dataType == "m.push_rules" { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(fmt.Sprintf("Unable to modify %q using this API", dataType)), + JSON: spec.Forbidden(fmt.Sprintf("Unable to modify %q using this API", dataType)), } } body, err := io.ReadAll(req.Body) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("io.ReadAll failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if !json.Valid(body) { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Bad JSON content"), + JSON: spec.BadJSON("Bad JSON content"), } } @@ -157,7 +157,7 @@ func SaveReadMarker( if r.FullyRead != "" { data, err := json.Marshal(fullyReadEvent{EventID: r.FullyRead}) if err != nil { - return jsonerror.InternalServerError() + return spec.InternalServerError() } dataReq := api.InputAccountDataRequest{ diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index a5fc4ec48..4d2cea681 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -17,7 +17,6 @@ import ( "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/httputil" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" @@ -37,7 +36,7 @@ func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAP case eventutil.ErrRoomNoExists: return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound(err.Error()), + JSON: spec.NotFound(err.Error()), } default: logrus.WithError(err).WithField("roomID", vars["roomID"]).Error("Failed to evacuate room") @@ -91,7 +90,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *api.De if req.Body == nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("Missing request body"), + JSON: spec.Unknown("Missing request body"), } } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) @@ -104,7 +103,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *api.De if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue(err.Error()), + JSON: spec.InvalidParam(err.Error()), } } accAvailableResp := &api.QueryAccountAvailabilityResponse{} @@ -114,13 +113,13 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *api.De }, accAvailableResp); err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalAPIError(req.Context(), err), + JSON: spec.InternalServerError(), } } if accAvailableResp.Available { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.Unknown("User does not exist"), + JSON: spec.Unknown("User does not exist"), } } request := struct { @@ -129,13 +128,13 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *api.De if err = json.NewDecoder(req.Body).Decode(&request); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("Failed to decode request body: " + err.Error()), + JSON: spec.Unknown("Failed to decode request body: " + err.Error()), } } if request.Password == "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Expecting non-empty password."), + JSON: spec.MissingParam("Expecting non-empty password."), } } @@ -153,7 +152,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *api.De if err := userAPI.PerformPasswordUpdate(req.Context(), updateReq, updateRes); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("Failed to perform password update: " + err.Error()), + JSON: spec.Unknown("Failed to perform password update: " + err.Error()), } } return util.JSONResponse{ @@ -170,7 +169,7 @@ func AdminReindex(req *http.Request, cfg *config.ClientAPI, device *api.Device, _, err := natsClient.RequestMsg(nats.NewMsg(cfg.Matrix.JetStream.Prefixed(jetstream.InputFulltextReindex)), time.Second*10) if err != nil { logrus.WithError(err).Error("failed to publish nats message") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ Code: http.StatusOK, @@ -192,7 +191,7 @@ func AdminMarkAsStale(req *http.Request, cfg *config.ClientAPI, keyAPI api.Clien if cfg.Matrix.IsLocalServerName(domain) { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidParam("Can not mark local device list as stale"), + JSON: spec.InvalidParam("Can not mark local device list as stale"), } } @@ -203,7 +202,7 @@ func AdminMarkAsStale(req *http.Request, cfg *config.ClientAPI, keyAPI api.Clien if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown(fmt.Sprintf("Failed to mark device list as stale: %s", err)), + JSON: spec.Unknown(fmt.Sprintf("Failed to mark device list as stale: %s", err)), } } return util.JSONResponse{ @@ -221,21 +220,21 @@ func AdminDownloadState(req *http.Request, device *api.Device, rsAPI roomserverA if !ok { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Expecting room ID."), + JSON: spec.MissingParam("Expecting room ID."), } } serverName, ok := vars["serverName"] if !ok { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Expecting remote server name."), + JSON: spec.MissingParam("Expecting remote server name."), } } if err = rsAPI.PerformAdminDownloadState(req.Context(), roomID, device.UserID, spec.ServerName(serverName)); err != nil { if errors.Is(err, eventutil.ErrRoomNoExists) { return util.JSONResponse{ Code: 200, - JSON: jsonerror.NotFound(eventutil.ErrRoomNoExists.Error()), + JSON: spec.NotFound(eventutil.ErrRoomNoExists.Error()), } } logrus.WithError(err).WithFields(logrus.Fields{ diff --git a/clientapi/routing/admin_whois.go b/clientapi/routing/admin_whois.go index f1cbd3467..cb2b8a26b 100644 --- a/clientapi/routing/admin_whois.go +++ b/clientapi/routing/admin_whois.go @@ -17,8 +17,8 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -51,7 +51,7 @@ func GetAdminWhois( if !allowed { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("userID does not match the current user"), + JSON: spec.Forbidden("userID does not match the current user"), } } @@ -61,7 +61,7 @@ func GetAdminWhois( }, &queryRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("GetAdminWhois failed to query user devices") - return jsonerror.InternalServerError() + return spec.InternalServerError() } devices := make(map[string]deviceInfo) diff --git a/clientapi/routing/aliases.go b/clientapi/routing/aliases.go index 5c2df79dc..87c1f9ffd 100644 --- a/clientapi/routing/aliases.go +++ b/clientapi/routing/aliases.go @@ -19,12 +19,10 @@ import ( "fmt" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" - "github.com/matrix-org/util" ) @@ -64,12 +62,12 @@ func GetAliases( var queryRes api.QueryMembershipForUserResponse if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryMembershipsForRoom failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if !queryRes.IsInRoom { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("You aren't a member of this room."), + JSON: spec.Forbidden("You aren't a member of this room."), } } } diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index abf5b4f46..f0cdd6f5a 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -32,7 +32,6 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" @@ -75,7 +74,7 @@ func (r createRoomRequest) Validate() *util.JSONResponse { if strings.ContainsAny(r.RoomAliasName, whitespace+":") { return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("room_alias_name cannot contain whitespace or ':'"), + JSON: spec.BadJSON("room_alias_name cannot contain whitespace or ':'"), } } for _, userID := range r.Invite { @@ -87,7 +86,7 @@ func (r createRoomRequest) Validate() *util.JSONResponse { if _, _, err := gomatrixserverlib.SplitID('@', userID); err != nil { return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("user id must be in the form @localpart:domain"), + JSON: spec.BadJSON("user id must be in the form @localpart:domain"), } } } @@ -96,7 +95,7 @@ func (r createRoomRequest) Validate() *util.JSONResponse { default: return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("preset must be any of 'private_chat', 'trusted_private_chat', 'public_chat'"), + JSON: spec.BadJSON("preset must be any of 'private_chat', 'trusted_private_chat', 'public_chat'"), } } @@ -108,7 +107,7 @@ func (r createRoomRequest) Validate() *util.JSONResponse { if err != nil { return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("malformed creation_content"), + JSON: spec.BadJSON("malformed creation_content"), } } @@ -117,7 +116,7 @@ func (r createRoomRequest) Validate() *util.JSONResponse { if err != nil { return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("malformed creation_content"), + JSON: spec.BadJSON("malformed creation_content"), } } @@ -156,7 +155,7 @@ func CreateRoom( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue(err.Error()), + JSON: spec.InvalidParam(err.Error()), } } return createRoom(req.Context(), r, device, cfg, profileAPI, rsAPI, asAPI, evTime) @@ -175,12 +174,12 @@ func createRoom( _, userDomain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if !cfg.Matrix.IsLocalServerName(userDomain) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(fmt.Sprintf("User domain %q not configured locally", userDomain)), + JSON: spec.Forbidden(fmt.Sprintf("User domain %q not configured locally", userDomain)), } } @@ -200,7 +199,7 @@ func createRoom( if roomVersionError != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UnsupportedRoomVersion(roomVersionError.Error()), + JSON: spec.UnsupportedRoomVersion(roomVersionError.Error()), } } roomVersion = candidateVersion @@ -219,7 +218,7 @@ func createRoom( profile, err := appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, profileAPI) if err != nil { util.GetLogger(ctx).WithError(err).Error("appserviceAPI.RetrieveUserProfile failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } createContent := map[string]interface{}{} @@ -228,7 +227,7 @@ func createRoom( util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for creation_content failed") return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("invalid create content"), + JSON: spec.BadJSON("invalid create content"), } } } @@ -249,7 +248,7 @@ func createRoom( util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for power_level_content_override failed") return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("malformed power_level_content_override"), + JSON: spec.BadJSON("malformed power_level_content_override"), } } } @@ -343,12 +342,12 @@ func createRoom( err = rsAPI.GetRoomIDForAlias(ctx, &hasAliasReq, &aliasResp) if err != nil { util.GetLogger(ctx).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if aliasResp.RoomID != "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.RoomInUse("Room ID already exists."), + JSON: spec.RoomInUse("Room ID already exists."), } } @@ -437,7 +436,7 @@ func createRoom( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("unknown room version"), + JSON: spec.BadJSON("unknown room version"), } } @@ -456,7 +455,7 @@ func createRoom( err = builder.SetContent(e.Content) if err != nil { util.GetLogger(ctx).WithError(err).Error("builder.SetContent failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if i > 0 { builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()} @@ -464,17 +463,17 @@ func createRoom( var ev gomatrixserverlib.PDU if err = builder.AddAuthEvents(&authEvents); err != nil { util.GetLogger(ctx).WithError(err).Error("AddAuthEvents failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } ev, err = builder.Build(evTime, userDomain, cfg.Matrix.KeyID, cfg.Matrix.PrivateKey) if err != nil { util.GetLogger(ctx).WithError(err).Error("buildEvent failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if err = gomatrixserverlib.Allowed(ev, &authEvents); err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } // Add the event to the list of auth events @@ -482,7 +481,7 @@ func createRoom( err = authEvents.AddEvent(ev) if err != nil { util.GetLogger(ctx).WithError(err).Error("authEvents.AddEvent failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } } @@ -497,7 +496,7 @@ func createRoom( } if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, device.UserDomain(), inputs, false); err != nil { util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } // TODO(#269): Reserve room alias while we create the room. This stops us @@ -514,13 +513,13 @@ func createRoom( err = rsAPI.SetRoomAlias(ctx, &aliasReq, &aliasResp) if err != nil { util.GetLogger(ctx).WithError(err).Error("aliasAPI.SetRoomAlias failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if aliasResp.AliasExists { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.RoomInUse("Room alias already exists."), + JSON: spec.RoomInUse("Room alias already exists."), } } } @@ -584,12 +583,12 @@ func createRoom( case roomserverAPI.ErrInvalidID: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown(e.Error()), + JSON: spec.Unknown(e.Error()), } case roomserverAPI.ErrNotAllowed: return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(e.Error()), + JSON: spec.Forbidden(e.Error()), } case nil: default: @@ -597,7 +596,7 @@ func createRoom( sentry.CaptureException(err) return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalServerError(), + JSON: spec.InternalServerError(), } } } @@ -610,7 +609,7 @@ func createRoom( Visibility: spec.Public, }); err != nil { util.GetLogger(ctx).WithError(err).Error("failed to publish room") - return jsonerror.InternalServerError() + return spec.InternalServerError() } } diff --git a/clientapi/routing/deactivate.go b/clientapi/routing/deactivate.go index 3f4f539f6..78cf9fe38 100644 --- a/clientapi/routing/deactivate.go +++ b/clientapi/routing/deactivate.go @@ -5,9 +5,9 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -24,7 +24,7 @@ func Deactivate( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be read: " + err.Error()), + JSON: spec.BadJSON("The request body could not be read: " + err.Error()), } } @@ -36,7 +36,7 @@ func Deactivate( localpart, serverName, err := gomatrixserverlib.SplitID('@', login.Username()) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } var res api.PerformAccountDeactivationResponse @@ -46,7 +46,7 @@ func Deactivate( }, &res) if err != nil { util.GetLogger(ctx).WithError(err).Error("userAPI.PerformAccountDeactivation failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go index 331bacc3c..6209d8e95 100644 --- a/clientapi/routing/device.go +++ b/clientapi/routing/device.go @@ -22,9 +22,9 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/tidwall/gjson" ) @@ -60,7 +60,7 @@ func GetDeviceByID( }, &queryRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("QueryDevices failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } var targetDevice *api.Device for _, device := range queryRes.Devices { @@ -72,7 +72,7 @@ func GetDeviceByID( if targetDevice == nil { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Unknown device"), + JSON: spec.NotFound("Unknown device"), } } @@ -97,7 +97,7 @@ func GetDevicesByLocalpart( }, &queryRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("QueryDevices failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } res := devicesJSON{} @@ -139,12 +139,12 @@ func UpdateDeviceByID( }, &performRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceUpdate failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if !performRes.DeviceExists { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.Forbidden("device does not exist"), + JSON: spec.Forbidden("device does not exist"), } } @@ -174,7 +174,7 @@ func DeleteDeviceById( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be read: " + err.Error()), + JSON: spec.BadJSON("The request body could not be read: " + err.Error()), } } @@ -184,7 +184,7 @@ func DeleteDeviceById( if dev != deviceID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("session and device mismatch"), + JSON: spec.Forbidden("session and device mismatch"), } } } @@ -206,7 +206,7 @@ func DeleteDeviceById( localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } // make sure that the access token being used matches the login creds used for user interactive auth, else @@ -214,7 +214,7 @@ func DeleteDeviceById( if login.Username() != localpart && login.Username() != device.UserID { return util.JSONResponse{ Code: 403, - JSON: jsonerror.Forbidden("Cannot delete another user's device"), + JSON: spec.Forbidden("Cannot delete another user's device"), } } @@ -224,7 +224,7 @@ func DeleteDeviceById( DeviceIDs: []string{deviceID}, }, &res); err != nil { util.GetLogger(ctx).WithError(err).Error("userAPI.PerformDeviceDeletion failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } deleteOK = true @@ -245,7 +245,7 @@ func DeleteDevices( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be read: " + err.Error()), + JSON: spec.BadJSON("The request body could not be read: " + err.Error()), } } defer req.Body.Close() // nolint:errcheck @@ -259,14 +259,14 @@ func DeleteDevices( if login.Username() != device.UserID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("unable to delete devices for other user"), + JSON: spec.Forbidden("unable to delete devices for other user"), } } payload := devicesDeleteJSON{} if err = json.Unmarshal(bodyBytes, &payload); err != nil { util.GetLogger(ctx).WithError(err).Error("unable to unmarshal device deletion request") - return jsonerror.InternalServerError() + return spec.InternalServerError() } var res api.PerformDeviceDeletionResponse @@ -275,7 +275,7 @@ func DeleteDevices( DeviceIDs: payload.Devices, }, &res); err != nil { util.GetLogger(ctx).WithError(err).Error("userAPI.PerformDeviceDeletion failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index 11ae5739c..0ca9475d7 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -24,7 +24,6 @@ import ( "github.com/matrix-org/util" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" federationAPI "github.com/matrix-org/dendrite/federationapi/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" @@ -56,7 +55,7 @@ func DirectoryRoom( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Room alias must be in the form '#localpart:domain'"), + JSON: spec.BadJSON("Room alias must be in the form '#localpart:domain'"), } } @@ -70,7 +69,7 @@ func DirectoryRoom( queryRes := &roomserverAPI.GetRoomIDForAliasResponse{} if err = rsAPI.GetRoomIDForAlias(req.Context(), queryReq, queryRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("rsAPI.GetRoomIDForAlias failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } res.RoomID = queryRes.RoomID @@ -84,7 +83,7 @@ func DirectoryRoom( // TODO: Return 502 if the remote server errored. // TODO: Return 504 if the remote server timed out. util.GetLogger(req.Context()).WithError(fedErr).Error("federation.LookupRoomAlias failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } res.RoomID = fedRes.RoomID res.fillServers(fedRes.Servers) @@ -93,7 +92,7 @@ func DirectoryRoom( if res.RoomID == "" { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound( + JSON: spec.NotFound( fmt.Sprintf("Room alias %s not found", roomAlias), ), } @@ -103,7 +102,7 @@ func DirectoryRoom( var joinedHostsRes federationAPI.QueryJoinedHostServerNamesInRoomResponse if err = fedSenderAPI.QueryJoinedHostServerNamesInRoom(req.Context(), &joinedHostsReq, &joinedHostsRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("fedSenderAPI.QueryJoinedHostServerNamesInRoom failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } res.fillServers(joinedHostsRes.ServerNames) } @@ -126,14 +125,14 @@ func SetLocalAlias( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Room alias must be in the form '#localpart:domain'"), + JSON: spec.BadJSON("Room alias must be in the form '#localpart:domain'"), } } if !cfg.Matrix.IsLocalServerName(domain) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Alias must be on local homeserver"), + JSON: spec.Forbidden("Alias must be on local homeserver"), } } @@ -146,7 +145,7 @@ func SetLocalAlias( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("User ID must be in the form '@localpart:domain'"), + JSON: spec.BadJSON("User ID must be in the form '@localpart:domain'"), } } for _, appservice := range cfg.Derived.ApplicationServices { @@ -158,7 +157,7 @@ func SetLocalAlias( if namespace.Exclusive && namespace.RegexpObject.MatchString(alias) { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.ASExclusive("Alias is reserved by an application service"), + JSON: spec.ASExclusive("Alias is reserved by an application service"), } } } @@ -181,13 +180,13 @@ func SetLocalAlias( var queryRes roomserverAPI.SetRoomAliasResponse if err := rsAPI.SetRoomAlias(req.Context(), &queryReq, &queryRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.SetRoomAlias failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if queryRes.AliasExists { return util.JSONResponse{ Code: http.StatusConflict, - JSON: jsonerror.Unknown("The alias " + alias + " already exists."), + JSON: spec.Unknown("The alias " + alias + " already exists."), } } @@ -211,20 +210,20 @@ func RemoveLocalAlias( var queryRes roomserverAPI.RemoveRoomAliasResponse if err := rsAPI.RemoveRoomAlias(req.Context(), &queryReq, &queryRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.RemoveRoomAlias failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if !queryRes.Found { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("The alias does not exist."), + JSON: spec.NotFound("The alias does not exist."), } } if !queryRes.Removed { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("You do not have permission to remove this alias."), + JSON: spec.Forbidden("You do not have permission to remove this alias."), } } @@ -249,7 +248,7 @@ func GetVisibility( }, &res) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("QueryPublishedRooms failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } var v roomVisibility @@ -287,7 +286,7 @@ func SetVisibility( err := rsAPI.QueryLatestEventsAndState(req.Context(), &queryEventsReq, &queryEventsRes) if err != nil || len(queryEventsRes.StateEvents) == 0 { util.GetLogger(req.Context()).WithError(err).Error("could not query events from room") - return jsonerror.InternalServerError() + return spec.InternalServerError() } // NOTSPEC: Check if the user's power is greater than power required to change m.room.canonical_alias event @@ -295,7 +294,7 @@ func SetVisibility( if power.UserLevel(dev.UserID) < power.EventLevel(spec.MRoomCanonicalAlias, true) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("userID doesn't have power level to change visibility"), + JSON: spec.Forbidden("userID doesn't have power level to change visibility"), } } @@ -309,7 +308,7 @@ func SetVisibility( Visibility: v.Visibility, }); err != nil { util.GetLogger(req.Context()).WithError(err).Error("failed to publish room") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ @@ -325,7 +324,7 @@ func SetVisibilityAS( if dev.AccountType != userapi.AccountTypeAppService { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Only appservice may use this endpoint"), + JSON: spec.Forbidden("Only appservice may use this endpoint"), } } var v roomVisibility @@ -345,7 +344,7 @@ func SetVisibilityAS( AppserviceID: dev.AppserviceID, }); err != nil { util.GetLogger(req.Context()).WithError(err).Error("failed to publish room") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ diff --git a/clientapi/routing/directory_public.go b/clientapi/routing/directory_public.go index c150d908e..9718ccab6 100644 --- a/clientapi/routing/directory_public.go +++ b/clientapi/routing/directory_public.go @@ -29,7 +29,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" ) @@ -68,7 +67,7 @@ func GetPostPublicRooms( if request.IncludeAllNetworks && request.NetworkID != "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidParam("include_all_networks and third_party_instance_id can not be used together"), + JSON: spec.InvalidParam("include_all_networks and third_party_instance_id can not be used together"), } } @@ -82,7 +81,7 @@ func GetPostPublicRooms( ) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("failed to get public rooms") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ Code: http.StatusOK, @@ -93,7 +92,7 @@ func GetPostPublicRooms( response, err := publicRooms(req.Context(), request, rsAPI, extRoomsProvider) if err != nil { util.GetLogger(req.Context()).WithError(err).Errorf("failed to work out public rooms") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ Code: http.StatusOK, @@ -173,7 +172,7 @@ func fillPublicRoomsReq(httpReq *http.Request, request *PublicRoomReq) *util.JSO if httpReq.Method != "GET" && httpReq.Method != "POST" { return &util.JSONResponse{ Code: http.StatusMethodNotAllowed, - JSON: jsonerror.NotFound("Bad method"), + JSON: spec.NotFound("Bad method"), } } if httpReq.Method == "GET" { @@ -184,7 +183,7 @@ func fillPublicRoomsReq(httpReq *http.Request, request *PublicRoomReq) *util.JSO util.GetLogger(httpReq.Context()).WithError(err).Error("strconv.Atoi failed") return &util.JSONResponse{ Code: 400, - JSON: jsonerror.BadJSON("limit param is not a number"), + JSON: spec.BadJSON("limit param is not a number"), } } request.Limit = int64(limit) diff --git a/clientapi/routing/joined_rooms.go b/clientapi/routing/joined_rooms.go index 4bb353ea9..51a96e4d9 100644 --- a/clientapi/routing/joined_rooms.go +++ b/clientapi/routing/joined_rooms.go @@ -19,9 +19,9 @@ import ( "github.com/matrix-org/util" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) type getJoinedRoomsResponse struct { @@ -40,7 +40,7 @@ func GetJoinedRooms( }, &res) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if res.RoomIDs == nil { res.RoomIDs = []string{} diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index 54a9aaa4b..a67d51327 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -22,7 +22,6 @@ import ( appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" @@ -75,7 +74,7 @@ func JoinRoomByIDOrAlias( util.GetLogger(req.Context()).Error("Unable to query user profile, no profile found.") return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("Unable to query user profile, no profile found."), + JSON: spec.Unknown("Unable to query user profile, no profile found."), } default: } @@ -99,12 +98,12 @@ func JoinRoomByIDOrAlias( case roomserverAPI.ErrInvalidID: response = util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown(e.Error()), + JSON: spec.Unknown(e.Error()), } case roomserverAPI.ErrNotAllowed: - jsonErr := jsonerror.Forbidden(e.Error()) + jsonErr := spec.Forbidden(e.Error()) if device.AccountType == api.AccountTypeGuest { - jsonErr = jsonerror.GuestAccessForbidden(e.Error()) + jsonErr = spec.GuestAccessForbidden(e.Error()) } response = util.JSONResponse{ Code: http.StatusForbidden, @@ -118,12 +117,12 @@ func JoinRoomByIDOrAlias( default: response = util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalServerError(), + JSON: spec.InternalServerError(), } if errors.Is(err, eventutil.ErrRoomNoExists) { response = util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound(e.Error()), + JSON: spec.NotFound(e.Error()), } } } @@ -137,7 +136,7 @@ func JoinRoomByIDOrAlias( case <-timer.C: return util.JSONResponse{ Code: http.StatusAccepted, - JSON: jsonerror.Unknown("The room join will continue in the background."), + JSON: spec.Unknown("The room join will continue in the background."), } case result := <-done: // Stop and drain the timer diff --git a/clientapi/routing/key_backup.go b/clientapi/routing/key_backup.go index 56b05db15..b7b1cadd2 100644 --- a/clientapi/routing/key_backup.go +++ b/clientapi/routing/key_backup.go @@ -20,8 +20,8 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -64,7 +64,7 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de if len(kb.AuthData) == 0 { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("missing auth_data"), + JSON: spec.BadJSON("missing auth_data"), } } version, err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ @@ -98,7 +98,7 @@ func KeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device * if !queryResp.Exists { return util.JSONResponse{ Code: 404, - JSON: jsonerror.NotFound("version not found"), + JSON: spec.NotFound("version not found"), } } return util.JSONResponse{ @@ -128,7 +128,7 @@ func ModifyKeyBackupVersionAuthData(req *http.Request, userAPI userapi.ClientUse Algorithm: kb.Algorithm, }) switch e := err.(type) { - case *jsonerror.ErrRoomKeysVersion: + case *spec.ErrRoomKeysVersion: return util.JSONResponse{ Code: http.StatusForbidden, JSON: e, @@ -141,7 +141,7 @@ func ModifyKeyBackupVersionAuthData(req *http.Request, userAPI userapi.ClientUse if !performKeyBackupResp.Exists { return util.JSONResponse{ Code: 404, - JSON: jsonerror.NotFound("backup version not found"), + JSON: spec.NotFound("backup version not found"), } } return util.JSONResponse{ @@ -162,7 +162,7 @@ func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de if !exists { return util.JSONResponse{ Code: 404, - JSON: jsonerror.NotFound("backup version not found"), + JSON: spec.NotFound("backup version not found"), } } return util.JSONResponse{ @@ -182,7 +182,7 @@ func UploadBackupKeys( }) switch e := err.(type) { - case *jsonerror.ErrRoomKeysVersion: + case *spec.ErrRoomKeysVersion: return util.JSONResponse{ Code: http.StatusForbidden, JSON: e, @@ -194,7 +194,7 @@ func UploadBackupKeys( if !performKeyBackupResp.Exists { return util.JSONResponse{ Code: 404, - JSON: jsonerror.NotFound("backup version not found"), + JSON: spec.NotFound("backup version not found"), } } return util.JSONResponse{ @@ -223,7 +223,7 @@ func GetBackupKeys( if !queryResp.Exists { return util.JSONResponse{ Code: 404, - JSON: jsonerror.NotFound("version not found"), + JSON: spec.NotFound("version not found"), } } if sessionID != "" { @@ -274,6 +274,6 @@ func GetBackupKeys( } return util.JSONResponse{ Code: 404, - JSON: jsonerror.NotFound("keys not found"), + JSON: spec.NotFound("keys not found"), } } diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index 267ba1dc5..6bf7c58e3 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -20,9 +20,9 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -71,31 +71,29 @@ func UploadCrossSigningDeviceKeys( sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) uploadReq.UserID = device.UserID - if err := keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes); err != nil { - return jsonerror.InternalAPIError(req.Context(), err) - } + keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes) if err := uploadRes.Error; err != nil { switch { case err.IsInvalidSignature: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidSignature(err.Error()), + JSON: spec.InvalidSignature(err.Error()), } case err.IsMissingParam: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingParam(err.Error()), + JSON: spec.MissingParam(err.Error()), } case err.IsInvalidParam: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidParam(err.Error()), + JSON: spec.InvalidParam(err.Error()), } default: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown(err.Error()), + JSON: spec.Unknown(err.Error()), } } } @@ -115,31 +113,29 @@ func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.Clie } uploadReq.UserID = device.UserID - if err := keyserverAPI.PerformUploadDeviceSignatures(req.Context(), uploadReq, uploadRes); err != nil { - return jsonerror.InternalAPIError(req.Context(), err) - } + keyserverAPI.PerformUploadDeviceSignatures(req.Context(), uploadReq, uploadRes) if err := uploadRes.Error; err != nil { switch { case err.IsInvalidSignature: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidSignature(err.Error()), + JSON: spec.InvalidSignature(err.Error()), } case err.IsMissingParam: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingParam(err.Error()), + JSON: spec.MissingParam(err.Error()), } case err.IsInvalidParam: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidParam(err.Error()), + JSON: spec.InvalidParam(err.Error()), } default: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown(err.Error()), + JSON: spec.Unknown(err.Error()), } } } diff --git a/clientapi/routing/keys.go b/clientapi/routing/keys.go index 3d60fcc3a..363ae3dc9 100644 --- a/clientapi/routing/keys.go +++ b/clientapi/routing/keys.go @@ -22,8 +22,8 @@ import ( "github.com/matrix-org/util" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) type uploadKeysRequest struct { @@ -67,7 +67,7 @@ func UploadKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) } if uploadRes.Error != nil { util.GetLogger(req.Context()).WithError(uploadRes.Error).Error("Failed to PerformUploadKeys") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if len(uploadRes.KeyErrors) > 0 { util.GetLogger(req.Context()).WithField("key_errors", uploadRes.KeyErrors).Error("Failed to upload one or more keys") @@ -112,14 +112,12 @@ func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) u return *resErr } queryRes := api.QueryKeysResponse{} - if err := keyAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{ + keyAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{ UserID: device.UserID, UserToDevices: r.DeviceKeys, Timeout: r.GetTimeout(), // TODO: Token? - }, &queryRes); err != nil { - return util.ErrorResponse(err) - } + }, &queryRes) return util.JSONResponse{ Code: 200, JSON: map[string]interface{}{ @@ -152,15 +150,13 @@ func ClaimKeys(req *http.Request, keyAPI api.ClientKeyAPI) util.JSONResponse { return *resErr } claimRes := api.PerformClaimKeysResponse{} - if err := keyAPI.PerformClaimKeys(req.Context(), &api.PerformClaimKeysRequest{ + keyAPI.PerformClaimKeys(req.Context(), &api.PerformClaimKeysRequest{ OneTimeKeys: r.OneTimeKeys, Timeout: r.GetTimeout(), - }, &claimRes); err != nil { - return jsonerror.InternalAPIError(req.Context(), err) - } + }, &claimRes) if claimRes.Error != nil { util.GetLogger(req.Context()).WithError(claimRes.Error).Error("failed to PerformClaimKeys") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ Code: 200, diff --git a/clientapi/routing/leaveroom.go b/clientapi/routing/leaveroom.go index a71661851..fbf148264 100644 --- a/clientapi/routing/leaveroom.go +++ b/clientapi/routing/leaveroom.go @@ -17,9 +17,9 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -41,12 +41,12 @@ func LeaveRoomByID( if leaveRes.Code != 0 { return util.JSONResponse{ Code: leaveRes.Code, - JSON: jsonerror.LeaveServerNoticeError(), + JSON: spec.LeaveServerNoticeError(), } } return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown(err.Error()), + JSON: spec.Unknown(err.Error()), } } diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 778c8c0c3..d326bff7f 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -19,10 +19,10 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -72,7 +72,7 @@ func Login( } return util.JSONResponse{ Code: http.StatusMethodNotAllowed, - JSON: jsonerror.NotFound("Bad method"), + JSON: spec.NotFound("Bad method"), } } @@ -83,13 +83,13 @@ func completeAuth( token, err := auth.GenerateAccessToken() if err != nil { util.GetLogger(ctx).WithError(err).Error("auth.GenerateAccessToken failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } localpart, serverName, err := userutil.ParseUsernameParam(login.Username(), cfg) if err != nil { util.GetLogger(ctx).WithError(err).Error("auth.ParseUsernameParam failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } var performRes userapi.PerformDeviceCreationResponse @@ -105,7 +105,7 @@ func completeAuth( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("failed to create device: " + err.Error()), + JSON: spec.Unknown("failed to create device: " + err.Error()), } } diff --git a/clientapi/routing/logout.go b/clientapi/routing/logout.go index 73bae7af7..049c88d57 100644 --- a/clientapi/routing/logout.go +++ b/clientapi/routing/logout.go @@ -17,8 +17,8 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -33,7 +33,7 @@ func Logout( }, &performRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ @@ -53,7 +53,7 @@ func LogoutAll( }, &performRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index d696f2b13..9b95ba5d8 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -27,7 +27,6 @@ import ( appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/threepid" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" @@ -52,7 +51,7 @@ func SendBan( if body.UserID == "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("missing user_id"), + JSON: spec.BadJSON("missing user_id"), } } @@ -69,7 +68,7 @@ func SendBan( if !allowedToBan { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("You don't have permission to ban this user, power level too low."), + JSON: spec.Forbidden("You don't have permission to ban this user, power level too low."), } } @@ -86,7 +85,7 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic ) if err != nil { util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } serverName := device.UserDomain() @@ -101,7 +100,7 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic false, ); err != nil { util.GetLogger(ctx).WithError(err).Error("SendEvents failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ @@ -122,7 +121,7 @@ func SendKick( if body.UserID == "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("missing user_id"), + JSON: spec.BadJSON("missing user_id"), } } @@ -139,7 +138,7 @@ func SendKick( if !allowedToKick { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("You don't have permission to kick this user, power level too low."), + JSON: spec.Forbidden("You don't have permission to kick this user, power level too low."), } } @@ -155,7 +154,7 @@ func SendKick( if queryRes.Membership != spec.Join && queryRes.Membership != spec.Invite { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Unknown("cannot /kick banned or left users"), + JSON: spec.Unknown("cannot /kick banned or left users"), } } // TODO: should we be using SendLeave instead? @@ -174,7 +173,7 @@ func SendUnban( if body.UserID == "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("missing user_id"), + JSON: spec.BadJSON("missing user_id"), } } @@ -196,7 +195,7 @@ func SendUnban( if queryRes.Membership != spec.Ban { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("can only /unban users that are banned"), + JSON: spec.Unknown("can only /unban users that are banned"), } } // TODO: should we be using SendLeave instead? @@ -233,7 +232,7 @@ func SendInvite( if body.UserID == "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("missing user_id"), + JSON: spec.BadJSON("missing user_id"), } } @@ -263,7 +262,7 @@ func sendInvite( ) if err != nil { util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed") - return jsonerror.InternalServerError(), err + return spec.InternalServerError(), err } err = rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{ @@ -277,12 +276,12 @@ func sendInvite( case roomserverAPI.ErrInvalidID: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown(e.Error()), + JSON: spec.Unknown(e.Error()), }, e case roomserverAPI.ErrNotAllowed: return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(e.Error()), + JSON: spec.Forbidden(e.Error()), }, e case nil: default: @@ -290,7 +289,7 @@ func sendInvite( sentry.CaptureException(err) return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalServerError(), + JSON: spec.InternalServerError(), }, err } @@ -377,7 +376,7 @@ func extractRequestData(req *http.Request) (body *threepid.MembershipRequest, ev if err != nil { resErr = &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue(err.Error()), + JSON: spec.InvalidParam(err.Error()), } return } @@ -402,27 +401,27 @@ func checkAndProcessThreepid( if err == threepid.ErrMissingParameter { return inviteStored, &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(err.Error()), + JSON: spec.BadJSON(err.Error()), } } else if err == threepid.ErrNotTrusted { return inviteStored, &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotTrusted(body.IDServer), + JSON: spec.NotTrusted(body.IDServer), } } else if err == eventutil.ErrRoomNoExists { return inviteStored, &util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound(err.Error()), + JSON: spec.NotFound(err.Error()), } } else if e, ok := err.(gomatrixserverlib.BadJSONError); ok { return inviteStored, &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(e.Error()), + JSON: spec.BadJSON(e.Error()), } } if err != nil { util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAndProcessInvite failed") - er := jsonerror.InternalServerError() + er := spec.InternalServerError() return inviteStored, &er } return @@ -436,13 +435,13 @@ func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserver }, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("QueryMembershipForUser: could not query membership for user") - e := jsonerror.InternalServerError() + e := spec.InternalServerError() return &e } if !membershipRes.IsInRoom { return &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("user does not belong to room"), + JSON: spec.Forbidden("user does not belong to room"), } } return nil @@ -462,18 +461,18 @@ func SendForget( err := rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes) if err != nil { logger.WithError(err).Error("QueryMembershipForUser: could not query membership for user") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if !membershipRes.RoomExists { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("room does not exist"), + JSON: spec.Forbidden("room does not exist"), } } if membershipRes.IsInRoom { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown(fmt.Sprintf("User %s is in room %s", device.UserID, roomID)), + JSON: spec.Unknown(fmt.Sprintf("User %s is in room %s", device.UserID, roomID)), } } @@ -484,7 +483,7 @@ func SendForget( response := roomserverAPI.PerformForgetResponse{} if err := rsAPI.PerformForget(ctx, &request, &response); err != nil { logger.WithError(err).Error("PerformForget: unable to forget room") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ Code: http.StatusOK, @@ -500,14 +499,14 @@ func getPowerlevels(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, if plEvent == nil { return nil, &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("You don't have permission to perform this action, no power_levels event in this room."), + JSON: spec.Forbidden("You don't have permission to perform this action, no power_levels event in this room."), } } pl, err := plEvent.PowerLevels() if err != nil { return nil, &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("You don't have permission to perform this action, the power_levels event for this room is malformed so auth checks cannot be performed."), + JSON: spec.Forbidden("You don't have permission to perform this action, the power_levels event for this room is malformed so auth checks cannot be performed."), } } return pl, nil diff --git a/clientapi/routing/notification.go b/clientapi/routing/notification.go index f593e27db..8ac12ce5d 100644 --- a/clientapi/routing/notification.go +++ b/clientapi/routing/notification.go @@ -18,9 +18,9 @@ import ( "net/http" "strconv" - "github.com/matrix-org/dendrite/clientapi/jsonerror" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -35,7 +35,7 @@ func GetNotifications( limit, err = strconv.ParseInt(limitStr, 10, 64) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("ParseInt(limit) failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } } @@ -43,7 +43,7 @@ func GetNotifications( localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{ Localpart: localpart, @@ -54,7 +54,7 @@ func GetNotifications( }, &queryRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("QueryNotifications failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } util.GetLogger(req.Context()).WithField("from", req.URL.Query().Get("from")).WithField("limit", limit).WithField("only", req.URL.Query().Get("only")).WithField("next", queryRes.NextToken).Infof("QueryNotifications: len %d", len(queryRes.Notifications)) return util.JSONResponse{ diff --git a/clientapi/routing/openid.go b/clientapi/routing/openid.go index 8e9be7889..1ead00eba 100644 --- a/clientapi/routing/openid.go +++ b/clientapi/routing/openid.go @@ -17,9 +17,9 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -43,7 +43,7 @@ func CreateOpenIDToken( if userID != device.UserID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Cannot request tokens for other users"), + JSON: spec.Forbidden("Cannot request tokens for other users"), } } @@ -55,7 +55,7 @@ func CreateOpenIDToken( err := userAPI.PerformOpenIDTokenCreation(req.Context(), &request, &response) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("userAPI.CreateOpenIDToken failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index f7f9da622..68466a77d 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -6,11 +6,11 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" ) @@ -90,7 +90,7 @@ func Password( localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } // Ask the user API to perform the password change. @@ -102,11 +102,11 @@ func Password( passwordRes := &api.PerformPasswordUpdateResponse{} if err := userAPI.PerformPasswordUpdate(req.Context(), passwordReq, passwordRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformPasswordUpdate failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if !passwordRes.PasswordUpdated { util.GetLogger(req.Context()).Error("Expected password to have been updated but wasn't") - return jsonerror.InternalServerError() + return spec.InternalServerError() } // If the request asks us to log out all other devices then @@ -120,7 +120,7 @@ func Password( logoutRes := &api.PerformDeviceDeletionResponse{} if err := userAPI.PerformDeviceDeletion(req.Context(), logoutReq, logoutRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } pushersReq := &api.PerformPusherDeletionRequest{ @@ -130,7 +130,7 @@ func Password( } if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } } diff --git a/clientapi/routing/peekroom.go b/clientapi/routing/peekroom.go index 3937b9ad2..af486f6d7 100644 --- a/clientapi/routing/peekroom.go +++ b/clientapi/routing/peekroom.go @@ -18,7 +18,6 @@ import ( "encoding/json" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrix" @@ -61,12 +60,12 @@ func PeekRoomByIDOrAlias( case roomserverAPI.ErrInvalidID: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown(e.Error()), + JSON: spec.Unknown(e.Error()), } case roomserverAPI.ErrNotAllowed: return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(e.Error()), + JSON: spec.Forbidden(e.Error()), } case *gomatrix.HTTPError: return util.JSONResponse{ @@ -76,7 +75,7 @@ func PeekRoomByIDOrAlias( case nil: default: logrus.WithError(err).WithField("roomID", roomIDOrAlias).Errorf("Failed to peek room") - return jsonerror.InternalServerError() + return spec.InternalServerError() } // if this user is already joined to the room, we let them peek anyway @@ -107,12 +106,12 @@ func UnpeekRoomByID( case roomserverAPI.ErrInvalidID: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown(e.Error()), + JSON: spec.Unknown(e.Error()), } case nil: default: logrus.WithError(err).WithField("roomID", roomID).Errorf("Failed to un-peek room") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ diff --git a/clientapi/routing/presence.go b/clientapi/routing/presence.go index c50b09434..d915f0603 100644 --- a/clientapi/routing/presence.go +++ b/clientapi/routing/presence.go @@ -21,7 +21,6 @@ import ( "time" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" @@ -54,7 +53,7 @@ func SetPresence( if device.UserID != userID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Unable to set presence for other user."), + JSON: spec.Forbidden("Unable to set presence for other user."), } } var presence presenceReq @@ -67,7 +66,7 @@ func SetPresence( if !ok { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown(fmt.Sprintf("Unknown presence '%s'.", presence.Presence)), + JSON: spec.Unknown(fmt.Sprintf("Unknown presence '%s'.", presence.Presence)), } } err := producer.SendPresence(req.Context(), userID, presenceStatus, presence.StatusMsg) @@ -75,7 +74,7 @@ func SetPresence( log.WithError(err).Errorf("failed to update presence") return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalServerError(), + JSON: spec.InternalServerError(), } } @@ -100,7 +99,7 @@ func GetPresence( log.WithError(err).Errorf("unable to get presence") return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalServerError(), + JSON: spec.InternalServerError(), } } @@ -119,7 +118,7 @@ func GetPresence( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalServerError(), + JSON: spec.InternalServerError(), } } diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 765ad7cbb..8e88e7c84 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -26,13 +26,11 @@ import ( appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrix" "github.com/matrix-org/util" ) @@ -49,12 +47,12 @@ func GetProfile( if err == appserviceAPI.ErrProfileNotExists { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("The user does not exist or does not have a profile"), + JSON: spec.NotFound("The user does not exist or does not have a profile"), } } util.GetLogger(req.Context()).WithError(err).Error("getProfile failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ @@ -95,7 +93,7 @@ func SetAvatarURL( if userID != device.UserID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("userID does not match the current user"), + JSON: spec.Forbidden("userID does not match the current user"), } } @@ -106,20 +104,20 @@ func SetAvatarURL( if r.AvatarURL == "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("'avatar_url' must be supplied."), + JSON: spec.BadJSON("'avatar_url' must be supplied."), } } localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if !cfg.Matrix.IsLocalServerName(domain) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("userID does not belong to a locally configured domain"), + JSON: spec.Forbidden("userID does not belong to a locally configured domain"), } } @@ -127,14 +125,14 @@ func SetAvatarURL( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue(err.Error()), + JSON: spec.InvalidParam(err.Error()), } } profile, changed, err := profileAPI.SetAvatarURL(req.Context(), localpart, domain, r.AvatarURL) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetAvatarURL failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } // No need to build new membership events, since nothing changed if !changed { @@ -184,7 +182,7 @@ func SetDisplayName( if userID != device.UserID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("userID does not match the current user"), + JSON: spec.Forbidden("userID does not match the current user"), } } @@ -195,20 +193,20 @@ func SetDisplayName( if r.DisplayName == "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("'displayname' must be supplied."), + JSON: spec.BadJSON("'displayname' must be supplied."), } } localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if !cfg.Matrix.IsLocalServerName(domain) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("userID does not belong to a locally configured domain"), + JSON: spec.Forbidden("userID does not belong to a locally configured domain"), } } @@ -216,14 +214,14 @@ func SetDisplayName( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue(err.Error()), + JSON: spec.InvalidParam(err.Error()), } } profile, changed, err := profileAPI.SetDisplayName(req.Context(), localpart, domain, r.DisplayName) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } // No need to build new membership events, since nothing changed if !changed { @@ -256,13 +254,13 @@ func updateProfile( }, &res) if err != nil { util.GetLogger(ctx).WithError(err).Error("QueryRoomsForUser failed") - return jsonerror.InternalServerError(), err + return spec.InternalServerError(), err } _, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError(), err + return spec.InternalServerError(), err } events, err := buildMembershipEvents( @@ -273,16 +271,16 @@ func updateProfile( case gomatrixserverlib.BadJSONError: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(e.Error()), + JSON: spec.BadJSON(e.Error()), }, e default: util.GetLogger(ctx).WithError(err).Error("buildMembershipEvents failed") - return jsonerror.InternalServerError(), e + return spec.InternalServerError(), e } if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, device.UserDomain(), domain, domain, nil, true); err != nil { util.GetLogger(ctx).WithError(err).Error("SendEvents failed") - return jsonerror.InternalServerError(), err + return spec.InternalServerError(), err } return util.JSONResponse{}, nil } diff --git a/clientapi/routing/pusher.go b/clientapi/routing/pusher.go index 89ec824bf..2f51583fb 100644 --- a/clientapi/routing/pusher.go +++ b/clientapi/routing/pusher.go @@ -19,9 +19,9 @@ import ( "net/url" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -34,7 +34,7 @@ func GetPushers( localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{ Localpart: localpart, @@ -42,7 +42,7 @@ func GetPushers( }, &queryRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } for i := range queryRes.Pushers { queryRes.Pushers[i].SessionID = 0 @@ -63,7 +63,7 @@ func SetPusher( localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } body := userapi.PerformPusherSetRequest{} if resErr := httputil.UnmarshalJSONRequest(req, &body); resErr != nil { @@ -99,7 +99,7 @@ func SetPusher( err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{}) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformPusherSet failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ @@ -111,6 +111,6 @@ func SetPusher( func invalidParam(msg string) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidParam(msg), + JSON: spec.InvalidParam(msg), } } diff --git a/clientapi/routing/pushrules.go b/clientapi/routing/pushrules.go index f1a539adf..7be6d2a7e 100644 --- a/clientapi/routing/pushrules.go +++ b/clientapi/routing/pushrules.go @@ -7,17 +7,17 @@ import ( "net/http" "reflect" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/pushrules" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) func errorResponse(ctx context.Context, err error, msg string, args ...interface{}) util.JSONResponse { - if eerr, ok := err.(*jsonerror.MatrixError); ok { + if eerr, ok := err.(*spec.MatrixError); ok { var status int switch eerr.ErrCode { - case "M_INVALID_ARGUMENT_VALUE": + case "M_INVALID_PARAM": status = http.StatusBadRequest case "M_NOT_FOUND": status = http.StatusNotFound @@ -27,7 +27,7 @@ func errorResponse(ctx context.Context, err error, msg string, args ...interface return util.MatrixErrorResponse(status, eerr.ErrCode, eerr.Err) } util.GetLogger(ctx).WithError(err).Errorf(msg, args...) - return jsonerror.InternalServerError() + return spec.InternalServerError() } func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { @@ -48,7 +48,7 @@ func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Devi } ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) if ruleSet == nil { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + return errorResponse(ctx, spec.InvalidParam("invalid push rule set"), "pushRuleSetByScope failed") } return util.JSONResponse{ Code: http.StatusOK, @@ -63,12 +63,12 @@ func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi } ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) if ruleSet == nil { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + return errorResponse(ctx, spec.InvalidParam("invalid push rule set"), "pushRuleSetByScope failed") } rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) // Even if rulesPtr is not nil, there may not be any rules for this kind if rulesPtr == nil || (rulesPtr != nil && len(*rulesPtr) == 0) { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + return errorResponse(ctx, spec.InvalidParam("invalid push rules kind"), "pushRuleSetKindPointer failed") } return util.JSONResponse{ Code: http.StatusOK, @@ -83,15 +83,15 @@ func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device } ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) if ruleSet == nil { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + return errorResponse(ctx, spec.InvalidParam("invalid push rule set"), "pushRuleSetByScope failed") } rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) if rulesPtr == nil { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + return errorResponse(ctx, spec.InvalidParam("invalid push rules kind"), "pushRuleSetKindPointer failed") } i := pushRuleIndexByID(*rulesPtr, ruleID) if i < 0 { - return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed") + return errorResponse(ctx, spec.NotFound("push rule ID not found"), "pushRuleIndexByID failed") } return util.JSONResponse{ Code: http.StatusOK, @@ -104,14 +104,14 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, if err := json.NewDecoder(body).Decode(&newRule); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(err.Error()), + JSON: spec.BadJSON(err.Error()), } } newRule.RuleID = ruleID errs := pushrules.ValidateRule(pushrules.Kind(kind), &newRule) if len(errs) > 0 { - return errorResponse(ctx, jsonerror.InvalidArgumentValue(errs[0].Error()), "rule sanity check failed: %v", errs) + return errorResponse(ctx, spec.InvalidParam(errs[0].Error()), "rule sanity check failed: %v", errs) } ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID) @@ -120,12 +120,12 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, } ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) if ruleSet == nil { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + return errorResponse(ctx, spec.InvalidParam("invalid push rule set"), "pushRuleSetByScope failed") } rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) if rulesPtr == nil { // while this should be impossible (ValidateRule would already return an error), better keep it around - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + return errorResponse(ctx, spec.InvalidParam("invalid push rules kind"), "pushRuleSetKindPointer failed") } i := pushRuleIndexByID(*rulesPtr, ruleID) if i >= 0 && afterRuleID == "" && beforeRuleID == "" { @@ -172,15 +172,15 @@ func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, dev } ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) if ruleSet == nil { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + return errorResponse(ctx, spec.InvalidParam("invalid push rule set"), "pushRuleSetByScope failed") } rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) if rulesPtr == nil { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + return errorResponse(ctx, spec.InvalidParam("invalid push rules kind"), "pushRuleSetKindPointer failed") } i := pushRuleIndexByID(*rulesPtr, ruleID) if i < 0 { - return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed") + return errorResponse(ctx, spec.NotFound("push rule ID not found"), "pushRuleIndexByID failed") } *rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...) @@ -203,15 +203,15 @@ func GetPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri } ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) if ruleSet == nil { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + return errorResponse(ctx, spec.InvalidParam("invalid push rule set"), "pushRuleSetByScope failed") } rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) if rulesPtr == nil { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + return errorResponse(ctx, spec.InvalidParam("invalid push rules kind"), "pushRuleSetKindPointer failed") } i := pushRuleIndexByID(*rulesPtr, ruleID) if i < 0 { - return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed") + return errorResponse(ctx, spec.NotFound("push rule ID not found"), "pushRuleIndexByID failed") } return util.JSONResponse{ Code: http.StatusOK, @@ -226,7 +226,7 @@ func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri if err := json.NewDecoder(body).Decode(&newPartialRule); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(err.Error()), + JSON: spec.BadJSON(err.Error()), } } if newPartialRule.Actions == nil { @@ -249,15 +249,15 @@ func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri } ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) if ruleSet == nil { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + return errorResponse(ctx, spec.InvalidParam("invalid push rule set"), "pushRuleSetByScope failed") } rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) if rulesPtr == nil { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + return errorResponse(ctx, spec.InvalidParam("invalid push rules kind"), "pushRuleSetKindPointer failed") } i := pushRuleIndexByID(*rulesPtr, ruleID) if i < 0 { - return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed") + return errorResponse(ctx, spec.NotFound("push rule ID not found"), "pushRuleIndexByID failed") } if !reflect.DeepEqual(attrGet((*rulesPtr)[i]), attrGet(&newPartialRule)) { @@ -313,7 +313,7 @@ func pushRuleAttrGetter(attr string) (func(*pushrules.Rule) interface{}, error) case "enabled": return func(rule *pushrules.Rule) interface{} { return rule.Enabled }, nil default: - return nil, jsonerror.InvalidArgumentValue("invalid push rule attribute") + return nil, spec.InvalidParam("invalid push rule attribute") } } @@ -324,7 +324,7 @@ func pushRuleAttrSetter(attr string) (func(dest, src *pushrules.Rule), error) { case "enabled": return func(dest, src *pushrules.Rule) { dest.Enabled = src.Enabled }, nil default: - return nil, jsonerror.InvalidArgumentValue("invalid push rule attribute") + return nil, spec.InvalidParam("invalid push rule attribute") } } @@ -338,10 +338,10 @@ func findPushRuleInsertionIndex(rules []*pushrules.Rule, afterID, beforeID strin } } if i == len(rules) { - return 0, jsonerror.NotFound("after: rule ID not found") + return 0, spec.NotFound("after: rule ID not found") } if rules[i].Default { - return 0, jsonerror.NotFound("after: rule ID must not be a default rule") + return 0, spec.NotFound("after: rule ID must not be a default rule") } // We stopped on the "after" match to differentiate // not-found from is-last-entry. Now we move to the earliest @@ -356,10 +356,10 @@ func findPushRuleInsertionIndex(rules []*pushrules.Rule, afterID, beforeID strin } } if i == len(rules) { - return 0, jsonerror.NotFound("before: rule ID not found") + return 0, spec.NotFound("before: rule ID not found") } if rules[i].Default { - return 0, jsonerror.NotFound("before: rule ID must not be a default rule") + return 0, spec.NotFound("before: rule ID must not be a default rule") } } diff --git a/clientapi/routing/receipt.go b/clientapi/routing/receipt.go index 634b60b71..0bbb20b9d 100644 --- a/clientapi/routing/receipt.go +++ b/clientapi/routing/receipt.go @@ -20,7 +20,6 @@ import ( "net/http" "time" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/gomatrixserverlib/spec" @@ -49,7 +48,7 @@ func SetReceipt(req *http.Request, userAPI api.ClientUserAPI, syncProducer *prod case "m.fully_read": data, err := json.Marshal(fullyReadEvent{EventID: eventID}) if err != nil { - return jsonerror.InternalServerError() + return spec.InternalServerError() } dataReq := api.InputAccountDataRequest{ diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index a65cf673c..12391d266 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -24,7 +24,6 @@ import ( "github.com/matrix-org/util" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/transactions" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" @@ -63,13 +62,13 @@ func SendRedaction( if ev == nil { return util.JSONResponse{ Code: 400, - JSON: jsonerror.NotFound("unknown event ID"), // TODO: is it ok to leak existence? + JSON: spec.NotFound("unknown event ID"), // TODO: is it ok to leak existence? } } if ev.RoomID() != roomID { return util.JSONResponse{ Code: 400, - JSON: jsonerror.NotFound("cannot redact event in another room"), + JSON: spec.NotFound("cannot redact event in another room"), } } @@ -85,14 +84,14 @@ func SendRedaction( if plEvent == nil { return util.JSONResponse{ Code: 403, - JSON: jsonerror.Forbidden("You don't have permission to redact this event, no power_levels event in this room."), + JSON: spec.Forbidden("You don't have permission to redact this event, no power_levels event in this room."), } } pl, err := plEvent.PowerLevels() if err != nil { return util.JSONResponse{ Code: 403, - JSON: jsonerror.Forbidden( + JSON: spec.Forbidden( "You don't have permission to redact this event, the power_levels event for this room is malformed so auth checks cannot be performed.", ), } @@ -102,7 +101,7 @@ func SendRedaction( if !allowedToRedact { return util.JSONResponse{ Code: 403, - JSON: jsonerror.Forbidden("You don't have permission to redact this event, power level too low."), + JSON: spec.Forbidden("You don't have permission to redact this event, power level too low."), } } @@ -122,12 +121,12 @@ func SendRedaction( err := proto.SetContent(r) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("proto.SetContent failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) if err != nil { - return jsonerror.InternalServerError() + return spec.InternalServerError() } var queryRes roomserverAPI.QueryLatestEventsAndStateResponse @@ -135,13 +134,13 @@ func SendRedaction( if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Room does not exist"), + JSON: spec.NotFound("Room does not exist"), } } domain := device.UserDomain() if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*types.HeaderedEvent{e}, device.UserDomain(), domain, domain, nil, false); err != nil { util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents") - return jsonerror.InternalServerError() + return spec.InternalServerError() } res := util.JSONResponse{ diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 35dd4846f..615ff2011 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -46,7 +46,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" userapi "github.com/matrix-org/dendrite/userapi/api" ) @@ -428,7 +427,7 @@ func validateApplicationService( if matchedApplicationService == nil { return "", &util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.UnknownToken("Supplied access_token does not match any known application service"), + JSON: spec.UnknownToken("Supplied access_token does not match any known application service"), } } @@ -439,7 +438,7 @@ func validateApplicationService( // If we didn't find any matches, return M_EXCLUSIVE return "", &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.ASExclusive(fmt.Sprintf( + JSON: spec.ASExclusive(fmt.Sprintf( "Supplied username %s did not match any namespaces for application service ID: %s", username, matchedApplicationService.ID)), } } @@ -448,7 +447,7 @@ func validateApplicationService( if UsernameMatchesMultipleExclusiveNamespaces(cfg, userID) { return "", &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.ASExclusive(fmt.Sprintf( + JSON: spec.ASExclusive(fmt.Sprintf( "Supplied username %s matches multiple exclusive application service namespaces. Only 1 match allowed", username)), } } @@ -474,7 +473,7 @@ func Register( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotJSON("Unable to read request body"), + JSON: spec.NotJSON("Unable to read request body"), } } @@ -518,7 +517,7 @@ func Register( if _, err = strconv.ParseInt(r.Username, 10, 64); err == nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"), + JSON: spec.InvalidUsername("Numeric user IDs are reserved"), } } // Auto generate a numeric username if r.Username is empty @@ -529,7 +528,7 @@ func Register( nres := &userapi.QueryNumericLocalpartResponse{} if err = userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil { util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } r.Username = strconv.FormatInt(nres.ID, 10) } @@ -552,7 +551,7 @@ func Register( // type is not known or specified) return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("A known registration type (e.g. m.login.application_service) must be specified if an access_token is provided"), + JSON: spec.MissingParam("A known registration type (e.g. m.login.application_service) must be specified if an access_token is provided"), } default: // Spec-compliant case (neither the access_token nor the login type are @@ -590,7 +589,7 @@ func handleGuestRegistration( if !registrationEnabled || !guestsEnabled { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden( + JSON: spec.Forbidden( fmt.Sprintf("Guest registration is disabled on %q", r.ServerName), ), } @@ -604,7 +603,7 @@ func handleGuestRegistration( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("failed to create account: " + err.Error()), + JSON: spec.Unknown("failed to create account: " + err.Error()), } } token, err := tokens.GenerateLoginToken(tokens.TokenOptions{ @@ -616,7 +615,7 @@ func handleGuestRegistration( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("Failed to generate access token"), + JSON: spec.Unknown("Failed to generate access token"), } } //we don't allow guests to specify their own device_id @@ -632,7 +631,7 @@ func handleGuestRegistration( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("failed to create device: " + err.Error()), + JSON: spec.Unknown("failed to create device: " + err.Error()), } } return util.JSONResponse{ @@ -682,7 +681,7 @@ func handleRegistrationFlow( if !registrationEnabled && r.Auth.Type != authtypes.LoginTypeSharedSecret { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden( + JSON: spec.Forbidden( fmt.Sprintf("Registration is disabled on %q", r.ServerName), ), } @@ -696,7 +695,7 @@ func handleRegistrationFlow( UsernameMatchesExclusiveNamespaces(cfg, r.Username) { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.ASExclusive("This username is reserved by an application service."), + JSON: spec.ASExclusive("This username is reserved by an application service."), } } @@ -706,15 +705,15 @@ func handleRegistrationFlow( err := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr) switch err { case ErrCaptchaDisabled: - return util.JSONResponse{Code: http.StatusForbidden, JSON: jsonerror.Unknown(err.Error())} + return util.JSONResponse{Code: http.StatusForbidden, JSON: spec.Unknown(err.Error())} case ErrMissingResponse: - return util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(err.Error())} + return util.JSONResponse{Code: http.StatusBadRequest, JSON: spec.BadJSON(err.Error())} case ErrInvalidCaptcha: - return util.JSONResponse{Code: http.StatusUnauthorized, JSON: jsonerror.BadJSON(err.Error())} + return util.JSONResponse{Code: http.StatusUnauthorized, JSON: spec.BadJSON(err.Error())} case nil: default: util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha") - return util.JSONResponse{Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError()} + return util.JSONResponse{Code: http.StatusInternalServerError, JSON: spec.InternalServerError()} } // Add Recaptcha to the list of completed registration stages @@ -732,7 +731,7 @@ func handleRegistrationFlow( default: return util.JSONResponse{ Code: http.StatusNotImplemented, - JSON: jsonerror.Unknown("unknown/unimplemented auth type"), + JSON: spec.Unknown("unknown/unimplemented auth type"), } } @@ -764,7 +763,7 @@ func handleApplicationServiceRegistration( if tokenErr != nil { return util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.MissingToken(tokenErr.Error()), + JSON: spec.MissingToken(tokenErr.Error()), } } @@ -834,14 +833,14 @@ func completeRegistration( if username == "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Missing username"), + JSON: spec.MissingParam("Missing username"), } } // Blank passwords are only allowed by registered application services if password == "" && appserviceID == "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Missing password"), + JSON: spec.MissingParam("Missing password"), } } var accRes userapi.PerformAccountCreationResponse @@ -857,12 +856,12 @@ func completeRegistration( if _, ok := err.(*userapi.ErrorConflict); ok { // user already exists return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UserInUse("Desired user ID is already taken."), + JSON: spec.UserInUse("Desired user ID is already taken."), } } return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("failed to create account: " + err.Error()), + JSON: spec.Unknown("failed to create account: " + err.Error()), } } @@ -884,7 +883,7 @@ func completeRegistration( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("Failed to generate access token"), + JSON: spec.Unknown("Failed to generate access token"), } } @@ -893,7 +892,7 @@ func completeRegistration( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("failed to set display name: " + err.Error()), + JSON: spec.Unknown("failed to set display name: " + err.Error()), } } } @@ -911,7 +910,7 @@ func completeRegistration( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("failed to create device: " + err.Error()), + JSON: spec.Unknown("failed to create device: " + err.Error()), } } @@ -1006,7 +1005,7 @@ func RegisterAvailable( if v.ServerName == domain && !v.AllowRegistration { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden( + JSON: spec.Forbidden( fmt.Sprintf("Registration is not allowed on %q", string(v.ServerName)), ), } @@ -1023,7 +1022,7 @@ func RegisterAvailable( if appservice.OwnsNamespaceCoveringUserId(userID) { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UserInUse("Desired user ID is reserved by an application service."), + JSON: spec.UserInUse("Desired user ID is reserved by an application service."), } } } @@ -1036,14 +1035,14 @@ func RegisterAvailable( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("failed to check availability:" + err.Error()), + JSON: spec.Unknown("failed to check availability:" + err.Error()), } } if !res.Available { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UserInUse("Desired User ID is already taken."), + JSON: spec.UserInUse("Desired User ID is already taken."), } } @@ -1060,7 +1059,7 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien if err != nil { return util.JSONResponse{ Code: 400, - JSON: jsonerror.BadJSON(fmt.Sprintf("malformed json: %s", err)), + JSON: spec.BadJSON(fmt.Sprintf("malformed json: %s", err)), } } valid, err := sr.IsValidMacLogin(ssrr.Nonce, ssrr.User, ssrr.Password, ssrr.Admin, ssrr.MacBytes) @@ -1070,7 +1069,7 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien if !valid { return util.JSONResponse{ Code: 403, - JSON: jsonerror.Forbidden("bad mac"), + JSON: spec.Forbidden("bad mac"), } } // downcase capitals diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index b07f636dd..9a60f5314 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -28,7 +28,6 @@ import ( "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -39,6 +38,7 @@ import ( "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/patrickmn/go-cache" "github.com/stretchr/testify/assert" @@ -306,7 +306,7 @@ func Test_register(t *testing.T) { guestsDisabled: true, wantResponse: util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(`Guest registration is disabled on "test"`), + JSON: spec.Forbidden(`Guest registration is disabled on "test"`), }, }, { @@ -318,7 +318,7 @@ func Test_register(t *testing.T) { loginType: "im.not.known", wantResponse: util.JSONResponse{ Code: http.StatusNotImplemented, - JSON: jsonerror.Unknown("unknown/unimplemented auth type"), + JSON: spec.Unknown("unknown/unimplemented auth type"), }, }, { @@ -326,7 +326,7 @@ func Test_register(t *testing.T) { registrationDisabled: true, wantResponse: util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(`Registration is disabled on "test"`), + JSON: spec.Forbidden(`Registration is disabled on "test"`), }, }, { @@ -344,7 +344,7 @@ func Test_register(t *testing.T) { username: "success", wantResponse: util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UserInUse("Desired user ID is already taken."), + JSON: spec.UserInUse("Desired user ID is already taken."), }, }, { @@ -361,7 +361,7 @@ func Test_register(t *testing.T) { username: "1337", wantResponse: util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"), + JSON: spec.InvalidUsername("Numeric user IDs are reserved"), }, }, { @@ -369,7 +369,7 @@ func Test_register(t *testing.T) { loginType: authtypes.LoginTypeRecaptcha, wantResponse: util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Unknown(ErrCaptchaDisabled.Error()), + JSON: spec.Unknown(ErrCaptchaDisabled.Error()), }, }, { @@ -378,7 +378,7 @@ func Test_register(t *testing.T) { loginType: authtypes.LoginTypeRecaptcha, wantResponse: util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(ErrMissingResponse.Error()), + JSON: spec.BadJSON(ErrMissingResponse.Error()), }, }, { @@ -388,7 +388,7 @@ func Test_register(t *testing.T) { captchaBody: `notvalid`, wantResponse: util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.BadJSON(ErrInvalidCaptcha.Error()), + JSON: spec.BadJSON(ErrInvalidCaptcha.Error()), }, }, { @@ -402,7 +402,7 @@ func Test_register(t *testing.T) { enableRecaptcha: true, loginType: authtypes.LoginTypeRecaptcha, captchaBody: `i should fail for other reasons`, - wantResponse: util.JSONResponse{Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError()}, + wantResponse: util.JSONResponse{Code: http.StatusInternalServerError, JSON: spec.InternalServerError()}, }, } @@ -484,7 +484,7 @@ func Test_register(t *testing.T) { if !reflect.DeepEqual(r.Flows, cfg.Derived.Registration.Flows) { t.Fatalf("unexpected registration flows: %+v, want %+v", r.Flows, cfg.Derived.Registration.Flows) } - case *jsonerror.MatrixError: + case *spec.MatrixError: if !reflect.DeepEqual(tc.wantResponse, resp) { t.Fatalf("(%s), unexpected response: %+v, want: %+v", tc.name, resp, tc.wantResponse) } @@ -541,7 +541,7 @@ func Test_register(t *testing.T) { resp = Register(req, userAPI, &cfg.ClientAPI) switch resp.JSON.(type) { - case *jsonerror.MatrixError: + case *spec.MatrixError: if !reflect.DeepEqual(tc.wantResponse, resp) { t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse) } diff --git a/clientapi/routing/room_tagging.go b/clientapi/routing/room_tagging.go index 92b9e6655..8802d22a4 100644 --- a/clientapi/routing/room_tagging.go +++ b/clientapi/routing/room_tagging.go @@ -19,10 +19,10 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrix" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -39,14 +39,14 @@ func GetTags( if device.UserID != userID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Cannot retrieve another user's tags"), + JSON: spec.Forbidden("Cannot retrieve another user's tags"), } } tagContent, err := obtainSavedTags(req, userID, roomID, userAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ @@ -71,7 +71,7 @@ func PutTag( if device.UserID != userID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Cannot modify another user's tags"), + JSON: spec.Forbidden("Cannot modify another user's tags"), } } @@ -83,7 +83,7 @@ func PutTag( tagContent, err := obtainSavedTags(req, userID, roomID, userAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if tagContent.Tags == nil { @@ -93,7 +93,7 @@ func PutTag( if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil { util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ @@ -118,14 +118,14 @@ func DeleteTag( if device.UserID != userID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Cannot modify another user's tags"), + JSON: spec.Forbidden("Cannot modify another user's tags"), } } tagContent, err := obtainSavedTags(req, userID, roomID, userAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } // Check whether the tag to be deleted exists @@ -141,7 +141,7 @@ func DeleteTag( if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil { util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 70299e14d..2a2fa6655 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -33,7 +33,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/auth" clientutil "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/httputil" @@ -148,7 +147,7 @@ func Setup( } return util.JSONResponse{ Code: http.StatusMethodNotAllowed, - JSON: jsonerror.NotFound("unknown method"), + JSON: spec.NotFound("unknown method"), } }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) @@ -659,7 +658,7 @@ func Setup( httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("missing trailing slash"), + JSON: spec.InvalidParam("missing trailing slash"), } }), ).Methods(http.MethodGet, http.MethodOptions) @@ -674,7 +673,7 @@ func Setup( httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("scope, kind and rule ID must be specified"), + JSON: spec.InvalidParam("scope, kind and rule ID must be specified"), } }), ).Methods(http.MethodPut) @@ -693,7 +692,7 @@ func Setup( httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("missing trailing slash after scope"), + JSON: spec.InvalidParam("missing trailing slash after scope"), } }), ).Methods(http.MethodGet, http.MethodOptions) @@ -702,7 +701,7 @@ func Setup( httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("kind and rule ID must be specified"), + JSON: spec.InvalidParam("kind and rule ID must be specified"), } }), ).Methods(http.MethodPut) @@ -721,7 +720,7 @@ func Setup( httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("missing trailing slash after kind"), + JSON: spec.InvalidParam("missing trailing slash after kind"), } }), ).Methods(http.MethodGet, http.MethodOptions) @@ -730,7 +729,7 @@ func Setup( httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("rule ID must be specified"), + JSON: spec.InvalidParam("rule ID must be specified"), } }), ).Methods(http.MethodPut) @@ -939,7 +938,7 @@ func Setup( // TODO: Allow people to peek into rooms. return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.GuestAccessForbidden("Guest access not implemented"), + JSON: spec.GuestAccessForbidden("Guest access not implemented"), } }), ).Methods(http.MethodGet, http.MethodOptions) @@ -1244,7 +1243,7 @@ func Setup( if version == "" { return util.JSONResponse{ Code: 400, - JSON: jsonerror.InvalidArgumentValue("version must be specified"), + JSON: spec.InvalidParam("version must be specified"), } } var reqBody keyBackupSessionRequest @@ -1265,7 +1264,7 @@ func Setup( if version == "" { return util.JSONResponse{ Code: 400, - JSON: jsonerror.InvalidArgumentValue("version must be specified"), + JSON: spec.InvalidParam("version must be specified"), } } roomID := vars["roomID"] @@ -1297,7 +1296,7 @@ func Setup( if version == "" { return util.JSONResponse{ Code: 400, - JSON: jsonerror.InvalidArgumentValue("version must be specified"), + JSON: spec.InvalidParam("version must be specified"), } } var reqBody userapi.KeyBackupSession diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 7e5918f2f..2e3cd4112 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -30,7 +30,6 @@ import ( "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/transactions" "github.com/matrix-org/dendrite/roomserver/api" @@ -81,7 +80,7 @@ func SendEvent( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UnsupportedRoomVersion(err.Error()), + JSON: spec.UnsupportedRoomVersion(err.Error()), } } @@ -126,7 +125,7 @@ func SendEvent( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue(err.Error()), + JSON: spec.InvalidParam(err.Error()), } } @@ -145,12 +144,12 @@ func SendEvent( if !aliasReq.Valid() { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidParam("Request contains invalid aliases."), + JSON: spec.InvalidParam("Request contains invalid aliases."), } } aliasRes := &api.GetAliasesForRoomIDResponse{} if err = rsAPI.GetAliasesForRoomID(req.Context(), &api.GetAliasesForRoomIDRequest{RoomID: roomID}, aliasRes); err != nil { - return jsonerror.InternalServerError() + return spec.InternalServerError() } var found int requestAliases := append(aliasReq.AltAliases, aliasReq.Alias) @@ -165,7 +164,7 @@ func SendEvent( if aliasReq.Alias != "" && found < len(requestAliases) { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadAlias("No matching alias found."), + JSON: spec.BadAlias("No matching alias found."), } } } @@ -194,7 +193,7 @@ func SendEvent( false, ); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } timeToSubmitEvent := time.Since(startedSubmittingEvent) util.GetLogger(req.Context()).WithFields(logrus.Fields{ @@ -273,13 +272,13 @@ func generateSendEvent( err := proto.SetContent(r) if err != nil { util.GetLogger(ctx).WithError(err).Error("proto.SetContent failed") - resErr := jsonerror.InternalServerError() + resErr := spec.InternalServerError() return nil, &resErr } identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) if err != nil { - resErr := jsonerror.InternalServerError() + resErr := spec.InternalServerError() return nil, &resErr } @@ -288,27 +287,27 @@ func generateSendEvent( if err == eventutil.ErrRoomNoExists { return nil, &util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Room does not exist"), + JSON: spec.NotFound("Room does not exist"), } } else if e, ok := err.(gomatrixserverlib.BadJSONError); ok { return nil, &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(e.Error()), + JSON: spec.BadJSON(e.Error()), } } else if e, ok := err.(gomatrixserverlib.EventValidationError); ok { if e.Code == gomatrixserverlib.EventValidationTooLarge { return nil, &util.JSONResponse{ Code: http.StatusRequestEntityTooLarge, - JSON: jsonerror.BadJSON(e.Error()), + JSON: spec.BadJSON(e.Error()), } } return nil, &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(e.Error()), + JSON: spec.BadJSON(e.Error()), } } else if err != nil { util.GetLogger(ctx).WithError(err).Error("eventutil.BuildEvent failed") - resErr := jsonerror.InternalServerError() + resErr := spec.InternalServerError() return nil, &resErr } @@ -321,7 +320,7 @@ func generateSendEvent( if err = gomatrixserverlib.Allowed(e.PDU, &provider); err != nil { return nil, &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(err.Error()), // TODO: Is this error string comprehensible to the client? + JSON: spec.Forbidden(err.Error()), // TODO: Is this error string comprehensible to the client? } } @@ -332,13 +331,13 @@ func generateSendEvent( util.GetLogger(ctx).WithError(err).Error("Cannot unmarshal the event content.") return nil, &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Cannot unmarshal the event content."), + JSON: spec.BadJSON("Cannot unmarshal the event content."), } } if content["replacement_room"] == e.RoomID() { return nil, &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidParam("Cannot send tombstone event that points to the same room."), + JSON: spec.InvalidParam("Cannot send tombstone event that points to the same room."), } } } diff --git a/clientapi/routing/sendtodevice.go b/clientapi/routing/sendtodevice.go index 0c0227937..6d4af0728 100644 --- a/clientapi/routing/sendtodevice.go +++ b/clientapi/routing/sendtodevice.go @@ -19,10 +19,10 @@ import ( "github.com/matrix-org/util" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/internal/transactions" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) // SendToDevice handles PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId} @@ -53,7 +53,7 @@ func SendToDevice( req.Context(), device.UserID, userID, deviceID, eventType, message, ); err != nil { util.GetLogger(req.Context()).WithError(err).Error("eduProducer.SendToDevice failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } } } diff --git a/clientapi/routing/sendtyping.go b/clientapi/routing/sendtyping.go index 9dc884d62..17532a2dd 100644 --- a/clientapi/routing/sendtyping.go +++ b/clientapi/routing/sendtyping.go @@ -18,10 +18,10 @@ import ( "github.com/matrix-org/util" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) type typingContentJSON struct { @@ -39,7 +39,7 @@ func SendTyping( if device.UserID != userID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Cannot set another user's typing state"), + JSON: spec.Forbidden("Cannot set another user's typing state"), } } @@ -58,7 +58,7 @@ func SendTyping( if err := syncProducer.SendTyping(req.Context(), userID, roomID, r.Typing, r.Timeout); err != nil { util.GetLogger(req.Context()).WithError(err).Error("eduProducer.Send failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index a9967adfe..a418677ea 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -32,12 +32,12 @@ import ( appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/transactions" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) // Unspecced server notice request @@ -68,7 +68,7 @@ func SendServerNotice( if device.AccountType != userapi.AccountTypeAdmin { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("This API can only be used by admin users."), + JSON: spec.Forbidden("This API can only be used by admin users."), } } @@ -90,7 +90,7 @@ func SendServerNotice( if !r.valid() { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Invalid request"), + JSON: spec.BadJSON("Invalid request"), } } @@ -175,7 +175,7 @@ func SendServerNotice( }} if err = saveTagData(req, r.UserID, roomID, userAPI, serverAlertTag); err != nil { util.GetLogger(ctx).WithError(err).Error("saveTagData failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } default: @@ -189,7 +189,7 @@ func SendServerNotice( err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: r.UserID, RoomID: roomID}, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("unable to query membership for user") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if !membershipRes.IsInRoom { // re-invite the user @@ -237,7 +237,7 @@ func SendServerNotice( false, ); err != nil { util.GetLogger(ctx).WithError(err).Error("SendEvents failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } util.GetLogger(ctx).WithFields(logrus.Fields{ "event_id": e.EventID(), diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index 705782e88..75abbda91 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -20,7 +20,6 @@ import ( "fmt" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/synctypes" @@ -57,12 +56,12 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a StateToFetch: []gomatrixserverlib.StateKeyTuple{}, }, &stateRes); err != nil { util.GetLogger(ctx).WithError(err).Error("queryAPI.QueryLatestEventsAndState failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if !stateRes.RoomExists { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("room does not exist"), + JSON: spec.Forbidden("room does not exist"), } } @@ -74,7 +73,7 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a content := map[string]string{} if err := json.Unmarshal(ev.Content(), &content); err != nil { util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for history visibility failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if visibility, ok := content["history_visibility"]; ok { worldReadable = visibility == "world_readable" @@ -100,14 +99,14 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a }, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") - return jsonerror.InternalServerError() + return spec.InternalServerError() } // If the user has never been in the room then stop at this point. // We won't tell the user about a room they have never joined. if !membershipRes.HasBeenInRoom { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(fmt.Sprintf("Unknown room %q or user %q has never joined this room", roomID, device.UserID)), + JSON: spec.Forbidden(fmt.Sprintf("Unknown room %q or user %q has never joined this room", roomID, device.UserID)), } } // Otherwise, if the user has been in the room, whether or not we @@ -148,7 +147,7 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a }, &stateAfterRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") - return jsonerror.InternalServerError() + return spec.InternalServerError() } for _, ev := range stateAfterRes.StateEvents { stateEvents = append( @@ -203,7 +202,7 @@ func OnIncomingStateTypeRequest( StateToFetch: stateToFetch, }, &stateRes); err != nil { util.GetLogger(ctx).WithError(err).Error("queryAPI.QueryLatestEventsAndState failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } // Look at the room state and see if we have a history visibility event @@ -214,7 +213,7 @@ func OnIncomingStateTypeRequest( content := map[string]string{} if err := json.Unmarshal(ev.Content(), &content); err != nil { util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for history visibility failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if visibility, ok := content["history_visibility"]; ok { worldReadable = visibility == "world_readable" @@ -240,14 +239,14 @@ func OnIncomingStateTypeRequest( }, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") - return jsonerror.InternalServerError() + return spec.InternalServerError() } // If the user has never been in the room then stop at this point. // We won't tell the user about a room they have never joined. if !membershipRes.HasBeenInRoom || membershipRes.Membership == spec.Ban { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(fmt.Sprintf("Unknown room %q or user %q has never joined this room", roomID, device.UserID)), + JSON: spec.Forbidden(fmt.Sprintf("Unknown room %q or user %q has never joined this room", roomID, device.UserID)), } } // Otherwise, if the user has been in the room, whether or not we @@ -295,7 +294,7 @@ func OnIncomingStateTypeRequest( }, &stateAfterRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if len(stateAfterRes.StateEvents) > 0 { event = stateAfterRes.StateEvents[0] @@ -307,7 +306,7 @@ func OnIncomingStateTypeRequest( if event == nil { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound(fmt.Sprintf("Cannot find state event for %q", evType)), + JSON: spec.NotFound(fmt.Sprintf("Cannot find state event for %q", evType)), } } diff --git a/clientapi/routing/thirdparty.go b/clientapi/routing/thirdparty.go index 7a62da449..0ee218556 100644 --- a/clientapi/routing/thirdparty.go +++ b/clientapi/routing/thirdparty.go @@ -21,8 +21,8 @@ import ( "github.com/matrix-org/util" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) // Protocols implements @@ -33,13 +33,13 @@ func Protocols(req *http.Request, asAPI appserviceAPI.AppServiceInternalAPI, dev resp := &appserviceAPI.ProtocolResponse{} if err := asAPI.Protocols(req.Context(), &appserviceAPI.ProtocolRequest{Protocol: protocol}, resp); err != nil { - return jsonerror.InternalServerError() + return spec.InternalServerError() } if !resp.Exists { if protocol != "" { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("The protocol is unknown."), + JSON: spec.NotFound("The protocol is unknown."), } } return util.JSONResponse{ @@ -71,12 +71,12 @@ func User(req *http.Request, asAPI appserviceAPI.AppServiceInternalAPI, device * Protocol: protocol, Params: params.Encode(), }, resp); err != nil { - return jsonerror.InternalServerError() + return spec.InternalServerError() } if !resp.Exists { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("The Matrix User ID was not found"), + JSON: spec.NotFound("The Matrix User ID was not found"), } } return util.JSONResponse{ @@ -97,12 +97,12 @@ func Location(req *http.Request, asAPI appserviceAPI.AppServiceInternalAPI, devi Protocol: protocol, Params: params.Encode(), }, resp); err != nil { - return jsonerror.InternalServerError() + return spec.InternalServerError() } if !resp.Exists { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("No portal rooms were found."), + JSON: spec.NotFound("No portal rooms were found."), } } return util.JSONResponse{ diff --git a/clientapi/routing/threepid.go b/clientapi/routing/threepid.go index 102b1d1cb..64fa59e40 100644 --- a/clientapi/routing/threepid.go +++ b/clientapi/routing/threepid.go @@ -19,12 +19,12 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/threepid" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -60,13 +60,13 @@ func RequestEmailToken(req *http.Request, threePIDAPI api.ClientUserAPI, cfg *co if err != nil { util.GetLogger(req.Context()).WithError(err).Error("threePIDAPI.QueryLocalpartForThreePID failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if len(res.Localpart) > 0 { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MatrixError{ + JSON: spec.MatrixError{ ErrCode: "M_THREEPID_IN_USE", Err: userdb.Err3PIDInUse.Error(), }, @@ -77,11 +77,11 @@ func RequestEmailToken(req *http.Request, threePIDAPI api.ClientUserAPI, cfg *co if err == threepid.ErrNotTrusted { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotTrusted(body.IDServer), + JSON: spec.NotTrusted(body.IDServer), } } else if err != nil { util.GetLogger(req.Context()).WithError(err).Error("threepid.CreateSession failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ @@ -105,17 +105,17 @@ func CheckAndSave3PIDAssociation( if err == threepid.ErrNotTrusted { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotTrusted(body.Creds.IDServer), + JSON: spec.NotTrusted(body.Creds.IDServer), } } else if err != nil { util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAssociation failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if !verified { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MatrixError{ + JSON: spec.MatrixError{ ErrCode: "M_THREEPID_AUTH_FAILED", Err: "Failed to auth 3pid", }, @@ -127,7 +127,7 @@ func CheckAndSave3PIDAssociation( err = threepid.PublishAssociation(req.Context(), body.Creds, device.UserID, cfg, client) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("threepid.PublishAssociation failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } } @@ -135,7 +135,7 @@ func CheckAndSave3PIDAssociation( localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if err = threePIDAPI.PerformSaveThreePIDAssociation(req.Context(), &api.PerformSaveThreePIDAssociationRequest{ @@ -145,7 +145,7 @@ func CheckAndSave3PIDAssociation( Medium: medium, }, &struct{}{}); err != nil { util.GetLogger(req.Context()).WithError(err).Error("threePIDAPI.PerformSaveThreePIDAssociation failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ @@ -161,7 +161,7 @@ func GetAssociated3PIDs( localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } res := &api.QueryThreePIDsForLocalpartResponse{} @@ -171,7 +171,7 @@ func GetAssociated3PIDs( }, res) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.QueryThreePIDsForLocalpart failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ @@ -192,7 +192,7 @@ func Forget3PID(req *http.Request, threepidAPI api.ClientUserAPI) util.JSONRespo Medium: body.Medium, }, &struct{}{}); err != nil { util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.PerformForgetThreePID failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ diff --git a/clientapi/routing/upgrade_room.go b/clientapi/routing/upgrade_room.go index f0936db1f..43f8d3e24 100644 --- a/clientapi/routing/upgrade_room.go +++ b/clientapi/routing/upgrade_room.go @@ -20,13 +20,13 @@ import ( appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -55,7 +55,7 @@ func UpgradeRoom( if _, err := version.SupportedRoomVersion(gomatrixserverlib.RoomVersion(r.NewVersion)); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UnsupportedRoomVersion("This server does not support that room version"), + JSON: spec.UnsupportedRoomVersion("This server does not support that room version"), } } @@ -65,16 +65,16 @@ func UpgradeRoom( case roomserverAPI.ErrNotAllowed: return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(e.Error()), + JSON: spec.Forbidden(e.Error()), } default: if errors.Is(err, eventutil.ErrRoomNoExists) { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Room does not exist"), + JSON: spec.NotFound("Room does not exist"), } } - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ diff --git a/clientapi/routing/voip.go b/clientapi/routing/voip.go index f0f69ce3c..f3db0cbe9 100644 --- a/clientapi/routing/voip.go +++ b/clientapi/routing/voip.go @@ -25,9 +25,9 @@ import ( "github.com/matrix-org/gomatrix" "github.com/matrix-org/util" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) // RequestTurnServer implements: @@ -60,7 +60,7 @@ func RequestTurnServer(req *http.Request, device *api.Device, cfg *config.Client if err != nil { util.GetLogger(req.Context()).WithError(err).Error("mac.Write failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } resp.Password = base64.StdEncoding.EncodeToString(mac.Sum(nil)) diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index bd49c5301..beb648a48 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -37,12 +37,11 @@ type fedRoomserverAPI struct { } // PerformJoin will call this function -func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) error { +func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) { if f.inputRoomEvents == nil { - return nil + return } f.inputRoomEvents(ctx, req, res) - return nil } // keychange consumer calls this diff --git a/federationapi/routing/backfill.go b/federationapi/routing/backfill.go index 05488af61..81b61322c 100644 --- a/federationapi/routing/backfill.go +++ b/federationapi/routing/backfill.go @@ -21,7 +21,6 @@ import ( "strconv" "time" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" @@ -50,7 +49,7 @@ func Backfill( if _, _, err = gomatrixserverlib.SplitID('!', roomID); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Bad room ID: " + err.Error()), + JSON: spec.MissingParam("Bad room ID: " + err.Error()), } } @@ -65,14 +64,14 @@ func Backfill( if !exists { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("v is missing"), + JSON: spec.MissingParam("v is missing"), } } limit = httpReq.URL.Query().Get("limit") if len(limit) == 0 { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("limit is missing"), + JSON: spec.MissingParam("limit is missing"), } } @@ -92,14 +91,14 @@ func Backfill( util.GetLogger(httpReq.Context()).WithError(err).Error("strconv.Atoi failed") return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue(fmt.Sprintf("limit %q is invalid format", limit)), + JSON: spec.InvalidParam(fmt.Sprintf("limit %q is invalid format", limit)), } } // Query the roomserver. if err = rsAPI.PerformBackfill(httpReq.Context(), &req, &res); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("query.PerformBackfill failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } // Filter any event that's not from the requested room out. diff --git a/federationapi/routing/devices.go b/federationapi/routing/devices.go index 6a2ef1527..318c0a349 100644 --- a/federationapi/routing/devices.go +++ b/federationapi/routing/devices.go @@ -16,7 +16,6 @@ import ( "encoding/json" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" @@ -39,7 +38,7 @@ func GetUserDevices( } if res.Error != nil { util.GetLogger(req.Context()).WithError(res.Error).Error("keyAPI.QueryDeviceMessages failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } sigReq := &api.QuerySignaturesRequest{ @@ -51,9 +50,7 @@ func GetUserDevices( for _, dev := range res.Devices { sigReq.TargetIDs[userID] = append(sigReq.TargetIDs[userID], gomatrixserverlib.KeyID(dev.DeviceID)) } - if err := keyAPI.QuerySignatures(req.Context(), sigReq, sigRes); err != nil { - return jsonerror.InternalAPIError(req.Context(), err) - } + keyAPI.QuerySignatures(req.Context(), sigReq, sigRes) response := fclient.RespUserDevices{ UserID: userID, diff --git a/federationapi/routing/eventauth.go b/federationapi/routing/eventauth.go index da1e77d8f..ca279ac22 100644 --- a/federationapi/routing/eventauth.go +++ b/federationapi/routing/eventauth.go @@ -16,10 +16,10 @@ import ( "context" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -43,7 +43,7 @@ func GetEventAuth( } if event.RoomID() != roomID { - return util.JSONResponse{Code: http.StatusNotFound, JSON: jsonerror.NotFound("event does not belong to this room")} + return util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")} } resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID) if resErr != nil { diff --git a/federationapi/routing/events.go b/federationapi/routing/events.go index efd64dce8..196a54db1 100644 --- a/federationapi/routing/events.go +++ b/federationapi/routing/events.go @@ -20,13 +20,11 @@ import ( "net/http" "time" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" - - "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/roomserver/api" ) // GetEvent returns the requested event @@ -95,7 +93,7 @@ func fetchEvent(ctx context.Context, rsAPI api.FederationRoomserverAPI, roomID, if len(eventsResponse.Events) == 0 { return nil, &util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Event not found"), + JSON: spec.NotFound("Event not found"), } } diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index 0fcb64145..bdfe2c821 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -21,12 +21,12 @@ import ( "net/http" "github.com/getsentry/sentry-go" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -46,14 +46,14 @@ func InviteV2( case gomatrixserverlib.UnsupportedRoomVersionError: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UnsupportedRoomVersion( + JSON: spec.UnsupportedRoomVersion( fmt.Sprintf("Room version %q is not supported by this server.", e.Version), ), } case gomatrixserverlib.BadJSONError: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(err.Error()), + JSON: spec.BadJSON(err.Error()), } case nil: return processInvite( @@ -62,7 +62,7 @@ func InviteV2( default: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotJSON("The request body could not be decoded into an invite request. " + err.Error()), + JSON: spec.NotJSON("The request body could not be decoded into an invite request. " + err.Error()), } } } @@ -85,13 +85,13 @@ func InviteV1( case gomatrixserverlib.BadJSONError: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(err.Error()), + JSON: spec.BadJSON(err.Error()), } case nil: default: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotJSON("The request body could not be decoded into an invite v1 request. " + err.Error()), + JSON: spec.NotJSON("The request body could not be decoded into an invite v1 request. " + err.Error()), } } var strippedState []fclient.InviteV2StrippedState @@ -122,7 +122,7 @@ func processInvite( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UnsupportedRoomVersion( + JSON: spec.UnsupportedRoomVersion( fmt.Sprintf("Room version %q is not supported by this server.", roomVer), ), } @@ -132,7 +132,7 @@ func processInvite( if event.RoomID() != roomID { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The room ID in the request path must match the room ID in the invite event JSON"), + JSON: spec.BadJSON("The room ID in the request path must match the room ID in the invite event JSON"), } } @@ -140,14 +140,14 @@ func processInvite( if event.EventID() != eventID { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The event ID in the request path must match the event ID in the invite event JSON"), + JSON: spec.BadJSON("The event ID in the request path must match the event ID in the invite event JSON"), } } if event.StateKey() == nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The invite event has no state key"), + JSON: spec.BadJSON("The invite event has no state key"), } } @@ -155,7 +155,7 @@ func processInvite( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue(fmt.Sprintf("The user ID is invalid or domain %q does not belong to this server", domain)), + JSON: spec.InvalidParam(fmt.Sprintf("The user ID is invalid or domain %q does not belong to this server", domain)), } } @@ -164,14 +164,14 @@ func processInvite( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The event JSON could not be redacted"), + JSON: spec.BadJSON("The event JSON could not be redacted"), } } _, serverName, err := gomatrixserverlib.SplitID('@', event.Sender()) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The event JSON contains an invalid sender"), + JSON: spec.BadJSON("The event JSON contains an invalid sender"), } } verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ @@ -183,12 +183,12 @@ func processInvite( verifyResults, err := keys.VerifyJSONs(ctx, verifyRequests) if err != nil { util.GetLogger(ctx).WithError(err).Error("keys.VerifyJSONs failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if verifyResults[0].Error != nil { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The invite must be signed by the server it originated on"), + JSON: spec.Forbidden("The invite must be signed by the server it originated on"), } } @@ -211,7 +211,7 @@ func processInvite( util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalServerError(), + JSON: spec.InternalServerError(), } } @@ -219,12 +219,12 @@ func processInvite( case api.ErrInvalidID: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown(e.Error()), + JSON: spec.Unknown(e.Error()), } case api.ErrNotAllowed: return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(e.Error()), + JSON: spec.Forbidden(e.Error()), } case nil: default: @@ -232,7 +232,7 @@ func processInvite( sentry.CaptureException(err) return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalServerError(), + JSON: spec.InternalServerError(), } } diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index c68a6f0cb..c301785cf 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -27,7 +27,6 @@ import ( "github.com/matrix-org/util" "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" @@ -47,7 +46,7 @@ func MakeJoin( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalServerError(), + JSON: spec.InternalServerError(), } } @@ -66,7 +65,7 @@ func MakeJoin( if !remoteSupportsVersion { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.IncompatibleRoomVersion(roomVersion), + JSON: spec.IncompatibleRoomVersion(string(roomVersion)), } } @@ -74,13 +73,13 @@ func MakeJoin( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Invalid UserID"), + JSON: spec.BadJSON("Invalid UserID"), } } if domain != request.Origin() { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The join must be sent by the server of the user"), + JSON: spec.Forbidden("The join must be sent by the server of the user"), } } @@ -92,18 +91,18 @@ func MakeJoin( inRoomRes := &api.QueryServerJoinedToRoomResponse{} if err = rsAPI.QueryServerJoinedToRoom(httpReq.Context(), inRoomReq, inRoomRes); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if !inRoomRes.RoomExists { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound(fmt.Sprintf("Room ID %q was not found on this server", roomID)), + JSON: spec.NotFound(fmt.Sprintf("Room ID %q was not found on this server", roomID)), } } if !inRoomRes.IsInRoom { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound(fmt.Sprintf("Room ID %q has no remaining users on this server", roomID)), + JSON: spec.NotFound(fmt.Sprintf("Room ID %q has no remaining users on this server", roomID)), } } @@ -112,7 +111,7 @@ func MakeJoin( res, authorisedVia, err := checkRestrictedJoin(httpReq, rsAPI, roomVersion, roomID, userID) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("checkRestrictedJoin failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } else if res != nil { return *res } @@ -130,14 +129,14 @@ func MakeJoin( } if err = proto.SetContent(content); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("builder.SetContent failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } identity, err := cfg.Matrix.SigningIdentityFor(request.Destination()) if err != nil { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound( + JSON: spec.NotFound( fmt.Sprintf("Server name %q does not exist", request.Destination()), ), } @@ -150,16 +149,16 @@ func MakeJoin( if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Room does not exist"), + JSON: spec.NotFound("Room does not exist"), } } else if e, ok := err.(gomatrixserverlib.BadJSONError); ok { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(e.Error()), + JSON: spec.BadJSON(e.Error()), } } else if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } // Check that the join is allowed or not @@ -172,7 +171,7 @@ func MakeJoin( if err = gomatrixserverlib.Allowed(event.PDU, &provider); err != nil { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(err.Error()), + JSON: spec.Forbidden(err.Error()), } } @@ -202,14 +201,14 @@ func SendJoin( util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryRoomVersionForRoom failed") return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalServerError(), + JSON: spec.InternalServerError(), } } verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.UnsupportedRoomVersion( + JSON: spec.UnsupportedRoomVersion( fmt.Sprintf("QueryRoomVersionForRoom returned unknown room version: %s", roomVersion), ), } @@ -219,7 +218,7 @@ func SendJoin( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON: " + err.Error()), + JSON: spec.BadJSON("The request body could not be decoded into valid JSON: " + err.Error()), } } @@ -227,13 +226,13 @@ func SendJoin( if event.StateKey() == nil || event.StateKeyEquals("") { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("No state key was provided in the join event."), + JSON: spec.BadJSON("No state key was provided in the join event."), } } if !event.StateKeyEquals(event.Sender()) { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Event state key must match the event sender."), + JSON: spec.BadJSON("Event state key must match the event sender."), } } @@ -244,12 +243,12 @@ func SendJoin( if _, serverName, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The sender of the join is invalid"), + JSON: spec.Forbidden("The sender of the join is invalid"), } } else if serverName != request.Origin() { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The sender does not match the server that originated the request"), + JSON: spec.Forbidden("The sender does not match the server that originated the request"), } } @@ -257,7 +256,7 @@ func SendJoin( if event.RoomID() != roomID { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON( + JSON: spec.BadJSON( fmt.Sprintf( "The room ID in the request path (%q) must match the room ID in the join event JSON (%q)", roomID, event.RoomID(), @@ -270,7 +269,7 @@ func SendJoin( if event.EventID() != eventID { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON( + JSON: spec.BadJSON( fmt.Sprintf( "The event ID in the request path (%q) must match the event ID in the join event JSON (%q)", eventID, event.EventID(), @@ -284,13 +283,13 @@ func SendJoin( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("missing content.membership key"), + JSON: spec.BadJSON("missing content.membership key"), } } if membership != spec.Join { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("membership must be 'join'"), + JSON: spec.BadJSON("membership must be 'join'"), } } @@ -300,7 +299,7 @@ func SendJoin( logrus.WithError(err).Errorf("XXX: join.go") return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The event JSON could not be redacted"), + JSON: spec.BadJSON("The event JSON could not be redacted"), } } verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ @@ -312,12 +311,12 @@ func SendJoin( verifyResults, err := keys.VerifyJSONs(httpReq.Context(), verifyRequests) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("keys.VerifyJSONs failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if verifyResults[0].Error != nil { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Signature check failed: " + verifyResults[0].Error.Error()), + JSON: spec.Forbidden("Signature check failed: " + verifyResults[0].Error.Error()), } } @@ -332,19 +331,19 @@ func SendJoin( }, &stateAndAuthChainResponse) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryStateAndAuthChain failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if !stateAndAuthChainResponse.RoomExists { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Room does not exist"), + JSON: spec.NotFound("Room does not exist"), } } if !stateAndAuthChainResponse.StateKnown { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("State not known"), + JSON: spec.Forbidden("State not known"), } } @@ -367,7 +366,7 @@ func SendJoin( if isBanned { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("user is banned"), + JSON: spec.Forbidden("user is banned"), } } @@ -377,7 +376,7 @@ func SendJoin( if err := json.Unmarshal(event.Content(), &memberContent); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(err.Error()), + JSON: spec.BadJSON(err.Error()), } } if memberContent.AuthorisedVia != "" { @@ -385,13 +384,13 @@ func SendJoin( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("The authorising username %q is invalid.", memberContent.AuthorisedVia)), + JSON: spec.BadJSON(fmt.Sprintf("The authorising username %q is invalid.", memberContent.AuthorisedVia)), } } if domain != cfg.Matrix.ServerName { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("The authorising username %q does not belong to this server.", memberContent.AuthorisedVia)), + JSON: spec.BadJSON(fmt.Sprintf("The authorising username %q does not belong to this server.", memberContent.AuthorisedVia)), } } } @@ -410,7 +409,7 @@ func SendJoin( // the room, so set SendAsServer to cfg.Matrix.ServerName if !alreadyJoined { var response api.InputRoomEventsResponse - if err := rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{ + rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{ InputRoomEvents: []api.InputRoomEvent{ { Kind: api.KindNew, @@ -419,18 +418,16 @@ func SendJoin( TransactionID: nil, }, }, - }, &response); err != nil { - return jsonerror.InternalAPIError(httpReq.Context(), err) - } + }, &response) if response.ErrMsg != "" { util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, response.ErrMsg).Error("SendEvents failed") if response.NotAllowed { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Forbidden(response.ErrMsg), + JSON: spec.Forbidden(response.ErrMsg), } } - return jsonerror.InternalServerError() + return spec.InternalServerError() } } @@ -498,7 +495,7 @@ func checkRestrictedJoin( // instead. return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UnableToAuthoriseJoin("This server cannot authorise the join."), + JSON: spec.UnableToAuthoriseJoin("This server cannot authorise the join."), }, "", nil case !res.Allowed: @@ -507,7 +504,7 @@ func checkRestrictedJoin( // and therefore can't join this room. return &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("You are not joined to any matching rooms."), + JSON: spec.Forbidden("You are not joined to any matching rooms."), }, "", nil default: diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index 6c30e5b06..d85de73d8 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -20,7 +20,6 @@ import ( "time" clienthttputil "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" @@ -46,7 +45,7 @@ func QueryDeviceKeys( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), + JSON: spec.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), } } // make sure we only query users on our domain @@ -63,14 +62,12 @@ func QueryDeviceKeys( } var queryRes api.QueryKeysResponse - if err := keyAPI.QueryKeys(httpReq.Context(), &api.QueryKeysRequest{ + keyAPI.QueryKeys(httpReq.Context(), &api.QueryKeysRequest{ UserToDevices: qkr.DeviceKeys, - }, &queryRes); err != nil { - return jsonerror.InternalAPIError(httpReq.Context(), err) - } + }, &queryRes) if queryRes.Error != nil { util.GetLogger(httpReq.Context()).WithError(queryRes.Error).Error("Failed to QueryKeys") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ Code: 200, @@ -100,7 +97,7 @@ func ClaimOneTimeKeys( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), + JSON: spec.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), } } // make sure we only claim users on our domain @@ -117,14 +114,12 @@ func ClaimOneTimeKeys( } var claimRes api.PerformClaimKeysResponse - if err := keyAPI.PerformClaimKeys(httpReq.Context(), &api.PerformClaimKeysRequest{ + keyAPI.PerformClaimKeys(httpReq.Context(), &api.PerformClaimKeysRequest{ OneTimeKeys: cor.OneTimeKeys, - }, &claimRes); err != nil { - return jsonerror.InternalAPIError(httpReq.Context(), err) - } + }, &claimRes) if claimRes.Error != nil { util.GetLogger(httpReq.Context()).WithError(claimRes.Error).Error("Failed to PerformClaimKeys") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ Code: 200, @@ -205,7 +200,7 @@ func NotaryKeys( if !cfg.Matrix.IsLocalServerName(serverName) { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Server name not known"), + JSON: spec.NotFound("Server name not known"), } } @@ -248,7 +243,7 @@ func NotaryKeys( j, err := json.Marshal(keys) if err != nil { logrus.WithError(err).Errorf("Failed to marshal %q response", serverName) - return jsonerror.InternalServerError() + return spec.InternalServerError() } js, err := gomatrixserverlib.SignJSON( @@ -256,7 +251,7 @@ func NotaryKeys( ) if err != nil { logrus.WithError(err).Errorf("Failed to sign %q response", serverName) - return jsonerror.InternalServerError() + return spec.InternalServerError() } response.ServerKeys = append(response.ServerKeys, js) diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index f4936d4ae..fdfbf15d7 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -17,7 +17,6 @@ import ( "net/http" "time" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" @@ -41,13 +40,13 @@ func MakeLeave( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Invalid UserID"), + JSON: spec.BadJSON("Invalid UserID"), } } if domain != request.Origin() { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The leave must be sent by the server of the user"), + JSON: spec.Forbidden("The leave must be sent by the server of the user"), } } @@ -61,14 +60,14 @@ func MakeLeave( err = proto.SetContent(map[string]interface{}{"membership": spec.Leave}) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("proto.SetContent failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } identity, err := cfg.Matrix.SigningIdentityFor(request.Destination()) if err != nil { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound( + JSON: spec.NotFound( fmt.Sprintf("Server name %q does not exist", request.Destination()), ), } @@ -79,16 +78,16 @@ func MakeLeave( if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Room does not exist"), + JSON: spec.NotFound("Room does not exist"), } } else if e, ok := err.(gomatrixserverlib.BadJSONError); ok { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(e.Error()), + JSON: spec.BadJSON(e.Error()), } } else if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } // If the user has already left then just return their last leave @@ -118,7 +117,7 @@ func MakeLeave( if err = gomatrixserverlib.Allowed(event, &provider); err != nil { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(err.Error()), + JSON: spec.Forbidden(err.Error()), } } @@ -145,7 +144,7 @@ func SendLeave( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UnsupportedRoomVersion(err.Error()), + JSON: spec.UnsupportedRoomVersion(err.Error()), } } @@ -153,7 +152,7 @@ func SendLeave( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.UnsupportedRoomVersion( + JSON: spec.UnsupportedRoomVersion( fmt.Sprintf("QueryRoomVersionForRoom returned unknown version: %s", roomVersion), ), } @@ -165,13 +164,13 @@ func SendLeave( case gomatrixserverlib.BadJSONError: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(err.Error()), + JSON: spec.BadJSON(err.Error()), } case nil: default: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), + JSON: spec.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), } } @@ -179,7 +178,7 @@ func SendLeave( if event.RoomID() != roomID { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The room ID in the request path must match the room ID in the leave event JSON"), + JSON: spec.BadJSON("The room ID in the request path must match the room ID in the leave event JSON"), } } @@ -187,20 +186,20 @@ func SendLeave( if event.EventID() != eventID { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The event ID in the request path must match the event ID in the leave event JSON"), + JSON: spec.BadJSON("The event ID in the request path must match the event ID in the leave event JSON"), } } if event.StateKey() == nil || event.StateKeyEquals("") { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("No state key was provided in the leave event."), + JSON: spec.BadJSON("No state key was provided in the leave event."), } } if !event.StateKeyEquals(event.Sender()) { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Event state key must match the event sender."), + JSON: spec.BadJSON("Event state key must match the event sender."), } } @@ -211,12 +210,12 @@ func SendLeave( if _, serverName, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The sender of the join is invalid"), + JSON: spec.Forbidden("The sender of the join is invalid"), } } else if serverName != request.Origin() { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The sender does not match the server that originated the request"), + JSON: spec.Forbidden("The sender does not match the server that originated the request"), } } @@ -234,7 +233,7 @@ func SendLeave( err = rsAPI.QueryLatestEventsAndState(httpReq.Context(), queryReq, queryRes) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryLatestEventsAndState failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } // The room doesn't exist or we weren't ever joined to it. Might as well // no-op here. @@ -268,7 +267,7 @@ func SendLeave( logrus.WithError(err).Errorf("XXX: leave.go") return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The event JSON could not be redacted"), + JSON: spec.BadJSON("The event JSON could not be redacted"), } } verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ @@ -280,12 +279,12 @@ func SendLeave( verifyResults, err := keys.VerifyJSONs(httpReq.Context(), verifyRequests) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("keys.VerifyJSONs failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if verifyResults[0].Error != nil { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The leave must be signed by the server it originated on"), + JSON: spec.Forbidden("The leave must be signed by the server it originated on"), } } @@ -295,13 +294,13 @@ func SendLeave( util.GetLogger(httpReq.Context()).WithError(err).Error("event.Membership failed") return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("missing content.membership key"), + JSON: spec.BadJSON("missing content.membership key"), } } if mem != spec.Leave { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The membership in the event content must be set to leave"), + JSON: spec.BadJSON("The membership in the event content must be set to leave"), } } @@ -309,7 +308,7 @@ func SendLeave( // We are responsible for notifying other servers that the user has left // the room, so set SendAsServer to cfg.Matrix.ServerName var response api.InputRoomEventsResponse - if err := rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{ + rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{ InputRoomEvents: []api.InputRoomEvent{ { Kind: api.KindNew, @@ -318,19 +317,17 @@ func SendLeave( TransactionID: nil, }, }, - }, &response); err != nil { - return jsonerror.InternalAPIError(httpReq.Context(), err) - } + }, &response) if response.ErrMsg != "" { util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, response.ErrMsg).WithField("not_allowed", response.NotAllowed).Error("producer.SendEvents failed") if response.NotAllowed { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Forbidden(response.ErrMsg), + JSON: spec.Forbidden(response.ErrMsg), } } - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ diff --git a/federationapi/routing/missingevents.go b/federationapi/routing/missingevents.go index 23a99bf00..f8dd9e4f1 100644 --- a/federationapi/routing/missingevents.go +++ b/federationapi/routing/missingevents.go @@ -16,10 +16,10 @@ import ( "encoding/json" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -42,7 +42,7 @@ func GetMissingEvents( if err := json.Unmarshal(request.Content(), &gme); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), + JSON: spec.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), } } @@ -63,7 +63,7 @@ func GetMissingEvents( &eventsResponse, ); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("query.QueryMissingEvents failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } eventsResponse.Events = filterEvents(eventsResponse.Events, roomID) diff --git a/federationapi/routing/openid.go b/federationapi/routing/openid.go index cbc75a9a7..d28f319f5 100644 --- a/federationapi/routing/openid.go +++ b/federationapi/routing/openid.go @@ -18,8 +18,8 @@ import ( "net/http" "time" - "github.com/matrix-org/dendrite/clientapi/jsonerror" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -36,7 +36,7 @@ func GetOpenIDUserInfo( if len(token) == 0 { return util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.MissingArgument("access_token is missing"), + JSON: spec.MissingParam("access_token is missing"), } } @@ -55,7 +55,7 @@ func GetOpenIDUserInfo( nowMS := time.Now().UnixNano() / int64(time.Millisecond) if openIDTokenAttrResponse.Sub == "" || nowMS > openIDTokenAttrResponse.ExpiresAtMS { code = http.StatusUnauthorized - res = jsonerror.UnknownToken("Access Token unknown or expired") + res = spec.UnknownToken("Access Token unknown or expired") } return util.JSONResponse{ diff --git a/federationapi/routing/peek.go b/federationapi/routing/peek.go index efc461464..9e924556f 100644 --- a/federationapi/routing/peek.go +++ b/federationapi/routing/peek.go @@ -17,12 +17,12 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -40,7 +40,7 @@ func Peek( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalServerError(), + JSON: spec.InternalServerError(), } } @@ -58,7 +58,7 @@ func Peek( if !remoteSupportsVersion { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.IncompatibleRoomVersion(roomVersion), + JSON: spec.IncompatibleRoomVersion(string(roomVersion)), } } diff --git a/federationapi/routing/profile.go b/federationapi/routing/profile.go index 55641b216..7d6cfcaa6 100644 --- a/federationapi/routing/profile.go +++ b/federationapi/routing/profile.go @@ -18,10 +18,10 @@ import ( "fmt" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -37,7 +37,7 @@ func GetProfile( if userID == "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("The request body did not contain required argument 'user_id'."), + JSON: spec.MissingParam("The request body did not contain required argument 'user_id'."), } } @@ -46,14 +46,14 @@ func GetProfile( util.GetLogger(httpReq.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue(fmt.Sprintf("Domain %q does not match this server", domain)), + JSON: spec.InvalidParam(fmt.Sprintf("Domain %q does not match this server", domain)), } } profile, err := userAPI.QueryProfile(httpReq.Context(), userID) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("userAPI.QueryProfile failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } var res interface{} @@ -71,7 +71,7 @@ func GetProfile( } default: code = http.StatusBadRequest - res = jsonerror.InvalidArgumentValue("The request body did not contain an allowed value of argument 'field'. Allowed values are either: 'avatar_url', 'displayname'.") + res = spec.InvalidParam("The request body did not contain an allowed value of argument 'field'. Allowed values are either: 'avatar_url', 'displayname'.") } } else { res = eventutil.UserProfile{ diff --git a/federationapi/routing/publicrooms.go b/federationapi/routing/publicrooms.go index 80343d93a..59ff4eb2a 100644 --- a/federationapi/routing/publicrooms.go +++ b/federationapi/routing/publicrooms.go @@ -12,7 +12,6 @@ import ( "github.com/matrix-org/util" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" ) @@ -40,7 +39,7 @@ func GetPostPublicRooms(req *http.Request, rsAPI roomserverAPI.FederationRoomser } response, err := publicRooms(req.Context(), request, rsAPI) if err != nil { - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ Code: http.StatusOK, @@ -107,7 +106,7 @@ func fillPublicRoomsReq(httpReq *http.Request, request *PublicRoomReq) *util.JSO // In that case, we want to assign 0 so we ignore the error if err != nil && len(httpReq.FormValue("limit")) > 0 { util.GetLogger(httpReq.Context()).WithError(err).Error("strconv.Atoi failed") - reqErr := jsonerror.InternalServerError() + reqErr := spec.InternalServerError() return &reqErr } request.Limit = int16(limit) @@ -119,7 +118,7 @@ func fillPublicRoomsReq(httpReq *http.Request, request *PublicRoomReq) *util.JSO return &util.JSONResponse{ Code: http.StatusMethodNotAllowed, - JSON: jsonerror.NotFound("Bad method"), + JSON: spec.NotFound("Bad method"), } } diff --git a/federationapi/routing/query.go b/federationapi/routing/query.go index 61efd73fe..233290e2e 100644 --- a/federationapi/routing/query.go +++ b/federationapi/routing/query.go @@ -18,13 +18,13 @@ import ( "fmt" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" federationAPI "github.com/matrix-org/dendrite/federationapi/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -40,14 +40,14 @@ func RoomAliasToID( if roomAlias == "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Must supply room alias parameter."), + JSON: spec.BadJSON("Must supply room alias parameter."), } } _, domain, err := gomatrixserverlib.SplitID('#', roomAlias) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Room alias must be in the form '#localpart:domain'"), + JSON: spec.BadJSON("Room alias must be in the form '#localpart:domain'"), } } @@ -61,7 +61,7 @@ func RoomAliasToID( queryRes := &roomserverAPI.GetRoomIDForAliasResponse{} if err = rsAPI.GetRoomIDForAlias(httpReq.Context(), queryReq, queryRes); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if queryRes.RoomID != "" { @@ -69,7 +69,7 @@ func RoomAliasToID( var serverQueryRes federationAPI.QueryJoinedHostServerNamesInRoomResponse if err = senderAPI.QueryJoinedHostServerNamesInRoom(httpReq.Context(), &serverQueryReq, &serverQueryRes); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("senderAPI.QueryJoinedHostServerNamesInRoom failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } resp = fclient.RespDirectory{ @@ -80,7 +80,7 @@ func RoomAliasToID( // If no alias was found, return an error return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound(fmt.Sprintf("Room alias %s not found", roomAlias)), + JSON: spec.NotFound(fmt.Sprintf("Room alias %s not found", roomAlias)), } } } else { @@ -91,14 +91,14 @@ func RoomAliasToID( if x.Code == http.StatusNotFound { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Room alias not found"), + JSON: spec.NotFound("Room alias not found"), } } } // TODO: Return 502 if the remote server errored. // TODO: Return 504 if the remote server timed out. util.GetLogger(httpReq.Context()).WithError(err).Error("federation.LookupRoomAlias failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } } diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 6ef544d06..f62a8f46c 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -23,7 +23,6 @@ import ( "github.com/getsentry/sentry-go" "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/clientapi/jsonerror" fedInternal "github.com/matrix-org/dendrite/federationapi/internal" "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/internal" @@ -150,7 +149,7 @@ func Setup( if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } return InviteV1( @@ -166,7 +165,7 @@ func Setup( if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } return InviteV2( @@ -206,7 +205,7 @@ func Setup( if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } return GetState( @@ -221,7 +220,7 @@ func Setup( if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } return GetStateIDs( @@ -236,7 +235,7 @@ func Setup( if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } return GetEventAuth( @@ -279,7 +278,7 @@ func Setup( if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } roomID := vars["roomID"] @@ -310,7 +309,7 @@ func Setup( if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } roomID := vars["roomID"] @@ -341,7 +340,7 @@ func Setup( if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } roomID := vars["roomID"] @@ -354,7 +353,7 @@ func Setup( body = []interface{}{ res.Code, res.JSON, } - jerr, ok := res.JSON.(*jsonerror.MatrixError) + jerr, ok := res.JSON.(*spec.MatrixError) if ok { body = jerr } @@ -373,7 +372,7 @@ func Setup( if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } roomID := vars["roomID"] @@ -390,7 +389,7 @@ func Setup( if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } roomID := vars["roomID"] @@ -407,7 +406,7 @@ func Setup( if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } roomID := vars["roomID"] @@ -420,7 +419,7 @@ func Setup( body = []interface{}{ res.Code, res.JSON, } - jerr, ok := res.JSON.(*jsonerror.MatrixError) + jerr, ok := res.JSON.(*spec.MatrixError) if ok { body = jerr } @@ -439,7 +438,7 @@ func Setup( if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } roomID := vars["roomID"] @@ -463,7 +462,7 @@ func Setup( if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } return GetMissingEvents(httpReq, request, rsAPI, vars["roomID"]) @@ -476,7 +475,7 @@ func Setup( if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } return Backfill(httpReq, request, rsAPI, vars["roomID"], cfg) @@ -528,7 +527,7 @@ func ErrorIfLocalServerNotInRoom( if !joinedRes.IsInRoom { return &util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound(fmt.Sprintf("This server is not joined to room %s", roomID)), + JSON: spec.NotFound(fmt.Sprintf("This server is not joined to room %s", roomID)), } } return nil diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 04bf505a9..3c8e0cbef 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -25,12 +25,12 @@ import ( "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/util" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userAPI "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) const ( @@ -104,7 +104,7 @@ func Send( if err := json.Unmarshal(request.Content(), &txnEvents); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), + JSON: spec.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), } } // Transactions are limited in size; they can have at most 50 PDUs and 100 EDUs. @@ -112,7 +112,7 @@ func Send( if len(txnEvents.PDUs) > 50 || len(txnEvents.EDUs) > 100 { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("max 50 pdus / 100 edus"), + JSON: spec.BadJSON("max 50 pdus / 100 edus"), } } diff --git a/federationapi/routing/state.go b/federationapi/routing/state.go index aa2cb2835..fa0e9351e 100644 --- a/federationapi/routing/state.go +++ b/federationapi/routing/state.go @@ -17,10 +17,10 @@ import ( "net/http" "net/url" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -88,7 +88,7 @@ func parseEventIDParam( if eventID == "" { resErr = &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("event_id missing"), + JSON: spec.MissingParam("event_id missing"), } } @@ -114,7 +114,7 @@ func getState( } if event.RoomID() != roomID { - return nil, nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: jsonerror.NotFound("event does not belong to this room")} + return nil, nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")} } resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID) if resErr != nil { @@ -140,17 +140,17 @@ func getState( case !response.RoomExists: return nil, nil, &util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Room not found"), + JSON: spec.NotFound("Room not found"), } case !response.StateKnown: return nil, nil, &util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("State not known"), + JSON: spec.NotFound("State not known"), } case response.IsRejected: return nil, nil, &util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Event not found"), + JSON: spec.NotFound("Event not found"), } } diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index 244553ba9..adfafe740 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -22,17 +22,14 @@ import ( "time" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" - - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" - "github.com/sirupsen/logrus" ) @@ -74,7 +71,7 @@ func CreateInvitesFrom3PIDInvites( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UnsupportedRoomVersion(err.Error()), + JSON: spec.UnsupportedRoomVersion(err.Error()), } } @@ -83,7 +80,7 @@ func CreateInvitesFrom3PIDInvites( ) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("createInviteFrom3PIDInvite failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if event != nil { evs = append(evs, &types.HeaderedEvent{PDU: event}) @@ -103,7 +100,7 @@ func CreateInvitesFrom3PIDInvites( false, ); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ @@ -125,7 +122,7 @@ func ExchangeThirdPartyInvite( if err := json.Unmarshal(request.Content(), &proto); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), + JSON: spec.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), } } @@ -133,7 +130,7 @@ func ExchangeThirdPartyInvite( if proto.RoomID != roomID { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The room ID in the request path must match the room ID in the invite event JSON"), + JSON: spec.BadJSON("The room ID in the request path must match the room ID in the invite event JSON"), } } @@ -141,7 +138,7 @@ func ExchangeThirdPartyInvite( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Invalid sender ID: " + err.Error()), + JSON: spec.BadJSON("Invalid sender ID: " + err.Error()), } } @@ -150,7 +147,7 @@ func ExchangeThirdPartyInvite( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The event's state key isn't a Matrix user ID"), + JSON: spec.BadJSON("The event's state key isn't a Matrix user ID"), } } @@ -158,7 +155,7 @@ func ExchangeThirdPartyInvite( if targetDomain != request.Origin() { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The event's state key doesn't have the same domain as the request's origin"), + JSON: spec.BadJSON("The event's state key doesn't have the same domain as the request's origin"), } } @@ -166,7 +163,7 @@ func ExchangeThirdPartyInvite( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UnsupportedRoomVersion(err.Error()), + JSON: spec.UnsupportedRoomVersion(err.Error()), } } @@ -175,11 +172,11 @@ func ExchangeThirdPartyInvite( if err == errNotInRoom { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Unknown room " + roomID), + JSON: spec.NotFound("Unknown room " + roomID), } } else if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("buildMembershipEvent failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } // Ask the requesting server to sign the newly created event so we know it @@ -187,22 +184,22 @@ func ExchangeThirdPartyInvite( inviteReq, err := fclient.NewInviteV2Request(event, nil) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request") - return jsonerror.InternalServerError() + return spec.InternalServerError() } signedEvent, err := federation.SendInviteV2(httpReq.Context(), senderDomain, request.Origin(), inviteReq) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Errorf("unknown room version: %s", roomVersion) - return jsonerror.InternalServerError() + return spec.InternalServerError() } inviteEvent, err := verImpl.NewEventFromUntrustedJSON(signedEvent.Event) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } // Send the event to the roomserver @@ -219,7 +216,7 @@ func ExchangeThirdPartyInvite( false, ); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("SendEvents failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ diff --git a/go.mod b/go.mod index c418faa4c..bd1d43fcb 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230504085954-69034410deb1 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230509222610-6fd532036ab6 github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 @@ -42,10 +42,10 @@ require ( github.com/uber/jaeger-lib v2.4.1+incompatible github.com/yggdrasil-network/yggdrasil-go v0.4.6 go.uber.org/atomic v1.10.0 - golang.org/x/crypto v0.8.0 + golang.org/x/crypto v0.9.0 golang.org/x/image v0.5.0 golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e - golang.org/x/term v0.7.0 + golang.org/x/term v0.8.0 gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/yaml.v2 v2.4.0 gotest.tools/v3 v3.4.0 @@ -125,8 +125,8 @@ require ( go.etcd.io/bbolt v1.3.6 // indirect golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect golang.org/x/mod v0.8.0 // indirect - golang.org/x/net v0.9.0 // indirect - golang.org/x/sys v0.7.0 // indirect + golang.org/x/net v0.10.0 // indirect + golang.org/x/sys v0.8.0 // indirect golang.org/x/text v0.9.0 // indirect golang.org/x/time v0.3.0 // indirect golang.org/x/tools v0.6.0 // indirect diff --git a/go.sum b/go.sum index 502dcea30..733d6e24f 100644 --- a/go.sum +++ b/go.sum @@ -323,18 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230428003202-267b4e79f138 h1:zqMuO/4ye8QnSPLhruxTC4cQcXfrvpPwdtT+4kqEgF4= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230428003202-267b4e79f138/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230428142634-a4fa967eac17 h1:So8d7SZZdKB7+vWFXwmAQ3C+tUkkegMlcGk8n60w2og= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230428142634-a4fa967eac17/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230502101247-782aebf83205 h1:foJFr0V1uZC0oJ3ooenScGtLViq7Hx3rioe1Hf0lnhY= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230502101247-782aebf83205/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230502133856-ad26780a085c h1:5xXMu/08j8tWfiVUvD4yfs6mepz07BgC4kL2i0oGJX4= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230502133856-ad26780a085c/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230503081352-9e29bff996eb h1:qg9iR39ctvB7A4hBcddjxmHQO/t3y4mpQnpmEc3xvNI= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230503081352-9e29bff996eb/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230504085954-69034410deb1 h1:K0wM4rUNdqzWVQ54am8IeQn1q6f03sTNvhUW+ZaK1Zs= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230504085954-69034410deb1/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230509222610-6fd532036ab6 h1:cF6fNfxC73fU9zT3pgzDXI9NDihAdnilqqGcpDWgNP4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230509222610-6fd532036ab6/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= @@ -521,8 +511,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ= -golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= +golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= +golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -602,8 +592,8 @@ golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= -golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= +golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -678,12 +668,12 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= -golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.7.0 h1:BEvjmm5fURWqcfbSKTdpkDXYBrUS1c0m8agp14W48vQ= -golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= +golang.org/x/term v0.8.0 h1:n5xxQn2i3PC0yLAbjTpNT85q/Kgzcr2gIoX9OrJUols= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index 289d1d2ca..1966e7546 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -31,9 +31,9 @@ import ( "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) // BasicAuth is used for authorization on /metrics handlers @@ -101,7 +101,7 @@ func MakeAuthAPI( if !opts.GuestAccessAllowed && device.AccountType == userapi.AccountTypeGuest { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.GuestAccessForbidden("Guest access not allowed"), + JSON: spec.GuestAccessForbidden("Guest access not allowed"), } } @@ -126,7 +126,7 @@ func MakeAdminAPI( if device.AccountType != userapi.AccountTypeAdmin { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("This API can only be used by admin users."), + JSON: spec.Forbidden("This API can only be used by admin users."), } } return f(req, device) diff --git a/internal/httputil/rate_limiting.go b/internal/httputil/rate_limiting.go index dab36481e..0b040d7f3 100644 --- a/internal/httputil/rate_limiting.go +++ b/internal/httputil/rate_limiting.go @@ -5,9 +5,9 @@ import ( "sync" "time" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -118,7 +118,7 @@ func (l *RateLimits) Limit(req *http.Request, device *userapi.Device) *util.JSON // We hit the rate limit. Tell the client to back off. return &util.JSONResponse{ Code: http.StatusTooManyRequests, - JSON: jsonerror.LimitExceeded("You are sending too many requests too quickly!", l.cooloffDuration.Milliseconds()), + JSON: spec.LimitExceeded("You are sending too many requests too quickly!", l.cooloffDuration.Milliseconds()), } } diff --git a/internal/transactionrequest.go b/internal/transactionrequest.go index 0d2503250..c9d321f25 100644 --- a/internal/transactionrequest.go +++ b/internal/transactionrequest.go @@ -21,7 +21,6 @@ import ( "sync" "github.com/getsentry/sentry-go" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/roomserver/api" @@ -153,7 +152,7 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut // See https://github.com/matrix-org/synapse/issues/7543 return nil, &util.JSONResponse{ Code: 400, - JSON: jsonerror.BadJSON("PDU contains bad JSON"), + JSON: spec.BadJSON("PDU contains bad JSON"), } } util.GetLogger(ctx).WithError(err).Debugf("Transaction: Failed to parse event JSON of event %s", string(pdu)) diff --git a/internal/transactionrequest_test.go b/internal/transactionrequest_test.go index c884ebfca..fb30d410e 100644 --- a/internal/transactionrequest_test.go +++ b/internal/transactionrequest_test.go @@ -66,9 +66,8 @@ var ( type FakeRsAPI struct { rsAPI.RoomserverInternalAPI - shouldFailQuery bool - bannedFromRoom bool - shouldEventsFail bool + shouldFailQuery bool + bannedFromRoom bool } func (r *FakeRsAPI) QueryRoomVersionForRoom( @@ -98,11 +97,7 @@ func (r *FakeRsAPI) InputRoomEvents( ctx context.Context, req *rsAPI.InputRoomEventsRequest, res *rsAPI.InputRoomEventsResponse, -) error { - if r.shouldEventsFail { - return fmt.Errorf("Failure") - } - return nil +) { } func TestEmptyTransactionRequest(t *testing.T) { @@ -184,18 +179,6 @@ func TestProcessTransactionRequestPDUInvalidSignature(t *testing.T) { } } -func TestProcessTransactionRequestPDUSendFail(t *testing.T) { - keyRing := &test.NopJSONVerifier{} - txn := NewTxnReq(&FakeRsAPI{shouldEventsFail: true}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "") - txnRes, jsonRes := txn.ProcessTransaction(context.Background()) - - assert.Nil(t, jsonRes) - assert.Equal(t, 1, len(txnRes.PDUs)) - for _, result := range txnRes.PDUs { - assert.NotEmpty(t, result.Error) - } -} - func createTransactionWithEDU(ctx *process.ProcessContext, edus []gomatrixserverlib.EDU) (TxnReq, nats.JetStreamContext, *config.Dendrite) { cfg := &config.Dendrite{} cfg.Defaults(config.DefaultOpts{ @@ -659,12 +642,11 @@ func (t *testRoomserverAPI) InputRoomEvents( ctx context.Context, request *rsAPI.InputRoomEventsRequest, response *rsAPI.InputRoomEventsResponse, -) error { +) { t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...) for _, ire := range request.InputRoomEvents { fmt.Println("InputRoomEvents: ", ire.Event.EventID()) } - return nil } // Query the latest events and state for a room from the room server. diff --git a/internal/validate.go b/internal/validate.go index f794d7a5b..99088f240 100644 --- a/internal/validate.go +++ b/internal/validate.go @@ -20,7 +20,6 @@ import ( "net/http" "regexp" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -58,12 +57,12 @@ func PasswordResponse(err error) *util.JSONResponse { case ErrPasswordWeak: return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.WeakPassword(ErrPasswordWeak.Error()), + JSON: spec.WeakPassword(ErrPasswordWeak.Error()), } case ErrPasswordTooLong: return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(ErrPasswordTooLong.Error()), + JSON: spec.BadJSON(ErrPasswordTooLong.Error()), } } return nil @@ -88,12 +87,12 @@ func UsernameResponse(err error) *util.JSONResponse { case ErrUsernameTooLong: return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(err.Error()), + JSON: spec.BadJSON(err.Error()), } case ErrUsernameInvalid, ErrUsernameUnderscore: return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername(err.Error()), + JSON: spec.InvalidUsername(err.Error()), } } return nil diff --git a/internal/validate_test.go b/internal/validate_test.go index 2244b7a96..e3a10178f 100644 --- a/internal/validate_test.go +++ b/internal/validate_test.go @@ -6,7 +6,6 @@ import ( "strings" "testing" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -22,13 +21,13 @@ func Test_validatePassword(t *testing.T) { name: "password too short", password: "shortpw", wantError: ErrPasswordWeak, - wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.WeakPassword(ErrPasswordWeak.Error())}, + wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: spec.WeakPassword(ErrPasswordWeak.Error())}, }, { name: "password too long", password: strings.Repeat("a", maxPasswordLength+1), wantError: ErrPasswordTooLong, - wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(ErrPasswordTooLong.Error())}, + wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: spec.BadJSON(ErrPasswordTooLong.Error())}, }, { name: "password OK", @@ -65,7 +64,7 @@ func Test_validateUsername(t *testing.T) { wantErr: ErrUsernameInvalid, wantJSON: &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + JSON: spec.InvalidUsername(ErrUsernameInvalid.Error()), }, }, { @@ -75,7 +74,7 @@ func Test_validateUsername(t *testing.T) { wantErr: ErrUsernameInvalid, wantJSON: &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + JSON: spec.InvalidUsername(ErrUsernameInvalid.Error()), }, }, { @@ -85,7 +84,7 @@ func Test_validateUsername(t *testing.T) { wantErr: ErrUsernameTooLong, wantJSON: &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(ErrUsernameTooLong.Error()), + JSON: spec.BadJSON(ErrUsernameTooLong.Error()), }, }, { @@ -95,7 +94,7 @@ func Test_validateUsername(t *testing.T) { wantErr: ErrUsernameUnderscore, wantJSON: &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername(ErrUsernameUnderscore.Error()), + JSON: spec.InvalidUsername(ErrUsernameUnderscore.Error()), }, }, { @@ -115,7 +114,7 @@ func Test_validateUsername(t *testing.T) { wantErr: ErrUsernameInvalid, wantJSON: &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + JSON: spec.InvalidUsername(ErrUsernameInvalid.Error()), }, }, { @@ -135,7 +134,7 @@ func Test_validateUsername(t *testing.T) { wantErr: ErrUsernameInvalid, wantJSON: &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + JSON: spec.InvalidUsername(ErrUsernameInvalid.Error()), }, }, { diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go index bba24327b..e9f161a3c 100644 --- a/mediaapi/routing/download.go +++ b/mediaapi/routing/download.go @@ -30,7 +30,6 @@ import ( "sync" "unicode" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/mediaapi/fileutils" "github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/thumbnailer" @@ -130,7 +129,7 @@ func Download( // TODO: Handle the fact we might have started writing the response dReq.jsonErrorResponse(w, util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Failed to download: " + err.Error()), + JSON: spec.NotFound("Failed to download: " + err.Error()), }) return } @@ -138,7 +137,7 @@ func Download( if metadata == nil { dReq.jsonErrorResponse(w, util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("File not found"), + JSON: spec.NotFound("File not found"), }) return } @@ -168,7 +167,7 @@ func (r *downloadRequest) Validate() *util.JSONResponse { if !mediaIDRegex.MatchString(string(r.MediaMetadata.MediaID)) { return &util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound(fmt.Sprintf("mediaId must be a non-empty string using only characters in %v", mediaIDCharacters)), + JSON: spec.NotFound(fmt.Sprintf("mediaId must be a non-empty string using only characters in %v", mediaIDCharacters)), } } // Note: the origin will be validated either by comparison to the configured server name of this homeserver @@ -176,7 +175,7 @@ func (r *downloadRequest) Validate() *util.JSONResponse { if r.MediaMetadata.Origin == "" { return &util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("serverName must be a non-empty string"), + JSON: spec.NotFound("serverName must be a non-empty string"), } } @@ -184,7 +183,7 @@ func (r *downloadRequest) Validate() *util.JSONResponse { if r.ThumbnailSize.Width <= 0 || r.ThumbnailSize.Height <= 0 { return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("width and height must be greater than 0"), + JSON: spec.Unknown("width and height must be greater than 0"), } } // Default method to scale if not set @@ -194,7 +193,7 @@ func (r *downloadRequest) Validate() *util.JSONResponse { if r.ThumbnailSize.ResizeMethod != types.Crop && r.ThumbnailSize.ResizeMethod != types.Scale { return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("method must be one of crop or scale"), + JSON: spec.Unknown("method must be one of crop or scale"), } } } diff --git a/mediaapi/routing/upload.go b/mediaapi/routing/upload.go index 2175648ea..5061d4762 100644 --- a/mediaapi/routing/upload.go +++ b/mediaapi/routing/upload.go @@ -26,7 +26,6 @@ import ( "path" "strings" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/mediaapi/fileutils" "github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/thumbnailer" @@ -34,6 +33,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" log "github.com/sirupsen/logrus" ) @@ -165,7 +165,7 @@ func (r *uploadRequest) doUpload( }).Warn("Error while transferring file") return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("Failed to upload"), + JSON: spec.Unknown("Failed to upload"), } } @@ -184,7 +184,7 @@ func (r *uploadRequest) doUpload( if err != nil { fileutils.RemoveDir(tmpDir, r.Logger) r.Logger.WithError(err).Error("Error querying the database by hash.") - resErr := jsonerror.InternalServerError() + resErr := spec.InternalServerError() return &resErr } if existingMetadata != nil { @@ -194,7 +194,7 @@ func (r *uploadRequest) doUpload( mediaID, merr := r.generateMediaID(ctx, db) if merr != nil { r.Logger.WithError(merr).Error("Failed to generate media ID for existing file") - resErr := jsonerror.InternalServerError() + resErr := spec.InternalServerError() return &resErr } @@ -217,7 +217,7 @@ func (r *uploadRequest) doUpload( if err != nil { fileutils.RemoveDir(tmpDir, r.Logger) r.Logger.WithError(err).Error("Failed to generate media ID for new upload") - resErr := jsonerror.InternalServerError() + resErr := spec.InternalServerError() return &resErr } } @@ -239,7 +239,7 @@ func (r *uploadRequest) doUpload( func requestEntityTooLargeJSONResponse(maxFileSizeBytes config.FileSizeBytes) *util.JSONResponse { return &util.JSONResponse{ Code: http.StatusRequestEntityTooLarge, - JSON: jsonerror.Unknown(fmt.Sprintf("HTTP Content-Length is greater than the maximum allowed upload size (%v).", maxFileSizeBytes)), + JSON: spec.Unknown(fmt.Sprintf("HTTP Content-Length is greater than the maximum allowed upload size (%v).", maxFileSizeBytes)), } } @@ -251,7 +251,7 @@ func (r *uploadRequest) Validate(maxFileSizeBytes config.FileSizeBytes) *util.JS if strings.HasPrefix(string(r.MediaMetadata.UploadName), "~") { return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("File name must not begin with '~'."), + JSON: spec.Unknown("File name must not begin with '~'."), } } // TODO: Validate filename - what are the valid characters? @@ -264,7 +264,7 @@ func (r *uploadRequest) Validate(maxFileSizeBytes config.FileSizeBytes) *util.JS if _, _, err := gomatrixserverlib.SplitID('@', string(r.MediaMetadata.UserID)); err != nil { return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("user id must be in the form @localpart:domain"), + JSON: spec.BadJSON("user id must be in the form @localpart:domain"), } } } @@ -290,7 +290,7 @@ func (r *uploadRequest) storeFileAndMetadata( r.Logger.WithError(err).Error("Failed to move file.") return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("Failed to upload"), + JSON: spec.Unknown("Failed to upload"), } } if duplicate { @@ -307,7 +307,7 @@ func (r *uploadRequest) storeFileAndMetadata( } return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("Failed to upload"), + JSON: spec.Unknown("Failed to upload"), } } diff --git a/relayapi/routing/relaytxn.go b/relayapi/routing/relaytxn.go index 9a3ced529..2f3225b62 100644 --- a/relayapi/routing/relaytxn.go +++ b/relayapi/routing/relaytxn.go @@ -18,7 +18,6 @@ import ( "encoding/json" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/relayapi/api" "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" @@ -40,13 +39,13 @@ func GetTransactionFromRelay( if err := json.Unmarshal(fedReq.Content(), &previousEntry); err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.BadJSON("invalid json provided"), + JSON: spec.BadJSON("invalid json provided"), } } if previousEntry.EntryID < 0 { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.BadJSON("Invalid entry id provided. Must be >= 0."), + JSON: spec.BadJSON("Invalid entry id provided. Must be >= 0."), } } logrus.Infof("Previous entry provided: %v", previousEntry.EntryID) diff --git a/relayapi/routing/routing.go b/relayapi/routing/routing.go index 6140d0326..f6e556119 100644 --- a/relayapi/routing/routing.go +++ b/relayapi/routing/routing.go @@ -21,7 +21,6 @@ import ( "github.com/getsentry/sentry-go" "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/httputil" relayInternal "github.com/matrix-org/dendrite/relayapi/internal" "github.com/matrix-org/dendrite/setup/config" @@ -59,7 +58,7 @@ func Setup( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username was invalid"), + JSON: spec.InvalidUsername("Username was invalid"), } } return SendTransactionToRelay( @@ -84,7 +83,7 @@ func Setup( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username was invalid"), + JSON: spec.InvalidUsername("Username was invalid"), } } return GetTransactionFromRelay(httpReq, request, relayAPI, *userID) diff --git a/relayapi/routing/sendrelay.go b/relayapi/routing/sendrelay.go index 6ff08e205..4a742dede 100644 --- a/relayapi/routing/sendrelay.go +++ b/relayapi/routing/sendrelay.go @@ -18,7 +18,6 @@ import ( "encoding/json" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/relayapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" @@ -43,7 +42,7 @@ func SendTransactionToRelay( logrus.Info("The request body could not be decoded into valid JSON." + err.Error()) return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON." + err.Error()), + JSON: spec.NotJSON("The request body could not be decoded into valid JSON." + err.Error()), } } @@ -52,7 +51,7 @@ func SendTransactionToRelay( if len(txnEvents.PDUs) > 50 || len(txnEvents.EDUs) > 100 { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("max 50 pdus / 100 edus"), + JSON: spec.BadJSON("max 50 pdus / 100 edus"), } } @@ -69,7 +68,7 @@ func SendTransactionToRelay( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.BadJSON("could not store the transaction for forwarding"), + JSON: spec.BadJSON("could not store the transaction for forwarding"), } } diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 2aaecbbf4..ab1ec28f8 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -60,7 +60,7 @@ type InputRoomEventsAPI interface { ctx context.Context, req *InputRoomEventsRequest, res *InputRoomEventsResponse, - ) error + ) } // Query the latest events and state for a room from the room server. diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index fc26a4740..2505a993b 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -104,9 +104,7 @@ func SendInputRoomEvents( VirtualHost: virtualHost, } var response InputRoomEventsResponse - if err := rsAPI.InputRoomEvents(ctx, &request, &response); err != nil { - return err - } + rsAPI.InputRoomEvents(ctx, &request, &response) return response.Err() } diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 3e7ff7f7c..3db2d0a67 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -389,18 +389,18 @@ func (r *Inputer) InputRoomEvents( ctx context.Context, request *api.InputRoomEventsRequest, response *api.InputRoomEventsResponse, -) error { +) { // Queue up the event into the roomserver. replySub, err := r.queueInputRoomEvents(ctx, request) if err != nil { response.ErrMsg = err.Error() - return nil + return } // If we aren't waiting for synchronous responses then we can // give up here, there is nothing further to do. if replySub == nil { - return nil + return } // Otherwise, we'll want to sit and wait for the responses @@ -412,14 +412,12 @@ func (r *Inputer) InputRoomEvents( msg, err := replySub.NextMsgWithContext(ctx) if err != nil { response.ErrMsg = err.Error() - return nil + return } if len(msg.Data) > 0 { response.ErrMsg = string(msg.Data) } } - - return nil } var roomserverInputBackpressure = prometheus.NewGaugeVec( diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index c8f5737ff..cd78b3722 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -893,5 +893,6 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r Asynchronous: true, // Needs to be async, as we otherwise create a deadlock } inputRes := &api.InputRoomEventsResponse{} - return r.InputRoomEvents(ctx, inputReq, inputRes) + r.InputRoomEvents(ctx, inputReq, inputRes) + return nil } diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index 375eefbec..a539efd1d 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -141,8 +141,8 @@ func (r *Admin) PerformAdminEvacuateRoom( Asynchronous: true, } inputRes := &api.InputRoomEventsResponse{} - err = r.Inputer.InputRoomEvents(ctx, inputReq, inputRes) - return affected, err + r.Inputer.InputRoomEvents(ctx, inputReq, inputRes) + return affected, nil } // PerformAdminEvacuateUser will remove the given user from all rooms. @@ -334,9 +334,7 @@ func (r *Admin) PerformAdminDownloadState( SendAsServer: string(r.Cfg.Matrix.ServerName), }) - if err = r.Inputer.InputRoomEvents(ctx, inputReq, inputRes); err != nil { - return fmt.Errorf("r.Inputer.InputRoomEvents: %w", err) - } + r.Inputer.InputRoomEvents(ctx, inputReq, inputRes) if inputRes.ErrMsg != "" { return inputRes.Err() diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index db0b53fef..a3fa2e011 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -226,9 +226,7 @@ func (r *Inviter) PerformInvite( }, } inputRes := &api.InputRoomEventsResponse{} - if err = r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes); err != nil { - return nil, fmt.Errorf("r.Inputer.InputRoomEvents: %w", err) - } + r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes) if err = inputRes.Err(); err != nil { logger.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed") return nil, api.ErrNotAllowed{Err: err} diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index d676bd4bb..a836eb1ae 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -313,9 +313,7 @@ func (r *Joiner) performJoinRoomByID( }, } inputRes := rsAPI.InputRoomEventsResponse{} - if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { - return "", "", rsAPI.ErrNotAllowed{Err: err} - } + r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes) if err = inputRes.Err(); err != nil { return "", "", rsAPI.ErrNotAllowed{Err: err} } diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index f0e958112..e71b3e908 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -21,7 +21,6 @@ import ( "strings" "time" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" @@ -115,7 +114,7 @@ func (r *Leaver) performLeaveRoomByID( // mimic the returned values from Synapse res.Message = "You cannot reject this invite" res.Code = 403 - return nil, jsonerror.LeaveServerNoticeError() + return nil, spec.LeaveServerNoticeError() } } } @@ -203,9 +202,7 @@ func (r *Leaver) performLeaveRoomByID( }, } inputRes := api.InputRoomEventsResponse{} - if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { - return nil, fmt.Errorf("r.Inputer.InputRoomEvents: %w", err) - } + r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes) if err = inputRes.Err(); err != nil { return nil, fmt.Errorf("r.InputRoomEvents: %w", err) } diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 21726c4cd..e9d61fede 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -27,7 +27,6 @@ import ( "strings" "time" - "github.com/matrix-org/dendrite/clientapi/jsonerror" fs "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/internal/httputil" @@ -169,7 +168,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP util.GetLogger(req.Context()).WithError(err).Error("failed to decode HTTP request as JSON") return util.JSONResponse{ Code: 400, - JSON: jsonerror.BadJSON(fmt.Sprintf("invalid json: %s", err)), + JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)), } } rc := reqCtx{ @@ -201,7 +200,7 @@ func federatedEventRelationship( util.GetLogger(ctx).WithError(err).Error("failed to decode HTTP request as JSON") return util.JSONResponse{ Code: 400, - JSON: jsonerror.BadJSON(fmt.Sprintf("invalid json: %s", err)), + JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)), } } rc := reqCtx{ @@ -268,7 +267,7 @@ func (rc *reqCtx) process() (*MSC2836EventRelationshipsResponse, *util.JSONRespo if event == nil || !rc.authorisedToSeeEvent(event) { return nil, &util.JSONResponse{ Code: 403, - JSON: jsonerror.Forbidden("Event does not exist or you are not authorised to see it"), + JSON: spec.Forbidden("Event does not exist or you are not authorised to see it"), } } rc.roomVersion = event.Version() @@ -428,7 +427,7 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen children, err := db.ChildrenForParent(rc.ctx, parentID, constRelType, recentFirst) if err != nil { util.GetLogger(rc.ctx).WithError(err).Error("failed to get ChildrenForParent") - resErr := jsonerror.InternalServerError() + resErr := spec.InternalServerError() return nil, &resErr } var childEvents []*types.HeaderedEvent diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index b75400947..291e0f3b2 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -28,7 +28,6 @@ import ( "github.com/google/uuid" "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/clientapi/jsonerror" fs "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" @@ -96,7 +95,7 @@ func federatedSpacesHandler( if err != nil { return util.JSONResponse{ Code: 400, - JSON: jsonerror.InvalidParam("bad request uri"), + JSON: spec.InvalidParam("bad request uri"), } } @@ -214,13 +213,13 @@ func (w *walker) walk() util.JSONResponse { // CS API format return util.JSONResponse{ Code: 403, - JSON: jsonerror.Forbidden("room is unknown/forbidden"), + JSON: spec.Forbidden("room is unknown/forbidden"), } } else { // SS API format return util.JSONResponse{ Code: 404, - JSON: jsonerror.NotFound("room is unknown/forbidden"), + JSON: spec.NotFound("room is unknown/forbidden"), } } } @@ -233,7 +232,7 @@ func (w *walker) walk() util.JSONResponse { if cache == nil { return util.JSONResponse{ Code: 400, - JSON: jsonerror.InvalidArgumentValue("invalid from"), + JSON: spec.InvalidParam("invalid from"), } } } else { @@ -377,7 +376,7 @@ func (w *walker) walk() util.JSONResponse { if len(discoveredRooms) == 0 { return util.JSONResponse{ Code: 404, - JSON: jsonerror.NotFound("room is unknown/forbidden"), + JSON: spec.NotFound("room is unknown/forbidden"), } } return util.JSONResponse{ diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index 119549045..23c2ecbaa 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -34,20 +34,16 @@ func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *userapi.Perform func (k *mockKeyAPI) SetUserAPI(i userapi.UserInternalAPI) {} // PerformClaimKeys claims one-time keys for use in pre-key messages -func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *userapi.PerformClaimKeysRequest, res *userapi.PerformClaimKeysResponse) error { - return nil +func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *userapi.PerformClaimKeysRequest, res *userapi.PerformClaimKeysResponse) { } func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *userapi.PerformDeleteKeysRequest, res *userapi.PerformDeleteKeysResponse) error { return nil } -func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *userapi.PerformUploadDeviceKeysRequest, res *userapi.PerformUploadDeviceKeysResponse) error { - return nil +func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *userapi.PerformUploadDeviceKeysRequest, res *userapi.PerformUploadDeviceKeysResponse) { } -func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *userapi.PerformUploadDeviceSignaturesRequest, res *userapi.PerformUploadDeviceSignaturesResponse) error { - return nil +func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *userapi.PerformUploadDeviceSignaturesRequest, res *userapi.PerformUploadDeviceSignaturesResponse) { } -func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *userapi.QueryKeysRequest, res *userapi.QueryKeysResponse) error { - return nil +func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *userapi.QueryKeysRequest, res *userapi.QueryKeysResponse) { } func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyChangesRequest, res *userapi.QueryKeyChangesResponse) error { return nil @@ -60,8 +56,7 @@ func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *userapi.Query return nil } -func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *userapi.QuerySignaturesRequest, res *userapi.QuerySignaturesResponse) error { - return nil +func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *userapi.QuerySignaturesRequest, res *userapi.QuerySignaturesResponse) { } type mockRoomserverAPI struct { diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index b85ab7f22..8ff656e7a 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -23,7 +23,6 @@ import ( "strconv" "time" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" roomserver "github.com/matrix-org/dendrite/roomserver/api" @@ -57,7 +56,7 @@ func Context( ) util.JSONResponse { snapshot, err := syncDB.NewDatabaseSnapshot(req.Context()) if err != nil { - return jsonerror.InternalServerError() + return spec.InternalServerError() } var succeeded bool defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) @@ -75,7 +74,7 @@ func Context( } return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidParam(errMsg), + JSON: spec.InvalidParam(errMsg), Headers: nil, } } @@ -88,12 +87,12 @@ func Context( membershipReq := roomserver.QueryMembershipForUserRequest{UserID: device.UserID, RoomID: roomID} if err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes); err != nil { logrus.WithError(err).Error("unable to query membership") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if !membershipRes.RoomExists { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("room does not exist"), + JSON: spec.Forbidden("room does not exist"), } } @@ -114,11 +113,11 @@ func Context( if err == sql.ErrNoRows { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound(fmt.Sprintf("Event %s not found", eventID)), + JSON: spec.NotFound(fmt.Sprintf("Event %s not found", eventID)), } } logrus.WithError(err).WithField("eventID", eventID).Error("unable to find requested event") - return jsonerror.InternalServerError() + return spec.InternalServerError() } // verify the user is allowed to see the context for this room/event @@ -126,7 +125,7 @@ func Context( filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, []*rstypes.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context") if err != nil { logrus.WithError(err).Error("unable to apply history visibility filter") - return jsonerror.InternalServerError() + return spec.InternalServerError() } logrus.WithFields(logrus.Fields{ "duration": time.Since(startTime), @@ -135,27 +134,27 @@ func Context( if len(filteredEvents) == 0 { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("User is not allowed to query context"), + JSON: spec.Forbidden("User is not allowed to query context"), } } eventsBefore, err := snapshot.SelectContextBeforeEvent(ctx, id, roomID, filter) if err != nil && err != sql.ErrNoRows { logrus.WithError(err).Error("unable to fetch before events") - return jsonerror.InternalServerError() + return spec.InternalServerError() } _, eventsAfter, err := snapshot.SelectContextAfterEvent(ctx, id, roomID, filter) if err != nil && err != sql.ErrNoRows { logrus.WithError(err).Error("unable to fetch after events") - return jsonerror.InternalServerError() + return spec.InternalServerError() } startTime = time.Now() eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, snapshot, rsAPI, eventsBefore, eventsAfter, device.UserID) if err != nil { logrus.WithError(err).Error("unable to apply history visibility filter") - return jsonerror.InternalServerError() + return spec.InternalServerError() } logrus.WithFields(logrus.Fields{ @@ -167,7 +166,7 @@ func Context( state, err := snapshot.CurrentState(ctx, roomID, &stateFilter, nil) if err != nil { logrus.WithError(err).Error("unable to fetch current room state") - return jsonerror.InternalServerError() + return spec.InternalServerError() } eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll) @@ -181,7 +180,7 @@ func Context( newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache) if err != nil { logrus.WithError(err).Error("unable to load membership events") - return jsonerror.InternalServerError() + return spec.InternalServerError() } } diff --git a/syncapi/routing/filter.go b/syncapi/routing/filter.go index 266ad4adc..5152e1f81 100644 --- a/syncapi/routing/filter.go +++ b/syncapi/routing/filter.go @@ -23,11 +23,11 @@ import ( "github.com/matrix-org/util" "github.com/tidwall/gjson" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/sync" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) // GetFilter implements GET /_matrix/client/r0/user/{userId}/filter/{filterId} @@ -37,13 +37,13 @@ func GetFilter( if userID != device.UserID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Cannot get filters for other users"), + JSON: spec.Forbidden("Cannot get filters for other users"), } } localpart, _, err := gomatrixserverlib.SplitID('@', userID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } filter := synctypes.DefaultFilter() @@ -53,7 +53,7 @@ func GetFilter( // even though it is not correct. return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotFound("No such filter"), + JSON: spec.NotFound("No such filter"), } } @@ -76,14 +76,14 @@ func PutFilter( if userID != device.UserID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Cannot create filters for other users"), + JSON: spec.Forbidden("Cannot create filters for other users"), } } localpart, _, err := gomatrixserverlib.SplitID('@', userID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } var filter synctypes.Filter @@ -93,14 +93,14 @@ func PutFilter( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be read. " + err.Error()), + JSON: spec.BadJSON("The request body could not be read. " + err.Error()), } } if err = json.Unmarshal(body, &filter); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), + JSON: spec.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), } } // the filter `limit` is `int` which defaults to 0 if not set which is not what we want. We want to use the default @@ -115,14 +115,14 @@ func PutFilter( if err = filter.Validate(); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Invalid filter: " + err.Error()), + JSON: spec.BadJSON("Invalid filter: " + err.Error()), } } filterID, err := syncDB.PutFilter(req.Context(), localpart, &filter) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("syncDB.PutFilter failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } return util.JSONResponse{ diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go index 7775598ee..e3d77cc33 100644 --- a/syncapi/routing/getevent.go +++ b/syncapi/routing/getevent.go @@ -20,13 +20,13 @@ import ( "github.com/matrix-org/util" "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) // GetEvent implements @@ -51,13 +51,13 @@ func GetEvent( }) if err != nil { logger.WithError(err).Error("GetEvent: syncDB.NewDatabaseTransaction failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } events, err := db.Events(ctx, []string{eventID}) if err != nil { logger.WithError(err).Error("GetEvent: syncDB.Events failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } // The requested event does not exist in our database @@ -65,7 +65,7 @@ func GetEvent( logger.Debugf("GetEvent: requested event doesn't exist locally") return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("The event was not found or you do not have permission to read this event"), + JSON: spec.NotFound("The event was not found or you do not have permission to read this event"), } } @@ -81,7 +81,7 @@ func GetEvent( logger.WithError(err).Error("GetEvent: internal.ApplyHistoryVisibilityFilter failed") return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalServerError(), + JSON: spec.InternalServerError(), } } @@ -91,7 +91,7 @@ func GetEvent( logger.WithField("event_count", len(events)).Debug("GetEvent: can't return the requested event") return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("The event was not found or you do not have permission to read this event"), + JSON: spec.NotFound("The event was not found or you do not have permission to read this event"), } } diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index a23f1525b..5a66009c8 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -19,13 +19,13 @@ import ( "math" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -67,26 +67,26 @@ func GetMemberships( var queryRes api.QueryMembershipForUserResponse if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryMembershipsForRoom failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } if !queryRes.HasBeenInRoom { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("You aren't a member of the room and weren't previously a member of the room."), + JSON: spec.Forbidden("You aren't a member of the room and weren't previously a member of the room."), } } if joinedOnly && !queryRes.IsInRoom { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("You aren't a member of the room and weren't previously a member of the room."), + JSON: spec.Forbidden("You aren't a member of the room and weren't previously a member of the room."), } } db, err := syncDB.NewDatabaseSnapshot(req.Context()) if err != nil { - return jsonerror.InternalServerError() + return spec.InternalServerError() } defer db.Rollback() // nolint: errcheck @@ -98,7 +98,7 @@ func GetMemberships( atToken, err = db.EventPositionInTopology(req.Context(), queryRes.EventID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("unable to get 'atToken'") - return jsonerror.InternalServerError() + return spec.InternalServerError() } } } @@ -106,13 +106,13 @@ func GetMemberships( eventIDs, err := db.SelectMemberships(req.Context(), roomID, atToken, membership, notMembership) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("db.SelectMemberships failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } qryRes := &api.QueryEventsByIDResponse{} if err := rsAPI.QueryEventsByID(req.Context(), &api.QueryEventsByIDRequest{EventIDs: eventIDs, RoomID: roomID}, qryRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryEventsByID failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } result := qryRes.Events @@ -124,7 +124,7 @@ func GetMemberships( var content databaseJoinedMember if err := json.Unmarshal(ev.Content(), &content); err != nil { util.GetLogger(req.Context()).WithError(err).Error("failed to unmarshal event content") - return jsonerror.InternalServerError() + return spec.InternalServerError() } res.Joined[ev.Sender()] = joinedMember(content) } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 54b72c64d..4d3c9e2eb 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -27,7 +27,6 @@ import ( "github.com/matrix-org/util" "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" @@ -82,7 +81,7 @@ func OnIncomingMessagesRequest( // request that requires backfilling from the roomserver or federation. snapshot, err := db.NewDatabaseTransaction(req.Context()) if err != nil { - return jsonerror.InternalServerError() + return spec.InternalServerError() } var succeeded bool defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) @@ -90,19 +89,19 @@ func OnIncomingMessagesRequest( // check if the user has already forgotten about this room membershipResp, err := getMembershipForUser(req.Context(), roomID, device.UserID, rsAPI) if err != nil { - return jsonerror.InternalServerError() + return spec.InternalServerError() } if !membershipResp.RoomExists { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("room does not exist"), + JSON: spec.Forbidden("room does not exist"), } } if membershipResp.IsRoomForgotten { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("user already forgot about this room"), + JSON: spec.Forbidden("user already forgot about this room"), } } @@ -110,7 +109,7 @@ func OnIncomingMessagesRequest( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("unable to parse filter"), + JSON: spec.InvalidParam("unable to parse filter"), } } @@ -132,7 +131,7 @@ func OnIncomingMessagesRequest( if dir != "b" && dir != "f" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Bad or missing dir query parameter (should be either 'b' or 'f')"), + JSON: spec.MissingParam("Bad or missing dir query parameter (should be either 'b' or 'f')"), } } // A boolean is easier to handle in this case, especially since dir is sure @@ -145,14 +144,14 @@ func OnIncomingMessagesRequest( if streamToken, err = types.NewStreamTokenFromString(fromQuery); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("Invalid from parameter: " + err.Error()), + JSON: spec.InvalidParam("Invalid from parameter: " + err.Error()), } } else { fromStream = &streamToken from, err = snapshot.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, backwardOrdering) if err != nil { logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken) - return jsonerror.InternalServerError() + return spec.InternalServerError() } } } @@ -168,13 +167,13 @@ func OnIncomingMessagesRequest( if streamToken, err = types.NewStreamTokenFromString(toQuery); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("Invalid to parameter: " + err.Error()), + JSON: spec.InvalidParam("Invalid to parameter: " + err.Error()), } } else { to, err = snapshot.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, !backwardOrdering) if err != nil { logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken) - return jsonerror.InternalServerError() + return spec.InternalServerError() } } } @@ -197,7 +196,7 @@ func OnIncomingMessagesRequest( if _, _, err = gomatrixserverlib.SplitID('!', roomID); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Bad room ID: " + err.Error()), + JSON: spec.MissingParam("Bad room ID: " + err.Error()), } } @@ -233,7 +232,7 @@ func OnIncomingMessagesRequest( clientEvents, start, end, err := mReq.retrieveEvents() if err != nil { util.GetLogger(req.Context()).WithError(err).Error("mreq.retrieveEvents failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } util.GetLogger(req.Context()).WithFields(logrus.Fields{ @@ -254,7 +253,7 @@ func OnIncomingMessagesRequest( membershipEvents, err := applyLazyLoadMembers(req.Context(), device, snapshot, roomID, clientEvents, lazyLoadCache) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("failed to apply lazy loading") - return jsonerror.InternalServerError() + return spec.InternalServerError() } res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll)...) } diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go index 55e4347d6..2bf11a566 100644 --- a/syncapi/routing/relations.go +++ b/syncapi/routing/relations.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/util" "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" rstypes "github.com/matrix-org/dendrite/roomserver/types" @@ -30,6 +29,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) type RelationsResponse struct { @@ -73,14 +73,14 @@ func Relations( if dir != "b" && dir != "f" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Bad or missing dir query parameter (should be either 'b' or 'f')"), + JSON: spec.MissingParam("Bad or missing dir query parameter (should be either 'b' or 'f')"), } } snapshot, err := syncDB.NewDatabaseSnapshot(req.Context()) if err != nil { logrus.WithError(err).Error("Failed to get snapshot for relations") - return jsonerror.InternalServerError() + return spec.InternalServerError() } var succeeded bool defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index 9607aa325..88c5c5045 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/httputil" @@ -158,12 +157,12 @@ func Setup( if !cfg.Fulltext.Enabled { return util.JSONResponse{ Code: http.StatusNotImplemented, - JSON: jsonerror.Unknown("Search has been disabled by the server administrator."), + JSON: spec.Unknown("Search has been disabled by the server administrator."), } } var nextBatch *string if err := req.ParseForm(); err != nil { - return jsonerror.InternalServerError() + return spec.InternalServerError() } if req.Form.Has("next_batch") { nb := req.FormValue("next_batch") diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index 5f0373926..986284d06 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -29,7 +29,6 @@ import ( "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/types" @@ -56,7 +55,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts if from != nil && *from != "" { nextBatch, err = strconv.Atoi(*from) if err != nil { - return jsonerror.InternalServerError() + return spec.InternalServerError() } } @@ -66,7 +65,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts snapshot, err := syncDB.NewDatabaseSnapshot(req.Context()) if err != nil { - return jsonerror.InternalServerError() + return spec.InternalServerError() } var succeeded bool defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) @@ -74,12 +73,12 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts // only search rooms the user is actually joined to joinedRooms, err := snapshot.RoomIDsWithMembership(ctx, device.UserID, "join") if err != nil { - return jsonerror.InternalServerError() + return spec.InternalServerError() } if len(joinedRooms) == 0 { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("User not joined to any rooms."), + JSON: spec.NotFound("User not joined to any rooms."), } } joinedRoomsMap := make(map[string]struct{}, len(joinedRooms)) @@ -100,7 +99,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts if len(rooms) == 0 { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Unknown("User not allowed to search in this room(s)."), + JSON: spec.Unknown("User not allowed to search in this room(s)."), } } @@ -116,7 +115,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts ) if err != nil { logrus.WithError(err).Error("failed to search fulltext") - return jsonerror.InternalServerError() + return spec.InternalServerError() } logrus.Debugf("Search took %s", result.Took) @@ -156,7 +155,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts evs, err := syncDB.Events(ctx, wantEvents) if err != nil { logrus.WithError(err).Error("failed to get events from database") - return jsonerror.InternalServerError() + return spec.InternalServerError() } groups := make(map[string]RoomResult) @@ -174,12 +173,12 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts eventsBefore, eventsAfter, err := contextEvents(ctx, snapshot, event, roomFilter, searchReq) if err != nil { logrus.WithError(err).Error("failed to get context events") - return jsonerror.InternalServerError() + return spec.InternalServerError() } startToken, endToken, err := getStartEnd(ctx, snapshot, eventsBefore, eventsAfter) if err != nil { logrus.WithError(err).Error("failed to get start/end") - return jsonerror.InternalServerError() + return spec.InternalServerError() } profileInfos := make(map[string]ProfileInfoResponse) @@ -222,7 +221,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts state, err := snapshot.CurrentState(ctx, event.RoomID(), &stateFilter, nil) if err != nil { logrus.WithError(err).Error("unable to get current state") - return jsonerror.InternalServerError() + return spec.InternalServerError() } stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync) } diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 6baaff3c8..09e5dee17 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -30,7 +30,6 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/sqlutil" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" @@ -232,12 +231,12 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. if err == types.ErrMalformedSyncToken { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue(err.Error()), + JSON: spec.InvalidParam(err.Error()), } } return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown(err.Error()), + JSON: spec.Unknown(err.Error()), } } @@ -517,32 +516,32 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use if from == "" || to == "" { return util.JSONResponse{ Code: 400, - JSON: jsonerror.InvalidArgumentValue("missing ?from= or ?to="), + JSON: spec.InvalidParam("missing ?from= or ?to="), } } fromToken, err := types.NewStreamTokenFromString(from) if err != nil { return util.JSONResponse{ Code: 400, - JSON: jsonerror.InvalidArgumentValue("bad 'from' value"), + JSON: spec.InvalidParam("bad 'from' value"), } } toToken, err := types.NewStreamTokenFromString(to) if err != nil { return util.JSONResponse{ Code: 400, - JSON: jsonerror.InvalidArgumentValue("bad 'to' value"), + JSON: spec.InvalidParam("bad 'to' value"), } } syncReq, err := newSyncRequest(req, *device, rp.db) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("newSyncRequest failed") - return jsonerror.InternalServerError() + return spec.InternalServerError() } snapshot, err := rp.db.NewDatabaseSnapshot(req.Context()) if err != nil { logrus.WithError(err).Error("Failed to acquire database snapshot for key change") - return jsonerror.InternalServerError() + return spec.InternalServerError() } var succeeded bool defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) @@ -553,7 +552,7 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use ) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("Failed to DeviceListCatchup info") - return jsonerror.InternalServerError() + return spec.InternalServerError() } succeeded = true return util.JSONResponse{ diff --git a/userapi/api/api.go b/userapi/api/api.go index 4e13a3b94..050402645 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -63,10 +63,10 @@ type FederationUserAPI interface { QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error QueryProfile(ctx context.Context, userID string) (*authtypes.Profile, error) QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error - QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error - QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) + QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error - PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) } // api functions required by the sync api @@ -646,17 +646,17 @@ type QueryAccountByLocalpartResponse struct { // API functions required by the clientapi type ClientKeyAPI interface { UploadDeviceKeysAPI - QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error - PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error + PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) // PerformClaimKeys claims one-time keys for use in pre-key messages - PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error } type UploadDeviceKeysAPI interface { - PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error + PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) } // API functions required by the syncapi @@ -668,10 +668,10 @@ type SyncKeyAPI interface { type FederationKeyAPI interface { UploadDeviceKeysAPI - QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error - QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) + QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error - PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) } // KeyError is returned if there was a problem performing/querying the server diff --git a/userapi/consumers/signingkeyupdate.go b/userapi/consumers/signingkeyupdate.go index 457a61838..9de866343 100644 --- a/userapi/consumers/signingkeyupdate.go +++ b/userapi/consumers/signingkeyupdate.go @@ -100,10 +100,7 @@ func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M UserID: updatePayload.UserID, } uploadRes := &api.PerformUploadDeviceKeysResponse{} - if err := t.userAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil { - logrus.WithError(err).Error("failed to upload device keys") - return false - } + t.userAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes) if uploadRes.Error != nil { logrus.WithError(uploadRes.Error).Error("failed to upload device keys") return true diff --git a/userapi/internal/cross_signing.go b/userapi/internal/cross_signing.go index ea7b84f6b..be05841c4 100644 --- a/userapi/internal/cross_signing.go +++ b/userapi/internal/cross_signing.go @@ -105,7 +105,7 @@ func sanityCheckKey(key fclient.CrossSigningKey, userID string, purpose fclient. } // nolint:gocyclo -func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error { +func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) { // Find the keys to store. byPurpose := map[fclient.CrossSigningKeyPurpose]fclient.CrossSigningKey{} toStore := types.CrossSigningKeyMap{} @@ -117,7 +117,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. Err: "Master key sanity check failed: " + err.Error(), IsInvalidParam: true, } - return nil + return } byPurpose[fclient.CrossSigningKeyPurposeMaster] = req.MasterKey @@ -133,7 +133,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. Err: "Self-signing key sanity check failed: " + err.Error(), IsInvalidParam: true, } - return nil + return } byPurpose[fclient.CrossSigningKeyPurposeSelfSigning] = req.SelfSigningKey @@ -148,7 +148,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. Err: "User-signing key sanity check failed: " + err.Error(), IsInvalidParam: true, } - return nil + return } byPurpose[fclient.CrossSigningKeyPurposeUserSigning] = req.UserSigningKey @@ -163,7 +163,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. Err: "No keys were supplied in the request", IsMissingParam: true, } - return nil + return } // We can't have a self-signing or user-signing key without a master @@ -176,7 +176,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. res.Error = &api.KeyError{ Err: "Retrieving cross-signing keys from database failed: " + err.Error(), } - return nil + return } // If we still can't find a master key for the user then stop the upload. @@ -187,7 +187,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. Err: "No master key was found", IsMissingParam: true, } - return nil + return } } @@ -214,7 +214,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. } } if !changed { - return nil + return } // Store the keys. @@ -222,7 +222,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err), } - return nil + return } // Now upload any signatures that were included with the keys. @@ -240,7 +240,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err), } - return nil + return } } } @@ -257,18 +257,16 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. update.SelfSigningKey = &ssk } if update.MasterKey == nil && update.SelfSigningKey == nil { - return nil + return } if err := a.KeyChangeProducer.ProduceSigningKeyUpdate(update); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), } - return nil } - return nil } -func (a *UserInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error { +func (a *UserInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) { // Before we do anything, we need the master and self-signing keys for this user. // Then we can verify the signatures make sense. queryReq := &api.QueryKeysRequest{ @@ -279,7 +277,7 @@ func (a *UserInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req for userID := range req.Signatures { queryReq.UserToDevices[userID] = []string{} } - _ = a.QueryKeys(ctx, queryReq, queryRes) + a.QueryKeys(ctx, queryReq, queryRes) selfSignatures := map[string]map[gomatrixserverlib.KeyID]fclient.CrossSigningForKeyOrDevice{} otherSignatures := map[string]map[gomatrixserverlib.KeyID]fclient.CrossSigningForKeyOrDevice{} @@ -325,14 +323,14 @@ func (a *UserInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req res.Error = &api.KeyError{ Err: fmt.Sprintf("a.processSelfSignatures: %s", err), } - return nil + return } if err := a.processOtherSignatures(ctx, req.UserID, queryRes, otherSignatures); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.processOtherSignatures: %s", err), } - return nil + return } // Finally, generate a notification that we updated the signatures. @@ -348,10 +346,9 @@ func (a *UserInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req res.Error = &api.KeyError{ Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), } - return nil + return } } - return nil } func (a *UserInternalAPI) processSelfSignatures( @@ -524,7 +521,7 @@ func (a *UserInternalAPI) crossSigningKeysFromDatabase( } } -func (a *UserInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error { +func (a *UserInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) { for targetUserID, forTargetUser := range req.TargetIDs { keyMap, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID) if err != nil && err != sql.ErrNoRows { @@ -563,7 +560,7 @@ func (a *UserInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySig res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err), } - return nil + return } for sourceUserID, forSourceUser := range sigMap { @@ -585,5 +582,4 @@ func (a *UserInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySig } } } - return nil } diff --git a/userapi/internal/device_list_update.go b/userapi/internal/device_list_update.go index d60e522e8..3fccf56bb 100644 --- a/userapi/internal/device_list_update.go +++ b/userapi/internal/device_list_update.go @@ -134,7 +134,7 @@ type DeviceListUpdaterDatabase interface { } type DeviceListUpdaterAPI interface { - PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error + PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) } // KeyChangeProducer is the interface for producers.KeyChange useful for testing. @@ -519,7 +519,7 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName sp uploadReq.SelfSigningKey = *res.SelfSigningKey } } - _ = u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes) + u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes) } err = u.updateDeviceList(&res) if err != nil { diff --git a/userapi/internal/device_list_update_test.go b/userapi/internal/device_list_update_test.go index 4d075e524..10b9c6521 100644 --- a/userapi/internal/device_list_update_test.go +++ b/userapi/internal/device_list_update_test.go @@ -125,8 +125,7 @@ func (d *mockDeviceListUpdaterDatabase) DeviceKeysJSON(ctx context.Context, keys type mockDeviceListUpdaterAPI struct { } -func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error { - return nil +func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) { } type roundTripper struct { diff --git a/userapi/internal/key_api.go b/userapi/internal/key_api.go index 0b188b091..786a2dcd8 100644 --- a/userapi/internal/key_api.go +++ b/userapi/internal/key_api.go @@ -63,7 +63,7 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor return nil } -func (a *UserInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error { +func (a *UserInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) { res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage) res.Failures = make(map[string]interface{}) // wrap request map in a top-level by-domain map @@ -110,7 +110,6 @@ func (a *UserInternalAPI) PerformClaimKeys(ctx context.Context, req *api.Perform if len(domainToDeviceKeys) > 0 { a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) } - return nil } func (a *UserInternalAPI) claimRemoteKeys( @@ -228,7 +227,7 @@ func (a *UserInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *a } // nolint:gocyclo -func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error { +func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { var respMu sync.Mutex res.DeviceKeys = make(map[string]map[string]json.RawMessage) res.MasterKeys = make(map[string]fclient.CrossSigningKey) @@ -252,7 +251,7 @@ func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReque res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query local device keys: %s", err), } - return nil + return } // pull out display names after we have the keys so we handle wildcards correctly @@ -330,7 +329,7 @@ func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReque // Stop executing the function if the context was canceled/the deadline was exceeded, // as we can't continue without a valid context. if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return nil + return } logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed") continue @@ -356,7 +355,7 @@ func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReque // Stop executing the function if the context was canceled/the deadline was exceeded, // as we can't continue without a valid context. if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return nil + return } logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed") continue @@ -384,7 +383,6 @@ func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReque } } } - return nil } func (a *UserInternalAPI) remoteKeysFromDatabase( diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index ea97fd353..32f3d84b5 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -25,7 +25,6 @@ import ( appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/jsonerror" fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/gomatrixserverlib" @@ -715,7 +714,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform return res, fmt.Errorf("backup was deleted") } if version != req.Version { - return res, jsonerror.WrongBackupVersionError(version) + return res, spec.WrongBackupVersionError(version) } res.Exists = true res.Version = version From 67d68768574a234b733eb3e4061644fc098a69f6 Mon Sep 17 00:00:00 2001 From: devonh Date: Wed, 17 May 2023 00:33:27 +0000 Subject: [PATCH 03/35] Move MakeJoin logic to GMSL (#3081) --- clientapi/auth/auth.go | 6 +- clientapi/auth/login_test.go | 14 +- clientapi/auth/login_token.go | 6 +- clientapi/auth/user_interactive.go | 18 +- clientapi/clientapi_test.go | 3 +- clientapi/httputil/httputil.go | 6 +- clientapi/routing/account_data.go | 10 +- clientapi/routing/admin.go | 13 +- clientapi/routing/admin_whois.go | 5 +- clientapi/routing/aliases.go | 5 +- clientapi/routing/createroom.go | 57 ++- clientapi/routing/deactivate.go | 10 +- clientapi/routing/device.go | 35 +- clientapi/routing/directory.go | 45 ++- clientapi/routing/directory_public.go | 10 +- clientapi/routing/joined_rooms.go | 5 +- clientapi/routing/joinroom.go | 14 +- clientapi/routing/key_backup.go | 4 +- clientapi/routing/keys.go | 10 +- clientapi/routing/login.go | 10 +- clientapi/routing/logout.go | 10 +- clientapi/routing/membership.go | 56 ++- clientapi/routing/notification.go | 15 +- clientapi/routing/openid.go | 5 +- clientapi/routing/password.go | 25 +- clientapi/routing/peekroom.go | 10 +- clientapi/routing/presence.go | 6 +- clientapi/routing/profile.go | 45 ++- clientapi/routing/pusher.go | 20 +- clientapi/routing/pushrules.go | 13 +- clientapi/routing/receipt.go | 5 +- clientapi/routing/redaction.go | 18 +- clientapi/routing/register.go | 7 +- clientapi/routing/register_test.go | 11 +- clientapi/routing/room_tagging.go | 25 +- clientapi/routing/sendevent.go | 46 ++- clientapi/routing/sendtodevice.go | 5 +- clientapi/routing/sendtyping.go | 5 +- clientapi/routing/server_notices.go | 15 +- clientapi/routing/state.go | 40 +- clientapi/routing/thirdparty.go | 15 +- clientapi/routing/threepid.go | 63 +++- clientapi/routing/upgrade_room.go | 7 +- clientapi/routing/voip.go | 5 +- clientapi/threepid/invites.go | 34 +- clientapi/threepid/threepid.go | 5 +- federationapi/routing/backfill.go | 5 +- federationapi/routing/devices.go | 5 +- federationapi/routing/invite.go | 9 +- federationapi/routing/join.go | 344 ++++++++---------- federationapi/routing/keys.go | 20 +- federationapi/routing/leave.go | 35 +- federationapi/routing/missingevents.go | 5 +- federationapi/routing/peek.go | 2 +- federationapi/routing/profile.go | 5 +- federationapi/routing/publicrooms.go | 11 +- federationapi/routing/query.go | 15 +- federationapi/routing/routing.go | 27 +- federationapi/routing/threepid.go | 40 +- go.mod | 4 +- go.sum | 8 +- internal/eventutil/events.go | 14 +- internal/httputil/routing.go | 8 +- mediaapi/routing/upload.go | 18 +- relayapi/routing/routing.go | 2 +- roomserver/api/api.go | 7 + roomserver/internal/perform/perform_admin.go | 4 +- roomserver/internal/perform/perform_join.go | 6 +- .../internal/perform/perform_upgrade.go | 15 +- roomserver/internal/query/query.go | 43 +++ setup/mscs/msc2836/msc2836.go | 6 +- syncapi/routing/context.go | 45 ++- syncapi/routing/filter.go | 15 +- syncapi/routing/getevent.go | 12 +- syncapi/routing/memberships.go | 30 +- syncapi/routing/messages.go | 30 +- syncapi/routing/relations.go | 5 +- syncapi/routing/routing.go | 5 +- syncapi/routing/search.go | 40 +- syncapi/sync/requestpool.go | 15 +- 80 files changed, 1158 insertions(+), 494 deletions(-) diff --git a/clientapi/auth/auth.go b/clientapi/auth/auth.go index 479b9ac7b..8fae45b8d 100644 --- a/clientapi/auth/auth.go +++ b/clientapi/auth/auth.go @@ -68,8 +68,10 @@ func VerifyUserFromRequest( }, &res) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccessToken failed") - jsonErr := spec.InternalServerError() - return nil, &jsonErr + return nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if res.Err != "" { if strings.HasPrefix(strings.ToLower(res.Err), "forbidden:") { // TODO: use actual error and no string comparison diff --git a/clientapi/auth/login_test.go b/clientapi/auth/login_test.go index eb87d5e8e..93d3e2713 100644 --- a/clientapi/auth/login_test.go +++ b/clientapi/auth/login_test.go @@ -107,13 +107,13 @@ func TestBadLoginFromJSONReader(t *testing.T) { Name string Body string - WantErrCode string + WantErrCode spec.MatrixErrorCode }{ - {Name: "empty", WantErrCode: "M_BAD_JSON"}, + {Name: "empty", WantErrCode: spec.ErrorBadJSON}, { Name: "badUnmarshal", Body: `badsyntaxJSON`, - WantErrCode: "M_BAD_JSON", + WantErrCode: spec.ErrorBadJSON, }, { Name: "badPassword", @@ -123,7 +123,7 @@ func TestBadLoginFromJSONReader(t *testing.T) { "password": "invalidpassword", "device_id": "adevice" }`, - WantErrCode: "M_FORBIDDEN", + WantErrCode: spec.ErrorForbidden, }, { Name: "badToken", @@ -132,7 +132,7 @@ func TestBadLoginFromJSONReader(t *testing.T) { "token": "invalidtoken", "device_id": "adevice" }`, - WantErrCode: "M_FORBIDDEN", + WantErrCode: spec.ErrorForbidden, }, { Name: "badType", @@ -140,7 +140,7 @@ func TestBadLoginFromJSONReader(t *testing.T) { "type": "m.login.invalid", "device_id": "adevice" }`, - WantErrCode: "M_INVALID_PARAM", + WantErrCode: spec.ErrorInvalidParam, }, } for _, tst := range tsts { @@ -157,7 +157,7 @@ func TestBadLoginFromJSONReader(t *testing.T) { if errRes == nil { cleanup(ctx, nil) t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode) - } else if merr, ok := errRes.JSON.(*spec.MatrixError); ok && merr.ErrCode != tst.WantErrCode { + } else if merr, ok := errRes.JSON.(spec.MatrixError); ok && merr.ErrCode != tst.WantErrCode { t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode) } }) diff --git a/clientapi/auth/login_token.go b/clientapi/auth/login_token.go index 073f728d6..eb631481a 100644 --- a/clientapi/auth/login_token.go +++ b/clientapi/auth/login_token.go @@ -48,8 +48,10 @@ func (t *LoginTypeToken) LoginFromJSON(ctx context.Context, reqBytes []byte) (*L var res uapi.QueryLoginTokenResponse if err := t.UserAPI.QueryLoginToken(ctx, &uapi.QueryLoginTokenRequest{Token: r.Token}, &res); err != nil { util.GetLogger(ctx).WithError(err).Error("UserAPI.QueryLoginToken failed") - jsonErr := spec.InternalServerError() - return nil, nil, &jsonErr + return nil, nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if res.Data == nil { return nil, nil, &util.JSONResponse{ diff --git a/clientapi/auth/user_interactive.go b/clientapi/auth/user_interactive.go index 58d34865f..92d83ad29 100644 --- a/clientapi/auth/user_interactive.go +++ b/clientapi/auth/user_interactive.go @@ -178,8 +178,10 @@ func (u *UserInteractive) NewSession() *util.JSONResponse { sessionID, err := GenerateAccessToken() if err != nil { logrus.WithError(err).Error("failed to generate session ID") - res := spec.InternalServerError() - return &res + return &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } u.Lock() u.Sessions[sessionID] = []string{} @@ -193,15 +195,19 @@ func (u *UserInteractive) ResponseWithChallenge(sessionID string, response inter mixedObjects := make(map[string]interface{}) b, err := json.Marshal(response) if err != nil { - ise := spec.InternalServerError() - return &ise + return &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } _ = json.Unmarshal(b, &mixedObjects) challenge := u.challenge(sessionID) b, err = json.Marshal(challenge.JSON) if err != nil { - ise := spec.InternalServerError() - return &ise + return &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } _ = json.Unmarshal(b, &mixedObjects) diff --git a/clientapi/clientapi_test.go b/clientapi/clientapi_test.go index 2c34c1098..b339818a4 100644 --- a/clientapi/clientapi_test.go +++ b/clientapi/clientapi_test.go @@ -33,6 +33,7 @@ import ( uapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" @@ -1105,7 +1106,7 @@ func Test3PID(t *testing.T) { resp := threepid.GetValidatedResponse{} switch r.URL.Query().Get("client_secret") { case "fail": - resp.ErrCode = "M_SESSION_NOT_VALIDATED" + resp.ErrCode = string(spec.ErrorSessionNotValidated) case "fail2": resp.ErrCode = "some other error" case "fail3": diff --git a/clientapi/httputil/httputil.go b/clientapi/httputil/httputil.go index aea0c3db6..d9f442323 100644 --- a/clientapi/httputil/httputil.go +++ b/clientapi/httputil/httputil.go @@ -32,8 +32,10 @@ func UnmarshalJSONRequest(req *http.Request, iface interface{}) *util.JSONRespon body, err := io.ReadAll(req.Body) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("io.ReadAll failed") - resp := spec.InternalServerError() - return &resp + return &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return UnmarshalJSON(body, iface) diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go index 572b28efb..7eacf9cc9 100644 --- a/clientapi/routing/account_data.go +++ b/clientapi/routing/account_data.go @@ -104,7 +104,10 @@ func SaveAccountData( body, err := io.ReadAll(req.Body) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("io.ReadAll failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !json.Valid(body) { @@ -157,7 +160,10 @@ func SaveReadMarker( if r.FullyRead != "" { data, err := json.Marshal(fullyReadEvent{EventID: r.FullyRead}) if err != nil { - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } dataReq := api.InputAccountDataRequest{ diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 4d2cea681..8dd662a1b 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -31,7 +31,7 @@ func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAP } affected, err := rsAPI.PerformAdminEvacuateRoom(req.Context(), vars["roomID"]) - switch err { + switch err.(type) { case nil: case eventutil.ErrRoomNoExists: return util.JSONResponse{ @@ -113,7 +113,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *api.De }, accAvailableResp); err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: spec.InternalServerError(), + JSON: spec.InternalServerError{}, } } if accAvailableResp.Available { @@ -169,7 +169,10 @@ func AdminReindex(req *http.Request, cfg *config.ClientAPI, device *api.Device, _, err := natsClient.RequestMsg(nats.NewMsg(cfg.Matrix.JetStream.Prefixed(jetstream.InputFulltextReindex)), time.Second*10) if err != nil { logrus.WithError(err).Error("failed to publish nats message") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ Code: http.StatusOK, @@ -231,10 +234,10 @@ func AdminDownloadState(req *http.Request, device *api.Device, rsAPI roomserverA } } if err = rsAPI.PerformAdminDownloadState(req.Context(), roomID, device.UserID, spec.ServerName(serverName)); err != nil { - if errors.Is(err, eventutil.ErrRoomNoExists) { + if errors.Is(err, eventutil.ErrRoomNoExists{}) { return util.JSONResponse{ Code: 200, - JSON: spec.NotFound(eventutil.ErrRoomNoExists.Error()), + JSON: spec.NotFound(err.Error()), } } logrus.WithError(err).WithFields(logrus.Fields{ diff --git a/clientapi/routing/admin_whois.go b/clientapi/routing/admin_whois.go index cb2b8a26b..7d7536564 100644 --- a/clientapi/routing/admin_whois.go +++ b/clientapi/routing/admin_whois.go @@ -61,7 +61,10 @@ func GetAdminWhois( }, &queryRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("GetAdminWhois failed to query user devices") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } devices := make(map[string]deviceInfo) diff --git a/clientapi/routing/aliases.go b/clientapi/routing/aliases.go index 87c1f9ffd..f6603be8b 100644 --- a/clientapi/routing/aliases.go +++ b/clientapi/routing/aliases.go @@ -62,7 +62,10 @@ func GetAliases( var queryRes api.QueryMembershipForUserResponse if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryMembershipsForRoom failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !queryRes.IsInRoom { return util.JSONResponse{ diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index f0cdd6f5a..bc9600060 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -174,7 +174,10 @@ func createRoom( _, userDomain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !cfg.Matrix.IsLocalServerName(userDomain) { return util.JSONResponse{ @@ -218,7 +221,10 @@ func createRoom( profile, err := appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, profileAPI) if err != nil { util.GetLogger(ctx).WithError(err).Error("appserviceAPI.RetrieveUserProfile failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } createContent := map[string]interface{}{} @@ -342,7 +348,10 @@ func createRoom( err = rsAPI.GetRoomIDForAlias(ctx, &hasAliasReq, &aliasResp) if err != nil { util.GetLogger(ctx).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if aliasResp.RoomID != "" { return util.JSONResponse{ @@ -455,7 +464,10 @@ func createRoom( err = builder.SetContent(e.Content) if err != nil { util.GetLogger(ctx).WithError(err).Error("builder.SetContent failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if i > 0 { builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()} @@ -463,17 +475,26 @@ func createRoom( var ev gomatrixserverlib.PDU if err = builder.AddAuthEvents(&authEvents); err != nil { util.GetLogger(ctx).WithError(err).Error("AddAuthEvents failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } ev, err = builder.Build(evTime, userDomain, cfg.Matrix.KeyID, cfg.Matrix.PrivateKey) if err != nil { util.GetLogger(ctx).WithError(err).Error("buildEvent failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if err = gomatrixserverlib.Allowed(ev, &authEvents); err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // Add the event to the list of auth events @@ -481,7 +502,10 @@ func createRoom( err = authEvents.AddEvent(ev) if err != nil { util.GetLogger(ctx).WithError(err).Error("authEvents.AddEvent failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } @@ -496,7 +520,10 @@ func createRoom( } if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, device.UserDomain(), inputs, false); err != nil { util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // TODO(#269): Reserve room alias while we create the room. This stops us @@ -513,7 +540,10 @@ func createRoom( err = rsAPI.SetRoomAlias(ctx, &aliasReq, &aliasResp) if err != nil { util.GetLogger(ctx).WithError(err).Error("aliasAPI.SetRoomAlias failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if aliasResp.AliasExists { @@ -596,7 +626,7 @@ func createRoom( sentry.CaptureException(err) return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: spec.InternalServerError(), + JSON: spec.InternalServerError{}, } } } @@ -609,7 +639,10 @@ func createRoom( Visibility: spec.Public, }); err != nil { util.GetLogger(ctx).WithError(err).Error("failed to publish room") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } diff --git a/clientapi/routing/deactivate.go b/clientapi/routing/deactivate.go index 78cf9fe38..c151c130a 100644 --- a/clientapi/routing/deactivate.go +++ b/clientapi/routing/deactivate.go @@ -36,7 +36,10 @@ func Deactivate( localpart, serverName, err := gomatrixserverlib.SplitID('@', login.Username()) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var res api.PerformAccountDeactivationResponse @@ -46,7 +49,10 @@ func Deactivate( }, &res) if err != nil { util.GetLogger(ctx).WithError(err).Error("userAPI.PerformAccountDeactivation failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go index 6209d8e95..6f2de3539 100644 --- a/clientapi/routing/device.go +++ b/clientapi/routing/device.go @@ -60,7 +60,10 @@ func GetDeviceByID( }, &queryRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("QueryDevices failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var targetDevice *api.Device for _, device := range queryRes.Devices { @@ -97,7 +100,10 @@ func GetDevicesByLocalpart( }, &queryRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("QueryDevices failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } res := devicesJSON{} @@ -139,7 +145,10 @@ func UpdateDeviceByID( }, &performRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceUpdate failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !performRes.DeviceExists { return util.JSONResponse{ @@ -206,7 +215,10 @@ func DeleteDeviceById( localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // make sure that the access token being used matches the login creds used for user interactive auth, else @@ -224,7 +236,10 @@ func DeleteDeviceById( DeviceIDs: []string{deviceID}, }, &res); err != nil { util.GetLogger(ctx).WithError(err).Error("userAPI.PerformDeviceDeletion failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } deleteOK = true @@ -266,7 +281,10 @@ func DeleteDevices( payload := devicesDeleteJSON{} if err = json.Unmarshal(bodyBytes, &payload); err != nil { util.GetLogger(ctx).WithError(err).Error("unable to unmarshal device deletion request") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var res api.PerformDeviceDeletionResponse @@ -275,7 +293,10 @@ func DeleteDevices( DeviceIDs: payload.Devices, }, &res); err != nil { util.GetLogger(ctx).WithError(err).Error("userAPI.PerformDeviceDeletion failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index 0ca9475d7..c786f8cc4 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -69,7 +69,10 @@ func DirectoryRoom( queryRes := &roomserverAPI.GetRoomIDForAliasResponse{} if err = rsAPI.GetRoomIDForAlias(req.Context(), queryReq, queryRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("rsAPI.GetRoomIDForAlias failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } res.RoomID = queryRes.RoomID @@ -83,7 +86,10 @@ func DirectoryRoom( // TODO: Return 502 if the remote server errored. // TODO: Return 504 if the remote server timed out. util.GetLogger(req.Context()).WithError(fedErr).Error("federation.LookupRoomAlias failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } res.RoomID = fedRes.RoomID res.fillServers(fedRes.Servers) @@ -102,7 +108,10 @@ func DirectoryRoom( var joinedHostsRes federationAPI.QueryJoinedHostServerNamesInRoomResponse if err = fedSenderAPI.QueryJoinedHostServerNamesInRoom(req.Context(), &joinedHostsReq, &joinedHostsRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("fedSenderAPI.QueryJoinedHostServerNamesInRoom failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } res.fillServers(joinedHostsRes.ServerNames) } @@ -180,7 +189,10 @@ func SetLocalAlias( var queryRes roomserverAPI.SetRoomAliasResponse if err := rsAPI.SetRoomAlias(req.Context(), &queryReq, &queryRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.SetRoomAlias failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if queryRes.AliasExists { @@ -210,7 +222,10 @@ func RemoveLocalAlias( var queryRes roomserverAPI.RemoveRoomAliasResponse if err := rsAPI.RemoveRoomAlias(req.Context(), &queryReq, &queryRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.RemoveRoomAlias failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !queryRes.Found { @@ -248,7 +263,10 @@ func GetVisibility( }, &res) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("QueryPublishedRooms failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var v roomVisibility @@ -286,7 +304,10 @@ func SetVisibility( err := rsAPI.QueryLatestEventsAndState(req.Context(), &queryEventsReq, &queryEventsRes) if err != nil || len(queryEventsRes.StateEvents) == 0 { util.GetLogger(req.Context()).WithError(err).Error("could not query events from room") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // NOTSPEC: Check if the user's power is greater than power required to change m.room.canonical_alias event @@ -308,7 +329,10 @@ func SetVisibility( Visibility: v.Visibility, }); err != nil { util.GetLogger(req.Context()).WithError(err).Error("failed to publish room") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ @@ -344,7 +368,10 @@ func SetVisibilityAS( AppserviceID: dev.AppserviceID, }); err != nil { util.GetLogger(req.Context()).WithError(err).Error("failed to publish room") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/clientapi/routing/directory_public.go b/clientapi/routing/directory_public.go index 9718ccab6..67146630c 100644 --- a/clientapi/routing/directory_public.go +++ b/clientapi/routing/directory_public.go @@ -81,7 +81,10 @@ func GetPostPublicRooms( ) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("failed to get public rooms") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ Code: http.StatusOK, @@ -92,7 +95,10 @@ func GetPostPublicRooms( response, err := publicRooms(req.Context(), request, rsAPI, extRoomsProvider) if err != nil { util.GetLogger(req.Context()).WithError(err).Errorf("failed to work out public rooms") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ Code: http.StatusOK, diff --git a/clientapi/routing/joined_rooms.go b/clientapi/routing/joined_rooms.go index 51a96e4d9..f664183f8 100644 --- a/clientapi/routing/joined_rooms.go +++ b/clientapi/routing/joined_rooms.go @@ -40,7 +40,10 @@ func GetJoinedRooms( }, &res) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if res.RoomIDs == nil { res.RoomIDs = []string{} diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index a67d51327..43331b42a 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -16,7 +16,6 @@ package routing import ( "encoding/json" - "errors" "net/http" "time" @@ -114,16 +113,15 @@ func JoinRoomByIDOrAlias( Code: e.Code, JSON: json.RawMessage(e.Message), } + case eventutil.ErrRoomNoExists: + response = util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound(e.Error()), + } default: response = util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: spec.InternalServerError(), - } - if errors.Is(err, eventutil.ErrRoomNoExists) { - response = util.JSONResponse{ - Code: http.StatusNotFound, - JSON: spec.NotFound(e.Error()), - } + JSON: spec.InternalServerError{}, } } done <- response diff --git a/clientapi/routing/key_backup.go b/clientapi/routing/key_backup.go index b7b1cadd2..7f8bd9f40 100644 --- a/clientapi/routing/key_backup.go +++ b/clientapi/routing/key_backup.go @@ -128,7 +128,7 @@ func ModifyKeyBackupVersionAuthData(req *http.Request, userAPI userapi.ClientUse Algorithm: kb.Algorithm, }) switch e := err.(type) { - case *spec.ErrRoomKeysVersion: + case spec.ErrRoomKeysVersion: return util.JSONResponse{ Code: http.StatusForbidden, JSON: e, @@ -182,7 +182,7 @@ func UploadBackupKeys( }) switch e := err.(type) { - case *spec.ErrRoomKeysVersion: + case spec.ErrRoomKeysVersion: return util.JSONResponse{ Code: http.StatusForbidden, JSON: e, diff --git a/clientapi/routing/keys.go b/clientapi/routing/keys.go index 363ae3dc9..72785cda8 100644 --- a/clientapi/routing/keys.go +++ b/clientapi/routing/keys.go @@ -67,7 +67,10 @@ func UploadKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) } if uploadRes.Error != nil { util.GetLogger(req.Context()).WithError(uploadRes.Error).Error("Failed to PerformUploadKeys") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if len(uploadRes.KeyErrors) > 0 { util.GetLogger(req.Context()).WithField("key_errors", uploadRes.KeyErrors).Error("Failed to upload one or more keys") @@ -156,7 +159,10 @@ func ClaimKeys(req *http.Request, keyAPI api.ClientKeyAPI) util.JSONResponse { }, &claimRes) if claimRes.Error != nil { util.GetLogger(req.Context()).WithError(claimRes.Error).Error("failed to PerformClaimKeys") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ Code: 200, diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index d326bff7f..bc38b8340 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -83,13 +83,19 @@ func completeAuth( token, err := auth.GenerateAccessToken() if err != nil { util.GetLogger(ctx).WithError(err).Error("auth.GenerateAccessToken failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } localpart, serverName, err := userutil.ParseUsernameParam(login.Username(), cfg) if err != nil { util.GetLogger(ctx).WithError(err).Error("auth.ParseUsernameParam failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var performRes userapi.PerformDeviceCreationResponse diff --git a/clientapi/routing/logout.go b/clientapi/routing/logout.go index 049c88d57..d06bac784 100644 --- a/clientapi/routing/logout.go +++ b/clientapi/routing/logout.go @@ -33,7 +33,10 @@ func Logout( }, &performRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ @@ -53,7 +56,10 @@ func LogoutAll( }, &performRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 9b95ba5d8..4f2a0e394 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -85,7 +85,10 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic ) if err != nil { util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } serverName := device.UserDomain() @@ -100,7 +103,10 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic false, ); err != nil { util.GetLogger(ctx).WithError(err).Error("SendEvents failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ @@ -262,7 +268,10 @@ func sendInvite( ) if err != nil { util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed") - return spec.InternalServerError(), err + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + }, err } err = rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{ @@ -289,7 +298,7 @@ func sendInvite( sentry.CaptureException(err) return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: spec.InternalServerError(), + JSON: spec.InternalServerError{}, }, err } @@ -398,31 +407,38 @@ func checkAndProcessThreepid( req.Context(), device, body, cfg, rsAPI, profileAPI, roomID, evTime, ) - if err == threepid.ErrMissingParameter { + switch e := err.(type) { + case nil: + case threepid.ErrMissingParameter: + util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAndProcessInvite failed") return inviteStored, &util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON(err.Error()), } - } else if err == threepid.ErrNotTrusted { + case threepid.ErrNotTrusted: + util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAndProcessInvite failed") return inviteStored, &util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.NotTrusted(body.IDServer), } - } else if err == eventutil.ErrRoomNoExists { + case eventutil.ErrRoomNoExists: + util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAndProcessInvite failed") return inviteStored, &util.JSONResponse{ Code: http.StatusNotFound, JSON: spec.NotFound(err.Error()), } - } else if e, ok := err.(gomatrixserverlib.BadJSONError); ok { + case gomatrixserverlib.BadJSONError: + util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAndProcessInvite failed") return inviteStored, &util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON(e.Error()), } - } - if err != nil { + default: util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAndProcessInvite failed") - er := spec.InternalServerError() - return inviteStored, &er + return inviteStored, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return } @@ -435,8 +451,10 @@ func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserver }, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("QueryMembershipForUser: could not query membership for user") - e := spec.InternalServerError() - return &e + return &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !membershipRes.IsInRoom { return &util.JSONResponse{ @@ -461,7 +479,10 @@ func SendForget( err := rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes) if err != nil { logger.WithError(err).Error("QueryMembershipForUser: could not query membership for user") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !membershipRes.RoomExists { return util.JSONResponse{ @@ -483,7 +504,10 @@ func SendForget( response := roomserverAPI.PerformForgetResponse{} if err := rsAPI.PerformForget(ctx, &request, &response); err != nil { logger.WithError(err).Error("PerformForget: unable to forget room") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ Code: http.StatusOK, diff --git a/clientapi/routing/notification.go b/clientapi/routing/notification.go index 8ac12ce5d..4b9043faa 100644 --- a/clientapi/routing/notification.go +++ b/clientapi/routing/notification.go @@ -35,7 +35,10 @@ func GetNotifications( limit, err = strconv.ParseInt(limitStr, 10, 64) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("ParseInt(limit) failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } @@ -43,7 +46,10 @@ func GetNotifications( localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{ Localpart: localpart, @@ -54,7 +60,10 @@ func GetNotifications( }, &queryRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("QueryNotifications failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } util.GetLogger(req.Context()).WithField("from", req.URL.Query().Get("from")).WithField("limit", limit).WithField("only", req.URL.Query().Get("only")).WithField("next", queryRes.NextToken).Infof("QueryNotifications: len %d", len(queryRes.Notifications)) return util.JSONResponse{ diff --git a/clientapi/routing/openid.go b/clientapi/routing/openid.go index 1ead00eba..8dfba8af9 100644 --- a/clientapi/routing/openid.go +++ b/clientapi/routing/openid.go @@ -55,7 +55,10 @@ func CreateOpenIDToken( err := userAPI.PerformOpenIDTokenCreation(req.Context(), &request, &response) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("userAPI.CreateOpenIDToken failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index 68466a77d..24c52b06d 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -90,7 +90,10 @@ func Password( localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // Ask the user API to perform the password change. @@ -102,11 +105,17 @@ func Password( passwordRes := &api.PerformPasswordUpdateResponse{} if err := userAPI.PerformPasswordUpdate(req.Context(), passwordReq, passwordRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformPasswordUpdate failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !passwordRes.PasswordUpdated { util.GetLogger(req.Context()).Error("Expected password to have been updated but wasn't") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // If the request asks us to log out all other devices then @@ -120,7 +129,10 @@ func Password( logoutRes := &api.PerformDeviceDeletionResponse{} if err := userAPI.PerformDeviceDeletion(req.Context(), logoutReq, logoutRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } pushersReq := &api.PerformPusherDeletionRequest{ @@ -130,7 +142,10 @@ func Password( } if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } diff --git a/clientapi/routing/peekroom.go b/clientapi/routing/peekroom.go index af486f6d7..772dc8477 100644 --- a/clientapi/routing/peekroom.go +++ b/clientapi/routing/peekroom.go @@ -75,7 +75,10 @@ func PeekRoomByIDOrAlias( case nil: default: logrus.WithError(err).WithField("roomID", roomIDOrAlias).Errorf("Failed to peek room") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // if this user is already joined to the room, we let them peek anyway @@ -111,7 +114,10 @@ func UnpeekRoomByID( case nil: default: logrus.WithError(err).WithField("roomID", roomID).Errorf("Failed to un-peek room") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/clientapi/routing/presence.go b/clientapi/routing/presence.go index d915f0603..5aa6d8dd2 100644 --- a/clientapi/routing/presence.go +++ b/clientapi/routing/presence.go @@ -74,7 +74,7 @@ func SetPresence( log.WithError(err).Errorf("failed to update presence") return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: spec.InternalServerError(), + JSON: spec.InternalServerError{}, } } @@ -99,7 +99,7 @@ func GetPresence( log.WithError(err).Errorf("unable to get presence") return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: spec.InternalServerError(), + JSON: spec.InternalServerError{}, } } @@ -118,7 +118,7 @@ func GetPresence( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: spec.InternalServerError(), + JSON: spec.InternalServerError{}, } } diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 8e88e7c84..76129f0a8 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -52,7 +52,10 @@ func GetProfile( } util.GetLogger(req.Context()).WithError(err).Error("getProfile failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ @@ -111,7 +114,10 @@ func SetAvatarURL( localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !cfg.Matrix.IsLocalServerName(domain) { @@ -132,7 +138,10 @@ func SetAvatarURL( profile, changed, err := profileAPI.SetAvatarURL(req.Context(), localpart, domain, r.AvatarURL) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetAvatarURL failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // No need to build new membership events, since nothing changed if !changed { @@ -200,7 +209,10 @@ func SetDisplayName( localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !cfg.Matrix.IsLocalServerName(domain) { @@ -221,7 +233,10 @@ func SetDisplayName( profile, changed, err := profileAPI.SetDisplayName(req.Context(), localpart, domain, r.DisplayName) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // No need to build new membership events, since nothing changed if !changed { @@ -254,13 +269,19 @@ func updateProfile( }, &res) if err != nil { util.GetLogger(ctx).WithError(err).Error("QueryRoomsForUser failed") - return spec.InternalServerError(), err + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + }, err } _, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") - return spec.InternalServerError(), err + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + }, err } events, err := buildMembershipEvents( @@ -275,12 +296,18 @@ func updateProfile( }, e default: util.GetLogger(ctx).WithError(err).Error("buildMembershipEvents failed") - return spec.InternalServerError(), e + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + }, e } if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, device.UserDomain(), domain, domain, nil, true); err != nil { util.GetLogger(ctx).WithError(err).Error("SendEvents failed") - return spec.InternalServerError(), err + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + }, err } return util.JSONResponse{}, nil } diff --git a/clientapi/routing/pusher.go b/clientapi/routing/pusher.go index 2f51583fb..ed59129cc 100644 --- a/clientapi/routing/pusher.go +++ b/clientapi/routing/pusher.go @@ -34,7 +34,10 @@ func GetPushers( localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{ Localpart: localpart, @@ -42,7 +45,10 @@ func GetPushers( }, &queryRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } for i := range queryRes.Pushers { queryRes.Pushers[i].SessionID = 0 @@ -63,7 +69,10 @@ func SetPusher( localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } body := userapi.PerformPusherSetRequest{} if resErr := httputil.UnmarshalJSONRequest(req, &body); resErr != nil { @@ -99,7 +108,10 @@ func SetPusher( err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{}) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformPusherSet failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/clientapi/routing/pushrules.go b/clientapi/routing/pushrules.go index 7be6d2a7e..74873d5c9 100644 --- a/clientapi/routing/pushrules.go +++ b/clientapi/routing/pushrules.go @@ -14,20 +14,23 @@ import ( ) func errorResponse(ctx context.Context, err error, msg string, args ...interface{}) util.JSONResponse { - if eerr, ok := err.(*spec.MatrixError); ok { + if eerr, ok := err.(spec.MatrixError); ok { var status int switch eerr.ErrCode { - case "M_INVALID_PARAM": + case spec.ErrorInvalidParam: status = http.StatusBadRequest - case "M_NOT_FOUND": + case spec.ErrorNotFound: status = http.StatusNotFound default: status = http.StatusInternalServerError } - return util.MatrixErrorResponse(status, eerr.ErrCode, eerr.Err) + return util.MatrixErrorResponse(status, string(eerr.ErrCode), eerr.Err) } util.GetLogger(ctx).WithError(err).Errorf(msg, args...) - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { diff --git a/clientapi/routing/receipt.go b/clientapi/routing/receipt.go index 0bbb20b9d..be6542979 100644 --- a/clientapi/routing/receipt.go +++ b/clientapi/routing/receipt.go @@ -48,7 +48,10 @@ func SetReceipt(req *http.Request, userAPI api.ClientUserAPI, syncProducer *prod case "m.fully_read": data, err := json.Marshal(fullyReadEvent{EventID: eventID}) if err != nil { - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } dataReq := api.InputAccountDataRequest{ diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 12391d266..ed70e5c5c 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -16,6 +16,7 @@ package routing import ( "context" + "errors" "net/http" "time" @@ -121,17 +122,23 @@ func SendRedaction( err := proto.SetContent(r) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("proto.SetContent failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) if err != nil { - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var queryRes roomserverAPI.QueryLatestEventsAndStateResponse e, err := eventutil.QueryAndBuildEvent(req.Context(), &proto, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes) - if err == eventutil.ErrRoomNoExists { + if errors.Is(err, eventutil.ErrRoomNoExists{}) { return util.JSONResponse{ Code: http.StatusNotFound, JSON: spec.NotFound("Room does not exist"), @@ -140,7 +147,10 @@ func SendRedaction( domain := device.UserDomain() if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*types.HeaderedEvent{e}, device.UserDomain(), domain, domain, nil, false); err != nil { util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } res := util.JSONResponse{ diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 615ff2011..565c41533 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -528,7 +528,10 @@ func Register( nres := &userapi.QueryNumericLocalpartResponse{} if err = userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil { util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } r.Username = strconv.FormatInt(nres.ID, 10) } @@ -713,7 +716,7 @@ func handleRegistrationFlow( case nil: default: util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha") - return util.JSONResponse{Code: http.StatusInternalServerError, JSON: spec.InternalServerError()} + return util.JSONResponse{Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}} } // Add Recaptcha to the list of completed registration stages diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index 9a60f5314..2a88ec380 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -402,7 +402,7 @@ func Test_register(t *testing.T) { enableRecaptcha: true, loginType: authtypes.LoginTypeRecaptcha, captchaBody: `i should fail for other reasons`, - wantResponse: util.JSONResponse{Code: http.StatusInternalServerError, JSON: spec.InternalServerError()}, + wantResponse: util.JSONResponse{Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}}, }, } @@ -484,7 +484,7 @@ func Test_register(t *testing.T) { if !reflect.DeepEqual(r.Flows, cfg.Derived.Registration.Flows) { t.Fatalf("unexpected registration flows: %+v, want %+v", r.Flows, cfg.Derived.Registration.Flows) } - case *spec.MatrixError: + case spec.MatrixError: if !reflect.DeepEqual(tc.wantResponse, resp) { t.Fatalf("(%s), unexpected response: %+v, want: %+v", tc.name, resp, tc.wantResponse) } @@ -541,7 +541,12 @@ func Test_register(t *testing.T) { resp = Register(req, userAPI, &cfg.ClientAPI) switch resp.JSON.(type) { - case *spec.MatrixError: + case spec.InternalServerError: + if !reflect.DeepEqual(tc.wantResponse, resp) { + t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse) + } + return + case spec.MatrixError: if !reflect.DeepEqual(tc.wantResponse, resp) { t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse) } diff --git a/clientapi/routing/room_tagging.go b/clientapi/routing/room_tagging.go index 8802d22a4..5a5296bf4 100644 --- a/clientapi/routing/room_tagging.go +++ b/clientapi/routing/room_tagging.go @@ -46,7 +46,10 @@ func GetTags( tagContent, err := obtainSavedTags(req, userID, roomID, userAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ @@ -83,7 +86,10 @@ func PutTag( tagContent, err := obtainSavedTags(req, userID, roomID, userAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if tagContent.Tags == nil { @@ -93,7 +99,10 @@ func PutTag( if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil { util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ @@ -125,7 +134,10 @@ func DeleteTag( tagContent, err := obtainSavedTags(req, userID, roomID, userAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // Check whether the tag to be deleted exists @@ -141,7 +153,10 @@ func DeleteTag( if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil { util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 2e3cd4112..bc14642f8 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -149,7 +149,10 @@ func SendEvent( } aliasRes := &api.GetAliasesForRoomIDResponse{} if err = rsAPI.GetAliasesForRoomID(req.Context(), &api.GetAliasesForRoomIDRequest{RoomID: roomID}, aliasRes); err != nil { - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var found int requestAliases := append(aliasReq.AltAliases, aliasReq.Alias) @@ -193,7 +196,10 @@ func SendEvent( false, ); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } timeToSubmitEvent := time.Since(startedSubmittingEvent) util.GetLogger(req.Context()).WithFields(logrus.Fields{ @@ -272,43 +278,51 @@ func generateSendEvent( err := proto.SetContent(r) if err != nil { util.GetLogger(ctx).WithError(err).Error("proto.SetContent failed") - resErr := spec.InternalServerError() - return nil, &resErr + return nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) if err != nil { - resErr := spec.InternalServerError() - return nil, &resErr + return nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var queryRes api.QueryLatestEventsAndStateResponse e, err := eventutil.QueryAndBuildEvent(ctx, &proto, cfg.Matrix, identity, evTime, rsAPI, &queryRes) - if err == eventutil.ErrRoomNoExists { + switch specificErr := err.(type) { + case nil: + case eventutil.ErrRoomNoExists: return nil, &util.JSONResponse{ Code: http.StatusNotFound, JSON: spec.NotFound("Room does not exist"), } - } else if e, ok := err.(gomatrixserverlib.BadJSONError); ok { + case gomatrixserverlib.BadJSONError: return nil, &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: spec.BadJSON(e.Error()), + JSON: spec.BadJSON(specificErr.Error()), } - } else if e, ok := err.(gomatrixserverlib.EventValidationError); ok { - if e.Code == gomatrixserverlib.EventValidationTooLarge { + case gomatrixserverlib.EventValidationError: + if specificErr.Code == gomatrixserverlib.EventValidationTooLarge { return nil, &util.JSONResponse{ Code: http.StatusRequestEntityTooLarge, - JSON: spec.BadJSON(e.Error()), + JSON: spec.BadJSON(specificErr.Error()), } } return nil, &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: spec.BadJSON(e.Error()), + JSON: spec.BadJSON(specificErr.Error()), } - } else if err != nil { + default: util.GetLogger(ctx).WithError(err).Error("eventutil.BuildEvent failed") - resErr := spec.InternalServerError() - return nil, &resErr + return nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // check to see if this user can perform this operation diff --git a/clientapi/routing/sendtodevice.go b/clientapi/routing/sendtodevice.go index 6d4af0728..58d3053e2 100644 --- a/clientapi/routing/sendtodevice.go +++ b/clientapi/routing/sendtodevice.go @@ -53,7 +53,10 @@ func SendToDevice( req.Context(), device.UserID, userID, deviceID, eventType, message, ); err != nil { util.GetLogger(req.Context()).WithError(err).Error("eduProducer.SendToDevice failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } } diff --git a/clientapi/routing/sendtyping.go b/clientapi/routing/sendtyping.go index 17532a2dd..c5b29297a 100644 --- a/clientapi/routing/sendtyping.go +++ b/clientapi/routing/sendtyping.go @@ -58,7 +58,10 @@ func SendTyping( if err := syncProducer.SendTyping(req.Context(), userID, roomID, r.Typing, r.Timeout); err != nil { util.GetLogger(req.Context()).WithError(err).Error("eduProducer.Send failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index a418677ea..ad50cc80b 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -175,7 +175,10 @@ func SendServerNotice( }} if err = saveTagData(req, r.UserID, roomID, userAPI, serverAlertTag); err != nil { util.GetLogger(ctx).WithError(err).Error("saveTagData failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } default: @@ -189,7 +192,10 @@ func SendServerNotice( err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: r.UserID, RoomID: roomID}, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("unable to query membership for user") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !membershipRes.IsInRoom { // re-invite the user @@ -237,7 +243,10 @@ func SendServerNotice( false, ); err != nil { util.GetLogger(ctx).WithError(err).Error("SendEvents failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } util.GetLogger(ctx).WithFields(logrus.Fields{ "event_id": e.EventID(), diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index 75abbda91..319f4eba5 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -56,7 +56,10 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a StateToFetch: []gomatrixserverlib.StateKeyTuple{}, }, &stateRes); err != nil { util.GetLogger(ctx).WithError(err).Error("queryAPI.QueryLatestEventsAndState failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !stateRes.RoomExists { return util.JSONResponse{ @@ -73,7 +76,10 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a content := map[string]string{} if err := json.Unmarshal(ev.Content(), &content); err != nil { util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for history visibility failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if visibility, ok := content["history_visibility"]; ok { worldReadable = visibility == "world_readable" @@ -99,7 +105,10 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a }, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // If the user has never been in the room then stop at this point. // We won't tell the user about a room they have never joined. @@ -147,7 +156,10 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a }, &stateAfterRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } for _, ev := range stateAfterRes.StateEvents { stateEvents = append( @@ -202,7 +214,10 @@ func OnIncomingStateTypeRequest( StateToFetch: stateToFetch, }, &stateRes); err != nil { util.GetLogger(ctx).WithError(err).Error("queryAPI.QueryLatestEventsAndState failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // Look at the room state and see if we have a history visibility event @@ -213,7 +228,10 @@ func OnIncomingStateTypeRequest( content := map[string]string{} if err := json.Unmarshal(ev.Content(), &content); err != nil { util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for history visibility failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if visibility, ok := content["history_visibility"]; ok { worldReadable = visibility == "world_readable" @@ -239,7 +257,10 @@ func OnIncomingStateTypeRequest( }, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // If the user has never been in the room then stop at this point. // We won't tell the user about a room they have never joined. @@ -294,7 +315,10 @@ func OnIncomingStateTypeRequest( }, &stateAfterRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if len(stateAfterRes.StateEvents) > 0 { event = stateAfterRes.StateEvents[0] diff --git a/clientapi/routing/thirdparty.go b/clientapi/routing/thirdparty.go index 0ee218556..b805d4b51 100644 --- a/clientapi/routing/thirdparty.go +++ b/clientapi/routing/thirdparty.go @@ -33,7 +33,10 @@ func Protocols(req *http.Request, asAPI appserviceAPI.AppServiceInternalAPI, dev resp := &appserviceAPI.ProtocolResponse{} if err := asAPI.Protocols(req.Context(), &appserviceAPI.ProtocolRequest{Protocol: protocol}, resp); err != nil { - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !resp.Exists { if protocol != "" { @@ -71,7 +74,10 @@ func User(req *http.Request, asAPI appserviceAPI.AppServiceInternalAPI, device * Protocol: protocol, Params: params.Encode(), }, resp); err != nil { - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !resp.Exists { return util.JSONResponse{ @@ -97,7 +103,10 @@ func Location(req *http.Request, asAPI appserviceAPI.AppServiceInternalAPI, devi Protocol: protocol, Params: params.Encode(), }, resp); err != nil { - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !resp.Exists { return util.JSONResponse{ diff --git a/clientapi/routing/threepid.go b/clientapi/routing/threepid.go index 64fa59e40..5261a1407 100644 --- a/clientapi/routing/threepid.go +++ b/clientapi/routing/threepid.go @@ -60,28 +60,37 @@ func RequestEmailToken(req *http.Request, threePIDAPI api.ClientUserAPI, cfg *co if err != nil { util.GetLogger(req.Context()).WithError(err).Error("threePIDAPI.QueryLocalpartForThreePID failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if len(res.Localpart) > 0 { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.MatrixError{ - ErrCode: "M_THREEPID_IN_USE", + ErrCode: spec.ErrorThreePIDInUse, Err: userdb.Err3PIDInUse.Error(), }, } } resp.SID, err = threepid.CreateSession(req.Context(), body, cfg, client) - if err == threepid.ErrNotTrusted { + switch err.(type) { + case nil: + case threepid.ErrNotTrusted: + util.GetLogger(req.Context()).WithError(err).Error("threepid.CreateSession failed") return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.NotTrusted(body.IDServer), } - } else if err != nil { + default: util.GetLogger(req.Context()).WithError(err).Error("threepid.CreateSession failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ @@ -102,21 +111,27 @@ func CheckAndSave3PIDAssociation( // Check if the association has been validated verified, address, medium, err := threepid.CheckAssociation(req.Context(), body.Creds, cfg, client) - if err == threepid.ErrNotTrusted { + switch err.(type) { + case nil: + case threepid.ErrNotTrusted: + util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAssociation failed") return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.NotTrusted(body.Creds.IDServer), } - } else if err != nil { + default: util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAssociation failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !verified { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.MatrixError{ - ErrCode: "M_THREEPID_AUTH_FAILED", + ErrCode: spec.ErrorThreePIDAuthFailed, Err: "Failed to auth 3pid", }, } @@ -127,7 +142,10 @@ func CheckAndSave3PIDAssociation( err = threepid.PublishAssociation(req.Context(), body.Creds, device.UserID, cfg, client) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("threepid.PublishAssociation failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } @@ -135,7 +153,10 @@ func CheckAndSave3PIDAssociation( localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if err = threePIDAPI.PerformSaveThreePIDAssociation(req.Context(), &api.PerformSaveThreePIDAssociationRequest{ @@ -145,7 +166,10 @@ func CheckAndSave3PIDAssociation( Medium: medium, }, &struct{}{}); err != nil { util.GetLogger(req.Context()).WithError(err).Error("threePIDAPI.PerformSaveThreePIDAssociation failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ @@ -161,7 +185,10 @@ func GetAssociated3PIDs( localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } res := &api.QueryThreePIDsForLocalpartResponse{} @@ -171,7 +198,10 @@ func GetAssociated3PIDs( }, res) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.QueryThreePIDsForLocalpart failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ @@ -192,7 +222,10 @@ func Forget3PID(req *http.Request, threepidAPI api.ClientUserAPI) util.JSONRespo Medium: body.Medium, }, &struct{}{}); err != nil { util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.PerformForgetThreePID failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/clientapi/routing/upgrade_room.go b/clientapi/routing/upgrade_room.go index 43f8d3e24..a0b280789 100644 --- a/clientapi/routing/upgrade_room.go +++ b/clientapi/routing/upgrade_room.go @@ -68,13 +68,16 @@ func UpgradeRoom( JSON: spec.Forbidden(e.Error()), } default: - if errors.Is(err, eventutil.ErrRoomNoExists) { + if errors.Is(err, eventutil.ErrRoomNoExists{}) { return util.JSONResponse{ Code: http.StatusNotFound, JSON: spec.NotFound("Room does not exist"), } } - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/clientapi/routing/voip.go b/clientapi/routing/voip.go index f3db0cbe9..14a08b79c 100644 --- a/clientapi/routing/voip.go +++ b/clientapi/routing/voip.go @@ -60,7 +60,10 @@ func RequestTurnServer(req *http.Request, device *api.Device, cfg *config.Client if err != nil { util.GetLogger(req.Context()).WithError(err).Error("mac.Write failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } resp.Password = base64.StdEncoding.EncodeToString(mac.Sum(nil)) diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index 2e9c1261e..c296939d5 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -64,14 +64,34 @@ type idServerStoreInviteResponse struct { } var ( - // ErrMissingParameter is the error raised if a request for 3PID invite has - // an incomplete body - ErrMissingParameter = errors.New("'address', 'id_server' and 'medium' must all be supplied") - // ErrNotTrusted is the error raised if an identity server isn't in the list - // of trusted servers in the configuration file. - ErrNotTrusted = errors.New("untrusted server") + errMissingParameter = fmt.Errorf("'address', 'id_server' and 'medium' must all be supplied") + errNotTrusted = fmt.Errorf("untrusted server") ) +// ErrMissingParameter is the error raised if a request for 3PID invite has +// an incomplete body +type ErrMissingParameter struct{} + +func (e ErrMissingParameter) Error() string { + return errMissingParameter.Error() +} + +func (e ErrMissingParameter) Unwrap() error { + return errMissingParameter +} + +// ErrNotTrusted is the error raised if an identity server isn't in the list +// of trusted servers in the configuration file. +type ErrNotTrusted struct{} + +func (e ErrNotTrusted) Error() string { + return errNotTrusted.Error() +} + +func (e ErrNotTrusted) Unwrap() error { + return errNotTrusted +} + // CheckAndProcessInvite analyses the body of an incoming membership request. // If the fields relative to a third-party-invite are all supplied, lookups the // matching Matrix ID from the given identity server. If no Matrix ID is @@ -99,7 +119,7 @@ func CheckAndProcessInvite( } else if body.Address == "" || body.IDServer == "" || body.Medium == "" { // If at least one of the 3PID-specific fields is supplied but not all // of them, return an error - err = ErrMissingParameter + err = ErrMissingParameter{} return } diff --git a/clientapi/threepid/threepid.go b/clientapi/threepid/threepid.go index 1fe573b1b..d61052cc0 100644 --- a/clientapi/threepid/threepid.go +++ b/clientapi/threepid/threepid.go @@ -26,6 +26,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" ) // EmailAssociationRequest represents the request defined at https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-client-r0-register-email-requesttoken @@ -133,7 +134,7 @@ func CheckAssociation( return false, "", "", err } - if respBody.ErrCode == "M_SESSION_NOT_VALIDATED" { + if respBody.ErrCode == string(spec.ErrorSessionNotValidated) { return false, "", "", nil } else if len(respBody.ErrCode) > 0 { return false, "", "", errors.New(respBody.Error) @@ -186,5 +187,5 @@ func isTrusted(idServer string, cfg *config.ClientAPI) error { return nil } } - return ErrNotTrusted + return ErrNotTrusted{} } diff --git a/federationapi/routing/backfill.go b/federationapi/routing/backfill.go index 81b61322c..9e1595053 100644 --- a/federationapi/routing/backfill.go +++ b/federationapi/routing/backfill.go @@ -98,7 +98,10 @@ func Backfill( // Query the roomserver. if err = rsAPI.PerformBackfill(httpReq.Context(), &req, &res); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("query.PerformBackfill failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // Filter any event that's not from the requested room out. diff --git a/federationapi/routing/devices.go b/federationapi/routing/devices.go index 318c0a349..a54ff0d9c 100644 --- a/federationapi/routing/devices.go +++ b/federationapi/routing/devices.go @@ -38,7 +38,10 @@ func GetUserDevices( } if res.Error != nil { util.GetLogger(req.Context()).WithError(res.Error).Error("keyAPI.QueryDeviceMessages failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } sigReq := &api.QuerySignaturesRequest{ diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index bdfe2c821..993d40466 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -183,7 +183,10 @@ func processInvite( verifyResults, err := keys.VerifyJSONs(ctx, verifyRequests) if err != nil { util.GetLogger(ctx).WithError(err).Error("keys.VerifyJSONs failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if verifyResults[0].Error != nil { return util.JSONResponse{ @@ -211,7 +214,7 @@ func processInvite( util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: spec.InternalServerError(), + JSON: spec.InternalServerError{}, } } @@ -232,7 +235,7 @@ func processInvite( sentry.CaptureException(err) return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: spec.InternalServerError(), + JSON: spec.InternalServerError{}, } } diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index c301785cf..cc22690a9 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -15,6 +15,7 @@ package routing import ( + "context" "encoding/json" "fmt" "net/http" @@ -33,153 +34,187 @@ import ( "github.com/matrix-org/dendrite/setup/config" ) +type JoinRoomQuerier struct { + roomserver api.FederationRoomserverAPI +} + +func (rq *JoinRoomQuerier) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) { + return rq.roomserver.CurrentStateEvent(ctx, roomID, eventType, stateKey) +} + +func (rq *JoinRoomQuerier) InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) { + return rq.roomserver.InvitePending(ctx, roomID, userID) +} + +func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { + roomInfo, err := rq.roomserver.QueryRoomInfo(ctx, roomID) + if err != nil || roomInfo == nil || roomInfo.IsStub() { + return nil, err + } + + req := api.QueryServerJoinedToRoomRequest{ + ServerName: localServerName, + RoomID: roomID.String(), + } + res := api.QueryServerJoinedToRoomResponse{} + if err = rq.roomserver.QueryServerJoinedToRoom(ctx, &req, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed") + return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err) + } + + userJoinedToRoom, err := rq.roomserver.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed") + return nil, fmt.Errorf("InternalServerError: %w", err) + } + + locallyJoinedUsers, err := rq.roomserver.LocallyJoinedUsers(ctx, roomInfo.RoomVersion, types.RoomNID(roomInfo.RoomNID)) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("rsAPI.GetLocallyJoinedUsers failed") + return nil, fmt.Errorf("InternalServerError: %w", err) + } + + return &gomatrixserverlib.RestrictedRoomJoinInfo{ + LocalServerInRoom: res.RoomExists && res.IsInRoom, + UserJoinedToRoom: userJoinedToRoom, + JoinedUsers: locallyJoinedUsers, + }, nil +} + // MakeJoin implements the /make_join API func MakeJoin( httpReq *http.Request, request *fclient.FederationRequest, cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, - roomID, userID string, + roomID spec.RoomID, userID spec.UserID, remoteVersions []gomatrixserverlib.RoomVersion, ) util.JSONResponse { - roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID) + roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID.String()) if err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Error("failed obtaining room version") return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: spec.InternalServerError(), + JSON: spec.InternalServerError{}, } } - // Check that the room that the remote side is trying to join is actually - // one of the room versions that they listed in their supported ?ver= in - // the make_join URL. - // https://matrix.org/docs/spec/server_server/r0.1.3#get-matrix-federation-v1-make-join-roomid-userid - remoteSupportsVersion := false - for _, v := range remoteVersions { - if v == roomVersion { - remoteSupportsVersion = true - break - } - } - // If it isn't, stop trying to join the room. - if !remoteSupportsVersion { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.IncompatibleRoomVersion(string(roomVersion)), - } - } - - _, domain, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON("Invalid UserID"), - } - } - if domain != request.Origin() { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("The join must be sent by the server of the user"), - } - } - - // Check if we think we are still joined to the room - inRoomReq := &api.QueryServerJoinedToRoomRequest{ + req := api.QueryServerJoinedToRoomRequest{ ServerName: cfg.Matrix.ServerName, - RoomID: roomID, + RoomID: roomID.String(), } - inRoomRes := &api.QueryServerJoinedToRoomResponse{} - if err = rsAPI.QueryServerJoinedToRoom(httpReq.Context(), inRoomReq, inRoomRes); err != nil { + res := api.QueryServerJoinedToRoomResponse{} + if err := rsAPI.QueryServerJoinedToRoom(httpReq.Context(), &req, &res); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed") - return spec.InternalServerError() - } - if !inRoomRes.RoomExists { return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: spec.NotFound(fmt.Sprintf("Room ID %q was not found on this server", roomID)), - } - } - if !inRoomRes.IsInRoom { - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: spec.NotFound(fmt.Sprintf("Room ID %q has no remaining users on this server", roomID)), + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, } } - // Check if the restricted join is allowed. If the room doesn't - // support restricted joins then this is effectively a no-op. - res, authorisedVia, err := checkRestrictedJoin(httpReq, rsAPI, roomVersion, roomID, userID) - if err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("checkRestrictedJoin failed") - return spec.InternalServerError() - } else if res != nil { - return *res + createJoinTemplate := func(proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, []gomatrixserverlib.PDU, error) { + identity, err := cfg.Matrix.SigningIdentityFor(request.Destination()) + if err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Errorf("obtaining signing identity for %s failed", request.Destination()) + return nil, nil, spec.NotFound(fmt.Sprintf("Server name %q does not exist", request.Destination())) + } + + queryRes := api.QueryLatestEventsAndStateResponse{ + RoomVersion: roomVersion, + } + event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), proto, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes) + switch e := err.(type) { + case nil: + case eventutil.ErrRoomNoExists: + util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") + return nil, nil, spec.NotFound("Room does not exist") + case gomatrixserverlib.BadJSONError: + util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") + return nil, nil, spec.BadJSON(e.Error()) + default: + util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") + return nil, nil, spec.InternalServerError{} + } + + stateEvents := make([]gomatrixserverlib.PDU, len(queryRes.StateEvents)) + for i, stateEvent := range queryRes.StateEvents { + stateEvents[i] = stateEvent.PDU + } + return event, stateEvents, nil } - // Try building an event for the server - proto := gomatrixserverlib.ProtoEvent{ - Sender: userID, - RoomID: roomID, - Type: "m.room.member", - StateKey: &userID, - } - content := gomatrixserverlib.MemberContent{ - Membership: spec.Join, - AuthorisedVia: authorisedVia, - } - if err = proto.SetContent(content); err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("builder.SetContent failed") - return spec.InternalServerError() + roomQuerier := JoinRoomQuerier{ + roomserver: rsAPI, } - identity, err := cfg.Matrix.SigningIdentityFor(request.Destination()) - if err != nil { - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: spec.NotFound( - fmt.Sprintf("Server name %q does not exist", request.Destination()), - ), + input := gomatrixserverlib.HandleMakeJoinInput{ + Context: httpReq.Context(), + UserID: userID, + RoomID: roomID, + RoomVersion: roomVersion, + RemoteVersions: remoteVersions, + RequestOrigin: request.Origin(), + LocalServerName: cfg.Matrix.ServerName, + LocalServerInRoom: res.RoomExists && res.IsInRoom, + RoomQuerier: &roomQuerier, + BuildEventTemplate: createJoinTemplate, + } + response, internalErr := gomatrixserverlib.HandleMakeJoin(input) + if internalErr != nil { + switch e := internalErr.(type) { + case nil: + case spec.InternalServerError: + util.GetLogger(httpReq.Context()).WithError(internalErr) + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + case spec.MatrixError: + util.GetLogger(httpReq.Context()).WithError(internalErr) + code := http.StatusInternalServerError + switch e.ErrCode { + case spec.ErrorForbidden: + code = http.StatusForbidden + case spec.ErrorNotFound: + code = http.StatusNotFound + case spec.ErrorUnableToAuthoriseJoin: + code = http.StatusBadRequest + case spec.ErrorBadJSON: + code = http.StatusBadRequest + } + + return util.JSONResponse{ + Code: code, + JSON: e, + } + case spec.IncompatibleRoomVersionError: + util.GetLogger(httpReq.Context()).WithError(internalErr) + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: e, + } + default: + util.GetLogger(httpReq.Context()).WithError(internalErr) + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("unknown error"), + } } } - queryRes := api.QueryLatestEventsAndStateResponse{ - RoomVersion: roomVersion, - } - event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &proto, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes) - if err == eventutil.ErrRoomNoExists { + if response == nil { + util.GetLogger(httpReq.Context()).Error("gmsl.HandleMakeJoin returned invalid response") return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: spec.NotFound("Room does not exist"), - } - } else if e, ok := err.(gomatrixserverlib.BadJSONError); ok { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON(e.Error()), - } - } else if err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") - return spec.InternalServerError() - } - - // Check that the join is allowed or not - stateEvents := make([]gomatrixserverlib.PDU, len(queryRes.StateEvents)) - for i := range queryRes.StateEvents { - stateEvents[i] = queryRes.StateEvents[i].PDU - } - - provider := gomatrixserverlib.NewAuthEvents(stateEvents) - if err = gomatrixserverlib.Allowed(event.PDU, &provider); err != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden(err.Error()), + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, } } return util.JSONResponse{ Code: http.StatusOK, JSON: map[string]interface{}{ - "event": proto, - "room_version": roomVersion, + "event": response.JoinTemplateEvent, + "room_version": response.RoomVersion, }, } } @@ -201,7 +236,7 @@ func SendJoin( util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryRoomVersionForRoom failed") return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: spec.InternalServerError(), + JSON: spec.InternalServerError{}, } } verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion) @@ -311,7 +346,10 @@ func SendJoin( verifyResults, err := keys.VerifyJSONs(httpReq.Context(), verifyRequests) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("keys.VerifyJSONs failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if verifyResults[0].Error != nil { return util.JSONResponse{ @@ -331,7 +369,10 @@ func SendJoin( }, &stateAndAuthChainResponse) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryStateAndAuthChain failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !stateAndAuthChainResponse.RoomExists { @@ -427,7 +468,10 @@ func SendJoin( JSON: spec.Forbidden(response.ErrMsg), } } - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } @@ -449,74 +493,6 @@ func SendJoin( } } -// checkRestrictedJoin finds out whether or not we can assist in processing -// a restricted room join. If the room version does not support restricted -// joins then this function returns with no side effects. This returns three -// values: -// - an optional JSON response body (i.e. M_UNABLE_TO_AUTHORISE_JOIN) which -// should always be sent back to the client if one is specified -// - a user ID of an authorising user, typically a user that has power to -// issue invites in the room, if one has been found -// - an error if there was a problem finding out if this was allowable, -// like if the room version isn't known or a problem happened talking to -// the roomserver -func checkRestrictedJoin( - httpReq *http.Request, - rsAPI api.FederationRoomserverAPI, - roomVersion gomatrixserverlib.RoomVersion, - roomID, userID string, -) (*util.JSONResponse, string, error) { - verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion) - if err != nil { - return nil, "", err - } - if !verImpl.MayAllowRestrictedJoinsInEventAuth() { - return nil, "", nil - } - req := &api.QueryRestrictedJoinAllowedRequest{ - RoomID: roomID, - UserID: userID, - } - res := &api.QueryRestrictedJoinAllowedResponse{} - if err := rsAPI.QueryRestrictedJoinAllowed(httpReq.Context(), req, res); err != nil { - return nil, "", err - } - - switch { - case !res.Restricted: - // The join rules for the room don't restrict membership. - return nil, "", nil - - case !res.Resident: - // The join rules restrict membership but our server isn't currently - // joined to all of the allowed rooms, so we can't actually decide - // whether or not to allow the user to join. This error code should - // tell the joining server to try joining via another resident server - // instead. - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.UnableToAuthoriseJoin("This server cannot authorise the join."), - }, "", nil - - case !res.Allowed: - // The join rules restrict membership, our server is in the relevant - // rooms and the user wasn't joined to join any of the allowed rooms - // and therefore can't join this room. - return &util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("You are not joined to any matching rooms."), - }, "", nil - - default: - // The join rules restrict membership, our server is in the relevant - // rooms and the user was allowed to join because they belong to one - // of the allowed rooms. We now need to pick one of our own local users - // from within the room to use as the authorising user ID, so that it - // can be referred to from within the membership content. - return nil, res.AuthorisedVia, nil - } -} - type eventsByDepth []*types.HeaderedEvent func (e eventsByDepth) Len() int { diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index d85de73d8..3d8ff2dea 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -67,7 +67,10 @@ func QueryDeviceKeys( }, &queryRes) if queryRes.Error != nil { util.GetLogger(httpReq.Context()).WithError(queryRes.Error).Error("Failed to QueryKeys") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ Code: 200, @@ -119,7 +122,10 @@ func ClaimOneTimeKeys( }, &claimRes) if claimRes.Error != nil { util.GetLogger(httpReq.Context()).WithError(claimRes.Error).Error("Failed to PerformClaimKeys") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ Code: 200, @@ -243,7 +249,10 @@ func NotaryKeys( j, err := json.Marshal(keys) if err != nil { logrus.WithError(err).Errorf("Failed to marshal %q response", serverName) - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } js, err := gomatrixserverlib.SignJSON( @@ -251,7 +260,10 @@ func NotaryKeys( ) if err != nil { logrus.WithError(err).Errorf("Failed to sign %q response", serverName) - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } response.ServerKeys = append(response.ServerKeys, js) diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index fdfbf15d7..e65403404 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -60,7 +60,10 @@ func MakeLeave( err = proto.SetContent(map[string]interface{}{"membership": spec.Leave}) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("proto.SetContent failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } identity, err := cfg.Matrix.SigningIdentityFor(request.Destination()) @@ -75,19 +78,26 @@ func MakeLeave( var queryRes api.QueryLatestEventsAndStateResponse event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &proto, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes) - if err == eventutil.ErrRoomNoExists { + switch e := err.(type) { + case nil: + case eventutil.ErrRoomNoExists: + util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") return util.JSONResponse{ Code: http.StatusNotFound, JSON: spec.NotFound("Room does not exist"), } - } else if e, ok := err.(gomatrixserverlib.BadJSONError); ok { + case gomatrixserverlib.BadJSONError: + util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON(e.Error()), } - } else if err != nil { + default: util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // If the user has already left then just return their last leave @@ -233,7 +243,10 @@ func SendLeave( err = rsAPI.QueryLatestEventsAndState(httpReq.Context(), queryReq, queryRes) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryLatestEventsAndState failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // The room doesn't exist or we weren't ever joined to it. Might as well // no-op here. @@ -279,7 +292,10 @@ func SendLeave( verifyResults, err := keys.VerifyJSONs(httpReq.Context(), verifyRequests) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("keys.VerifyJSONs failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if verifyResults[0].Error != nil { return util.JSONResponse{ @@ -327,7 +343,10 @@ func SendLeave( JSON: spec.Forbidden(response.ErrMsg), } } - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/federationapi/routing/missingevents.go b/federationapi/routing/missingevents.go index f8dd9e4f1..f57d30204 100644 --- a/federationapi/routing/missingevents.go +++ b/federationapi/routing/missingevents.go @@ -63,7 +63,10 @@ func GetMissingEvents( &eventsResponse, ); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("query.QueryMissingEvents failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } eventsResponse.Events = filterEvents(eventsResponse.Events, roomID) diff --git a/federationapi/routing/peek.go b/federationapi/routing/peek.go index 9e924556f..f5003b147 100644 --- a/federationapi/routing/peek.go +++ b/federationapi/routing/peek.go @@ -40,7 +40,7 @@ func Peek( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: spec.InternalServerError(), + JSON: spec.InternalServerError{}, } } diff --git a/federationapi/routing/profile.go b/federationapi/routing/profile.go index 7d6cfcaa6..e6a488ba3 100644 --- a/federationapi/routing/profile.go +++ b/federationapi/routing/profile.go @@ -53,7 +53,10 @@ func GetProfile( profile, err := userAPI.QueryProfile(httpReq.Context(), userID) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("userAPI.QueryProfile failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var res interface{} diff --git a/federationapi/routing/publicrooms.go b/federationapi/routing/publicrooms.go index 59ff4eb2a..213d1631a 100644 --- a/federationapi/routing/publicrooms.go +++ b/federationapi/routing/publicrooms.go @@ -39,7 +39,10 @@ func GetPostPublicRooms(req *http.Request, rsAPI roomserverAPI.FederationRoomser } response, err := publicRooms(req.Context(), request, rsAPI) if err != nil { - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ Code: http.StatusOK, @@ -106,8 +109,10 @@ func fillPublicRoomsReq(httpReq *http.Request, request *PublicRoomReq) *util.JSO // In that case, we want to assign 0 so we ignore the error if err != nil && len(httpReq.FormValue("limit")) > 0 { util.GetLogger(httpReq.Context()).WithError(err).Error("strconv.Atoi failed") - reqErr := spec.InternalServerError() - return &reqErr + return &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } request.Limit = int16(limit) request.Since = httpReq.FormValue("since") diff --git a/federationapi/routing/query.go b/federationapi/routing/query.go index 233290e2e..2e845f32c 100644 --- a/federationapi/routing/query.go +++ b/federationapi/routing/query.go @@ -61,7 +61,10 @@ func RoomAliasToID( queryRes := &roomserverAPI.GetRoomIDForAliasResponse{} if err = rsAPI.GetRoomIDForAlias(httpReq.Context(), queryReq, queryRes); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if queryRes.RoomID != "" { @@ -69,7 +72,10 @@ func RoomAliasToID( var serverQueryRes federationAPI.QueryJoinedHostServerNamesInRoomResponse if err = senderAPI.QueryJoinedHostServerNamesInRoom(httpReq.Context(), &serverQueryReq, &serverQueryRes); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("senderAPI.QueryJoinedHostServerNamesInRoom failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } resp = fclient.RespDirectory{ @@ -98,7 +104,10 @@ func RoomAliasToID( // TODO: Return 502 if the remote server errored. // TODO: Return 504 if the remote server timed out. util.GetLogger(httpReq.Context()).WithError(err).Error("federation.LookupRoomAlias failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index f62a8f46c..44faad918 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -312,8 +312,6 @@ func Setup( JSON: spec.Forbidden("Forbidden by server ACLs"), } } - roomID := vars["roomID"] - userID := vars["userID"] queryVars := httpReq.URL.Query() remoteVersions := []gomatrixserverlib.RoomVersion{} if vers, ok := queryVars["ver"]; ok { @@ -328,8 +326,25 @@ func Setup( // https://matrix.org/docs/spec/server_server/r0.1.3#get-matrix-federation-v1-make-join-roomid-userid remoteVersions = append(remoteVersions, gomatrixserverlib.RoomVersionV1) } + + userID, err := spec.NewUserID(vars["userID"], true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("Invalid UserID"), + } + } + roomID, err := spec.NewRoomID(vars["roomID"]) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("Invalid RoomID"), + } + } + + logrus.Debugf("Processing make_join for user %s, room %s", userID.String(), roomID.String()) return MakeJoin( - httpReq, request, cfg, rsAPI, roomID, userID, remoteVersions, + httpReq, request, cfg, rsAPI, *roomID, *userID, remoteVersions, ) }, )).Methods(http.MethodGet) @@ -353,7 +368,7 @@ func Setup( body = []interface{}{ res.Code, res.JSON, } - jerr, ok := res.JSON.(*spec.MatrixError) + jerr, ok := res.JSON.(spec.MatrixError) if ok { body = jerr } @@ -419,7 +434,7 @@ func Setup( body = []interface{}{ res.Code, res.JSON, } - jerr, ok := res.JSON.(*spec.MatrixError) + jerr, ok := res.JSON.(spec.MatrixError) if ok { body = jerr } @@ -566,7 +581,7 @@ func MakeFedAPI( go wakeup.Wakeup(req.Context(), fedReq.Origin()) vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { - return util.MatrixErrorResponse(400, "M_UNRECOGNISED", "badly encoded query params") + return util.MatrixErrorResponse(400, string(spec.ErrorUnrecognized), "badly encoded query params") } jsonRes := f(req, fedReq, vars) diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index adfafe740..beeb52495 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -80,7 +80,10 @@ func CreateInvitesFrom3PIDInvites( ) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("createInviteFrom3PIDInvite failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if event != nil { evs = append(evs, &types.HeaderedEvent{PDU: event}) @@ -100,7 +103,10 @@ func CreateInvitesFrom3PIDInvites( false, ); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ @@ -176,7 +182,10 @@ func ExchangeThirdPartyInvite( } } else if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("buildMembershipEvent failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // Ask the requesting server to sign the newly created event so we know it @@ -184,22 +193,34 @@ func ExchangeThirdPartyInvite( inviteReq, err := fclient.NewInviteV2Request(event, nil) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } signedEvent, err := federation.SendInviteV2(httpReq.Context(), senderDomain, request.Origin(), inviteReq) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Errorf("unknown room version: %s", roomVersion) - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } inviteEvent, err := verImpl.NewEventFromUntrustedJSON(signedEvent.Event) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // Send the event to the roomserver @@ -216,7 +237,10 @@ func ExchangeThirdPartyInvite( false, ); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("SendEvents failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/go.mod b/go.mod index bd1d43fcb..eff9e50f7 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230509222610-6fd532036ab6 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230517000105-1ff06fc8a6a2 github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 @@ -34,7 +34,7 @@ require ( github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.13.0 - github.com/sirupsen/logrus v1.9.0 + github.com/sirupsen/logrus v1.9.1 github.com/stretchr/testify v1.8.2 github.com/tidwall/gjson v1.14.4 github.com/tidwall/sjson v1.2.5 diff --git a/go.sum b/go.sum index 733d6e24f..faf70c6df 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230509222610-6fd532036ab6 h1:cF6fNfxC73fU9zT3pgzDXI9NDihAdnilqqGcpDWgNP4= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230509222610-6fd532036ab6/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230517000105-1ff06fc8a6a2 h1:V36yCWt2CoSfI1xr6WYJ9Mb3eyl95SknMRLGFvEuYak= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230517000105-1ff06fc8a6a2/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= @@ -444,8 +444,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= -github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= -github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sirupsen/logrus v1.9.1 h1:Ou41VVR3nMWWmTiEUnj0OlsgOSCUFgsPAOl6jRIcVtQ= +github.com/sirupsen/logrus v1.9.1/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= diff --git a/internal/eventutil/events.go b/internal/eventutil/events.go index dff459684..79882d8d8 100644 --- a/internal/eventutil/events.go +++ b/internal/eventutil/events.go @@ -31,7 +31,17 @@ import ( // ErrRoomNoExists is returned when trying to lookup the state of a room that // doesn't exist -var ErrRoomNoExists = errors.New("room does not exist") +var errRoomNoExists = fmt.Errorf("room does not exist") + +type ErrRoomNoExists struct{} + +func (e ErrRoomNoExists) Error() string { + return errRoomNoExists.Error() +} + +func (e ErrRoomNoExists) Unwrap() error { + return errRoomNoExists +} // QueryAndBuildEvent builds a Matrix event using the event builder and roomserver query // API client provided. If also fills roomserver query API response (if provided) @@ -116,7 +126,7 @@ func addPrevEventsToEvent( queryRes *api.QueryLatestEventsAndStateResponse, ) error { if !queryRes.RoomExists { - return ErrRoomNoExists + return ErrRoomNoExists{} } verImpl, err := gomatrixserverlib.GetRoomVersion(queryRes.RoomVersion) diff --git a/internal/httputil/routing.go b/internal/httputil/routing.go index c733c8ce7..2052c798f 100644 --- a/internal/httputil/routing.go +++ b/internal/httputil/routing.go @@ -15,10 +15,12 @@ package httputil import ( + "encoding/json" "net/http" "net/url" "github.com/gorilla/mux" + "github.com/matrix-org/gomatrixserverlib/spec" ) // URLDecodeMapValues is a function that iterates through each of the items in a @@ -66,13 +68,15 @@ func NewRouters() Routers { var NotAllowedHandler = WrapHandlerInCORS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusMethodNotAllowed) w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"errcode":"M_UNRECOGNIZED","error":"Unrecognized request"}`)) // nolint:misspell + unrecognizedErr, _ := json.Marshal(spec.Unrecognized("Unrecognized request")) // nolint:misspell + _, _ = w.Write(unrecognizedErr) // nolint:misspell })) var NotFoundCORSHandler = WrapHandlerInCORS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"errcode":"M_UNRECOGNIZED","error":"Unrecognized request"}`)) // nolint:misspell + unrecognizedErr, _ := json.Marshal(spec.Unrecognized("Unrecognized request")) // nolint:misspell + _, _ = w.Write(unrecognizedErr) // nolint:misspell })) func (r *Routers) configureHTTPErrors() { diff --git a/mediaapi/routing/upload.go b/mediaapi/routing/upload.go index 5061d4762..5ac1d076b 100644 --- a/mediaapi/routing/upload.go +++ b/mediaapi/routing/upload.go @@ -184,8 +184,10 @@ func (r *uploadRequest) doUpload( if err != nil { fileutils.RemoveDir(tmpDir, r.Logger) r.Logger.WithError(err).Error("Error querying the database by hash.") - resErr := spec.InternalServerError() - return &resErr + return &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if existingMetadata != nil { // The file already exists, delete the uploaded temporary file. @@ -194,8 +196,10 @@ func (r *uploadRequest) doUpload( mediaID, merr := r.generateMediaID(ctx, db) if merr != nil { r.Logger.WithError(merr).Error("Failed to generate media ID for existing file") - resErr := spec.InternalServerError() - return &resErr + return &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // Then amend the upload metadata. @@ -217,8 +221,10 @@ func (r *uploadRequest) doUpload( if err != nil { fileutils.RemoveDir(tmpDir, r.Logger) r.Logger.WithError(err).Error("Failed to generate media ID for new upload") - resErr := spec.InternalServerError() - return &resErr + return &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } diff --git a/relayapi/routing/routing.go b/relayapi/routing/routing.go index f6e556119..f11b0a7c5 100644 --- a/relayapi/routing/routing.go +++ b/relayapi/routing/routing.go @@ -122,7 +122,7 @@ func MakeRelayAPI( }() vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { - return util.MatrixErrorResponse(400, "M_UNRECOGNISED", "badly encoded query params") + return util.MatrixErrorResponse(400, string(spec.ErrorUnrecognized), "badly encoded query params") } jsonRes := f(req, fedReq, vars) diff --git a/roomserver/api/api.go b/roomserver/api/api.go index ab1ec28f8..f2e2bf84a 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -8,6 +8,7 @@ import ( asAPI "github.com/matrix-org/dendrite/appservice/api" fsAPI "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/roomserver/types" userapi "github.com/matrix-org/dendrite/userapi/api" ) @@ -224,6 +225,12 @@ type FederationRoomserverAPI interface { PerformInvite(ctx context.Context, req *PerformInviteRequest) error // Query a given amount (or less) of events prior to a given set of events. PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error + + CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) + InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) + QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types.RoomInfo, error) + UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, userID spec.UserID) (bool, error) + LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error) } type KeyserverRoomserverAPI interface { diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index a539efd1d..fadc8bcfc 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -52,7 +52,7 @@ func (r *Admin) PerformAdminEvacuateRoom( return nil, err } if roomInfo == nil || roomInfo.IsStub() { - return nil, eventutil.ErrRoomNoExists + return nil, eventutil.ErrRoomNoExists{} } memberNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true) @@ -240,7 +240,7 @@ func (r *Admin) PerformAdminDownloadState( } if roomInfo == nil || roomInfo.IsStub() { - return eventutil.ErrRoomNoExists + return eventutil.ErrRoomNoExists{} } fwdExtremities, _, depth, err := r.DB.LatestEventIDs(ctx, roomInfo.RoomNID) diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index a836eb1ae..5f4ad1861 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -145,7 +145,7 @@ func (r *Joiner) performJoinRoomByAlias( return r.performJoinRoomByID(ctx, req) } -// TODO: Break this function up a bit +// TODO: Break this function up a bit & move to GMSL // nolint:gocyclo func (r *Joiner) performJoinRoomByID( ctx context.Context, @@ -286,7 +286,7 @@ func (r *Joiner) performJoinRoomByID( } event, err := eventutil.QueryAndBuildEvent(ctx, &proto, r.Cfg.Matrix, identity, time.Now(), r.RSAPI, &buildRes) - switch err { + switch err.(type) { case nil: // The room join is local. Send the new join event into the // roomserver. First of all check that the user isn't already @@ -328,7 +328,7 @@ func (r *Joiner) performJoinRoomByID( // Otherwise we'll try a federated join as normal, since it's quite // possible that the room still exists on other servers. if len(req.ServerNames) == 0 { - return "", "", eventutil.ErrRoomNoExists + return "", "", eventutil.ErrRoomNoExists{} } } diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index e88cb1dcc..abe63145a 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -274,7 +274,7 @@ func publishNewRoomAndUnpublishOldRoom( func (r *Upgrader) validateRoomExists(ctx context.Context, roomID string) error { if _, err := r.URSAPI.QueryRoomVersionForRoom(ctx, roomID); err != nil { - return eventutil.ErrRoomNoExists + return eventutil.ErrRoomNoExists{} } return nil } @@ -556,15 +556,18 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user } var queryRes api.QueryLatestEventsAndStateResponse headeredEvent, err := eventutil.QueryAndBuildEvent(ctx, &proto, r.Cfg.Matrix, identity, evTime, r.URSAPI, &queryRes) - if err == eventutil.ErrRoomNoExists { - return nil, err - } else if e, ok := err.(gomatrixserverlib.BadJSONError); ok { + switch e := err.(type) { + case nil: + case eventutil.ErrRoomNoExists: return nil, e - } else if e, ok := err.(gomatrixserverlib.EventValidationError); ok { + case gomatrixserverlib.BadJSONError: return nil, e - } else if err != nil { + case gomatrixserverlib.EventValidationError: + return nil, e + default: return nil, fmt.Errorf("failed to build new %q event: %w", proto.Type, err) } + // check to see if this user can perform this operation stateEvents := make([]gomatrixserverlib.PDU, len(queryRes.StateEvents)) for i := range queryRes.StateEvents { diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 27c0dd0c0..e4dac45ea 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -858,6 +858,49 @@ func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainReq return nil } +func (r *Queryer) InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) { + pending, _, _, _, err := helpers.IsInvitePending(ctx, r.DB, roomID.String(), userID.String()) + return pending, err +} + +func (r *Queryer) QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types.RoomInfo, error) { + return r.DB.RoomInfo(ctx, roomID.String()) +} + +func (r *Queryer) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) { + return r.DB.GetStateEvent(ctx, roomID.String(), string(eventType), "") +} + +func (r *Queryer) UserJoinedToRoom(ctx context.Context, roomNID types.RoomNID, userID spec.UserID) (bool, error) { + _, isIn, _, err := r.DB.GetMembership(ctx, roomNID, userID.String()) + return isIn, err +} + +func (r *Queryer) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error) { + joinNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true, true) + if err != nil { + return nil, err + } + + events, err := r.DB.Events(ctx, roomVersion, joinNIDs) + if err != nil { + return nil, err + } + + // For each of the joined users, let's see if we can get a valid + // membership event. + joinedUsers := []gomatrixserverlib.PDU{} + for _, event := range events { + if event.Type() != spec.MRoomMember || event.StateKey() == nil { + continue // shouldn't happen + } + + joinedUsers = append(joinedUsers, event) + } + + return joinedUsers, nil +} + // nolint:gocyclo func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.QueryRestrictedJoinAllowedRequest, res *api.QueryRestrictedJoinAllowedResponse) error { // Look up if we know anything about the room. If it doesn't exist diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index e9d61fede..f468b048a 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -427,8 +427,10 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen children, err := db.ChildrenForParent(rc.ctx, parentID, constRelType, recentFirst) if err != nil { util.GetLogger(rc.ctx).WithError(err).Error("failed to get ChildrenForParent") - resErr := spec.InternalServerError() - return nil, &resErr + return nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var childEvents []*types.HeaderedEvent for _, child := range children { diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 8ff656e7a..ac17d39d2 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -56,7 +56,10 @@ func Context( ) util.JSONResponse { snapshot, err := syncDB.NewDatabaseSnapshot(req.Context()) if err != nil { - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var succeeded bool defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) @@ -87,7 +90,10 @@ func Context( membershipReq := roomserver.QueryMembershipForUserRequest{UserID: device.UserID, RoomID: roomID} if err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes); err != nil { logrus.WithError(err).Error("unable to query membership") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !membershipRes.RoomExists { return util.JSONResponse{ @@ -117,7 +123,10 @@ func Context( } } logrus.WithError(err).WithField("eventID", eventID).Error("unable to find requested event") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // verify the user is allowed to see the context for this room/event @@ -125,7 +134,10 @@ func Context( filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, []*rstypes.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context") if err != nil { logrus.WithError(err).Error("unable to apply history visibility filter") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } logrus.WithFields(logrus.Fields{ "duration": time.Since(startTime), @@ -141,20 +153,29 @@ func Context( eventsBefore, err := snapshot.SelectContextBeforeEvent(ctx, id, roomID, filter) if err != nil && err != sql.ErrNoRows { logrus.WithError(err).Error("unable to fetch before events") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } _, eventsAfter, err := snapshot.SelectContextAfterEvent(ctx, id, roomID, filter) if err != nil && err != sql.ErrNoRows { logrus.WithError(err).Error("unable to fetch after events") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } startTime = time.Now() eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, snapshot, rsAPI, eventsBefore, eventsAfter, device.UserID) if err != nil { logrus.WithError(err).Error("unable to apply history visibility filter") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } logrus.WithFields(logrus.Fields{ @@ -166,7 +187,10 @@ func Context( state, err := snapshot.CurrentState(ctx, roomID, &stateFilter, nil) if err != nil { logrus.WithError(err).Error("unable to fetch current room state") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll) @@ -180,7 +204,10 @@ func Context( newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache) if err != nil { logrus.WithError(err).Error("unable to load membership events") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } diff --git a/syncapi/routing/filter.go b/syncapi/routing/filter.go index 5152e1f81..c4eecbdb8 100644 --- a/syncapi/routing/filter.go +++ b/syncapi/routing/filter.go @@ -43,7 +43,10 @@ func GetFilter( localpart, _, err := gomatrixserverlib.SplitID('@', userID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } filter := synctypes.DefaultFilter() @@ -83,7 +86,10 @@ func PutFilter( localpart, _, err := gomatrixserverlib.SplitID('@', userID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var filter synctypes.Filter @@ -122,7 +128,10 @@ func PutFilter( filterID, err := syncDB.PutFilter(req.Context(), localpart, &filter) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("syncDB.PutFilter failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go index e3d77cc33..0d3d412f6 100644 --- a/syncapi/routing/getevent.go +++ b/syncapi/routing/getevent.go @@ -51,13 +51,19 @@ func GetEvent( }) if err != nil { logger.WithError(err).Error("GetEvent: syncDB.NewDatabaseTransaction failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } events, err := db.Events(ctx, []string{eventID}) if err != nil { logger.WithError(err).Error("GetEvent: syncDB.Events failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // The requested event does not exist in our database @@ -81,7 +87,7 @@ func GetEvent( logger.WithError(err).Error("GetEvent: internal.ApplyHistoryVisibilityFilter failed") return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: spec.InternalServerError(), + JSON: spec.InternalServerError{}, } } diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index 5a66009c8..7d2e137d3 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -67,7 +67,10 @@ func GetMemberships( var queryRes api.QueryMembershipForUserResponse if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryMembershipsForRoom failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !queryRes.HasBeenInRoom { @@ -86,7 +89,10 @@ func GetMemberships( db, err := syncDB.NewDatabaseSnapshot(req.Context()) if err != nil { - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } defer db.Rollback() // nolint: errcheck @@ -98,7 +104,10 @@ func GetMemberships( atToken, err = db.EventPositionInTopology(req.Context(), queryRes.EventID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("unable to get 'atToken'") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } } @@ -106,13 +115,19 @@ func GetMemberships( eventIDs, err := db.SelectMemberships(req.Context(), roomID, atToken, membership, notMembership) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("db.SelectMemberships failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } qryRes := &api.QueryEventsByIDResponse{} if err := rsAPI.QueryEventsByID(req.Context(), &api.QueryEventsByIDRequest{EventIDs: eventIDs, RoomID: roomID}, qryRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryEventsByID failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } result := qryRes.Events @@ -124,7 +139,10 @@ func GetMemberships( var content databaseJoinedMember if err := json.Unmarshal(ev.Content(), &content); err != nil { util.GetLogger(req.Context()).WithError(err).Error("failed to unmarshal event content") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } res.Joined[ev.Sender()] = joinedMember(content) } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 4d3c9e2eb..58f663d0b 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -81,7 +81,10 @@ func OnIncomingMessagesRequest( // request that requires backfilling from the roomserver or federation. snapshot, err := db.NewDatabaseTransaction(req.Context()) if err != nil { - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var succeeded bool defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) @@ -89,7 +92,10 @@ func OnIncomingMessagesRequest( // check if the user has already forgotten about this room membershipResp, err := getMembershipForUser(req.Context(), roomID, device.UserID, rsAPI) if err != nil { - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !membershipResp.RoomExists { return util.JSONResponse{ @@ -151,7 +157,10 @@ func OnIncomingMessagesRequest( from, err = snapshot.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, backwardOrdering) if err != nil { logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken) - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } } @@ -173,7 +182,10 @@ func OnIncomingMessagesRequest( to, err = snapshot.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, !backwardOrdering) if err != nil { logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken) - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } } @@ -232,7 +244,10 @@ func OnIncomingMessagesRequest( clientEvents, start, end, err := mReq.retrieveEvents() if err != nil { util.GetLogger(req.Context()).WithError(err).Error("mreq.retrieveEvents failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } util.GetLogger(req.Context()).WithFields(logrus.Fields{ @@ -253,7 +268,10 @@ func OnIncomingMessagesRequest( membershipEvents, err := applyLazyLoadMembers(req.Context(), device, snapshot, roomID, clientEvents, lazyLoadCache) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("failed to apply lazy loading") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll)...) } diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go index 2bf11a566..8374bf5b0 100644 --- a/syncapi/routing/relations.go +++ b/syncapi/routing/relations.go @@ -80,7 +80,10 @@ func Relations( snapshot, err := syncDB.NewDatabaseSnapshot(req.Context()) if err != nil { logrus.WithError(err).Error("Failed to get snapshot for relations") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var succeeded bool defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index 88c5c5045..9ad0c0476 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -162,7 +162,10 @@ func Setup( } var nextBatch *string if err := req.ParseForm(); err != nil { - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if req.Form.Has("next_batch") { nb := req.FormValue("next_batch") diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index 986284d06..b7191873e 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -55,7 +55,10 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts if from != nil && *from != "" { nextBatch, err = strconv.Atoi(*from) if err != nil { - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } @@ -65,7 +68,10 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts snapshot, err := syncDB.NewDatabaseSnapshot(req.Context()) if err != nil { - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var succeeded bool defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) @@ -73,7 +79,10 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts // only search rooms the user is actually joined to joinedRooms, err := snapshot.RoomIDsWithMembership(ctx, device.UserID, "join") if err != nil { - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if len(joinedRooms) == 0 { return util.JSONResponse{ @@ -115,7 +124,10 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts ) if err != nil { logrus.WithError(err).Error("failed to search fulltext") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } logrus.Debugf("Search took %s", result.Took) @@ -155,7 +167,10 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts evs, err := syncDB.Events(ctx, wantEvents) if err != nil { logrus.WithError(err).Error("failed to get events from database") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } groups := make(map[string]RoomResult) @@ -173,12 +188,18 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts eventsBefore, eventsAfter, err := contextEvents(ctx, snapshot, event, roomFilter, searchReq) if err != nil { logrus.WithError(err).Error("failed to get context events") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } startToken, endToken, err := getStartEnd(ctx, snapshot, eventsBefore, eventsAfter) if err != nil { logrus.WithError(err).Error("failed to get start/end") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } profileInfos := make(map[string]ProfileInfoResponse) @@ -221,7 +242,10 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts state, err := snapshot.CurrentState(ctx, event.RoomID(), &stateFilter, nil) if err != nil { logrus.WithError(err).Error("unable to get current state") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync) } diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 09e5dee17..5a92c70e1 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -536,12 +536,18 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use syncReq, err := newSyncRequest(req, *device, rp.db) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("newSyncRequest failed") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } snapshot, err := rp.db.NewDatabaseSnapshot(req.Context()) if err != nil { logrus.WithError(err).Error("Failed to acquire database snapshot for key change") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var succeeded bool defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) @@ -552,7 +558,10 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use ) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("Failed to DeviceListCatchup info") - return spec.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } succeeded = true return util.JSONResponse{ From 345f025ee3654d120b9e668e943a4f2d428c12c7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 17 May 2023 17:44:59 +0200 Subject: [PATCH 04/35] Bump github.com/docker/distribution from 2.8.1+incompatible to 2.8.2+incompatible (#3082) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [github.com/docker/distribution](https://github.com/docker/distribution) from 2.8.1+incompatible to 2.8.2+incompatible.
Release notes

Sourced from github.com/docker/distribution's releases.

v2.8.2

What's Changed

Full Changelog: https://github.com/distribution/distribution/compare/v2.8.1...v2.8.2

v2.8.2-beta.2

What's Changed

Full Changelog: https://github.com/distribution/distribution/compare/v2.8.1...v2.8.2-beta.2

v2.8.2-beta.1

NOTE: This is a pre-release that does not contain any artifacts!

What's Changed

Full Changelog: https://github.com/distribution/distribution/compare/v2.8.1...v2.8.2-beta.1

Commits
  • 7c354a4 Merge pull request #3915 from distribution/2.8.2-release-notes
  • a173a9c Add v2.8.2 release notes
  • 4894d35 Merge pull request #3914 from vvoland/handle-forbidden-28
  • f067f66 Merge pull request #3783 from ndeloof/accept-encoding-28
  • 483ad69 registry/errors: Parse http forbidden as denied
  • 2b0f84d Revert "registry/client: set Accept: identity header when getting layers"
  • 320d6a1 Merge pull request #3912 from distribution/2.8.2-beta.2-release-notes
  • 5f3ca1b Add release notes for 2.8.2-beta.2 release
  • cb840f6 Merge pull request #3911 from thaJeztah/2.8_backport_fix_releaser_filenames
  • e884644 Dockerfile: fix filenames of artifacts
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=github.com/docker/distribution&package-manager=go_modules&previous-version=2.8.1+incompatible&new-version=2.8.2+incompatible)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/matrix-org/dendrite/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index eff9e50f7..87fcbf0c0 100644 --- a/go.mod +++ b/go.mod @@ -77,7 +77,7 @@ require ( github.com/blevesearch/zapx/v15 v15.3.8 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/docker/distribution v2.8.1+incompatible // indirect + github.com/docker/distribution v2.8.2+incompatible // indirect github.com/docker/go-units v0.5.0 // indirect github.com/dustin/go-humanize v1.0.0 // indirect github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect diff --git a/go.sum b/go.sum index faf70c6df..9589ed168 100644 --- a/go.sum +++ b/go.sum @@ -133,8 +133,8 @@ github.com/dgraph-io/ristretto v0.1.1 h1:6CWw5tJNgpegArSHpNHJKldNeq03FQCwYvfMVWa github.com/dgraph-io/ristretto v0.1.1/go.mod h1:S1GPSBCYCIhmVNfcth17y2zZtQT6wzkzgwUve0VDWWA= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= -github.com/docker/distribution v2.8.1+incompatible h1:Q50tZOPR6T/hjNsyc9g8/syEs6bk8XXApsHjKukMl68= -github.com/docker/distribution v2.8.1+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= +github.com/docker/distribution v2.8.2+incompatible h1:T3de5rq0dB1j30rp0sA2rER+m322EBzniBPB6ZIzuh8= +github.com/docker/distribution v2.8.2+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= github.com/docker/docker v20.10.24+incompatible h1:Ugvxm7a8+Gz6vqQYQQ2W7GYq5EUPaAiuPgIfVyI3dYE= github.com/docker/docker v20.10.24+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/go-connections v0.4.0 h1:El9xVISelRB7BuFusrZozjnkIM5YnzCViNKohAFqRJQ= From 027a9b8ce0a7e2d577e2c41f9de7a6fe42ace655 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Thu, 18 May 2023 13:41:47 -0600 Subject: [PATCH 05/35] Fix bug with nil interface return & add test --- roomserver/internal/query/query.go | 6 ++++- roomserver/internal/query/query_test.go | 33 +++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index e4dac45ea..35cafd0ec 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -868,7 +868,11 @@ func (r *Queryer) QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types } func (r *Queryer) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) { - return r.DB.GetStateEvent(ctx, roomID.String(), string(eventType), "") + res, err := r.DB.GetStateEvent(ctx, roomID.String(), string(eventType), "") + if res == nil { + return nil, err + } + return res, err } func (r *Queryer) UserJoinedToRoom(ctx context.Context, roomNID types.RoomNID, userID spec.UserID) (bool, error) { diff --git a/roomserver/internal/query/query_test.go b/roomserver/internal/query/query_test.go index 2ebf7f334..b6715cb00 100644 --- a/roomserver/internal/query/query_test.go +++ b/roomserver/internal/query/query_test.go @@ -18,10 +18,16 @@ import ( "context" "encoding/json" "testing" + "time" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // used to implement RoomserverInternalAPIEventDB to test getAuthChain @@ -155,3 +161,30 @@ func TestGetAuthChainMultiple(t *testing.T) { t.Fatalf("returnedIDs got '%v', expected '%v'", returnedIDs, expectedIDs) } } + +func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { + conStr, close := test.PrepareDBConnectionString(t, dbType) + caches := caching.NewRistrettoCache(8*1024*1024, time.Hour, caching.DisableMetrics) + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) + db, err := storage.Open(context.Background(), cm, &config.DatabaseOptions{ConnectionString: config.DataSource(conStr)}, caches) + if err != nil { + t.Fatalf("failed to create Database: %v", err) + } + return db, close +} + +func TestCurrentEventIsNil(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateDatabase(t, dbType) + defer close() + querier := Queryer{ + DB: db, + } + + roomID, _ := spec.NewRoomID("!room:server") + event, _ := querier.CurrentStateEvent(context.Background(), *roomID, spec.MRoomMember, "@user:server") + if event != nil { + t.Fatal("Event should equal nil, most likely this is failing because the interface type is not nil, but the value is.") + } + }) +} From 2eae8dc489f056df5aec0ee4ace1b8ba8260e18e Mon Sep 17 00:00:00 2001 From: devonh Date: Fri, 19 May 2023 16:27:01 +0000 Subject: [PATCH 06/35] Move SendJoin logic to GMSL (#3084) Moves the core matrix logic for handling the send_join endpoint over to gmsl. --- federationapi/routing/join.go | 339 ++++++++++------------------- federationapi/routing/routing.go | 26 ++- go.mod | 4 +- go.sum | 8 +- roomserver/api/api.go | 1 + roomserver/internal/query/query.go | 2 +- 6 files changed, 147 insertions(+), 233 deletions(-) diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index cc22690a9..cbdeca51e 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -16,7 +16,6 @@ package routing import ( "context" - "encoding/json" "fmt" "net/http" "sort" @@ -160,45 +159,43 @@ func MakeJoin( BuildEventTemplate: createJoinTemplate, } response, internalErr := gomatrixserverlib.HandleMakeJoin(input) - if internalErr != nil { - switch e := internalErr.(type) { - case nil: - case spec.InternalServerError: - util.GetLogger(httpReq.Context()).WithError(internalErr) - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } - case spec.MatrixError: - util.GetLogger(httpReq.Context()).WithError(internalErr) - code := http.StatusInternalServerError - switch e.ErrCode { - case spec.ErrorForbidden: - code = http.StatusForbidden - case spec.ErrorNotFound: - code = http.StatusNotFound - case spec.ErrorUnableToAuthoriseJoin: - code = http.StatusBadRequest - case spec.ErrorBadJSON: - code = http.StatusBadRequest - } + switch e := internalErr.(type) { + case nil: + case spec.InternalServerError: + util.GetLogger(httpReq.Context()).WithError(internalErr) + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + case spec.MatrixError: + util.GetLogger(httpReq.Context()).WithError(internalErr) + code := http.StatusInternalServerError + switch e.ErrCode { + case spec.ErrorForbidden: + code = http.StatusForbidden + case spec.ErrorNotFound: + code = http.StatusNotFound + case spec.ErrorUnableToAuthoriseJoin: + fallthrough // http.StatusBadRequest + case spec.ErrorBadJSON: + code = http.StatusBadRequest + } - return util.JSONResponse{ - Code: code, - JSON: e, - } - case spec.IncompatibleRoomVersionError: - util.GetLogger(httpReq.Context()).WithError(internalErr) - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: e, - } - default: - util.GetLogger(httpReq.Context()).WithError(internalErr) - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.Unknown("unknown error"), - } + return util.JSONResponse{ + Code: code, + JSON: e, + } + case spec.IncompatibleRoomVersionError: + util.GetLogger(httpReq.Context()).WithError(internalErr) + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: e, + } + default: + util.GetLogger(httpReq.Context()).WithError(internalErr) + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("unknown error"), } } @@ -219,6 +216,25 @@ func MakeJoin( } } +type MembershipQuerier struct { + roomserver api.FederationRoomserverAPI +} + +func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) { + req := api.QueryMembershipForUserRequest{ + RoomID: roomID.String(), + UserID: userID.String(), + } + res := api.QueryMembershipForUserResponse{} + err := mq.roomserver.QueryMembershipForUser(ctx, &req, &res) + + membership := "" + if err == nil { + membership = res.Membership + } + return membership, err +} + // SendJoin implements the /send_join API // The make-join send-join dance makes much more sense as a single // flow so the cyclomatic complexity is high: @@ -229,9 +245,10 @@ func SendJoin( cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, keys gomatrixserverlib.JSONVerifier, - roomID, eventID string, + roomID spec.RoomID, + eventID string, ) util.JSONResponse { - roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID) + roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID.String()) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryRoomVersionForRoom failed") return util.JSONResponse{ @@ -239,132 +256,71 @@ func SendJoin( JSON: spec.InternalServerError{}, } } - verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion) - if err != nil { - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.UnsupportedRoomVersion( - fmt.Sprintf("QueryRoomVersionForRoom returned unknown room version: %s", roomVersion), - ), - } - } - event, err := verImpl.NewEventFromUntrustedJSON(request.Content()) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON("The request body could not be decoded into valid JSON: " + err.Error()), - } + input := gomatrixserverlib.HandleSendJoinInput{ + Context: httpReq.Context(), + RoomID: roomID, + EventID: eventID, + JoinEvent: request.Content(), + RoomVersion: roomVersion, + RequestOrigin: request.Origin(), + LocalServerName: cfg.Matrix.ServerName, + KeyID: cfg.Matrix.KeyID, + PrivateKey: cfg.Matrix.PrivateKey, + Verifier: keys, + MembershipQuerier: &MembershipQuerier{roomserver: rsAPI}, } - - // Check that a state key is provided. - if event.StateKey() == nil || event.StateKeyEquals("") { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON("No state key was provided in the join event."), - } - } - if !event.StateKeyEquals(event.Sender()) { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON("Event state key must match the event sender."), - } - } - - // Check that the sender belongs to the server that is sending us - // the request. By this point we've already asserted that the sender - // and the state key are equal so we don't need to check both. - var serverName spec.ServerName - if _, serverName, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("The sender of the join is invalid"), - } - } else if serverName != request.Origin() { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("The sender does not match the server that originated the request"), - } - } - - // Check that the room ID is correct. - if event.RoomID() != roomID { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON( - fmt.Sprintf( - "The room ID in the request path (%q) must match the room ID in the join event JSON (%q)", - roomID, event.RoomID(), - ), - ), - } - } - - // Check that the event ID is correct. - if event.EventID() != eventID { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON( - fmt.Sprintf( - "The event ID in the request path (%q) must match the event ID in the join event JSON (%q)", - eventID, event.EventID(), - ), - ), - } - } - - // Check that this is in fact a join event - membership, err := event.Membership() - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON("missing content.membership key"), - } - } - if membership != spec.Join { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON("membership must be 'join'"), - } - } - - // Check that the event is signed by the server sending the request. - redacted, err := verImpl.RedactEventJSON(event.JSON()) - if err != nil { - logrus.WithError(err).Errorf("XXX: join.go") - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON("The event JSON could not be redacted"), - } - } - verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ - ServerName: serverName, - Message: redacted, - AtTS: event.OriginServerTS(), - StrictValidityChecking: true, - }} - verifyResults, err := keys.VerifyJSONs(httpReq.Context(), verifyRequests) - if err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("keys.VerifyJSONs failed") + response, joinErr := gomatrixserverlib.HandleSendJoin(input) + switch e := joinErr.(type) { + case nil: + case spec.InternalServerError: + util.GetLogger(httpReq.Context()).WithError(joinErr) return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}, } - } - if verifyResults[0].Error != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("Signature check failed: " + verifyResults[0].Error.Error()), + case spec.MatrixError: + util.GetLogger(httpReq.Context()).WithError(joinErr) + code := http.StatusInternalServerError + switch e.ErrCode { + case spec.ErrorForbidden: + code = http.StatusForbidden + case spec.ErrorNotFound: + code = http.StatusNotFound + case spec.ErrorUnsupportedRoomVersion: + code = http.StatusInternalServerError + case spec.ErrorBadJSON: + code = http.StatusBadRequest } + + return util.JSONResponse{ + Code: code, + JSON: e, + } + default: + util.GetLogger(httpReq.Context()).WithError(joinErr) + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("unknown error"), + } + } + + if response == nil { + util.GetLogger(httpReq.Context()).Error("gmsl.HandleMakeJoin returned invalid response") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } // Fetch the state and auth chain. We do this before we send the events // on, in case this fails. var stateAndAuthChainResponse api.QueryStateAndAuthChainResponse err = rsAPI.QueryStateAndAuthChain(httpReq.Context(), &api.QueryStateAndAuthChainRequest{ - PrevEventIDs: event.PrevEventIDs(), - AuthEventIDs: event.AuthEventIDs(), - RoomID: roomID, + PrevEventIDs: response.JoinEvent.PrevEventIDs(), + AuthEventIDs: response.JoinEvent.AuthEventIDs(), + RoomID: roomID.String(), ResolveState: true, }, &stateAndAuthChainResponse) if err != nil { @@ -388,84 +344,27 @@ func SendJoin( } } - // Check if the user is already in the room. If they're already in then - // there isn't much point in sending another join event into the room. - // Also check to see if they are banned: if they are then we reject them. - alreadyJoined := false - isBanned := false - for _, se := range stateAndAuthChainResponse.StateEvents { - if !se.StateKeyEquals(*event.StateKey()) { - continue - } - if membership, merr := se.Membership(); merr == nil { - alreadyJoined = (membership == spec.Join) - isBanned = (membership == spec.Ban) - break - } - } - - if isBanned { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("user is banned"), - } - } - - // If the membership content contains a user ID for a server that is not - // ours then we should kick it back. - var memberContent gomatrixserverlib.MemberContent - if err := json.Unmarshal(event.Content(), &memberContent); err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON(err.Error()), - } - } - if memberContent.AuthorisedVia != "" { - _, domain, err := gomatrixserverlib.SplitID('@', memberContent.AuthorisedVia) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON(fmt.Sprintf("The authorising username %q is invalid.", memberContent.AuthorisedVia)), - } - } - if domain != cfg.Matrix.ServerName { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON(fmt.Sprintf("The authorising username %q does not belong to this server.", memberContent.AuthorisedVia)), - } - } - } - - // Sign the membership event. This is required for restricted joins to work - // in the case that the authorised via user is one of our own users. It also - // doesn't hurt to do it even if it isn't a restricted join. - signed := event.Sign( - string(cfg.Matrix.ServerName), - cfg.Matrix.KeyID, - cfg.Matrix.PrivateKey, - ) - // Send the events to the room server. // We are responsible for notifying other servers that the user has joined // the room, so set SendAsServer to cfg.Matrix.ServerName - if !alreadyJoined { - var response api.InputRoomEventsResponse + if !response.AlreadyJoined { + var rsResponse api.InputRoomEventsResponse rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{ InputRoomEvents: []api.InputRoomEvent{ { Kind: api.KindNew, - Event: &types.HeaderedEvent{PDU: signed}, + Event: &types.HeaderedEvent{PDU: response.JoinEvent}, SendAsServer: string(cfg.Matrix.ServerName), TransactionID: nil, }, }, - }, &response) - if response.ErrMsg != "" { - util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, response.ErrMsg).Error("SendEvents failed") - if response.NotAllowed { + }, &rsResponse) + if rsResponse.ErrMsg != "" { + util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, rsResponse.ErrMsg).Error("SendEvents failed") + if rsResponse.NotAllowed { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: spec.Forbidden(response.ErrMsg), + JSON: spec.Forbidden(rsResponse.ErrMsg), } } return util.JSONResponse{ @@ -488,7 +387,7 @@ func SendJoin( StateEvents: types.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.StateEvents), AuthEvents: types.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.AuthChainEvents), Origin: cfg.Matrix.ServerName, - Event: signed.JSON(), + Event: response.JoinEvent.JSON(), }, } } diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 44faad918..7be0857a6 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -331,14 +331,14 @@ func Setup( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: spec.BadJSON("Invalid UserID"), + JSON: spec.InvalidParam("Invalid UserID"), } } roomID, err := spec.NewRoomID(vars["roomID"]) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: spec.BadJSON("Invalid RoomID"), + JSON: spec.InvalidParam("Invalid RoomID"), } } @@ -358,10 +358,17 @@ func Setup( JSON: spec.Forbidden("Forbidden by server ACLs"), } } - roomID := vars["roomID"] eventID := vars["eventID"] + roomID, err := spec.NewRoomID(vars["roomID"]) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Invalid RoomID"), + } + } + res := SendJoin( - httpReq, request, cfg, rsAPI, keys, roomID, eventID, + httpReq, request, cfg, rsAPI, keys, *roomID, eventID, ) // not all responses get wrapped in [code, body] var body interface{} @@ -390,10 +397,17 @@ func Setup( JSON: spec.Forbidden("Forbidden by server ACLs"), } } - roomID := vars["roomID"] eventID := vars["eventID"] + roomID, err := spec.NewRoomID(vars["roomID"]) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Invalid RoomID"), + } + } + return SendJoin( - httpReq, request, cfg, rsAPI, keys, roomID, eventID, + httpReq, request, cfg, rsAPI, keys, *roomID, eventID, ) }, )).Methods(http.MethodPut) diff --git a/go.mod b/go.mod index 87fcbf0c0..e85051777 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230517000105-1ff06fc8a6a2 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230519160810-b92e84b02a7c github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 @@ -34,7 +34,7 @@ require ( github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.13.0 - github.com/sirupsen/logrus v1.9.1 + github.com/sirupsen/logrus v1.9.2 github.com/stretchr/testify v1.8.2 github.com/tidwall/gjson v1.14.4 github.com/tidwall/sjson v1.2.5 diff --git a/go.sum b/go.sum index 9589ed168..6f034fef6 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230517000105-1ff06fc8a6a2 h1:V36yCWt2CoSfI1xr6WYJ9Mb3eyl95SknMRLGFvEuYak= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230517000105-1ff06fc8a6a2/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230519160810-b92e84b02a7c h1:EF04pmshcDmBQOrBQbzT5htyTivetfyvR70gX2hB9AM= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230519160810-b92e84b02a7c/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= @@ -444,8 +444,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= -github.com/sirupsen/logrus v1.9.1 h1:Ou41VVR3nMWWmTiEUnj0OlsgOSCUFgsPAOl6jRIcVtQ= -github.com/sirupsen/logrus v1.9.1/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sirupsen/logrus v1.9.2 h1:oxx1eChJGI6Uks2ZC4W1zpLlVgqB8ner4EuQwV4Ik1Y= +github.com/sirupsen/logrus v1.9.2/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= diff --git a/roomserver/api/api.go b/roomserver/api/api.go index f2e2bf84a..213e16e5d 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -202,6 +202,7 @@ type FederationRoomserverAPI interface { QueryBulkStateContentAPI // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error + QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 35cafd0ec..effcc90d7 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -868,7 +868,7 @@ func (r *Queryer) QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types } func (r *Queryer) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) { - res, err := r.DB.GetStateEvent(ctx, roomID.String(), string(eventType), "") + res, err := r.DB.GetStateEvent(ctx, roomID.String(), eventType, "") if res == nil { return nil, err } From 5d6221d1917c3494fed57e055e46928aaa4b5bda Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 23 May 2023 19:37:04 +0200 Subject: [PATCH 07/35] Move `MakeLeave` to GMSL (#3085) Basically the same API shape as for `/make_join` https://github.com/matrix-org/gomatrixserverlib/pull/385 --- federationapi/routing/join.go | 10 +- federationapi/routing/leave.go | 161 ++++++++++++++++--------------- federationapi/routing/routing.go | 20 +++- go.mod | 2 +- go.sum | 4 +- 5 files changed, 108 insertions(+), 89 deletions(-) diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index cbdeca51e..4cbfc5e87 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -99,7 +99,7 @@ func MakeJoin( } req := api.QueryServerJoinedToRoomRequest{ - ServerName: cfg.Matrix.ServerName, + ServerName: request.Destination(), RoomID: roomID.String(), } res := api.QueryServerJoinedToRoomResponse{} @@ -162,13 +162,13 @@ func MakeJoin( switch e := internalErr.(type) { case nil: case spec.InternalServerError: - util.GetLogger(httpReq.Context()).WithError(internalErr) + util.GetLogger(httpReq.Context()).WithError(internalErr).Error("failed to handle make_join request") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}, } case spec.MatrixError: - util.GetLogger(httpReq.Context()).WithError(internalErr) + util.GetLogger(httpReq.Context()).WithError(internalErr).Error("failed to handle make_join request") code := http.StatusInternalServerError switch e.ErrCode { case spec.ErrorForbidden: @@ -186,13 +186,13 @@ func MakeJoin( JSON: e, } case spec.IncompatibleRoomVersionError: - util.GetLogger(httpReq.Context()).WithError(internalErr) + util.GetLogger(httpReq.Context()).WithError(internalErr).Error("failed to handle make_join request") return util.JSONResponse{ Code: http.StatusBadRequest, JSON: e, } default: - util.GetLogger(httpReq.Context()).WithError(internalErr) + util.GetLogger(httpReq.Context()).WithError(internalErr).Error("failed to handle make_join request") return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.Unknown("unknown error"), diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index e65403404..3e576e09c 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -34,108 +34,115 @@ func MakeLeave( request *fclient.FederationRequest, cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, - roomID, userID string, + roomID spec.RoomID, userID spec.UserID, ) util.JSONResponse { - _, domain, err := gomatrixserverlib.SplitID('@', userID) + roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID.String()) if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON("Invalid UserID"), - } - } - if domain != request.Origin() { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("The leave must be sent by the server of the user"), - } - } - - // Try building an event for the server - proto := gomatrixserverlib.ProtoEvent{ - Sender: userID, - RoomID: roomID, - Type: "m.room.member", - StateKey: &userID, - } - err = proto.SetContent(map[string]interface{}{"membership": spec.Leave}) - if err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("proto.SetContent failed") + util.GetLogger(httpReq.Context()).WithError(err).Error("failed obtaining room version") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}, } } - identity, err := cfg.Matrix.SigningIdentityFor(request.Destination()) - if err != nil { + req := api.QueryServerJoinedToRoomRequest{ + ServerName: request.Destination(), + RoomID: roomID.String(), + } + res := api.QueryServerJoinedToRoomResponse{} + if err := rsAPI.QueryServerJoinedToRoom(httpReq.Context(), &req, &res); err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed") return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: spec.NotFound( - fmt.Sprintf("Server name %q does not exist", request.Destination()), - ), + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, } } - var queryRes api.QueryLatestEventsAndStateResponse - event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &proto, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes) - switch e := err.(type) { - case nil: - case eventutil.ErrRoomNoExists: - util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: spec.NotFound("Room does not exist"), + createLeaveTemplate := func(proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, []gomatrixserverlib.PDU, error) { + identity, err := cfg.Matrix.SigningIdentityFor(request.Destination()) + if err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Errorf("obtaining signing identity for %s failed", request.Destination()) + return nil, nil, spec.NotFound(fmt.Sprintf("Server name %q does not exist", request.Destination())) } - case gomatrixserverlib.BadJSONError: - util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") + + queryRes := api.QueryLatestEventsAndStateResponse{} + event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), proto, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes) + switch e := err.(type) { + case nil: + case eventutil.ErrRoomNoExists: + util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") + return nil, nil, spec.NotFound("Room does not exist") + case gomatrixserverlib.BadJSONError: + util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") + return nil, nil, spec.BadJSON(e.Error()) + default: + util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") + return nil, nil, spec.InternalServerError{} + } + + stateEvents := make([]gomatrixserverlib.PDU, len(queryRes.StateEvents)) + for i, stateEvent := range queryRes.StateEvents { + stateEvents[i] = stateEvent.PDU + } + return event, stateEvents, nil + } + + input := gomatrixserverlib.HandleMakeLeaveInput{ + UserID: userID, + RoomID: roomID, + RoomVersion: roomVersion, + RequestOrigin: request.Origin(), + LocalServerName: cfg.Matrix.ServerName, + LocalServerInRoom: res.RoomExists && res.IsInRoom, + BuildEventTemplate: createLeaveTemplate, + } + + response, internalErr := gomatrixserverlib.HandleMakeLeave(input) + switch e := internalErr.(type) { + case nil: + case spec.InternalServerError: + util.GetLogger(httpReq.Context()).WithError(internalErr).Error("failed to handle make_leave request") return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON(e.Error()), + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + case spec.MatrixError: + util.GetLogger(httpReq.Context()).WithError(internalErr).Error("failed to handle make_leave request") + code := http.StatusInternalServerError + switch e.ErrCode { + case spec.ErrorForbidden: + code = http.StatusForbidden + case spec.ErrorNotFound: + code = http.StatusNotFound + case spec.ErrorBadJSON: + code = http.StatusBadRequest + } + + return util.JSONResponse{ + Code: code, + JSON: e, } default: - util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") + util.GetLogger(httpReq.Context()).WithError(internalErr).Error("failed to handle make_leave request") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("unknown error"), + } + } + + if response == nil { + util.GetLogger(httpReq.Context()).Error("gmsl.HandleMakeLeave returned invalid response") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}, } } - // If the user has already left then just return their last leave - // event. This means that /send_leave will be a no-op, which helps - // to reject invites multiple times - hopefully. - for _, state := range queryRes.StateEvents { - if !state.StateKeyEquals(userID) { - continue - } - if mem, merr := state.Membership(); merr == nil && mem == spec.Leave { - return util.JSONResponse{ - Code: http.StatusOK, - JSON: map[string]interface{}{ - "room_version": event.Version(), - "event": state, - }, - } - } - } - - // Check that the leave is allowed or not - stateEvents := make([]gomatrixserverlib.PDU, len(queryRes.StateEvents)) - for i := range queryRes.StateEvents { - stateEvents[i] = queryRes.StateEvents[i].PDU - } - provider := gomatrixserverlib.NewAuthEvents(stateEvents) - if err = gomatrixserverlib.Allowed(event, &provider); err != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden(err.Error()), - } - } - return util.JSONResponse{ Code: http.StatusOK, JSON: map[string]interface{}{ - "room_version": event.Version(), - "event": proto, + "event": response.LeaveTemplateEvent, + "room_version": response.RoomVersion, }, } } diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 7be0857a6..fad06c1cf 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -412,7 +412,7 @@ func Setup( }, )).Methods(http.MethodPut) - v1fedmux.Handle("/make_leave/{roomID}/{eventID}", MakeFedAPI( + v1fedmux.Handle("/make_leave/{roomID}/{userID}", MakeFedAPI( "federation_make_leave", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { @@ -421,10 +421,22 @@ func Setup( JSON: spec.Forbidden("Forbidden by server ACLs"), } } - roomID := vars["roomID"] - eventID := vars["eventID"] + roomID, err := spec.NewRoomID(vars["roomID"]) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Invalid RoomID"), + } + } + userID, err := spec.NewUserID(vars["userID"], true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Invalid UserID"), + } + } return MakeLeave( - httpReq, request, cfg, rsAPI, roomID, eventID, + httpReq, request, cfg, rsAPI, *roomID, *userID, ) }, )).Methods(http.MethodGet) diff --git a/go.mod b/go.mod index e85051777..bf2dc5de0 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230519160810-b92e84b02a7c + github.com/matrix-org/gomatrixserverlib v0.0.0-20230523164045-3fddabebb511 github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 diff --git a/go.sum b/go.sum index 6f034fef6..574a7bd7e 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230519160810-b92e84b02a7c h1:EF04pmshcDmBQOrBQbzT5htyTivetfyvR70gX2hB9AM= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230519160810-b92e84b02a7c/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230523164045-3fddabebb511 h1:om6z/WEVZMxZfgtiyfp5r5ubAObGMyRrnlVD07gIRY4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230523164045-3fddabebb511/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= From 11b557097c6745309c09b58f681080d3fcc4f9f5 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Wed, 24 May 2023 12:14:42 +0200 Subject: [PATCH 08/35] Drop `reference_sha` column (#3083) Companion PR to https://github.com/matrix-org/gomatrixserverlib/pull/383 --- .gitignore | 5 +- build/gobind-pinecone/monolith_test.go | 10 ++- clientapi/routing/createroom.go | 2 +- go.mod | 2 +- go.sum | 4 +- internal/eventutil/events.go | 29 +----- roomserver/api/query.go | 2 +- roomserver/internal/input/input_events.go | 4 +- .../internal/input/input_latest_events.go | 8 +- roomserver/internal/input/input_missing.go | 6 +- roomserver/internal/perform/perform_admin.go | 8 +- .../internal/perform/perform_inbound_peek.go | 2 +- .../internal/perform/perform_upgrade.go | 2 +- roomserver/internal/query/query_test.go | 7 +- roomserver/storage/interface.go | 4 +- .../20230516154000_drop_reference_sha.go | 54 ++++++++++++ roomserver/storage/postgres/events_table.go | 64 ++++---------- .../storage/postgres/previous_events_table.go | 35 +++++--- roomserver/storage/shared/room_updater.go | 16 +--- roomserver/storage/shared/storage.go | 16 ++-- .../20230516154000_drop_reference_sha.go | 72 +++++++++++++++ roomserver/storage/sqlite3/events_table.go | 88 +++++++------------ .../storage/sqlite3/previous_events_table.go | 47 +++++++--- roomserver/storage/sqlite3/storage.go | 3 +- .../storage/tables/events_table_test.go | 13 +-- roomserver/storage/tables/interface.go | 7 +- .../tables/previous_events_table_test.go | 10 +-- roomserver/types/types.go | 2 +- test/room.go | 4 +- 29 files changed, 299 insertions(+), 227 deletions(-) create mode 100644 roomserver/storage/postgres/deltas/20230516154000_drop_reference_sha.go create mode 100644 roomserver/storage/sqlite3/deltas/20230516154000_drop_reference_sha.go diff --git a/.gitignore b/.gitignore index dcfbf8007..043956ee4 100644 --- a/.gitignore +++ b/.gitignore @@ -74,4 +74,7 @@ complement/ docs/_site media_store/ -build \ No newline at end of file +build + +# golang workspaces +go.work* \ No newline at end of file diff --git a/build/gobind-pinecone/monolith_test.go b/build/gobind-pinecone/monolith_test.go index 7a7e36c7e..f16d1d764 100644 --- a/build/gobind-pinecone/monolith_test.go +++ b/build/gobind-pinecone/monolith_test.go @@ -22,7 +22,10 @@ import ( ) func TestMonolithStarts(t *testing.T) { - monolith := DendriteMonolith{} + monolith := DendriteMonolith{ + StorageDirectory: t.TempDir(), + CacheDirectory: t.TempDir(), + } monolith.Start() monolith.PublicKey() monolith.Stop() @@ -60,7 +63,10 @@ func TestMonolithSetRelayServers(t *testing.T) { } for _, tc := range testCases { - monolith := DendriteMonolith{} + monolith := DendriteMonolith{ + StorageDirectory: t.TempDir(), + CacheDirectory: t.TempDir(), + } monolith.Start() inputRelays := tc.relays diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index bc9600060..7a7a85e85 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -470,7 +470,7 @@ func createRoom( } } if i > 0 { - builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()} + builder.PrevEvents = []string{builtEvents[i-1].EventID()} } var ev gomatrixserverlib.PDU if err = builder.AddAuthEvents(&authEvents); err != nil { diff --git a/go.mod b/go.mod index bf2dc5de0..16e5adc8c 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230523164045-3fddabebb511 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230524095531-95ba6c68efb6 github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 diff --git a/go.sum b/go.sum index 574a7bd7e..98e7e839d 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230523164045-3fddabebb511 h1:om6z/WEVZMxZfgtiyfp5r5ubAObGMyRrnlVD07gIRY4= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230523164045-3fddabebb511/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230524095531-95ba6c68efb6 h1:FQpdh/KGCCQJytz4GAdG6pbx3DJ1HNzdKFc/BCZ0hP0= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230524095531-95ba6c68efb6/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= diff --git a/internal/eventutil/events.go b/internal/eventutil/events.go index 79882d8d8..ca052c310 100644 --- a/internal/eventutil/events.go +++ b/internal/eventutil/events.go @@ -129,18 +129,12 @@ func addPrevEventsToEvent( return ErrRoomNoExists{} } - verImpl, err := gomatrixserverlib.GetRoomVersion(queryRes.RoomVersion) - if err != nil { - return fmt.Errorf("GetRoomVersion: %w", err) - } - eventFormat := verImpl.EventFormat() - builder.Depth = queryRes.Depth authEvents := gomatrixserverlib.NewAuthEvents(nil) for i := range queryRes.StateEvents { - err = authEvents.AddEvent(queryRes.StateEvents[i].PDU) + err := authEvents.AddEvent(queryRes.StateEvents[i].PDU) if err != nil { return fmt.Errorf("authEvents.AddEvent: %w", err) } @@ -151,22 +145,7 @@ func addPrevEventsToEvent( return fmt.Errorf("eventsNeeded.AuthEventReferences: %w", err) } - truncAuth, truncPrev := truncateAuthAndPrevEvents(refs, queryRes.LatestEvents) - switch eventFormat { - case gomatrixserverlib.EventFormatV1: - builder.AuthEvents = truncAuth - builder.PrevEvents = truncPrev - case gomatrixserverlib.EventFormatV2: - v2AuthRefs, v2PrevRefs := []string{}, []string{} - for _, ref := range truncAuth { - v2AuthRefs = append(v2AuthRefs, ref.EventID) - } - for _, ref := range truncPrev { - v2PrevRefs = append(v2PrevRefs, ref.EventID) - } - builder.AuthEvents = v2AuthRefs - builder.PrevEvents = v2PrevRefs - } + builder.AuthEvents, builder.PrevEvents = truncateAuthAndPrevEvents(refs, queryRes.LatestEvents) return nil } @@ -176,8 +155,8 @@ func addPrevEventsToEvent( // NOTSPEC: The limits here feel a bit arbitrary but they are currently // here because of https://github.com/matrix-org/matrix-doc/issues/2307 // and because Synapse will just drop events that don't comply. -func truncateAuthAndPrevEvents(auth, prev []gomatrixserverlib.EventReference) ( - truncAuth, truncPrev []gomatrixserverlib.EventReference, +func truncateAuthAndPrevEvents(auth, prev []string) ( + truncAuth, truncPrev []string, ) { truncAuth, truncPrev = auth, prev if len(truncAuth) > 10 { diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 55d1a6dba..1726bfe1f 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -49,7 +49,7 @@ type QueryLatestEventsAndStateResponse struct { RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` // The latest events in the room. // These are used to set the prev_events when sending an event. - LatestEvents []gomatrixserverlib.EventReference `json:"latest_events"` + LatestEvents []string `json:"latest_events"` // The state events requested. // This list will be in an arbitrary order. // These are used to set the auth_events when sending an event. diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index cd78b3722..02a1a2802 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -883,9 +883,7 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r Origin: senderDomain, SendAsServer: string(senderDomain), }) - prevEvents = []gomatrixserverlib.EventReference{ - event.EventReference(), - } + prevEvents = []string{event.EventID()} } inputReq := &api.InputRoomEventsRequest{ diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index 54a5f6234..7a7a021a3 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -154,8 +154,8 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { extremitiesChanged, err := u.calculateLatest( u.oldLatest, u.event, types.StateAtEventAndReference{ - EventReference: u.event.EventReference(), - StateAtEvent: u.stateAtEvent, + EventID: u.event.EventID(), + StateAtEvent: u.stateAtEvent, }, ) if err != nil { @@ -349,7 +349,7 @@ func (u *latestEventsUpdater) calculateLatest( // If the "new" event is already referenced by an existing event // then do nothing - it's not a candidate to be a new extremity if // it has been referenced. - if referenced, err := u.updater.IsReferenced(newEvent.EventReference()); err != nil { + if referenced, err := u.updater.IsReferenced(newEvent.EventID()); err != nil { return false, fmt.Errorf("u.updater.IsReferenced(new): %w", err) } else if referenced { u.latest = oldLatest @@ -360,7 +360,7 @@ func (u *latestEventsUpdater) calculateLatest( // have entries in the previous events table. If they do then we // will no longer include them as forward extremities. for k, l := range existingRefs { - referenced, err := u.updater.IsReferenced(l.EventReference) + referenced, err := u.updater.IsReferenced(l.EventID) if err != nil { return false, fmt.Errorf("u.updater.IsReferenced: %w", err) } else if referenced { diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 8a1235221..10486138d 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -520,9 +520,9 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver return nil, false, false, fmt.Errorf("t.DB.LatestEventIDs: %w", err) } latestEvents := make([]string, len(latest)) - for i, ev := range latest { - latestEvents[i] = ev.EventID - t.hadEvent(ev.EventID) + for i := range latest { + latestEvents[i] = latest[i] + t.hadEvent(latest[i]) } var missingResp *fclient.RespMissingEvents diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index fadc8bcfc..17296febc 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -131,9 +131,7 @@ func (r *Admin) PerformAdminEvacuateRoom( SendAsServer: string(senderDomain), }) affected = append(affected, stateKey) - prevEvents = []gomatrixserverlib.EventReference{ - event.EventReference(), - } + prevEvents = []string{event.EventID()} } inputReq := &api.InputRoomEventsRequest{ @@ -253,9 +251,9 @@ func (r *Admin) PerformAdminDownloadState( for _, fwdExtremity := range fwdExtremities { var state gomatrixserverlib.StateResponse - state, err = r.Inputer.FSAPI.LookupState(ctx, r.Inputer.ServerName, serverName, roomID, fwdExtremity.EventID, roomInfo.RoomVersion) + state, err = r.Inputer.FSAPI.LookupState(ctx, r.Inputer.ServerName, serverName, roomID, fwdExtremity, roomInfo.RoomVersion) if err != nil { - return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity.EventID, err) + return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err) } for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) { if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing); err != nil { diff --git a/roomserver/internal/perform/perform_inbound_peek.go b/roomserver/internal/perform/perform_inbound_peek.go index 3094a17fd..3ac0f6f4d 100644 --- a/roomserver/internal/perform/perform_inbound_peek.go +++ b/roomserver/internal/perform/perform_inbound_peek.go @@ -64,7 +64,7 @@ func (r *InboundPeeker) PerformInboundPeek( if err != nil { return err } - latestEvents, err := r.DB.EventsFromIDs(ctx, info, []string{latestEventRefs[0].EventID}) + latestEvents, err := r.DB.EventsFromIDs(ctx, info, []string{latestEventRefs[0]}) if err != nil { return err } diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index abe63145a..60085cb6d 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -471,7 +471,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user return fmt.Errorf("failed to set content of new %q event: %w", proto.Type, err) } if i > 0 { - proto.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()} + proto.PrevEvents = []string{builtEvents[i-1].EventID()} } var verImpl gomatrixserverlib.IRoomVersion diff --git a/roomserver/internal/query/query_test.go b/roomserver/internal/query/query_test.go index b6715cb00..619d93030 100644 --- a/roomserver/internal/query/query_test.go +++ b/roomserver/internal/query/query_test.go @@ -43,13 +43,10 @@ func createEventDB() *getEventDB { // Adds a fake event to the storage with given auth events. func (db *getEventDB) addFakeEvent(eventID string, authIDs []string) error { - authEvents := []gomatrixserverlib.EventReference{} + authEvents := make([]any, 0, len(authIDs)) for _, authID := range authIDs { - authEvents = append(authEvents, gomatrixserverlib.EventReference{ - EventID: authID, - }) + authEvents = append(authEvents, []any{authID, struct{}{}}) } - builder := map[string]interface{}{ "event_id": eventID, "auth_events": authEvents, diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 6bc4ce9ab..7d22df008 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -102,7 +102,7 @@ type Database interface { // Look up event references for the latest events in the room and the current state snapshot. // Returns the latest events, the current state and the maximum depth of the latest events plus 1. // Returns an error if there was a problem talking to the database. - LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) + LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]string, types.StateSnapshotNID, int64, error) // Look up the active invites targeting a user in a room and return the // numeric state key IDs for the user IDs who sent them along with the event IDs for the invites. // Returns an error if there was a problem talking to the database. @@ -206,7 +206,7 @@ type RoomDatabase interface { BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) - LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) + LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]string, types.StateSnapshotNID, int64, error) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserverlib.PDU) (*types.RoomInfo, error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) diff --git a/roomserver/storage/postgres/deltas/20230516154000_drop_reference_sha.go b/roomserver/storage/postgres/deltas/20230516154000_drop_reference_sha.go new file mode 100644 index 000000000..c19577713 --- /dev/null +++ b/roomserver/storage/postgres/deltas/20230516154000_drop_reference_sha.go @@ -0,0 +1,54 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpDropEventReferenceSHAEvents(ctx context.Context, tx *sql.Tx) error { + var count int + err := tx.QueryRowContext(ctx, `SELECT count(*) FROM roomserver_events GROUP BY event_id HAVING count(event_id) > 1`). + Scan(&count) + if err != nil && err != sql.ErrNoRows { + return fmt.Errorf("failed to query duplicate event ids") + } + if count > 0 { + return fmt.Errorf("unable to drop column, as there are duplicate event ids") + } + _, err = tx.ExecContext(ctx, `ALTER TABLE roomserver_events DROP COLUMN IF EXISTS reference_sha256;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func UpDropEventReferenceSHAPrevEvents(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, "ALTER TABLE roomserver_previous_events DROP CONSTRAINT roomserver_previous_event_id_unique;") + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + _, err = tx.ExecContext(ctx, `ALTER TABLE roomserver_previous_events DROP COLUMN IF EXISTS previous_reference_sha256;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + _, err = tx.ExecContext(ctx, `ALTER TABLE roomserver_previous_events ADD CONSTRAINT roomserver_previous_event_id_unique UNIQUE (previous_event_id);`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index c935608a5..a00b4b1d7 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -22,10 +22,9 @@ import ( "sort" "github.com/lib/pq" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -62,9 +61,6 @@ CREATE TABLE IF NOT EXISTS roomserver_events ( -- Needed for state resolution. -- An event may only appear in this table once. event_id TEXT NOT NULL CONSTRAINT roomserver_event_id_unique UNIQUE, - -- The sha256 reference hash for the event. - -- Needed for setting reference hashes when sending new events. - reference_sha256 BYTEA NOT NULL, -- A list of numeric IDs for events that can authenticate this event. auth_event_nids BIGINT[] NOT NULL, is_rejected BOOLEAN NOT NULL DEFAULT FALSE @@ -75,10 +71,10 @@ CREATE INDEX IF NOT EXISTS roomserver_events_memberships_idx ON roomserver_event ` const insertEventSQL = "" + - "INSERT INTO roomserver_events AS e (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)" + - " VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + + "INSERT INTO roomserver_events AS e (room_nid, event_type_nid, event_state_key_nid, event_id, auth_event_nids, depth, is_rejected)" + + " VALUES ($1, $2, $3, $4, $5, $6, $7)" + " ON CONFLICT ON CONSTRAINT roomserver_event_id_unique DO UPDATE" + - " SET is_rejected = $8 WHERE e.event_id = $4 AND e.is_rejected = TRUE" + + " SET is_rejected = $7 WHERE e.event_id = $4 AND e.is_rejected = TRUE" + " RETURNING event_nid, state_snapshot_nid" const selectEventSQL = "" + @@ -130,12 +126,9 @@ const selectEventIDSQL = "" + "SELECT event_id FROM roomserver_events WHERE event_nid = $1" const bulkSelectStateAtEventAndReferenceSQL = "" + - "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" + + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id" + " FROM roomserver_events WHERE event_nid = ANY($1)" -const bulkSelectEventReferenceSQL = "" + - "SELECT event_id, reference_sha256 FROM roomserver_events WHERE event_nid = ANY($1)" - const bulkSelectEventIDSQL = "" + "SELECT event_nid, event_id FROM roomserver_events WHERE event_nid = ANY($1)" @@ -167,7 +160,6 @@ type eventStatements struct { updateEventSentToOutputStmt *sql.Stmt selectEventIDStmt *sql.Stmt bulkSelectStateAtEventAndReferenceStmt *sql.Stmt - bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt bulkSelectEventNIDStmt *sql.Stmt bulkSelectUnsentEventNIDStmt *sql.Stmt @@ -178,7 +170,18 @@ type eventStatements struct { func CreateEventsTable(db *sql.DB) error { _, err := db.Exec(eventsSchema) - return err + if err != nil { + return err + } + + m := sqlutil.NewMigrator(db) + m.AddMigrations([]sqlutil.Migration{ + { + Version: "roomserver: drop column reference_sha from roomserver_events", + Up: deltas.UpDropEventReferenceSHAEvents, + }, + }...) + return m.Up(context.Background()) } func PrepareEventsTable(db *sql.DB) (tables.Events, error) { @@ -197,7 +200,6 @@ func PrepareEventsTable(db *sql.DB) (tables.Events, error) { {&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL}, {&s.selectEventIDStmt, selectEventIDSQL}, {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL}, - {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, {&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL}, @@ -214,7 +216,6 @@ func (s *eventStatements) InsertEvent( eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, eventID string, - referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, isRejected bool, @@ -224,7 +225,7 @@ func (s *eventStatements) InsertEvent( stmt := sqlutil.TxStmt(txn, s.insertEventStmt) err := stmt.QueryRowContext( ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), - eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, + eventID, eventNIDsAsArray(authEventNIDs), depth, isRejected, ).Scan(&eventNID, &stateNID) return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err @@ -441,11 +442,10 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( eventNID int64 stateSnapshotNID int64 eventID string - eventSHA256 []byte ) for ; rows.Next(); i++ { if err = rows.Scan( - &eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, &eventSHA256, + &eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, ); err != nil { return nil, err } @@ -455,32 +455,6 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( result.EventNID = types.EventNID(eventNID) result.BeforeStateSnapshotNID = types.StateSnapshotNID(stateSnapshotNID) result.EventID = eventID - result.EventSHA256 = eventSHA256 - } - if err = rows.Err(); err != nil { - return nil, err - } - if i != len(eventNIDs) { - return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) - } - return results, nil -} - -func (s *eventStatements) BulkSelectEventReference( - ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, -) ([]gomatrixserverlib.EventReference, error) { - rows, err := s.bulkSelectEventReferenceStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventReference: rows.close() failed") - results := make([]gomatrixserverlib.EventReference, len(eventNIDs)) - i := 0 - for ; rows.Next(); i++ { - result := &results[i] - if err = rows.Scan(&result.EventID, &result.EventSHA256); err != nil { - return nil, err - } } if err = rows.Err(); err != nil { return nil, err diff --git a/roomserver/storage/postgres/previous_events_table.go b/roomserver/storage/postgres/previous_events_table.go index 26999a290..ceb5e26ba 100644 --- a/roomserver/storage/postgres/previous_events_table.go +++ b/roomserver/storage/postgres/previous_events_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -32,11 +33,9 @@ const previousEventSchema = ` CREATE TABLE IF NOT EXISTS roomserver_previous_events ( -- The string event ID taken from the prev_events key of an event. previous_event_id TEXT NOT NULL, - -- The SHA256 reference hash taken from the prev_events key of an event. - previous_reference_sha256 BYTEA NOT NULL, -- A list of numeric event IDs of events that reference this prev_event. event_nids BIGINT[] NOT NULL, - CONSTRAINT roomserver_previous_event_id_unique UNIQUE (previous_event_id, previous_reference_sha256) + CONSTRAINT roomserver_previous_event_id_unique UNIQUE (previous_event_id) ); ` @@ -47,17 +46,17 @@ CREATE TABLE IF NOT EXISTS roomserver_previous_events ( // The lock is necessary to avoid data races when checking whether an event is already referenced by another event. const insertPreviousEventSQL = "" + "INSERT INTO roomserver_previous_events" + - " (previous_event_id, previous_reference_sha256, event_nids)" + - " VALUES ($1, $2, array_append('{}'::bigint[], $3))" + + " (previous_event_id, event_nids)" + + " VALUES ($1, array_append('{}'::bigint[], $2))" + " ON CONFLICT ON CONSTRAINT roomserver_previous_event_id_unique" + - " DO UPDATE SET event_nids = array_append(roomserver_previous_events.event_nids, $3)" + - " WHERE $3 != ALL(roomserver_previous_events.event_nids)" + " DO UPDATE SET event_nids = array_append(roomserver_previous_events.event_nids, $2)" + + " WHERE $2 != ALL(roomserver_previous_events.event_nids)" // Check if the event is referenced by another event in the table. // This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room. const selectPreviousEventExistsSQL = "" + "SELECT 1 FROM roomserver_previous_events" + - " WHERE previous_event_id = $1 AND previous_reference_sha256 = $2" + " WHERE previous_event_id = $1" type previousEventStatements struct { insertPreviousEventStmt *sql.Stmt @@ -66,7 +65,18 @@ type previousEventStatements struct { func CreatePrevEventsTable(db *sql.DB) error { _, err := db.Exec(previousEventSchema) - return err + if err != nil { + return err + } + + m := sqlutil.NewMigrator(db) + m.AddMigrations([]sqlutil.Migration{ + { + Version: "roomserver: drop column reference_sha from roomserver_prev_events", + Up: deltas.UpDropEventReferenceSHAPrevEvents, + }, + }...) + return m.Up(context.Background()) } func PreparePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { @@ -82,12 +92,11 @@ func (s *previousEventStatements) InsertPreviousEvent( ctx context.Context, txn *sql.Tx, previousEventID string, - previousEventReferenceSHA256 []byte, eventNID types.EventNID, ) error { stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) _, err := stmt.ExecContext( - ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), + ctx, previousEventID, int64(eventNID), ) return err } @@ -95,9 +104,9 @@ func (s *previousEventStatements) InsertPreviousEvent( // Check if the event reference exists // Returns sql.ErrNoRows if the event reference doesn't exist. func (s *previousEventStatements) SelectPreviousEventExists( - ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte, + ctx context.Context, txn *sql.Tx, eventID string, ) error { var ok int64 stmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt) - return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok) + return stmt.QueryRowContext(ctx, eventID).Scan(&ok) } diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 5a20c67b3..70672a33e 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -104,18 +104,6 @@ func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { return u.currentStateSnapshotNID } -// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer -func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { - return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { - for _, ref := range previousEventReferences { - if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { - return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err) - } - } - return nil - }) -} - func (u *RoomUpdater) Events(ctx context.Context, _ gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) { if u.roomInfo == nil { return nil, types.ErrorInvalidRoomInfo @@ -203,8 +191,8 @@ func (u *RoomUpdater) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInf } // IsReferenced implements types.RoomRecentEventsUpdater -func (u *RoomUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { - err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) +func (u *RoomUpdater) IsReferenced(eventID string) (bool, error) { + err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventID) if err == nil { return true, nil } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 60e46c478..cefa58a3d 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -398,15 +398,13 @@ func (d *EventDatabase) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomInfo return d.events(ctx, txn, roomInfo.RoomVersion, nids) } -func (d *Database) LatestEventIDs( - ctx context.Context, roomNID types.RoomNID, -) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) { +func (d *Database) LatestEventIDs(ctx context.Context, roomNID types.RoomNID) (references []string, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) { var eventNIDs []types.EventNID eventNIDs, currentStateSnapshotNID, err = d.RoomsTable.SelectLatestEventNIDs(ctx, nil, roomNID) if err != nil { return } - references, err = d.EventsTable.BulkSelectEventReference(ctx, nil, eventNIDs) + eventNIDMap, err := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) if err != nil { return } @@ -414,6 +412,9 @@ func (d *Database) LatestEventIDs( if err != nil { return } + for _, eventID := range eventNIDMap { + references = append(references, eventID) + } return } @@ -742,7 +743,6 @@ func (d *EventDatabase) StoreEvent( eventTypeNID, eventStateKeyNID, event.EventID(), - event.EventReference().EventSHA256, authEventNIDs, event.Depth(), isRejected, @@ -762,7 +762,7 @@ func (d *EventDatabase) StoreEvent( return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err) } - if prevEvents := event.PrevEvents(); len(prevEvents) > 0 { + if prevEvents := event.PrevEventIDs(); len(prevEvents) > 0 { // Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of // GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This // function only does SELECTs though so the created txn (at this point) is just a read txn like @@ -770,8 +770,8 @@ func (d *EventDatabase) StoreEvent( // to do writes however then this will need to go inside `Writer.Do`. // The following is a copy of RoomUpdater.StorePreviousEvents - for _, ref := range prevEvents { - if err = d.PrevEventsTable.InsertPreviousEvent(ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + for _, eventID := range prevEvents { + if err = d.PrevEventsTable.InsertPreviousEvent(ctx, txn, eventID, eventNID); err != nil { return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err) } } diff --git a/roomserver/storage/sqlite3/deltas/20230516154000_drop_reference_sha.go b/roomserver/storage/sqlite3/deltas/20230516154000_drop_reference_sha.go new file mode 100644 index 000000000..452d72ace --- /dev/null +++ b/roomserver/storage/sqlite3/deltas/20230516154000_drop_reference_sha.go @@ -0,0 +1,72 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpDropEventReferenceSHA(ctx context.Context, tx *sql.Tx) error { + var count int + err := tx.QueryRowContext(ctx, `SELECT count(*) FROM roomserver_events GROUP BY event_id HAVING count(event_id) > 1`). + Scan(&count) + if err != nil && err != sql.ErrNoRows { + return fmt.Errorf("failed to query duplicate event ids") + } + if count > 0 { + return fmt.Errorf("unable to drop column, as there are duplicate event ids") + } + _, err = tx.ExecContext(ctx, `ALTER TABLE roomserver_events DROP COLUMN reference_sha256;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func UpDropEventReferenceSHAPrevEvents(ctx context.Context, tx *sql.Tx) error { + // rename the table + if _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_previous_events RENAME TO _roomserver_previous_events;`); err != nil { + return fmt.Errorf("tx.ExecContext: %w", err) + } + + // create new table + if _, err := tx.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS roomserver_previous_events ( + previous_event_id TEXT NOT NULL, + event_nids TEXT NOT NULL, + UNIQUE (previous_event_id) + );`); err != nil { + return fmt.Errorf("tx.ExecContext: %w", err) + } + + // move data + if _, err := tx.ExecContext(ctx, ` +INSERT + INTO roomserver_previous_events ( + previous_event_id, event_nids + ) SELECT + previous_event_id, event_nids + FROM _roomserver_previous_events +;`); err != nil { + return fmt.Errorf("tx.ExecContext: %w", err) + } + // drop old table + _, err := tx.ExecContext(ctx, `DROP TABLE _roomserver_previous_events;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index aacf4bc9a..c49c6dc38 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -19,14 +19,14 @@ import ( "context" "database/sql" "encoding/json" + "errors" "fmt" "sort" "strings" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -41,17 +41,16 @@ const eventsSchema = ` state_snapshot_nid INTEGER NOT NULL DEFAULT 0, depth INTEGER NOT NULL, event_id TEXT NOT NULL UNIQUE, - reference_sha256 BLOB NOT NULL, auth_event_nids TEXT NOT NULL DEFAULT '[]', is_rejected BOOLEAN NOT NULL DEFAULT FALSE ); ` const insertEventSQL = ` - INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, auth_event_nids, depth, is_rejected) + VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT DO UPDATE - SET is_rejected = $8 WHERE is_rejected = 1 + SET is_rejected = $7 WHERE is_rejected = 1 RETURNING event_nid, state_snapshot_nid; ` @@ -100,12 +99,9 @@ const selectEventIDSQL = "" + "SELECT event_id FROM roomserver_events WHERE event_nid = $1" const bulkSelectStateAtEventAndReferenceSQL = "" + - "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" + + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id" + " FROM roomserver_events WHERE event_nid IN ($1)" -const bulkSelectEventReferenceSQL = "" + - "SELECT event_id, reference_sha256 FROM roomserver_events WHERE event_nid IN ($1)" - const bulkSelectEventIDSQL = "" + "SELECT event_nid, event_id FROM roomserver_events WHERE event_nid IN ($1)" @@ -137,7 +133,6 @@ type eventStatements struct { updateEventSentToOutputStmt *sql.Stmt selectEventIDStmt *sql.Stmt bulkSelectStateAtEventAndReferenceStmt *sql.Stmt - bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt selectEventRejectedStmt *sql.Stmt //bulkSelectEventNIDStmt *sql.Stmt @@ -147,7 +142,32 @@ type eventStatements struct { func CreateEventsTable(db *sql.DB) error { _, err := db.Exec(eventsSchema) - return err + if err != nil { + return err + } + + // check if the column exists + var cName string + migrationName := "roomserver: drop column reference_sha from roomserver_events" + err = db.QueryRowContext(context.Background(), `SELECT p.name FROM sqlite_master AS m JOIN pragma_table_info(m.name) AS p WHERE m.name = 'roomserver_events' AND p.name = 'reference_sha256'`).Scan(&cName) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed + if err = sqlutil.InsertMigration(context.Background(), db, migrationName); err != nil { + return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err) + } + return nil + } + return err + } + + m := sqlutil.NewMigrator(db) + m.AddMigrations([]sqlutil.Migration{ + { + Version: migrationName, + Up: deltas.UpDropEventReferenceSHA, + }, + }...) + return m.Up(context.Background()) } func PrepareEventsTable(db *sql.DB) (tables.Events, error) { @@ -167,7 +187,6 @@ func PrepareEventsTable(db *sql.DB) (tables.Events, error) { {&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL}, {&s.selectEventIDStmt, selectEventIDSQL}, {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL}, - {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, //{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, //{&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL}, @@ -183,7 +202,6 @@ func (s *eventStatements) InsertEvent( eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, eventID string, - referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, isRejected bool, @@ -194,7 +212,7 @@ func (s *eventStatements) InsertEvent( insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) err := insertStmt.QueryRowContext( ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), - eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, isRejected, + eventID, eventNIDsAsArray(authEventNIDs), depth, isRejected, ).Scan(&eventNID, &stateNID) return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err } @@ -475,11 +493,10 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( eventNID int64 stateSnapshotNID int64 eventID string - eventSHA256 []byte ) for ; rows.Next(); i++ { if err = rows.Scan( - &eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, &eventSHA256, + &eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, ); err != nil { return nil, err } @@ -489,43 +506,6 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( result.EventNID = types.EventNID(eventNID) result.BeforeStateSnapshotNID = types.StateSnapshotNID(stateSnapshotNID) result.EventID = eventID - result.EventSHA256 = eventSHA256 - } - if i != len(eventNIDs) { - return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) - } - return results, nil -} - -func (s *eventStatements) BulkSelectEventReference( - ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, -) ([]gomatrixserverlib.EventReference, error) { - /////////////// - iEventNIDs := make([]interface{}, len(eventNIDs)) - for k, v := range eventNIDs { - iEventNIDs[k] = v - } - selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) - selectPrep, err := s.db.Prepare(selectOrig) - if err != nil { - return nil, err - } - defer selectPrep.Close() // nolint:errcheck - /////////////// - - selectStmt := sqlutil.TxStmt(txn, selectPrep) - rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventReference: rows.close() failed") - results := make([]gomatrixserverlib.EventReference, len(eventNIDs)) - i := 0 - for ; rows.Next(); i++ { - result := &results[i] - if err = rows.Scan(&result.EventID, &result.EventSHA256); err != nil { - return nil, err - } } if i != len(eventNIDs) { return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go index 2a146ef64..4e59fbba7 100644 --- a/roomserver/storage/sqlite3/previous_events_table.go +++ b/roomserver/storage/sqlite3/previous_events_table.go @@ -18,10 +18,12 @@ package sqlite3 import ( "context" "database/sql" + "errors" "fmt" "strings" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -34,9 +36,8 @@ import ( const previousEventSchema = ` CREATE TABLE IF NOT EXISTS roomserver_previous_events ( previous_event_id TEXT NOT NULL, - previous_reference_sha256 BLOB, event_nids TEXT NOT NULL, - UNIQUE (previous_event_id, previous_reference_sha256) + UNIQUE (previous_event_id) ); ` @@ -47,20 +48,20 @@ const previousEventSchema = ` // The lock is necessary to avoid data races when checking whether an event is already referenced by another event. const insertPreviousEventSQL = ` INSERT OR REPLACE INTO roomserver_previous_events - (previous_event_id, previous_reference_sha256, event_nids) - VALUES ($1, $2, $3) + (previous_event_id, event_nids) + VALUES ($1, $2) ` const selectPreviousEventNIDsSQL = ` SELECT event_nids FROM roomserver_previous_events - WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 + WHERE previous_event_id = $1 ` // Check if the event is referenced by another event in the table. // This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room. const selectPreviousEventExistsSQL = ` SELECT 1 FROM roomserver_previous_events - WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 + WHERE previous_event_id = $1 ` type previousEventStatements struct { @@ -72,7 +73,30 @@ type previousEventStatements struct { func CreatePrevEventsTable(db *sql.DB) error { _, err := db.Exec(previousEventSchema) - return err + if err != nil { + return err + } + // check if the column exists + var cName string + migrationName := "roomserver: drop column reference_sha from roomserver_prev_events" + err = db.QueryRowContext(context.Background(), `SELECT p.name FROM sqlite_master AS m JOIN pragma_table_info(m.name) AS p WHERE m.name = 'roomserver_previous_events' AND p.name = 'previous_reference_sha256'`).Scan(&cName) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed + if err = sqlutil.InsertMigration(context.Background(), db, migrationName); err != nil { + return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err) + } + return nil + } + return err + } + m := sqlutil.NewMigrator(db) + m.AddMigrations([]sqlutil.Migration{ + { + Version: migrationName, + Up: deltas.UpDropEventReferenceSHAPrevEvents, + }, + }...) + return m.Up(context.Background()) } func PreparePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { @@ -91,13 +115,12 @@ func (s *previousEventStatements) InsertPreviousEvent( ctx context.Context, txn *sql.Tx, previousEventID string, - previousEventReferenceSHA256 []byte, eventNID types.EventNID, ) error { var eventNIDs string eventNIDAsString := fmt.Sprintf("%d", eventNID) selectStmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt) - err := selectStmt.QueryRowContext(ctx, previousEventID, previousEventReferenceSHA256).Scan(&eventNIDs) + err := selectStmt.QueryRowContext(ctx, previousEventID).Scan(&eventNIDs) if err != nil && err != sql.ErrNoRows { return fmt.Errorf("selectStmt.QueryRowContext.Scan: %w", err) } @@ -115,7 +138,7 @@ func (s *previousEventStatements) InsertPreviousEvent( } insertStmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) _, err = insertStmt.ExecContext( - ctx, previousEventID, previousEventReferenceSHA256, eventNIDs, + ctx, previousEventID, eventNIDs, ) return err } @@ -123,9 +146,9 @@ func (s *previousEventStatements) InsertPreviousEvent( // Check if the event reference exists // Returns sql.ErrNoRows if the event reference doesn't exist. func (s *previousEventStatements) SelectPreviousEventExists( - ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte, + ctx context.Context, txn *sql.Tx, eventID string, ) error { var ok int64 stmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt) - return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok) + return stmt.QueryRowContext(ctx, eventID).Scan(&ok) } diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 89e16fc14..6ab427a84 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -21,14 +21,13 @@ import ( "errors" "fmt" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" ) // A Database is used to store room events and stream offsets. diff --git a/roomserver/storage/tables/events_table_test.go b/roomserver/storage/tables/events_table_test.go index 107af4784..5ed805648 100644 --- a/roomserver/storage/tables/events_table_test.go +++ b/roomserver/storage/tables/events_table_test.go @@ -11,7 +11,6 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/gomatrixserverlib" "github.com/stretchr/testify/assert" ) @@ -48,10 +47,9 @@ func Test_EventsTable(t *testing.T) { // create some dummy data eventIDs := make([]string, 0, len(room.Events())) wantStateAtEvent := make([]types.StateAtEvent, 0, len(room.Events())) - wantEventReferences := make([]gomatrixserverlib.EventReference, 0, len(room.Events())) wantStateAtEventAndRefs := make([]types.StateAtEventAndReference, 0, len(room.Events())) for _, ev := range room.Events() { - eventNID, snapNID, err := tab.InsertEvent(ctx, nil, 1, 1, 1, ev.EventID(), ev.EventReference().EventSHA256, nil, ev.Depth(), false) + eventNID, snapNID, err := tab.InsertEvent(ctx, nil, 1, 1, 1, ev.EventID(), nil, ev.Depth(), false) assert.NoError(t, err) gotEventNID, gotSnapNID, err := tab.SelectEvent(ctx, nil, ev.EventID()) assert.NoError(t, err) @@ -75,7 +73,6 @@ func Test_EventsTable(t *testing.T) { assert.True(t, sentToOutput) eventIDs = append(eventIDs, ev.EventID()) - wantEventReferences = append(wantEventReferences, ev.EventReference()) // Set the stateSnapshot to 2 for some events to verify they are returned later stateSnapshot := 0 @@ -97,8 +94,8 @@ func Test_EventsTable(t *testing.T) { } wantStateAtEvent = append(wantStateAtEvent, stateAtEvent) wantStateAtEventAndRefs = append(wantStateAtEventAndRefs, types.StateAtEventAndReference{ - StateAtEvent: stateAtEvent, - EventReference: ev.EventReference(), + StateAtEvent: stateAtEvent, + EventID: ev.EventID(), }) } @@ -140,10 +137,6 @@ func Test_EventsTable(t *testing.T) { assert.True(t, ok) } - references, err := tab.BulkSelectEventReference(ctx, nil, nids) - assert.NoError(t, err) - assert.Equal(t, wantEventReferences, references) - stateAndRefs, err := tab.BulkSelectStateAtEventAndReference(ctx, nil, nids) assert.NoError(t, err) assert.Equal(t, wantStateAtEventAndRefs, stateAndRefs) diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index a9c5f8b11..333483b32 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -42,7 +42,7 @@ type Events interface { InsertEvent( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, eventID string, - referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, isRejected bool, + authEventNIDs []types.EventNID, depth int64, isRejected bool, ) (types.EventNID, types.StateSnapshotNID, error) SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error) BulkSelectSnapshotsFromEventIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[types.StateSnapshotNID][]string, error) @@ -59,7 +59,6 @@ type Events interface { UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error SelectEventID(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (eventID string, err error) BulkSelectStateAtEventAndReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) - BulkSelectEventReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error) // BulkSelectEventID returns a map from numeric event ID to string event ID. BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) // BulkSelectEventNIDs returns a map from string event ID to numeric event ID. @@ -113,10 +112,10 @@ type RoomAliases interface { } type PreviousEvents interface { - InsertPreviousEvent(ctx context.Context, txn *sql.Tx, previousEventID string, previousEventReferenceSHA256 []byte, eventNID types.EventNID) error + InsertPreviousEvent(ctx context.Context, txn *sql.Tx, previousEventID string, eventNID types.EventNID) error // Check if the event reference exists // Returns sql.ErrNoRows if the event reference doesn't exist. - SelectPreviousEventExists(ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error + SelectPreviousEventExists(ctx context.Context, txn *sql.Tx, eventID string) error } type Invites interface { diff --git a/roomserver/storage/tables/previous_events_table_test.go b/roomserver/storage/tables/previous_events_table_test.go index 63d540696..9d41e90be 100644 --- a/roomserver/storage/tables/previous_events_table_test.go +++ b/roomserver/storage/tables/previous_events_table_test.go @@ -45,17 +45,17 @@ func TestPreviousEventsTable(t *testing.T) { defer close() for _, x := range room.Events() { - for _, prevEvent := range x.PrevEvents() { - err := tab.InsertPreviousEvent(ctx, nil, prevEvent.EventID, prevEvent.EventSHA256, 1) + for _, eventID := range x.PrevEventIDs() { + err := tab.InsertPreviousEvent(ctx, nil, eventID, 1) assert.NoError(t, err) - err = tab.SelectPreviousEventExists(ctx, nil, prevEvent.EventID, prevEvent.EventSHA256) + err = tab.SelectPreviousEventExists(ctx, nil, eventID) assert.NoError(t, err) } } - // RandomString with a correct EventSHA256 should fail and return sql.ErrNoRows - err := tab.SelectPreviousEventExists(ctx, nil, util.RandomString(16), room.Events()[0].EventReference().EventSHA256) + // RandomString should fail and return sql.ErrNoRows + err := tab.SelectPreviousEventExists(ctx, nil, util.RandomString(16)) assert.Error(t, err) }) } diff --git a/roomserver/types/types.go b/roomserver/types/types.go index e986b9da7..f57978ad5 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -200,7 +200,7 @@ func (s StateAtEvent) IsStateEvent() bool { // The StateAtEvent is used to construct the current state of the room from the latest events. type StateAtEventAndReference struct { StateAtEvent - gomatrixserverlib.EventReference + EventID string } type StateAtEventAndReferences []StateAtEventAndReference diff --git a/test/room.go b/test/room.go index 1c0f01e4b..852e31533 100644 --- a/test/room.go +++ b/test/room.go @@ -75,7 +75,7 @@ func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room { return r } -func (r *Room) MustGetAuthEventRefsForEvent(t *testing.T, needed gomatrixserverlib.StateNeeded) []gomatrixserverlib.EventReference { +func (r *Room) MustGetAuthEventRefsForEvent(t *testing.T, needed gomatrixserverlib.StateNeeded) []string { t.Helper() a, err := needed.AuthEventReferences(&r.authEvents) if err != nil { @@ -176,7 +176,7 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten t.Fatalf("CreateEvent[%s]: failed to SetContent: %s", eventType, err) } if depth > 1 { - builder.PrevEvents = []gomatrixserverlib.EventReference{r.events[len(r.events)-1].EventReference()} + builder.PrevEvents = []string{r.events[len(r.events)-1].EventID()} } err = builder.AddAuthEvents(&r.authEvents) From f956a8c1d9172f6bbfb9f7515feacd477a0e35f5 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 30 May 2023 10:02:53 +0200 Subject: [PATCH 09/35] Docs restructure (#2953) Needs to be merged into `gh-pages` later on. --- .github/workflows/helm.yml | 1 + README.md | 8 +- build.cmd | 51 ------ build.sh | 24 --- build/docker/README.md | 17 +- build/docker/docker-compose.monolith.yml | 44 ----- build/docker/docker-compose.yml | 52 ++++++ build/docker/postgres/create_db.sh | 5 - clientapi/admin_test.go | 6 +- clientapi/routing/admin.go | 5 +- dendrite-sample.yaml | 3 +- docs/FAQ.md | 6 +- docs/INSTALL.md | 10 +- docs/administration/1_createusers.md | 5 +- docs/administration/4_adminapi.md | 18 +- .../5_optimisation.md} | 11 +- ...roubleshooting.md => 6_troubleshooting.md} | 12 +- docs/caddy/{monolith => }/Caddyfile | 0 docs/caddy/polylith/Caddyfile | 85 --------- docs/development/CONTRIBUTING.md | 1 + docs/development/PROFILING.md | 1 + docs/development/coverage.md | 164 ++++++++++++------ docs/development/sytest.md | 70 +------- docs/development/tracing/opentracing.md | 114 ------------ docs/development/tracing/setup.md | 57 ------ ...olith-sample.conf => dendrite-sample.conf} | 0 docs/hiawatha/polylith-sample.conf | 35 ---- docs/installation/1_planning.md | 21 +-- docs/installation/2_domainname.md | 2 +- docs/installation/5_install_monolith.md | 21 --- docs/installation/9_starting_monolith.md | 42 ----- docs/installation/docker.md | 11 ++ docs/installation/docker/1_docker.md | 57 ++++++ docs/installation/helm.md | 11 ++ docs/installation/helm/1_helm.md | 58 +++++++ docs/installation/manual.md | 11 ++ .../{3_build.md => manual/1_build.md} | 23 +-- .../{4_database.md => manual/2_database.md} | 52 ++---- .../3_configuration.md} | 62 ++----- .../4_signingkey.md} | 9 +- .../manual/5_starting_dendrite.md | 26 +++ ...olith-sample.conf => dendrite-sample.conf} | 0 docs/nginx/polylith-sample.conf | 58 ------- docs/systemd/monolith-example.service | 19 -- roomserver/internal/perform/perform_admin.go | 12 +- syncapi/storage/sqlite3/account_data_table.go | 2 +- 46 files changed, 447 insertions(+), 855 deletions(-) delete mode 100644 build.cmd delete mode 100755 build.sh delete mode 100644 build/docker/docker-compose.monolith.yml create mode 100644 build/docker/docker-compose.yml delete mode 100755 build/docker/postgres/create_db.sh rename docs/{installation/11_optimisation.md => administration/5_optimisation.md} (90%) rename docs/administration/{5_troubleshooting.md => 6_troubleshooting.md} (88%) rename docs/caddy/{monolith => }/Caddyfile (100%) delete mode 100644 docs/caddy/polylith/Caddyfile delete mode 100644 docs/development/tracing/opentracing.md delete mode 100644 docs/development/tracing/setup.md rename docs/hiawatha/{monolith-sample.conf => dendrite-sample.conf} (100%) delete mode 100644 docs/hiawatha/polylith-sample.conf delete mode 100644 docs/installation/5_install_monolith.md delete mode 100644 docs/installation/9_starting_monolith.md create mode 100644 docs/installation/docker.md create mode 100644 docs/installation/docker/1_docker.md create mode 100644 docs/installation/helm.md create mode 100644 docs/installation/helm/1_helm.md create mode 100644 docs/installation/manual.md rename docs/installation/{3_build.md => manual/1_build.md} (53%) rename docs/installation/{4_database.md => manual/2_database.md} (57%) rename docs/installation/{7_configuration.md => manual/3_configuration.md} (67%) rename docs/installation/{8_signingkey.md => manual/4_signingkey.md} (92%) create mode 100644 docs/installation/manual/5_starting_dendrite.md rename docs/nginx/{monolith-sample.conf => dendrite-sample.conf} (100%) delete mode 100644 docs/nginx/polylith-sample.conf delete mode 100644 docs/systemd/monolith-example.service diff --git a/.github/workflows/helm.yml b/.github/workflows/helm.yml index a9c1718a0..bf62a1c19 100644 --- a/.github/workflows/helm.yml +++ b/.github/workflows/helm.yml @@ -38,3 +38,4 @@ jobs: with: config: helm/cr.yaml charts_dir: helm/ + mark_as_latest: false diff --git a/README.md b/README.md index 295203eb4..0b9788768 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ It intends to provide an **efficient**, **reliable** and **scalable** alternativ Dendrite is **beta** software, which means: -- Dendrite is ready for early adopters. We recommend running in Monolith mode with a PostgreSQL database. +- Dendrite is ready for early adopters. We recommend running Dendrite with a PostgreSQL database. - Dendrite has periodic releases. We intend to release new versions as we fix bugs and land significant features. - Dendrite supports database schema upgrades between releases. This means you should never lose your messages when upgrading Dendrite. @@ -21,7 +21,7 @@ This does not mean: - Dendrite is bug-free. It has not yet been battle-tested in the real world and so will be error prone initially. - Dendrite is feature-complete. There may be client or federation APIs that are not implemented. -- Dendrite is ready for massive homeserver deployments. There is no sharding of microservices (although it is possible to run them on separate machines) and there is no high-availability/clustering support. +- Dendrite is ready for massive homeserver deployments. There is no high-availability/clustering support. Currently, we expect Dendrite to function well for small (10s/100s of users) homeserver deployments as well as P2P Matrix nodes in-browser or on mobile devices. @@ -60,7 +60,7 @@ The following instructions are enough to get Dendrite started as a non-federatin ```bash $ git clone https://github.com/matrix-org/dendrite $ cd dendrite -$ ./build.sh +$ go build -o bin/ ./cmd/... # Generate a Matrix signing key for federation (required) $ ./bin/generate-keys --private-key matrix_key.pem @@ -85,7 +85,7 @@ Then point your favourite Matrix client at `http://localhost:8008` or `https://l ## Progress -We use a script called Are We Synapse Yet which checks Sytest compliance rates. Sytest is a black-box homeserver +We use a script called "Are We Synapse Yet" which checks Sytest compliance rates. Sytest is a black-box homeserver test rig with around 900 tests. The script works out how many of these tests are passing on Dendrite and it updates with CI. As of January 2023, we have 100% server-server parity with Synapse, and the client-server parity is at 93% , though check CI for the latest numbers. In practice, this means you can communicate locally and via federation with Synapse diff --git a/build.cmd b/build.cmd deleted file mode 100644 index 9e90622c8..000000000 --- a/build.cmd +++ /dev/null @@ -1,51 +0,0 @@ -@echo off - -:ENTRY_POINT - setlocal EnableDelayedExpansion - - REM script base dir - set SCRIPTDIR=%~dp0 - set PROJDIR=%SCRIPTDIR:~0,-1% - - REM Put installed packages into ./bin - set GOBIN=%PROJDIR%\bin - - set FLAGS= - - REM Check if sources are under Git control - if not exist ".git" goto :CHECK_BIN - - REM set BUILD=`git rev-parse --short HEAD \\ ""` - FOR /F "tokens=*" %%X IN ('git rev-parse --short HEAD') DO ( - set BUILD=%%X - ) - - REM set BRANCH=`(git symbolic-ref --short HEAD \ tr -d \/ ) \\ ""` - FOR /F "tokens=*" %%X IN ('git symbolic-ref --short HEAD') DO ( - set BRANCHRAW=%%X - set BRANCH=!BRANCHRAW:/=! - ) - if "%BRANCH%" == "main" set BRANCH= - - set FLAGS=-X github.com/matrix-org/dendrite/internal.branch=%BRANCH% -X github.com/matrix-org/dendrite/internal.build=%BUILD% - -:CHECK_BIN - if exist "bin" goto :ALL_SET - mkdir "bin" - -:ALL_SET - set CGO_ENABLED=1 - for /D %%P in (cmd\*) do ( - go build -trimpath -ldflags "%FLAGS%" -v -o ".\bin" ".\%%P" - ) - - set CGO_ENABLED=0 - set GOOS=js - set GOARCH=wasm - go build -trimpath -ldflags "%FLAGS%" -o bin\main.wasm .\cmd\dendritejs-pinecone - - goto :DONE - -:DONE - echo Done - endlocal \ No newline at end of file diff --git a/build.sh b/build.sh deleted file mode 100755 index f8b5001bf..000000000 --- a/build.sh +++ /dev/null @@ -1,24 +0,0 @@ -#!/bin/sh -eu - -# Put installed packages into ./bin -export GOBIN=$PWD/`dirname $0`/bin - -if [ -d ".git" ] -then - export BUILD=`git rev-parse --short HEAD || ""` - export BRANCH=`(git symbolic-ref --short HEAD | tr -d \/ ) || ""` - if [ "$BRANCH" = main ] - then - export BRANCH="" - fi - - export FLAGS="-X github.com/matrix-org/dendrite/internal.branch=$BRANCH -X github.com/matrix-org/dendrite/internal.build=$BUILD" -else - export FLAGS="" -fi - -mkdir -p bin - -CGO_ENABLED=1 go build -trimpath -ldflags "$FLAGS" -v -o "bin/" ./cmd/... - -# CGO_ENABLED=0 GOOS=js GOARCH=wasm go build -trimpath -ldflags "$FLAGS" -o bin/main.wasm ./cmd/dendritejs-pinecone diff --git a/build/docker/README.md b/build/docker/README.md index b66cb864b..8d69b9af1 100644 --- a/build/docker/README.md +++ b/build/docker/README.md @@ -6,23 +6,20 @@ They can be found on Docker Hub: - [matrixdotorg/dendrite-monolith](https://hub.docker.com/r/matrixdotorg/dendrite-monolith) for monolith deployments -## Dockerfiles +## Dockerfile -The `Dockerfile` is a multistage file which can build all four Dendrite -images depending on the supplied `--target`. From the root of the Dendrite +The `Dockerfile` is a multistage file which can build Dendrite. From the root of the Dendrite repository, run: ``` -docker build . --target monolith -t matrixdotorg/dendrite-monolith -docker build . --target demo-pinecone -t matrixdotorg/dendrite-demo-pinecone -docker build . --target demo-yggdrasil -t matrixdotorg/dendrite-demo-yggdrasil +docker build . -t matrixdotorg/dendrite-monolith ``` -## Compose files +## Compose file -There are two sample `docker-compose` files: +There is one sample `docker-compose` files: -- `docker-compose.monolith.yml` which runs a monolith Dendrite deployment +- `docker-compose.yml` which runs a Dendrite deployment with Postgres ## Configuration @@ -55,7 +52,7 @@ Create your config based on the [`dendrite-sample.yaml`](https://github.com/matr Then start the deployment: ``` -docker-compose -f docker-compose.monolith.yml up +docker-compose -f docker-compose.yml up ``` ## Building the images diff --git a/build/docker/docker-compose.monolith.yml b/build/docker/docker-compose.monolith.yml deleted file mode 100644 index 1a8fe4eee..000000000 --- a/build/docker/docker-compose.monolith.yml +++ /dev/null @@ -1,44 +0,0 @@ -version: "3.4" -services: - postgres: - hostname: postgres - image: postgres:14 - restart: always - volumes: - - ./postgres/create_db.sh:/docker-entrypoint-initdb.d/20-create_db.sh - # To persist your PostgreSQL databases outside of the Docker image, - # to prevent data loss, modify the following ./path_to path: - - ./path_to/postgresql:/var/lib/postgresql/data - environment: - POSTGRES_PASSWORD: itsasecret - POSTGRES_USER: dendrite - healthcheck: - test: ["CMD-SHELL", "pg_isready -U dendrite"] - interval: 5s - timeout: 5s - retries: 5 - networks: - - internal - - monolith: - hostname: monolith - image: matrixdotorg/dendrite-monolith:latest - command: [ - "--tls-cert=server.crt", - "--tls-key=server.key" - ] - ports: - - 8008:8008 - - 8448:8448 - volumes: - - ./config:/etc/dendrite - - ./media:/var/dendrite/media - depends_on: - - postgres - networks: - - internal - restart: unless-stopped - -networks: - internal: - attachable: true diff --git a/build/docker/docker-compose.yml b/build/docker/docker-compose.yml new file mode 100644 index 000000000..9397673f8 --- /dev/null +++ b/build/docker/docker-compose.yml @@ -0,0 +1,52 @@ +version: "3.4" + +services: + postgres: + hostname: postgres + image: postgres:15-alpine + restart: always + volumes: + # This will create a docker volume to persist the database files in. + # If you prefer those files to be outside of docker, you'll need to change this. + - dendrite_postgres_data:/var/lib/postgresql/data + environment: + POSTGRES_PASSWORD: itsasecret + POSTGRES_USER: dendrite + POSTGRES_DATABASE: dendrite + healthcheck: + test: ["CMD-SHELL", "pg_isready -U dendrite"] + interval: 5s + timeout: 5s + retries: 5 + networks: + - internal + + monolith: + hostname: monolith + image: matrixdotorg/dendrite-monolith:latest + ports: + - 8008:8008 + - 8448:8448 + volumes: + - ./config:/etc/dendrite + # The following volumes use docker volumes, change this + # if you prefer to have those files outside of docker. + - dendrite_media:/var/dendrite/media + - dendrite_jetstream:/var/dendrite/jetstream + - dendrite_search_index:/var/dendrite/searchindex + depends_on: + postgres: + condition: service_healthy + networks: + - internal + restart: unless-stopped + +networks: + internal: + attachable: true + +volumes: + dendrite_postgres_data: + dendrite_media: + dendrite_jetstream: + dendrite_search_index: \ No newline at end of file diff --git a/build/docker/postgres/create_db.sh b/build/docker/postgres/create_db.sh deleted file mode 100755 index 27d2a4df4..000000000 --- a/build/docker/postgres/create_db.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/sh - -for db in userapi_accounts mediaapi syncapi roomserver keyserver federationapi appservice mscs; do - createdb -U dendrite -O dendrite dendrite_$db -done diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index 69a815321..1145cb12d 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -133,7 +133,11 @@ func TestPurgeRoom(t *testing.T) { cfg, processCtx, close := testrig.CreateConfig(t, dbType) caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) natsInstance := jetstream.NATSInstance{} - defer close() + defer func() { + // give components the time to process purge requests + time.Sleep(time.Millisecond * 50) + close() + }() routers := httputil.NewRouters() cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 8dd662a1b..3d64454c4 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -123,7 +123,8 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *api.De } } request := struct { - Password string `json:"password"` + Password string `json:"password"` + LogoutDevices bool `json:"logout_devices"` }{} if err = json.NewDecoder(req.Body).Decode(&request); err != nil { return util.JSONResponse{ @@ -146,7 +147,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *api.De Localpart: localpart, ServerName: serverName, Password: request.Password, - LogoutDevices: true, + LogoutDevices: request.LogoutDevices, } updateRes := &api.PerformPasswordUpdateResponse{} if err := userAPI.PerformPasswordUpdate(req.Context(), updateReq, updateRes); err != nil { diff --git a/dendrite-sample.yaml b/dendrite-sample.yaml index 6b3ea74f2..96143d85f 100644 --- a/dendrite-sample.yaml +++ b/dendrite-sample.yaml @@ -69,8 +69,7 @@ global: # e.g. localhost:443 well_known_server_name: "" - # The server name to delegate client-server communications to, with optional port - # e.g. localhost:443 + # The base URL to delegate client-server communications to e.g. https://localhost well_known_client_name: "" # Lists of domains that the server will trust as identity servers to verify third diff --git a/docs/FAQ.md b/docs/FAQ.md index 200020726..757bf9625 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -24,7 +24,7 @@ No, although a good portion of the Matrix specification has been implemented. Mo Dendrite development is currently supported by a small team of developers and due to those limited resources, the majority of the effort is focused on getting Dendrite to be specification complete. If there are major features you're requesting (e.g. new administration endpoints), we'd like to strongly encourage you to join the community in supporting -the development efforts through [contributing](https://matrix-org.github.io/dendrite/development/contributing). +the development efforts through [contributing](../development/contributing). ## Is there a migration path from Synapse to Dendrite? @@ -103,7 +103,7 @@ This can be done by performing a room upgrade. Use the command `/upgraderoom " } diff --git a/docs/installation/11_optimisation.md b/docs/administration/5_optimisation.md similarity index 90% rename from docs/installation/11_optimisation.md rename to docs/administration/5_optimisation.md index 686ec2eb9..b327171eb 100644 --- a/docs/installation/11_optimisation.md +++ b/docs/administration/5_optimisation.md @@ -1,9 +1,9 @@ --- title: Optimise your installation -parent: Installation +parent: Administration has_toc: true -nav_order: 11 -permalink: /installation/start/optimisation +nav_order: 5 +permalink: /administration/optimisation --- # Optimise your installation @@ -36,11 +36,6 @@ connections it will open to the database. **If you are using the `global` database pool** then you only need to configure the `max_open_conns` setting once in the `global` section. -**If you are defining a `database` config per component** then you will need to ensure that -the **sum total** of all configured `max_open_conns` to a given database server do not exceed -the connection limit. If you configure a total that adds up to more connections than are available -then this will cause database queries to fail. - You may wish to raise the `max_connections` limit on your PostgreSQL server to accommodate additional connections, in which case you should also update the `max_open_conns` in your Dendrite configuration accordingly. However be aware that this is only advisable on particularly diff --git a/docs/administration/5_troubleshooting.md b/docs/administration/6_troubleshooting.md similarity index 88% rename from docs/administration/5_troubleshooting.md rename to docs/administration/6_troubleshooting.md index 8ba510ef6..5f11f9931 100644 --- a/docs/administration/5_troubleshooting.md +++ b/docs/administration/6_troubleshooting.md @@ -1,6 +1,7 @@ --- title: Troubleshooting parent: Administration +nav_order: 6 permalink: /administration/troubleshooting --- @@ -18,7 +19,7 @@ be clues in the logs. You can increase this log level to the more verbose `debug` level if necessary by adding this to the config and restarting Dendrite: -``` +```yaml logging: - type: std level: debug @@ -56,12 +57,7 @@ number of database connections does not exceed the maximum allowed by PostgreSQL Open your `postgresql.conf` configuration file and check the value of `max_connections` (which is typically `100` by default). Then open your `dendrite.yaml` configuration file -and ensure that: - -1. If you are using the `global.database` section, that `max_open_conns` does not exceed - that number; -2. If you are **not** using the `global.database` section, that the sum total of all - `max_open_conns` across all `database` blocks does not exceed that number. +and ensure that in the `global.database` section, `max_open_conns` does not exceed that number. ## 5. File descriptors @@ -77,7 +73,7 @@ If there aren't, you will see a log lines like this: level=warning msg="IMPORTANT: Process file descriptor limit is currently 65535, it is recommended to raise the limit for Dendrite to at least 65535 to avoid issues" ``` -Follow the [Optimisation](../installation/11_optimisation.md) instructions to correct the +Follow the [Optimisation](5_optimisation.md) instructions to correct the available number of file descriptors. ## 6. STUN/TURN Server tester diff --git a/docs/caddy/monolith/Caddyfile b/docs/caddy/Caddyfile similarity index 100% rename from docs/caddy/monolith/Caddyfile rename to docs/caddy/Caddyfile diff --git a/docs/caddy/polylith/Caddyfile b/docs/caddy/polylith/Caddyfile deleted file mode 100644 index c2d81b49b..000000000 --- a/docs/caddy/polylith/Caddyfile +++ /dev/null @@ -1,85 +0,0 @@ -# Sample Caddyfile for using Caddy in front of Dendrite - -# - -# Customize email address and domain names - -# Optional settings commented out - -# - -# BE SURE YOUR DOMAINS ARE POINTED AT YOUR SERVER FIRST - -# Documentation: - -# - -# Bonus tip: If your IP address changes, use Caddy's - -# dynamic DNS plugin to update your DNS records to - -# point to your new IP automatically - -# - -# - -# Global options block - -{ - # In case there is a problem with your certificates. - # email example@example.com - - # Turn off the admin endpoint if you don't need graceful config - # changes and/or are running untrusted code on your machine. - # admin off - - # Enable this if your clients don't send ServerName in TLS handshakes. - # default_sni example.com - - # Enable debug mode for verbose logging. - # debug - - # Use Let's Encrypt's staging endpoint for testing. - # acme_ca https://acme-staging-v02.api.letsencrypt.org/directory - - # If you're port-forwarding HTTP/HTTPS ports from 80/443 to something - # else, enable these and put the alternate port numbers here. - # http_port 8080 - # https_port 8443 -} - -# The server name of your matrix homeserver. This example shows - -# "well-known delegation" from the registered domain to a subdomain - -# which is only needed if your server_name doesn't match your Matrix - -# homeserver URL (i.e. you can show users a vanity domain that looks - -# nice and is easy to remember but still have your Matrix server on - -# its own subdomain or hosted service) - -example.com { - header /.well-known/matrix/*Content-Type application/json - header /.well-known/matrix/* Access-Control-Allow-Origin * - respond /.well-known/matrix/server `{"m.server": "matrix.example.com:443"}` - respond /.well-known/matrix/client `{"m.homeserver": {"base_url": "https://matrix.example.com"}}` -} - -# The actual domain name whereby your Matrix server is accessed - -matrix.example.com { - # Change the end of each reverse_proxy line to the correct - # address for your various services. - @sync_api { - path_regexp /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/(messages|.*?_?members|context/.*?|relations/.*?|event/.*?))$ - } - reverse_proxy @sync_api sync_api:8073 - - reverse_proxy /_matrix/client* client_api:8071 - reverse_proxy /_matrix/federation* federation_api:8071 - reverse_proxy /_matrix/key* federation_api:8071 - reverse_proxy /_matrix/media* media_api:8071 -} diff --git a/docs/development/CONTRIBUTING.md b/docs/development/CONTRIBUTING.md index 2aec4c363..71e7516a2 100644 --- a/docs/development/CONTRIBUTING.md +++ b/docs/development/CONTRIBUTING.md @@ -1,6 +1,7 @@ --- title: Contributing parent: Development +nav_order: 1 permalink: /development/contributing --- diff --git a/docs/development/PROFILING.md b/docs/development/PROFILING.md index 57c37a900..dc4eca7b7 100644 --- a/docs/development/PROFILING.md +++ b/docs/development/PROFILING.md @@ -1,6 +1,7 @@ --- title: Profiling parent: Development +nav_order: 4 permalink: /development/profiling --- diff --git a/docs/development/coverage.md b/docs/development/coverage.md index c4a8a1174..1b15f71a2 100644 --- a/docs/development/coverage.md +++ b/docs/development/coverage.md @@ -1,78 +1,130 @@ --- title: Coverage parent: Development +nav_order: 3 permalink: /development/coverage --- -To generate a test coverage report for Sytest, a small patch needs to be applied to the Sytest repository to compile and use the instrumented binary: -```patch -diff --git a/lib/SyTest/Homeserver/Dendrite.pm b/lib/SyTest/Homeserver/Dendrite.pm -index 8f0e209c..ad057e52 100644 ---- a/lib/SyTest/Homeserver/Dendrite.pm -+++ b/lib/SyTest/Homeserver/Dendrite.pm -@@ -337,7 +337,7 @@ sub _start_monolith - - $output->diag( "Starting monolith server" ); - my @command = ( -- $self->{bindir} . '/dendrite', -+ $self->{bindir} . '/dendrite', '--test.coverprofile=' . $self->{hs_dir} . '/integrationcover.log', "DEVEL", - '--config', $self->{paths}{config}, - '--http-bind-address', $self->{bind_host} . ':' . $self->unsecure_port, - '--https-bind-address', $self->{bind_host} . ':' . $self->secure_port, -diff --git a/scripts/dendrite_sytest.sh b/scripts/dendrite_sytest.sh -index f009332b..7ea79869 100755 ---- a/scripts/dendrite_sytest.sh -+++ b/scripts/dendrite_sytest.sh -@@ -34,7 +34,8 @@ export GOBIN=/tmp/bin - echo >&2 "--- Building dendrite from source" - cd /src - mkdir -p $GOBIN --go install -v ./cmd/dendrite -+# go install -v ./cmd/dendrite -+go test -c -cover -covermode=atomic -o $GOBIN/dendrite -coverpkg "github.com/matrix-org/..." ./cmd/dendrite - go install -v ./cmd/generate-keys - cd - - ``` +## Running unit tests with coverage enabled + +Running unit tests with coverage enabled can be done with the following commands, this will generate a `integrationcover.log` +```bash +go test -covermode=atomic -coverpkg=./... -coverprofile=integrationcover.log $(go list ./... | grep -v '/cmd/') +go tool cover -func=integrationcover.log +``` + +## Running Sytest with coverage enabled + +To run Sytest with coverage enabled: + +```bash +docker run --rm --name sytest -v "/Users/kegan/github/sytest:/sytest" \ + -v "/Users/kegan/github/dendrite:/src" -v "$(pwd)/sytest_logs:/logs" \ + -v "/Users/kegan/go/:/gopath" -e "POSTGRES=1" \ + -e "COVER=1" \ + matrixdotorg/sytest-dendrite:latest + +# to get a more accurate coverage you may also need to run Sytest using SQLite as the database: +docker run --rm --name sytest -v "/Users/kegan/github/sytest:/sytest" \ + -v "/Users/kegan/github/dendrite:/src" -v "$(pwd)/sytest_logs:/logs" \ + -v "/Users/kegan/go/:/gopath" \ + -e "COVER=1" \ + matrixdotorg/sytest-dendrite:latest +``` + +This will generate a folder `covdatafiles` in each server's directory, e.g `server-0/covdatafiles`. To parse them, +ensure your working directory is under the Dendrite repository then run: - Then run Sytest. This will generate a new file `integrationcover.log` in each server's directory e.g `server-0/integrationcover.log`. To parse it, - ensure your working directory is under the Dendrite repository then run: ```bash - go tool cover -func=/path/to/server-0/integrationcover.log + go tool covdata func -i="$(find -name 'covmeta*' -type f -exec dirname {} \; | uniq | paste -s -d ',' -)" ``` which will produce an output like: ``` ... - github.com/matrix-org/util/json.go:83: NewJSONRequestHandler 100.0% -github.com/matrix-org/util/json.go:90: Protect 57.1% -github.com/matrix-org/util/json.go:110: RequestWithLogging 100.0% -github.com/matrix-org/util/json.go:132: MakeJSONAPI 70.0% -github.com/matrix-org/util/json.go:151: respond 61.5% -github.com/matrix-org/util/json.go:180: WithCORSOptions 0.0% -github.com/matrix-org/util/json.go:191: SetCORSHeaders 100.0% -github.com/matrix-org/util/json.go:202: RandomString 100.0% -github.com/matrix-org/util/json.go:210: init 100.0% -github.com/matrix-org/util/unique.go:13: Unique 91.7% -github.com/matrix-org/util/unique.go:48: SortAndUnique 100.0% -github.com/matrix-org/util/unique.go:55: UniqueStrings 100.0% -total: (statements) 53.7% +github.com/matrix-org/util/json.go:132: MakeJSONAPI 70.0% +github.com/matrix-org/util/json.go:151: respond 84.6% +github.com/matrix-org/util/json.go:180: WithCORSOptions 0.0% +github.com/matrix-org/util/json.go:191: SetCORSHeaders 100.0% +github.com/matrix-org/util/json.go:202: RandomString 100.0% +github.com/matrix-org/util/json.go:210: init 100.0% +github.com/matrix-org/util/unique.go:13: Unique 91.7% +github.com/matrix-org/util/unique.go:48: SortAndUnique 100.0% +github.com/matrix-org/util/unique.go:55: UniqueStrings 100.0% +total (statements) 64.0% ``` -The total coverage for this run is the last line at the bottom. However, this value is misleading because Dendrite can run in many different configurations, -which will never be tested in a single test run (e.g sqlite or postgres). To get a more accurate value, additional processing is required -to remove packages which will never be tested and extension MSCs: +(after running Sytest for Postgres _and_ SQLite) + +The total coverage for this run is the last line at the bottom. However, this value is misleading because Dendrite can run in different configurations, +which will never be tested in a single test run (e.g sqlite or postgres). To get a more accurate value, you'll need run Sytest for Postgres and SQLite (see commands above). +Additional processing is required also to remove packages which will never be tested and extension MSCs: + ```bash -# These commands are all similar but change which package paths are _removed_ from the output. +# If you executed both commands from above, you can get the total coverage using the following commands +go tool covdata textfmt -i="$(find -name 'covmeta*' -type f -exec dirname {} \; | uniq | paste -s -d ',' -)" -o sytest.cov +grep -Ev 'relayapi|setup/mscs' sytest.cov > final.cov +go tool cover -func=final.cov -# For Postgres -go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'inthttp|sqlite|setup/mscs|api_trace' > coverage.txt +# If you only executed the one for Postgres: +go tool covdata textfmt -i="$(find -name 'covmeta*' -type f -exec dirname {} \; | uniq | paste -s -d ',' -)" -o sytest.cov +grep -Ev 'relayapi|sqlite|setup/mscs' sytest.cov > final.cov +go tool cover -func=final.cov -# For SQLite -go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'inthttp|postgres|setup/mscs|api_trace' > coverage.txt +# If you only executed the one for SQLite: +go tool covdata textfmt -i="$(find -name 'covmeta*' -type f -exec dirname {} \; | uniq | paste -s -d ',' -)" -o sytest.cov +grep -Ev 'relayapi|postgres|setup/mscs' sytest.cov > final.cov +go tool cover -func=final.cov ``` -A total value can then be calculated using: +## Getting coverage from Complement + +Getting the coverage for Complement runs is a bit more involved. + +First you'll need a docker image compatible with Complement, one can be built using ```bash -cat coverage.txt | awk -F '\t+' '{x = x + $3} END {print x/NR}' +docker build -t complement-dendrite -f build/scripts/Complement.Dockerfile . +``` +from within the Dendrite repository. + +Clone complement to a directory of your liking: +```bash +git clone https://github.com/matrix-org/complement.git +cd complement ``` +Next we'll need a script to execute after a test finishes, create a new file `posttest.sh`, make the file executable (`chmod +x posttest.sh`) +and add the following content: +```bash +#!/bin/bash -We currently do not have a way to combine Sytest/Complement/Unit Tests into a single coverage report. \ No newline at end of file +mkdir -p /tmp/Complement/logs/$2/$1/ +docker cp $1:/tmp/covdatafiles/. /tmp/Complement/logs/$2/$1/ +``` +This will copy the `covdatafiles` files from each container to something like +`/tmp/Complement/logs/TestLogin/94f9c428de95779d2b62a3ccd8eab9d5ddcf65cc259a40ece06bdc61687ffed3/`. (`$1` is the containerID, `$2` the test name) + +Now that we have set up everything we need, we can finally execute Complement: +```bash +COMPLEMENT_BASE_IMAGE=complement-dendrite \ +COMPLEMENT_SHARE_ENV_PREFIX=COMPLEMENT_DENDRITE_ \ +COMPLEMENT_DENDRITE_COVER=1 \ +COMPLEMENT_POST_TEST_SCRIPT=$(pwd)/posttest.sh \ + go test -tags dendrite_blacklist ./tests/... -count=1 -v -timeout=30m -failfast=false +``` + +Once this is done, you can copy the resulting `covdatafiles` files to your Dendrite repository for the next step. +```bash +cp -pr /tmp/Complement/logs PathToYourDendriteRepository +``` + +You can also run the following to get the coverage for Complement runs alone: +```bash +go tool covdata func -i="$(find /tmp/Complement -name 'covmeta*' -type f -exec dirname {} \; | uniq | paste -s -d ',' -)" +``` + +## Combining the results of (almost) all runs + +Now that we have all our `covdatafiles` files within the Dendrite repository, you can now execute the following command, to get the coverage +overall (excluding unit tests): +```bash +go tool covdata func -i="$(find -name 'covmeta*' -type f -exec dirname {} \; | uniq | paste -s -d ',' -)" +``` \ No newline at end of file diff --git a/docs/development/sytest.md b/docs/development/sytest.md index 4fae2ea3d..2f681f3e5 100644 --- a/docs/development/sytest.md +++ b/docs/development/sytest.md @@ -1,6 +1,7 @@ --- title: SyTest parent: Development +nav_order: 2 permalink: /development/sytest --- @@ -23,7 +24,7 @@ After running the tests, a script will print the tests you need to add to You should proceed after you see no build problems for dendrite after running: ```sh -./build.sh +go build -o bin/ ./cmd/... ``` If you are fixing an issue marked with @@ -61,6 +62,8 @@ When debugging, the following Docker `run` options may also be useful: * `-e "DENDRITE_TRACE_HTTP=1"`: Adds HTTP tracing to server logs. * `-e "DENDRITE_TRACE_INTERNAL=1"`: Adds roomserver internal API tracing to server logs. +* `-e "COVER=1"`: Run Sytest with an instrumented binary, producing a Go coverage file per server. +* `-e "RACE_DETECTION=1"`: Build the binaries with the `-race` flag (Note: This will significantly slow down test runs) The docker command also supports a single positional argument for the test file to run, so you can run a single `.pl` file rather than the whole test suite. For example: @@ -71,68 +74,3 @@ docker run --rm --name sytest -v "/Users/kegan/github/sytest:/sytest" -v "/Users/kegan/go/:/gopath" -e "POSTGRES=1" -e "DENDRITE_TRACE_HTTP=1" matrixdotorg/sytest-dendrite:latest tests/50federation/40devicelists.pl ``` - -### Manually Setting up SyTest - -**We advise AGAINST using manual SyTest setups.** - -If you don't want to use the Docker image, you can also run SyTest by hand. Make -sure you have Perl 5 or above, and get SyTest with: - -(Note that this guide assumes your SyTest checkout is next to your -`dendrite` checkout.) - -```sh -git clone -b develop https://github.com/matrix-org/sytest -cd sytest -./install-deps.pl -``` - -Set up the database: - -```sh -sudo -u postgres psql -c "CREATE USER dendrite PASSWORD 'itsasecret'" -sudo -u postgres psql -c "ALTER USER dendrite CREATEDB" -for i in dendrite0 dendrite1 sytest_template; do sudo -u postgres psql -c "CREATE DATABASE $i OWNER dendrite;"; done -mkdir -p "server-0" -cat > "server-0/database.yaml" << EOF -args: - user: dendrite - password: itsasecret - database: dendrite0 - host: 127.0.0.1 - sslmode: disable -type: pg -EOF -mkdir -p "server-1" -cat > "server-1/database.yaml" << EOF -args: - user: dendrite - password: itsasecret - database: dendrite1 - host: 127.0.0.1 - sslmode: disable -type: pg -EOF -``` - -Run the tests: - -```sh -POSTGRES=1 ./run-tests.pl -I Dendrite::Monolith -d ../dendrite/bin -W ../dendrite/sytest-whitelist -O tap --all | tee results.tap -``` - -where `tee` lets you see the results while they're being piped to the file, and -`POSTGRES=1` enables testing with PostgeSQL. If the `POSTGRES` environment -variable is not set or is set to 0, SyTest will fall back to SQLite 3. For more -flags and options, see . - -Once the tests are complete, run the helper script to see if you need to add -any newly passing test names to `sytest-whitelist` in the project's root -directory: - -```sh -../dendrite/show-expected-fail-tests.sh results.tap ../dendrite/sytest-whitelist ../dendrite/sytest-blacklist -``` - -If the script prints nothing/exits with 0, then you're good to go. diff --git a/docs/development/tracing/opentracing.md b/docs/development/tracing/opentracing.md deleted file mode 100644 index 8528c2ba3..000000000 --- a/docs/development/tracing/opentracing.md +++ /dev/null @@ -1,114 +0,0 @@ ---- -title: OpenTracing -has_children: true -parent: Development -permalink: /development/opentracing ---- - -# OpenTracing - -Dendrite extensively uses the [opentracing.io](http://opentracing.io) framework -to trace work across the different logical components. - -At its most basic opentracing tracks "spans" of work; recording start and end -times as well as any parent span that caused the piece of work. - -A typical example would be a new span being created on an incoming request that -finishes when the response is sent. When the code needs to hit out to a -different component a new span is created with the initial span as its parent. -This would end up looking roughly like: - -``` -Received request Sent response - |<───────────────────────────────────────>| - |<────────────────────>| - RPC call RPC call returns -``` - -This is useful to see where the time is being spent processing a request on a -component. However, opentracing allows tracking of spans across components. This -makes it possible to see exactly what work goes into processing a request: - -``` -Component 1 |<─────────────────── HTTP ────────────────────>| - |<──────────────── RPC ─────────────────>| -Component 2 |<─ SQL ─>| |<── RPC ───>| -Component 3 |<─ SQL ─>| -``` - -This is achieved by serializing span information during all communication -between components. For HTTP requests, this is achieved by the sender -serializing the span into a HTTP header, and the receiver deserializing the span -on receipt. (Generally a new span is then immediately created with the -deserialized span as the parent). - -A collection of spans that are related is called a trace. - -Spans are passed through the code via contexts, rather than manually. It is -therefore important that all spans that are created are immediately added to the -current context. Thankfully the opentracing library gives helper functions for -doing this: - -```golang -span, ctx := opentracing.StartSpanFromContext(ctx, spanName) -defer span.Finish() -``` - -This will create a new span, adding any span already in `ctx` as a parent to the -new span. - -Adding Information ------------------- - -Opentracing allows adding information to a trace via three mechanisms: - -- "tags" ─ A span can be tagged with a key/value pair. This is typically - information that relates to the span, e.g. for spans created for incoming HTTP - requests could include the request path and response codes as tags, spans for - SQL could include the query being executed. -- "logs" ─ Key/value pairs can be looged at a particular instance in a trace. - This can be useful to log e.g. any errors that happen. -- "baggage" ─ Arbitrary key/value pairs can be added to a span to which all - child spans have access. Baggage isn't saved and so isn't available when - inspecting the traces, but can be used to add context to logs or tags in child - spans. - -See -[specification.md](https://github.com/opentracing/specification/blob/master/specification.md) -for some of the common tags and log fields used. - -Span Relationships ------------------- - -Spans can be related to each other. The most common relation is `childOf`, which -indicates the child span somehow depends on the parent span ─ typically the -parent span cannot complete until all child spans are completed. - -A second relation type is `followsFrom`, where the parent has no dependence on -the child span. This usually indicates some sort of fire and forget behaviour, -e.g. adding a message to a pipeline or inserting into a kafka topic. - -Jaeger ------- - -Opentracing is just a framework. We use -[jaeger](https://github.com/jaegertracing/jaeger) as the actual implementation. - -Jaeger is responsible for recording, sending and saving traces, as well as -giving a UI for viewing and interacting with traces. - -To enable jaeger a `Tracer` object must be instansiated from the config (as well -as having a jaeger server running somewhere, usually locally). A `Tracer` does -several things: - -- Decides which traces to save and send to the server. There are multiple - schemes for doing this, with a simple example being to save a certain fraction - of traces. -- Communicating with the jaeger backend. If not explicitly specified uses the - default port on localhost. -- Associates a service name to all spans created by the tracer. This service - name equates to a logical component, e.g. spans created by clientapi will have - a different service name than ones created by the syncapi. Database access - will also typically use a different service name. - - This means that there is a tracer per service name/component. diff --git a/docs/development/tracing/setup.md b/docs/development/tracing/setup.md deleted file mode 100644 index cef1089e4..000000000 --- a/docs/development/tracing/setup.md +++ /dev/null @@ -1,57 +0,0 @@ ---- -title: Setup -parent: OpenTracing -grand_parent: Development -permalink: /development/opentracing/setup ---- - -# OpenTracing Setup - -Dendrite uses [Jaeger](https://www.jaegertracing.io/) for tracing between microservices. -Tracing shows the nesting of logical spans which provides visibility on how the microservices interact. -This document explains how to set up Jaeger locally on a single machine. - -## Set up the Jaeger backend - -The [easiest way](https://www.jaegertracing.io/docs/1.18/getting-started/) is to use the all-in-one Docker image: - -``` -$ docker run -d --name jaeger \ - -e COLLECTOR_ZIPKIN_HTTP_PORT=9411 \ - -p 5775:5775/udp \ - -p 6831:6831/udp \ - -p 6832:6832/udp \ - -p 5778:5778 \ - -p 16686:16686 \ - -p 14268:14268 \ - -p 14250:14250 \ - -p 9411:9411 \ - jaegertracing/all-in-one:1.18 -``` - -## Configuring Dendrite to talk to Jaeger - -Modify your config to look like: (this will send every single span to Jaeger which will be slow on large instances, but for local testing it's fine) - -``` -tracing: - enabled: true - jaeger: - serviceName: "dendrite" - disabled: false - rpc_metrics: true - tags: [] - sampler: - type: const - param: 1 -``` - -then run the monolith server: - -``` -./dendrite --tls-cert server.crt --tls-key server.key --config dendrite.yaml -``` - -## Checking traces - -Visit to see traces under `DendriteMonolith`. diff --git a/docs/hiawatha/monolith-sample.conf b/docs/hiawatha/dendrite-sample.conf similarity index 100% rename from docs/hiawatha/monolith-sample.conf rename to docs/hiawatha/dendrite-sample.conf diff --git a/docs/hiawatha/polylith-sample.conf b/docs/hiawatha/polylith-sample.conf deleted file mode 100644 index eb1dd4f9a..000000000 --- a/docs/hiawatha/polylith-sample.conf +++ /dev/null @@ -1,35 +0,0 @@ -# Depending on which port is used for federation (.well-known/matrix/server or SRV record), -# ensure there's a binding for that port in the configuration. Replace "FEDPORT" with port -# number, (e.g. "8448"), and "IPV4" with your server's ipv4 address (separate binding for -# each ip address, e.g. if you use both ipv4 and ipv6 addresses). - -Binding { - Port = FEDPORT - Interface = IPV4 - TLScertFile = /path/to/fullchainandprivkey.pem -} - - -VirtualHost { - ... - # route requests to: - # /_matrix/client/.*/sync - # /_matrix/client/.*/user/{userId}/filter - # /_matrix/client/.*/user/{userId}/filter/{filterID} - # /_matrix/client/.*/keys/changes - # /_matrix/client/.*/rooms/{roomId}/messages - # /_matrix/client/.*/rooms/{roomId}/context/{eventID} - # /_matrix/client/.*/rooms/{roomId}/event/{eventID} - # /_matrix/client/.*/rooms/{roomId}/relations/{eventID} - # /_matrix/client/.*/rooms/{roomId}/relations/{eventID}/{relType} - # /_matrix/client/.*/rooms/{roomId}/relations/{eventID}/{relType}/{eventType} - # /_matrix/client/.*/rooms/{roomId}/members - # /_matrix/client/.*/rooms/{roomId}/joined_members - # to sync_api - ReverseProxy = /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/(messages|.*?_?members|context/.*?|relations/.*?|event/.*?))$ http://localhost:8073 600 - ReverseProxy = /_matrix/client http://localhost:8071 600 - ReverseProxy = /_matrix/federation http://localhost:8072 600 - ReverseProxy = /_matrix/key http://localhost:8072 600 - ReverseProxy = /_matrix/media http://localhost:8074 600 - ... -} diff --git a/docs/installation/1_planning.md b/docs/installation/1_planning.md index 36d90abda..354003aef 100644 --- a/docs/installation/1_planning.md +++ b/docs/installation/1_planning.md @@ -7,23 +7,13 @@ permalink: /installation/planning # Planning your installation -## Modes - -Dendrite consists of several components, each responsible for a different aspect of the Matrix protocol. -Users can run Dendrite in one of two modes which dictate how these components are executed and communicate. - -* **Monolith mode** runs all components in a single process. Components communicate through an internal NATS - server with generally low overhead. This mode dramatically simplifies deployment complexity and offers the - best balance between performance and resource usage for low-to-mid volume deployments. - - -## Databases +## Database Dendrite can run with either a PostgreSQL or a SQLite backend. There are considerable tradeoffs to consider: * **PostgreSQL**: Needs to run separately to Dendrite, needs to be installed and configured separately - and and will use more resources over all, but will be **considerably faster** than SQLite. PostgreSQL + and will use more resources over all, but will be **considerably faster** than SQLite. PostgreSQL has much better write concurrency which will allow Dendrite to process more tasks in parallel. This will be necessary for federated deployments to perform adequately. @@ -80,18 +70,17 @@ If using the PostgreSQL database engine, you should install PostgreSQL 12 or lat ### NATS Server Dendrite comes with a built-in [NATS Server](https://github.com/nats-io/nats-server) and -therefore does not need this to be manually installed. If you are planning a monolith installation, you -do not need to do anything. +therefore does not need this to be manually installed. ### Reverse proxy A reverse proxy such as [Caddy](https://caddyserver.com), [NGINX](https://www.nginx.com) or -[HAProxy](http://www.haproxy.org) is useful for deployments. Configuring those is not covered in this documentation, although sample configurations +[HAProxy](http://www.haproxy.org) is useful for deployments. Configuring this is not covered in this documentation, although sample configurations for [Caddy](https://github.com/matrix-org/dendrite/blob/main/docs/caddy) and [NGINX](https://github.com/matrix-org/dendrite/blob/main/docs/nginx) are provided. ### Windows -Finally, if you want to build Dendrite on Windows, you will need need `gcc` in the path. The best +Finally, if you want to build Dendrite on Windows, you will need `gcc` in the path. The best way to achieve this is by installing and building Dendrite under [MinGW-w64](https://www.mingw-w64.org/). diff --git a/docs/installation/2_domainname.md b/docs/installation/2_domainname.md index 545a2daf6..d86a664cb 100644 --- a/docs/installation/2_domainname.md +++ b/docs/installation/2_domainname.md @@ -20,7 +20,7 @@ Matrix servers usually discover each other when federating using the following m well-known file to connect to the remote homeserver; 2. If a DNS SRV delegation exists on `example.com`, use the IP address and port from the DNS SRV record to connect to the remote homeserver; -3. If neither well-known or DNS SRV delegation are configured, attempt to connect to the remote +3. If neither well-known nor DNS SRV delegation are configured, attempt to connect to the remote homeserver by connecting to `example.com` port TCP/8448 using HTTPS. The exact details of how server name resolution works can be found in diff --git a/docs/installation/5_install_monolith.md b/docs/installation/5_install_monolith.md deleted file mode 100644 index 901975a65..000000000 --- a/docs/installation/5_install_monolith.md +++ /dev/null @@ -1,21 +0,0 @@ ---- -title: Installing as a monolith -parent: Installation -has_toc: true -nav_order: 5 -permalink: /installation/install/monolith ---- - -# Installing as a monolith - -You can install the Dendrite monolith binary into `$GOPATH/bin` by using `go install`: - -```sh -go install ./cmd/dendrite -``` - -Alternatively, you can specify a custom path for the binary to be written to using `go build`: - -```sh -go build -o /usr/local/bin/ ./cmd/dendrite -``` diff --git a/docs/installation/9_starting_monolith.md b/docs/installation/9_starting_monolith.md deleted file mode 100644 index d7e8c0b8b..000000000 --- a/docs/installation/9_starting_monolith.md +++ /dev/null @@ -1,42 +0,0 @@ ---- -title: Starting the monolith -parent: Installation -has_toc: true -nav_order: 9 -permalink: /installation/start/monolith ---- - -# Starting the monolith - -Once you have completed all of the preparation and installation steps, -you can start your Dendrite monolith deployment by starting `dendrite`: - -```bash -./dendrite -config /path/to/dendrite.yaml -``` - -By default, Dendrite will listen HTTP on port 8008. If you want to change the addresses -or ports that Dendrite listens on, you can use the `-http-bind-address` and -`-https-bind-address` command line arguments: - -```bash -./dendrite -config /path/to/dendrite.yaml \ - -http-bind-address 1.2.3.4:12345 \ - -https-bind-address 1.2.3.4:54321 -``` - -## Running under systemd - -A common deployment pattern is to run the monolith under systemd. For this, you -will need to create a service unit file. An example service unit file is available -in the [GitHub repository](https://github.com/matrix-org/dendrite/blob/main/docs/systemd/monolith-example.service). - -Once you have installed the service unit, you can notify systemd, enable and start -the service: - -```bash -systemctl daemon-reload -systemctl enable dendrite -systemctl start dendrite -journalctl -fu dendrite -``` diff --git a/docs/installation/docker.md b/docs/installation/docker.md new file mode 100644 index 000000000..1ecc7c6ee --- /dev/null +++ b/docs/installation/docker.md @@ -0,0 +1,11 @@ +--- +title: Docker +parent: Installation +has_children: true +nav_order: 4 +permalink: /docker +--- + +# Installation using Docker + +This section contains documentation how to install Dendrite using Docker diff --git a/docs/installation/docker/1_docker.md b/docs/installation/docker/1_docker.md new file mode 100644 index 000000000..1fe792636 --- /dev/null +++ b/docs/installation/docker/1_docker.md @@ -0,0 +1,57 @@ +--- +title: Installation +parent: Docker +grand_parent: Installation +has_toc: true +nav_order: 1 +permalink: /installation/docker/install +--- + +# Installing Dendrite using Docker Compose + +Dendrite provides an [example](https://github.com/matrix-org/dendrite/blob/main/build/docker/docker-compose.yml) +Docker compose file, which needs some preparation to start successfully. +Please note that this compose file only has Postgres as a dependency, and you need to configure +a [reverse proxy](../planning#reverse-proxy). + +## Preparations + +### Generate a private key + +First we'll generate private key, which is used to sign events, the following will create one in `./config`: + +```bash +mkdir -p ./config +docker run --rm --entrypoint="/usr/bin/generate-keys" \ + -v $(pwd)/config:/mnt \ + matrixdotorg/dendrite-monolith:latest \ + -private-key /mnt/matrix_key.pem +``` +(**NOTE**: This only needs to be executed **once**, as you otherwise overwrite the key) + +### Generate a config + +Similar to the command above, we can generate a config to be used, which will use the correct paths +as specified in the example docker-compose file. Change `server` to your domain and `db` according to your changes +to the docker-compose file (`services.postgres.environment` values): + +```bash +mkdir -p ./config +docker run --rm --entrypoint="/bin/sh" \ + -v $(pwd)/config:/mnt \ + matrixdotorg/dendrite-monolith:latest \ + -c "/usr/bin/generate-config \ + -dir /var/dendrite/ \ + -db postgres://dendrite:itsasecret@postgres/dendrite?sslmode=disable \ + -server YourDomainHere > /mnt/dendrite.yaml" +``` + +You can then change `config/dendrite.yaml` to your liking. + +## Starting Dendrite + +Once you're done changing the config, you can now start up Dendrite with + +```bash +docker-compose -f docker-compose.yml up +``` diff --git a/docs/installation/helm.md b/docs/installation/helm.md new file mode 100644 index 000000000..dd20e0261 --- /dev/null +++ b/docs/installation/helm.md @@ -0,0 +1,11 @@ +--- +title: Helm +parent: Installation +has_children: true +nav_order: 3 +permalink: /helm +--- + +# Helm + +This section contains documentation how to use [Helm](https://helm.sh/) to install Dendrite on a [Kubernetes](https://kubernetes.io/) cluster. diff --git a/docs/installation/helm/1_helm.md b/docs/installation/helm/1_helm.md new file mode 100644 index 000000000..00fe4fdca --- /dev/null +++ b/docs/installation/helm/1_helm.md @@ -0,0 +1,58 @@ +--- +title: Installation +parent: Helm +grand_parent: Installation +has_toc: true +nav_order: 1 +permalink: /installation/helm/install +--- + +# Installing Dendrite using Helm + +To install Dendrite using the Helm chart, you first have to add the repository using the following commands: + +```bash +helm repo add dendrite https://matrix-org.github.io/dendrite/ +helm repo update +``` + +Next you'll need to create a `values.yaml` file and configure it to your liking. All possible values can be found +[here](https://github.com/matrix-org/dendrite/blob/main/helm/dendrite/values.yaml), but at least you need to configure +a `server_name`, otherwise the chart will complain about it: + +```yaml +dendrite_config: + global: + server_name: "localhost" +``` + +If you are going to use an existing Postgres database, you'll also need to configure this connection: + +```yaml +dendrite_config: + global: + database: + connection_string: "postgresql://PostgresUser:PostgresPassword@PostgresHostName/DendriteDatabaseName" + max_open_conns: 90 + max_idle_conns: 5 + conn_max_lifetime: -1 +``` + +## Installing with PostgreSQL + +The chart comes with a dependency on Postgres, which can be installed alongside Dendrite, this needs to be enabled in +the `values.yaml`: + +```yaml +postgresql: + enabled: true # this installs Postgres + primary: + persistence: + size: 1Gi # defines the size for $PGDATA + +dendrite_config: + global: + server_name: "localhost" +``` + +Using this option, the `database.connection_string` will be set for you automatically. \ No newline at end of file diff --git a/docs/installation/manual.md b/docs/installation/manual.md new file mode 100644 index 000000000..3ab1fd627 --- /dev/null +++ b/docs/installation/manual.md @@ -0,0 +1,11 @@ +--- +title: Manual +parent: Installation +has_children: true +nav_order: 5 +permalink: /manual +--- + +# Manual Installation + +This section contains documentation how to manually install Dendrite diff --git a/docs/installation/3_build.md b/docs/installation/manual/1_build.md similarity index 53% rename from docs/installation/3_build.md rename to docs/installation/manual/1_build.md index 824c81d37..73a626882 100644 --- a/docs/installation/3_build.md +++ b/docs/installation/manual/1_build.md @@ -1,31 +1,26 @@ --- -title: Building Dendrite -parent: Installation +title: Building/Installing Dendrite +parent: Manual +grand_parent: Installation has_toc: true -nav_order: 3 -permalink: /installation/build +nav_order: 1 +permalink: /installation/manual/build --- # Build all Dendrite commands Dendrite has numerous utility commands in addition to the actual server binaries. -Build them all from the root of the source repo with `build.sh` (Linux/Mac): +Build them all from the root of the source repo with: ```sh -./build.sh -``` - -or `build.cmd` (Windows): - -```powershell -build.cmd +go build -o bin/ ./cmd/... ``` The resulting binaries will be placed in the `bin` subfolder. -# Installing as a monolith +# Installing Dendrite -You can install the Dendrite monolith binary into `$GOPATH/bin` by using `go install`: +You can install the Dendrite binary into `$GOPATH/bin` by using `go install`: ```sh go install ./cmd/dendrite diff --git a/docs/installation/4_database.md b/docs/installation/manual/2_database.md similarity index 57% rename from docs/installation/4_database.md rename to docs/installation/manual/2_database.md index d64ee6615..1be602c66 100644 --- a/docs/installation/4_database.md +++ b/docs/installation/manual/2_database.md @@ -1,8 +1,10 @@ --- title: Preparing database storage parent: Installation -nav_order: 3 -permalink: /installation/database +nav_order: 2 +parent: Manual +grand_parent: Installation +permalink: /installation/manual/database --- # Preparing database storage @@ -13,31 +15,22 @@ may need to perform some manual steps outlined below. ## PostgreSQL Dendrite can automatically populate the database with the relevant tables and indexes, but -it is not capable of creating the databases themselves. You will need to create the databases +it is not capable of creating the database itself. You will need to create the database manually. -The databases **must** be created with UTF-8 encoding configured or you will likely run into problems +The database **must** be created with UTF-8 encoding configured, or you will likely run into problems with your Dendrite deployment. -At this point, you can choose to either use a single database for all Dendrite components, -or you can run each component with its own separate database: +You will need to create a single PostgreSQL database. Deployments +can use a single global connection pool, which makes updating the configuration file much easier. +Only one database connection string to manage and likely simpler to back up the database. All +components will be sharing the same database resources (CPU, RAM, storage). -* **Single database**: You will need to create a single PostgreSQL database. Monolith deployments - can use a single global connection pool, which makes updating the configuration file much easier. - Only one database connection string to manage and likely simpler to back up the database. All - components will be sharing the same database resources (CPU, RAM, storage). - -* **Separate databases**: You will need to create a separate PostgreSQL database for each - component. You will need to configure each component that has storage in the Dendrite - configuration file with its own connection parameters. Allows running a different database engine - for each component on a different machine if needs be, each with their own CPU, RAM and storage — - almost certainly overkill unless you are running a very large Dendrite deployment. - -For either configuration, you will want to: +You will most likely want to: 1. Configure a role (with a username and password) which Dendrite can use to connect to the database; -2. Create the database(s) themselves, ensuring that the Dendrite role has privileges over them. +2. Create the database itself, ensuring that the Dendrite role has privileges over them. As Dendrite will create and manage the database tables, indexes and sequences by itself, the Dendrite role must have suitable privileges over the database. @@ -71,27 +64,6 @@ Create the database itself, using the `dendrite` role from above: sudo -u postgres createdb -O dendrite -E UTF-8 dendrite ``` -### Multiple database creation - -The following eight components require a database. In this example they will be named: - -| Appservice API | `dendrite_appservice` | -| Federation API | `dendrite_federationapi` | -| Media API | `dendrite_mediaapi` | -| MSCs | `dendrite_mscs` | -| Roomserver | `dendrite_roomserver` | -| Sync API | `dendrite_syncapi` | -| Key server | `dendrite_keyserver` | -| User API | `dendrite_userapi` | - -... therefore you will need to create eight different databases: - -```bash -for i in appservice federationapi mediaapi mscs roomserver syncapi keyserver userapi; do - sudo -u postgres createdb -O dendrite -E UTF-8 dendrite_$i -done -``` - ## SQLite **WARNING:** The Dendrite SQLite backend is slower, less reliable and not recommended for diff --git a/docs/installation/7_configuration.md b/docs/installation/manual/3_configuration.md similarity index 67% rename from docs/installation/7_configuration.md rename to docs/installation/manual/3_configuration.md index 0cc67b156..a9dd81c87 100644 --- a/docs/installation/7_configuration.md +++ b/docs/installation/manual/3_configuration.md @@ -1,8 +1,9 @@ --- title: Configuring Dendrite -parent: Installation -nav_order: 7 -permalink: /installation/configuration +parent: Manual +grand_parent: Installation +nav_order: 3 +permalink: /installation/manual/configuration --- # Configuring Dendrite @@ -20,7 +21,7 @@ sections: First of all, you will need to configure the server name of your Matrix homeserver. This must match the domain name that you have selected whilst [configuring the domain -name delegation](domainname). +name delegation](domainname#delegation). In the `global` section, set the `server_name` to your delegated domain name: @@ -44,7 +45,7 @@ global: ## JetStream configuration -Monolith deployments can use the built-in NATS Server rather than running a standalone +Dendrite deployments can use the built-in NATS Server rather than running a standalone server. If you want to use a standalone NATS Server anyway, you can also configure that too. ### Built-in NATS Server @@ -56,7 +57,6 @@ configured and set a `storage_path` to a persistent folder on the filesystem: global: # ... jetstream: - in_memory: false storage_path: /path/to/storage/folder topic_prefix: Dendrite ``` @@ -79,22 +79,17 @@ You do not need to configure the `storage_path` when using a standalone NATS Ser In the case that you are connecting to a multi-node NATS cluster, you can configure more than one address in the `addresses` field. -## Database connections +## Database connection using a global connection pool -Configuring database connections varies based on the [database configuration](database) -that you chose. - -### Global connection pool - -If you want to use a single connection pool to a single PostgreSQL database, then you must -uncomment and configure the `database` section within the `global` section: +If you want to use a single connection pool to a single PostgreSQL database, +then you must uncomment and configure the `database` section within the `global` section: ```yaml global: # ... database: connection_string: postgres://user:pass@hostname/database?sslmode=disable - max_open_conns: 100 + max_open_conns: 90 max_idle_conns: 5 conn_max_lifetime: -1 ``` @@ -104,42 +99,13 @@ configuration file, e.g. under the `app_service_api`, `federation_api`, `key_ser `media_api`, `mscs`, `relay_api`, `room_server`, `sync_api` and `user_api` blocks, otherwise these will override the `global` database configuration. -### Per-component connections (all other configurations) - -If you are are using SQLite databases or separate PostgreSQL -databases per component, then you must instead configure the `database` sections under each -of the component blocks ,e.g. under the `app_service_api`, `federation_api`, `key_server`, -`media_api`, `mscs`, `relay_api`, `room_server`, `sync_api` and `user_api` blocks. - -For example, with PostgreSQL: - -```yaml -room_server: - # ... - database: - connection_string: postgres://user:pass@hostname/dendrite_component?sslmode=disable - max_open_conns: 10 - max_idle_conns: 2 - conn_max_lifetime: -1 -``` - -... or with SQLite: - -```yaml -room_server: - # ... - database: - connection_string: file:roomserver.db - max_open_conns: 10 - max_idle_conns: 2 - conn_max_lifetime: -1 -``` - ## Full-text search -Dendrite supports experimental full-text indexing using [Bleve](https://github.com/blevesearch/bleve). It is configured in the `sync_api` section as follows. +Dendrite supports full-text indexing using [Bleve](https://github.com/blevesearch/bleve). It is configured in the `sync_api` section as follows. -Depending on the language most likely to be used on the server, it might make sense to change the `language` used when indexing, to ensure the returned results match the expectations. A full list of possible languages can be found [here](https://github.com/blevesearch/bleve/tree/master/analysis/lang). +Depending on the language most likely to be used on the server, it might make sense to change the `language` used when indexing, +to ensure the returned results match the expectations. A full list of possible languages +can be found [here](https://github.com/matrix-org/dendrite/blob/5b73592f5a4dddf64184fcbe33f4c1835c656480/internal/fulltext/bleve.go#L25-L46). ```yaml sync_api: diff --git a/docs/installation/8_signingkey.md b/docs/installation/manual/4_signingkey.md similarity index 92% rename from docs/installation/8_signingkey.md rename to docs/installation/manual/4_signingkey.md index 323759a88..bd9c242ab 100644 --- a/docs/installation/8_signingkey.md +++ b/docs/installation/manual/4_signingkey.md @@ -1,8 +1,9 @@ --- title: Generating signing keys -parent: Installation -nav_order: 8 -permalink: /installation/signingkeys +parent: Manual +grand_parent: Installation +nav_order: 4 +permalink: /installation/manual/signingkeys --- # Generating signing keys @@ -11,7 +12,7 @@ All Matrix homeservers require a signing private key, which will be used to auth federation requests and events. The `generate-keys` utility can be used to generate a private key. Assuming that Dendrite was -built using `build.sh`, you should find the `generate-keys` utility in the `bin` folder. +built using `go build -o bin/ ./cmd/...`, you should find the `generate-keys` utility in the `bin` folder. To generate a Matrix signing private key: diff --git a/docs/installation/manual/5_starting_dendrite.md b/docs/installation/manual/5_starting_dendrite.md new file mode 100644 index 000000000..d13504372 --- /dev/null +++ b/docs/installation/manual/5_starting_dendrite.md @@ -0,0 +1,26 @@ +--- +title: Starting Dendrite +parent: Manual +grand_parent: Installation +nav_order: 5 +permalink: /installation/manual/start +--- + +# Starting Dendrite + +Once you have completed all preparation and installation steps, +you can start your Dendrite deployment by executing the `dendrite` binary: + +```bash +./dendrite -config /path/to/dendrite.yaml +``` + +By default, Dendrite will listen HTTP on port 8008. If you want to change the addresses +or ports that Dendrite listens on, you can use the `-http-bind-address` and +`-https-bind-address` command line arguments: + +```bash +./dendrite -config /path/to/dendrite.yaml \ + -http-bind-address 1.2.3.4:12345 \ + -https-bind-address 1.2.3.4:54321 +``` diff --git a/docs/nginx/monolith-sample.conf b/docs/nginx/dendrite-sample.conf similarity index 100% rename from docs/nginx/monolith-sample.conf rename to docs/nginx/dendrite-sample.conf diff --git a/docs/nginx/polylith-sample.conf b/docs/nginx/polylith-sample.conf deleted file mode 100644 index 0ad24509a..000000000 --- a/docs/nginx/polylith-sample.conf +++ /dev/null @@ -1,58 +0,0 @@ -server { - listen 443 ssl; # IPv4 - listen [::]:443 ssl; # IPv6 - server_name my.hostname.com; - - ssl_certificate /path/to/fullchain.pem; - ssl_certificate_key /path/to/privkey.pem; - ssl_dhparam /path/to/ssl-dhparams.pem; - - proxy_set_header Host $host; - proxy_set_header X-Real-IP $remote_addr; - proxy_read_timeout 600; - - location /.well-known/matrix/server { - return 200 '{ "m.server": "my.hostname.com:443" }'; - } - - location /.well-known/matrix/client { - # If your sever_name here doesn't match your matrix homeserver URL - # (e.g. hostname.com as server_name and matrix.hostname.com as homeserver URL) - # add_header Access-Control-Allow-Origin '*'; - return 200 '{ "m.homeserver": { "base_url": "https://my.hostname.com" } }'; - } - - # route requests to: - # /_matrix/client/.*/sync - # /_matrix/client/.*/user/{userId}/filter - # /_matrix/client/.*/user/{userId}/filter/{filterID} - # /_matrix/client/.*/keys/changes - # /_matrix/client/.*/rooms/{roomId}/messages - # /_matrix/client/.*/rooms/{roomId}/context/{eventID} - # /_matrix/client/.*/rooms/{roomId}/event/{eventID} - # /_matrix/client/.*/rooms/{roomId}/relations/{eventID} - # /_matrix/client/.*/rooms/{roomId}/relations/{eventID}/{relType} - # /_matrix/client/.*/rooms/{roomId}/relations/{eventID}/{relType}/{eventType} - # /_matrix/client/.*/rooms/{roomId}/members - # /_matrix/client/.*/rooms/{roomId}/joined_members - # to sync_api - location ~ /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/(messages|.*?_?members|context/.*?|relations/.*?|event/.*?))$ { - proxy_pass http://sync_api:8073; - } - - location /_matrix/client { - proxy_pass http://client_api:8071; - } - - location /_matrix/federation { - proxy_pass http://federation_api:8072; - } - - location /_matrix/key { - proxy_pass http://federation_api:8072; - } - - location /_matrix/media { - proxy_pass http://media_api:8074; - } -} diff --git a/docs/systemd/monolith-example.service b/docs/systemd/monolith-example.service deleted file mode 100644 index 8a948a3fa..000000000 --- a/docs/systemd/monolith-example.service +++ /dev/null @@ -1,19 +0,0 @@ -[Unit] -Description=Dendrite (Matrix Homeserver) -After=syslog.target -After=network.target -After=postgresql.service - -[Service] -Environment=GODEBUG=madvdontneed=1 -RestartSec=2s -Type=simple -User=dendrite -Group=dendrite -WorkingDirectory=/opt/dendrite/ -ExecStart=/opt/dendrite/bin/dendrite -Restart=always -LimitNOFILE=65535 - -[Install] -WantedBy=multi-user.target diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index 17296febc..8d21b7829 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -136,7 +136,7 @@ func (r *Admin) PerformAdminEvacuateRoom( inputReq := &api.InputRoomEventsRequest{ InputRoomEvents: inputEvents, - Asynchronous: true, + Asynchronous: false, } inputRes := &api.InputRoomEventsResponse{} r.Inputer.InputRoomEvents(ctx, inputReq, inputRes) @@ -200,18 +200,24 @@ func (r *Admin) PerformAdminPurgeRoom( } // Evacuate the room before purging it from the database - if _, err := r.PerformAdminEvacuateRoom(ctx, roomID); err != nil { + evacAffected, err := r.PerformAdminEvacuateRoom(ctx, roomID) + if err != nil { logrus.WithField("room_id", roomID).WithError(err).Warn("Failed to evacuate room before purging") return err } + logrus.WithFields(logrus.Fields{ + "room_id": roomID, + "evacuated_users": len(evacAffected), + }).Warn("Evacuated room, purging room from roomserver now") + logrus.WithField("room_id", roomID).Warn("Purging room from roomserver") if err := r.DB.PurgeRoom(ctx, roomID); err != nil { logrus.WithField("room_id", roomID).WithError(err).Warn("Failed to purge room from roomserver") return err } - logrus.WithField("room_id", roomID).Warn("Room purged from roomserver") + logrus.WithField("room_id", roomID).Warn("Room purged from roomserver, informing other components") return r.Inputer.OutputProducer.ProduceRoomEvents(roomID, []api.OutputEvent{ { diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index b49c2f701..132bd80c8 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -130,7 +130,7 @@ func (s *accountDataStatements) SelectAccountDataInRange( if pos == 0 { pos = r.High() } - return data, pos, nil + return data, pos, rows.Err() } func (s *accountDataStatements) SelectMaxAccountDataID( From 3dcca4017cb919fb249784d9cf9b83ea60a77f15 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 30 May 2023 15:27:11 +0200 Subject: [PATCH 10/35] Fix potential state reset when trying to join a room (#3040) When trying to join a room in short sequence, it is possible that a state reset occurs. This fixes it by using `singleflight`. --- clientapi/routing/routing.go | 42 +++++++++++++++++++++++++++++------- go.mod | 1 + go.sum | 1 + 3 files changed, 36 insertions(+), 8 deletions(-) diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 2a2fa6655..d3f19cae1 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -20,14 +20,16 @@ import ( "strings" "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/setup/base" - userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/nats-io/nats.go" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" + "golang.org/x/sync/singleflight" + + "github.com/matrix-org/dendrite/setup/base" + userapi "github.com/matrix-org/dendrite/userapi/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/api" @@ -84,6 +86,14 @@ func Setup( unstableFeatures["org.matrix."+msc] = true } + // singleflight protects /join endpoints from being invoked + // multiple times from the same user and room, otherwise + // a state reset can occur. This also avoids unneeded + // state calculations. + // TODO: actually fix this in the roomserver, as there are + // possibly other ways that can result in a stat reset. + sf := singleflight.Group{} + if cfg.Matrix.WellKnownClientName != "" { logrus.Infof("Setting m.homeserver base_url as %s at /.well-known/matrix/client", cfg.Matrix.WellKnownClientName) wkMux.Handle("/client", httputil.MakeExternalAPI("wellknown", func(r *http.Request) util.JSONResponse { @@ -264,9 +274,17 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return JoinRoomByIDOrAlias( - req, device, rsAPI, userAPI, vars["roomIDOrAlias"], - ) + // Only execute a join for roomIDOrAlias and UserID once. If there is a join in progress + // it waits for it to complete and returns that result for subsequent requests. + resp, _, _ := sf.Do(vars["roomIDOrAlias"]+device.UserID, func() (any, error) { + return JoinRoomByIDOrAlias( + req, device, rsAPI, userAPI, vars["roomIDOrAlias"], + ), nil + }) + // once all joins are processed, drop them from the cache. Further requests + // will be processed as usual. + sf.Forget(vars["roomIDOrAlias"] + device.UserID) + return resp.(util.JSONResponse) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) @@ -300,9 +318,17 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return JoinRoomByIDOrAlias( - req, device, rsAPI, userAPI, vars["roomID"], - ) + // Only execute a join for roomID and UserID once. If there is a join in progress + // it waits for it to complete and returns that result for subsequent requests. + resp, _, _ := sf.Do(vars["roomID"]+device.UserID, func() (any, error) { + return JoinRoomByIDOrAlias( + req, device, rsAPI, userAPI, vars["roomID"], + ), nil + }) + // once all joins are processed, drop them from the cache. Further requests + // will be processed as usual. + sf.Forget(vars["roomID"] + device.UserID) + return resp.(util.JSONResponse) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/leave", diff --git a/go.mod b/go.mod index 16e5adc8c..360ddf5b2 100644 --- a/go.mod +++ b/go.mod @@ -45,6 +45,7 @@ require ( golang.org/x/crypto v0.9.0 golang.org/x/image v0.5.0 golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e + golang.org/x/sync v0.1.0 golang.org/x/term v0.8.0 gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/yaml.v2 v2.4.0 diff --git a/go.sum b/go.sum index 98e7e839d..4a6054af8 100644 --- a/go.sum +++ b/go.sum @@ -614,6 +614,7 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= From 61341aca500ec4d87e5b6d4c3f965c3836d6e6d6 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 30 May 2023 18:05:48 +0200 Subject: [PATCH 11/35] Add tests for the `UpDropEventReferenceSHAPrevEvents` migration (#3087) ... as they could fail if there are duplicate events in `roomserver_previous_events`. This fixes the migration by trying to combine the `event_nids` if possible (same room) as mentioned by @kegsay in https://github.com/matrix-org/dendrite/pull/3083#discussion_r1195508963 --- .../20230516154000_drop_reference_sha.go | 86 ++++++++++++++++--- .../20230516154000_drop_reference_sha_test.go | 60 +++++++++++++ .../20230516154000_drop_reference_sha.go | 78 ++++++++++++++++- .../20230516154000_drop_reference_sha_test.go | 59 +++++++++++++ 4 files changed, 271 insertions(+), 12 deletions(-) create mode 100644 roomserver/storage/postgres/deltas/20230516154000_drop_reference_sha_test.go create mode 100644 roomserver/storage/sqlite3/deltas/20230516154000_drop_reference_sha_test.go diff --git a/roomserver/storage/postgres/deltas/20230516154000_drop_reference_sha.go b/roomserver/storage/postgres/deltas/20230516154000_drop_reference_sha.go index c19577713..1b1dd44d3 100644 --- a/roomserver/storage/postgres/deltas/20230516154000_drop_reference_sha.go +++ b/roomserver/storage/postgres/deltas/20230516154000_drop_reference_sha.go @@ -18,19 +18,14 @@ import ( "context" "database/sql" "fmt" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/util" ) func UpDropEventReferenceSHAEvents(ctx context.Context, tx *sql.Tx) error { - var count int - err := tx.QueryRowContext(ctx, `SELECT count(*) FROM roomserver_events GROUP BY event_id HAVING count(event_id) > 1`). - Scan(&count) - if err != nil && err != sql.ErrNoRows { - return fmt.Errorf("failed to query duplicate event ids") - } - if count > 0 { - return fmt.Errorf("unable to drop column, as there are duplicate event ids") - } - _, err = tx.ExecContext(ctx, `ALTER TABLE roomserver_events DROP COLUMN IF EXISTS reference_sha256;`) + _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_events DROP COLUMN IF EXISTS reference_sha256;`) if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } @@ -46,9 +41,80 @@ func UpDropEventReferenceSHAPrevEvents(ctx context.Context, tx *sql.Tx) error { if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } + + // figure out if there are duplicates + dupeRows, err := tx.QueryContext(ctx, `SELECT previous_event_id FROM roomserver_previous_events GROUP BY previous_event_id HAVING count(previous_event_id) > 1`) + if err != nil { + return fmt.Errorf("failed to query duplicate event ids") + } + defer internal.CloseAndLogIfError(ctx, dupeRows, "failed to close rows") + + var prevEvents []string + var prevEventID string + for dupeRows.Next() { + if err = dupeRows.Scan(&prevEventID); err != nil { + return err + } + prevEvents = append(prevEvents, prevEventID) + } + if dupeRows.Err() != nil { + return dupeRows.Err() + } + + // if we found duplicates, check if we can combine them, e.g. they are in the same room + for _, dupeID := range prevEvents { + var dupeNIDsRows *sql.Rows + dupeNIDsRows, err = tx.QueryContext(ctx, `SELECT event_nids FROM roomserver_previous_events WHERE previous_event_id = $1`, dupeID) + if err != nil { + return fmt.Errorf("failed to query duplicate event ids") + } + defer internal.CloseAndLogIfError(ctx, dupeNIDsRows, "failed to close rows") + var dupeNIDs []int64 + for dupeNIDsRows.Next() { + var nids pq.Int64Array + if err = dupeNIDsRows.Scan(&nids); err != nil { + return err + } + dupeNIDs = append(dupeNIDs, nids...) + } + + if dupeNIDsRows.Err() != nil { + return dupeNIDsRows.Err() + } + // dedupe NIDs + dupeNIDs = dupeNIDs[:util.SortAndUnique(nids(dupeNIDs))] + // now that we have all NIDs, check which room they belong to + var roomCount int + err = tx.QueryRowContext(ctx, `SELECT count(distinct room_nid) FROM roomserver_events WHERE event_nid = ANY($1)`, pq.Array(dupeNIDs)).Scan(&roomCount) + if err != nil { + return err + } + // if the events are from different rooms, that's bad and we can't continue + if roomCount > 1 { + return fmt.Errorf("detected events (%v) referenced for different rooms (%v)", dupeNIDs, roomCount) + } + // otherwise delete the dupes + _, err = tx.ExecContext(ctx, "DELETE FROM roomserver_previous_events WHERE previous_event_id = $1", dupeID) + if err != nil { + return fmt.Errorf("unable to delete duplicates: %w", err) + } + + // insert combined values + _, err = tx.ExecContext(ctx, "INSERT INTO roomserver_previous_events (previous_event_id, event_nids) VALUES ($1, $2)", dupeID, pq.Array(dupeNIDs)) + if err != nil { + return fmt.Errorf("unable to insert new event NIDs: %w", err) + } + } + _, err = tx.ExecContext(ctx, `ALTER TABLE roomserver_previous_events ADD CONSTRAINT roomserver_previous_event_id_unique UNIQUE (previous_event_id);`) if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } return nil } + +type nids []int64 + +func (s nids) Len() int { return len(s) } +func (s nids) Less(i, j int) bool { return s[i] < s[j] } +func (s nids) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/roomserver/storage/postgres/deltas/20230516154000_drop_reference_sha_test.go b/roomserver/storage/postgres/deltas/20230516154000_drop_reference_sha_test.go new file mode 100644 index 000000000..c79daac5f --- /dev/null +++ b/roomserver/storage/postgres/deltas/20230516154000_drop_reference_sha_test.go @@ -0,0 +1,60 @@ +package deltas + +import ( + "testing" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/stretchr/testify/assert" +) + +func TestUpDropEventReferenceSHAPrevEvents(t *testing.T) { + + cfg, ctx, close := testrig.CreateConfig(t, test.DBTypePostgres) + defer close() + + db, err := sqlutil.Open(&cfg.Global.DatabaseOptions, sqlutil.NewDummyWriter()) + assert.Nil(t, err) + assert.NotNil(t, db) + defer db.Close() + + // create the table in the old layout + _, err = db.ExecContext(ctx.Context(), ` +CREATE TABLE IF NOT EXISTS roomserver_previous_events ( + previous_event_id TEXT NOT NULL, + event_nids BIGINT[] NOT NULL, + previous_reference_sha256 BYTEA NOT NULL, + CONSTRAINT roomserver_previous_event_id_unique UNIQUE (previous_event_id, previous_reference_sha256) +);`) + assert.Nil(t, err) + + // create the events table as well, slimmed down with one eventNID + _, err = db.ExecContext(ctx.Context(), ` +CREATE SEQUENCE IF NOT EXISTS roomserver_event_nid_seq; +CREATE TABLE IF NOT EXISTS roomserver_events ( + event_nid BIGINT PRIMARY KEY DEFAULT nextval('roomserver_event_nid_seq'), + room_nid BIGINT NOT NULL +); + +INSERT INTO roomserver_events (event_nid, room_nid) VALUES (1, 1) +`) + assert.Nil(t, err) + + // insert duplicate prev events with different event_nids + stmt, err := db.PrepareContext(ctx.Context(), `INSERT INTO roomserver_previous_events (previous_event_id, event_nids, previous_reference_sha256) VALUES ($1, $2, $3)`) + assert.Nil(t, err) + assert.NotNil(t, stmt) + _, err = stmt.ExecContext(ctx.Context(), "1", pq.Array([]int64{1, 2}), "a") + assert.Nil(t, err) + _, err = stmt.ExecContext(ctx.Context(), "1", pq.Array([]int64{1, 2, 3}), "b") + assert.Nil(t, err) + // execute the migration + txn, err := db.Begin() + assert.Nil(t, err) + assert.NotNil(t, txn) + defer txn.Rollback() + err = UpDropEventReferenceSHAPrevEvents(ctx.Context(), txn) + assert.NoError(t, err) +} diff --git a/roomserver/storage/sqlite3/deltas/20230516154000_drop_reference_sha.go b/roomserver/storage/sqlite3/deltas/20230516154000_drop_reference_sha.go index 452d72ace..515bccc37 100644 --- a/roomserver/storage/sqlite3/deltas/20230516154000_drop_reference_sha.go +++ b/roomserver/storage/sqlite3/deltas/20230516154000_drop_reference_sha.go @@ -18,6 +18,10 @@ import ( "context" "database/sql" "fmt" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/util" ) func UpDropEventReferenceSHA(ctx context.Context, tx *sql.Tx) error { @@ -52,8 +56,72 @@ func UpDropEventReferenceSHAPrevEvents(ctx context.Context, tx *sql.Tx) error { return fmt.Errorf("tx.ExecContext: %w", err) } + // figure out if there are duplicates + dupeRows, err := tx.QueryContext(ctx, `SELECT previous_event_id FROM _roomserver_previous_events GROUP BY previous_event_id HAVING count(previous_event_id) > 1`) + if err != nil { + return fmt.Errorf("failed to query duplicate event ids") + } + defer internal.CloseAndLogIfError(ctx, dupeRows, "failed to close rows") + + var prevEvents []string + var prevEventID string + for dupeRows.Next() { + if err = dupeRows.Scan(&prevEventID); err != nil { + return err + } + prevEvents = append(prevEvents, prevEventID) + } + if dupeRows.Err() != nil { + return dupeRows.Err() + } + + // if we found duplicates, check if we can combine them, e.g. they are in the same room + for _, dupeID := range prevEvents { + var dupeNIDsRows *sql.Rows + dupeNIDsRows, err = tx.QueryContext(ctx, `SELECT event_nids FROM _roomserver_previous_events WHERE previous_event_id = $1`, dupeID) + if err != nil { + return fmt.Errorf("failed to query duplicate event ids") + } + defer internal.CloseAndLogIfError(ctx, dupeNIDsRows, "failed to close rows") + var dupeNIDs []int64 + for dupeNIDsRows.Next() { + var nids pq.Int64Array + if err = dupeNIDsRows.Scan(&nids); err != nil { + return err + } + dupeNIDs = append(dupeNIDs, nids...) + } + + if dupeNIDsRows.Err() != nil { + return dupeNIDsRows.Err() + } + // dedupe NIDs + dupeNIDs = dupeNIDs[:util.SortAndUnique(nids(dupeNIDs))] + // now that we have all NIDs, check which room they belong to + var roomCount int + err = tx.QueryRowContext(ctx, `SELECT count(distinct room_nid) FROM roomserver_events WHERE event_nid IN ($1)`, pq.Array(dupeNIDs)).Scan(&roomCount) + if err != nil { + return err + } + // if the events are from different rooms, that's bad and we can't continue + if roomCount > 1 { + return fmt.Errorf("detected events (%v) referenced for different rooms (%v)", dupeNIDs, roomCount) + } + // otherwise delete the dupes + _, err = tx.ExecContext(ctx, "DELETE FROM _roomserver_previous_events WHERE previous_event_id = $1", dupeID) + if err != nil { + return fmt.Errorf("unable to delete duplicates: %w", err) + } + + // insert combined values + _, err = tx.ExecContext(ctx, "INSERT INTO _roomserver_previous_events (previous_event_id, event_nids) VALUES ($1, $2)", dupeID, pq.Array(dupeNIDs)) + if err != nil { + return fmt.Errorf("unable to insert new event NIDs: %w", err) + } + } + // move data - if _, err := tx.ExecContext(ctx, ` + if _, err = tx.ExecContext(ctx, ` INSERT INTO roomserver_previous_events ( previous_event_id, event_nids @@ -64,9 +132,15 @@ INSERT return fmt.Errorf("tx.ExecContext: %w", err) } // drop old table - _, err := tx.ExecContext(ctx, `DROP TABLE _roomserver_previous_events;`) + _, err = tx.ExecContext(ctx, `DROP TABLE _roomserver_previous_events;`) if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } return nil } + +type nids []int64 + +func (s nids) Len() int { return len(s) } +func (s nids) Less(i, j int) bool { return s[i] < s[j] } +func (s nids) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/roomserver/storage/sqlite3/deltas/20230516154000_drop_reference_sha_test.go b/roomserver/storage/sqlite3/deltas/20230516154000_drop_reference_sha_test.go new file mode 100644 index 000000000..547d9703b --- /dev/null +++ b/roomserver/storage/sqlite3/deltas/20230516154000_drop_reference_sha_test.go @@ -0,0 +1,59 @@ +package deltas + +import ( + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/stretchr/testify/assert" +) + +func TestUpDropEventReferenceSHAPrevEvents(t *testing.T) { + + cfg, ctx, close := testrig.CreateConfig(t, test.DBTypeSQLite) + defer close() + + db, err := sqlutil.Open(&cfg.RoomServer.Database, sqlutil.NewExclusiveWriter()) + assert.Nil(t, err) + assert.NotNil(t, db) + defer db.Close() + + // create the table in the old layout + _, err = db.ExecContext(ctx.Context(), ` + CREATE TABLE IF NOT EXISTS roomserver_previous_events ( + previous_event_id TEXT NOT NULL, + previous_reference_sha256 BLOB, + event_nids TEXT NOT NULL, + UNIQUE (previous_event_id, previous_reference_sha256) + );`) + assert.Nil(t, err) + + // create the events table as well, slimmed down with one eventNID + _, err = db.ExecContext(ctx.Context(), ` + CREATE TABLE IF NOT EXISTS roomserver_events ( + event_nid INTEGER PRIMARY KEY AUTOINCREMENT, + room_nid INTEGER NOT NULL +); + +INSERT INTO roomserver_events (event_nid, room_nid) VALUES (1, 1) +`) + assert.Nil(t, err) + + // insert duplicate prev events with different event_nids + stmt, err := db.PrepareContext(ctx.Context(), `INSERT INTO roomserver_previous_events (previous_event_id, event_nids, previous_reference_sha256) VALUES ($1, $2, $3)`) + assert.Nil(t, err) + assert.NotNil(t, stmt) + _, err = stmt.ExecContext(ctx.Context(), "1", "{1,2}", "a") + assert.Nil(t, err) + _, err = stmt.ExecContext(ctx.Context(), "1", "{1,2,3}", "b") + assert.Nil(t, err) + + // execute the migration + txn, err := db.Begin() + assert.Nil(t, err) + assert.NotNil(t, txn) + err = UpDropEventReferenceSHAPrevEvents(ctx.Context(), txn) + defer txn.Rollback() + assert.NoError(t, err) +} From cbdc601f1b6d1c2a648b69ff44b3a49916f4d31a Mon Sep 17 00:00:00 2001 From: devonh Date: Wed, 31 May 2023 15:27:08 +0000 Subject: [PATCH 12/35] Move CreateRoom logic to Roomserver (#3093) Move create room logic over to roomserver. --- clientapi/routing/createroom.go | 550 +++--------------- clientapi/routing/joinroom_test.go | 5 +- clientapi/routing/membership.go | 59 +- clientapi/routing/profile.go | 2 +- clientapi/routing/redaction.go | 2 +- clientapi/routing/sendevent.go | 2 +- clientapi/routing/server_notices.go | 2 +- clientapi/threepid/invites.go | 2 +- cmd/dendrite-upgrade-tests/tests.go | 5 +- federationapi/routing/join.go | 2 +- federationapi/routing/leave.go | 2 +- go.mod | 2 +- go.sum | 4 +- internal/eventutil/events.go | 7 +- roomserver/api/api.go | 2 + roomserver/api/perform.go | 24 + roomserver/internal/alias.go | 2 +- roomserver/internal/api.go | 13 + roomserver/internal/input/input_events.go | 2 +- roomserver/internal/perform/perform_admin.go | 4 +- .../internal/perform/perform_create_room.go | 498 ++++++++++++++++ roomserver/internal/perform/perform_join.go | 2 +- roomserver/internal/perform/perform_leave.go | 2 +- .../internal/perform/perform_upgrade.go | 39 +- 24 files changed, 684 insertions(+), 550 deletions(-) create mode 100644 roomserver/internal/perform/perform_create_room.go diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index 7a7a85e85..aaa305f06 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -22,17 +22,13 @@ import ( "strings" "time" - "github.com/getsentry/sentry-go" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/roomserver/types" roomserverVersion "github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -41,32 +37,19 @@ import ( // https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-client-r0-createroom type createRoomRequest struct { - Invite []string `json:"invite"` - Name string `json:"name"` - Visibility string `json:"visibility"` - Topic string `json:"topic"` - Preset string `json:"preset"` - CreationContent json.RawMessage `json:"creation_content"` - InitialState []fledglingEvent `json:"initial_state"` - RoomAliasName string `json:"room_alias_name"` - RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` - PowerLevelContentOverride json.RawMessage `json:"power_level_content_override"` - IsDirect bool `json:"is_direct"` + Invite []string `json:"invite"` + Name string `json:"name"` + Visibility string `json:"visibility"` + Topic string `json:"topic"` + Preset string `json:"preset"` + CreationContent json.RawMessage `json:"creation_content"` + InitialState []gomatrixserverlib.FledglingEvent `json:"initial_state"` + RoomAliasName string `json:"room_alias_name"` + RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` + PowerLevelContentOverride json.RawMessage `json:"power_level_content_override"` + IsDirect bool `json:"is_direct"` } -const ( - presetPrivateChat = "private_chat" - presetTrustedPrivateChat = "trusted_private_chat" - presetPublicChat = "public_chat" -) - -const ( - historyVisibilityShared = "shared" - // TODO: These should be implemented once history visibility is implemented - // historyVisibilityWorldReadable = "world_readable" - // historyVisibilityInvited = "invited" -) - func (r createRoomRequest) Validate() *util.JSONResponse { whitespace := "\t\n\x0b\x0c\r " // https://docs.python.org/2/library/string.html#string.whitespace // https://github.com/matrix-org/synapse/blob/v0.19.2/synapse/handlers/room.py#L81 @@ -78,12 +61,7 @@ func (r createRoomRequest) Validate() *util.JSONResponse { } } for _, userID := range r.Invite { - // TODO: We should put user ID parsing code into gomatrixserverlib and use that instead - // (see https://github.com/matrix-org/gomatrixserverlib/blob/3394e7c7003312043208aa73727d2256eea3d1f6/eventcontent.go#L347 ) - // It should be a struct (with pointers into a single string to avoid copying) and - // we should update all refs to use UserID types rather than strings. - // https://github.com/matrix-org/synapse/blob/v0.19.2/synapse/types.py#L92 - if _, _, err := gomatrixserverlib.SplitID('@', userID); err != nil { + if _, err := spec.NewUserID(userID, true); err != nil { return &util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON("user id must be in the form @localpart:domain"), @@ -91,7 +69,7 @@ func (r createRoomRequest) Validate() *util.JSONResponse { } } switch r.Preset { - case presetPrivateChat, presetTrustedPrivateChat, presetPublicChat, "": + case spec.PresetPrivateChat, spec.PresetTrustedPrivateChat, spec.PresetPublicChat, "": default: return &util.JSONResponse{ Code: http.StatusBadRequest, @@ -129,13 +107,6 @@ type createRoomResponse struct { RoomAlias string `json:"room_alias,omitempty"` // in synapse not spec } -// fledglingEvent is a helper representation of an event used when creating many events in succession. -type fledglingEvent struct { - Type string `json:"type"` - StateKey string `json:"state_key"` - Content interface{} `json:"content"` -} - // CreateRoom implements /createRoom func CreateRoom( req *http.Request, device *api.Device, @@ -143,12 +114,12 @@ func CreateRoom( profileAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, ) util.JSONResponse { - var r createRoomRequest - resErr := httputil.UnmarshalJSONRequest(req, &r) + var createRequest createRoomRequest + resErr := httputil.UnmarshalJSONRequest(req, &createRequest) if resErr != nil { return *resErr } - if resErr = r.Validate(); resErr != nil { + if resErr = createRequest.Validate(); resErr != nil { return *resErr } evTime, err := httputil.ParseTSParam(req) @@ -158,46 +129,52 @@ func CreateRoom( JSON: spec.InvalidParam(err.Error()), } } - return createRoom(req.Context(), r, device, cfg, profileAPI, rsAPI, asAPI, evTime) + return createRoom(req.Context(), createRequest, device, cfg, profileAPI, rsAPI, asAPI, evTime) } // createRoom implements /createRoom -// nolint: gocyclo func createRoom( ctx context.Context, - r createRoomRequest, device *api.Device, + // TODO: remove dependency on createRoomRequest + createRequest createRoomRequest, device *api.Device, cfg *config.ClientAPI, profileAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, evTime time.Time, ) util.JSONResponse { - _, userDomain, err := gomatrixserverlib.SplitID('@', device.UserID) + userID, err := spec.NewUserID(device.UserID, true) if err != nil { - util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") + util.GetLogger(ctx).WithError(err).Error("invalid userID") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}, } } - if !cfg.Matrix.IsLocalServerName(userDomain) { + if !cfg.Matrix.IsLocalServerName(userID.Domain()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: spec.Forbidden(fmt.Sprintf("User domain %q not configured locally", userDomain)), + JSON: spec.Forbidden(fmt.Sprintf("User domain %q not configured locally", userID.Domain())), } } - // TODO (#267): Check room ID doesn't clash with an existing one, and we - // probably shouldn't be using pseudo-random strings, maybe GUIDs? - roomID := fmt.Sprintf("!%s:%s", util.RandomString(16), userDomain) - logger := util.GetLogger(ctx) - userID := device.UserID + + // TODO: Check room ID doesn't clash with an existing one, and we + // probably shouldn't be using pseudo-random strings, maybe GUIDs? + roomID, err := spec.NewRoomID(fmt.Sprintf("!%s:%s", util.RandomString(16), userID.Domain())) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("invalid roomID") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } // Clobber keys: creator, room_version roomVersion := roomserverVersion.DefaultRoomVersion() - if r.RoomVersion != "" { - candidateVersion := gomatrixserverlib.RoomVersion(r.RoomVersion) + if createRequest.RoomVersion != "" { + candidateVersion := gomatrixserverlib.RoomVersion(createRequest.RoomVersion) _, roomVersionError := roomserverVersion.SupportedRoomVersion(candidateVersion) if roomVersionError != nil { return util.JSONResponse{ @@ -208,17 +185,13 @@ func createRoom( roomVersion = candidateVersion } - // TODO: visibility/presets/raw initial state - // TODO: Create room alias association - // Make sure this doesn't fall into an application service's namespace though! - logger.WithFields(log.Fields{ - "userID": userID, - "roomID": roomID, + "userID": userID.String(), + "roomID": roomID.String(), "roomVersion": roomVersion, }).Info("Creating new room") - profile, err := appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, profileAPI) + profile, err := appserviceAPI.RetrieveUserProfile(ctx, userID.String(), asAPI, profileAPI) if err != nil { util.GetLogger(ctx).WithError(err).Error("appserviceAPI.RetrieveUserProfile failed") return util.JSONResponse{ @@ -227,427 +200,38 @@ func createRoom( } } - createContent := map[string]interface{}{} - if len(r.CreationContent) > 0 { - if err = json.Unmarshal(r.CreationContent, &createContent); err != nil { - util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for creation_content failed") - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON("invalid create content"), - } - } + userDisplayName := profile.DisplayName + userAvatarURL := profile.AvatarURL + + keyID := cfg.Matrix.KeyID + privateKey := cfg.Matrix.PrivateKey + + req := roomserverAPI.PerformCreateRoomRequest{ + InvitedUsers: createRequest.Invite, + RoomName: createRequest.Name, + Visibility: createRequest.Visibility, + Topic: createRequest.Topic, + StatePreset: createRequest.Preset, + CreationContent: createRequest.CreationContent, + InitialState: createRequest.InitialState, + RoomAliasName: createRequest.RoomAliasName, + RoomVersion: roomVersion, + PowerLevelContentOverride: createRequest.PowerLevelContentOverride, + IsDirect: createRequest.IsDirect, + + UserDisplayName: userDisplayName, + UserAvatarURL: userAvatarURL, + KeyID: keyID, + PrivateKey: privateKey, + EventTime: evTime, } - createContent["creator"] = userID - createContent["room_version"] = roomVersion - powerLevelContent := eventutil.InitialPowerLevelsContent(userID) - joinRuleContent := gomatrixserverlib.JoinRuleContent{ - JoinRule: spec.Invite, - } - historyVisibilityContent := gomatrixserverlib.HistoryVisibilityContent{ - HistoryVisibility: historyVisibilityShared, - } - - if r.PowerLevelContentOverride != nil { - // Merge powerLevelContentOverride fields by unmarshalling it atop the defaults - err = json.Unmarshal(r.PowerLevelContentOverride, &powerLevelContent) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for power_level_content_override failed") - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON("malformed power_level_content_override"), - } - } - } - - var guestsCanJoin bool - switch r.Preset { - case presetPrivateChat: - joinRuleContent.JoinRule = spec.Invite - historyVisibilityContent.HistoryVisibility = historyVisibilityShared - guestsCanJoin = true - case presetTrustedPrivateChat: - joinRuleContent.JoinRule = spec.Invite - historyVisibilityContent.HistoryVisibility = historyVisibilityShared - for _, invitee := range r.Invite { - powerLevelContent.Users[invitee] = 100 - } - guestsCanJoin = true - case presetPublicChat: - joinRuleContent.JoinRule = spec.Public - historyVisibilityContent.HistoryVisibility = historyVisibilityShared - } - - createEvent := fledglingEvent{ - Type: spec.MRoomCreate, - Content: createContent, - } - powerLevelEvent := fledglingEvent{ - Type: spec.MRoomPowerLevels, - Content: powerLevelContent, - } - joinRuleEvent := fledglingEvent{ - Type: spec.MRoomJoinRules, - Content: joinRuleContent, - } - historyVisibilityEvent := fledglingEvent{ - Type: spec.MRoomHistoryVisibility, - Content: historyVisibilityContent, - } - membershipEvent := fledglingEvent{ - Type: spec.MRoomMember, - StateKey: userID, - Content: gomatrixserverlib.MemberContent{ - Membership: spec.Join, - DisplayName: profile.DisplayName, - AvatarURL: profile.AvatarURL, - }, - } - - var nameEvent *fledglingEvent - var topicEvent *fledglingEvent - var guestAccessEvent *fledglingEvent - var aliasEvent *fledglingEvent - - if r.Name != "" { - nameEvent = &fledglingEvent{ - Type: spec.MRoomName, - Content: eventutil.NameContent{ - Name: r.Name, - }, - } - } - - if r.Topic != "" { - topicEvent = &fledglingEvent{ - Type: spec.MRoomTopic, - Content: eventutil.TopicContent{ - Topic: r.Topic, - }, - } - } - - if guestsCanJoin { - guestAccessEvent = &fledglingEvent{ - Type: spec.MRoomGuestAccess, - Content: eventutil.GuestAccessContent{ - GuestAccess: "can_join", - }, - } - } - - var roomAlias string - if r.RoomAliasName != "" { - roomAlias = fmt.Sprintf("#%s:%s", r.RoomAliasName, userDomain) - // check it's free TODO: This races but is better than nothing - hasAliasReq := roomserverAPI.GetRoomIDForAliasRequest{ - Alias: roomAlias, - IncludeAppservices: false, - } - - var aliasResp roomserverAPI.GetRoomIDForAliasResponse - err = rsAPI.GetRoomIDForAlias(ctx, &hasAliasReq, &aliasResp) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed") - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } - } - if aliasResp.RoomID != "" { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.RoomInUse("Room ID already exists."), - } - } - - aliasEvent = &fledglingEvent{ - Type: spec.MRoomCanonicalAlias, - Content: eventutil.CanonicalAlias{ - Alias: roomAlias, - }, - } - } - - var initialStateEvents []fledglingEvent - for i := range r.InitialState { - if r.InitialState[i].StateKey != "" { - initialStateEvents = append(initialStateEvents, r.InitialState[i]) - continue - } - - switch r.InitialState[i].Type { - case spec.MRoomCreate: - continue - - case spec.MRoomPowerLevels: - powerLevelEvent = r.InitialState[i] - - case spec.MRoomJoinRules: - joinRuleEvent = r.InitialState[i] - - case spec.MRoomHistoryVisibility: - historyVisibilityEvent = r.InitialState[i] - - case spec.MRoomGuestAccess: - guestAccessEvent = &r.InitialState[i] - - case spec.MRoomName: - nameEvent = &r.InitialState[i] - - case spec.MRoomTopic: - topicEvent = &r.InitialState[i] - - default: - initialStateEvents = append(initialStateEvents, r.InitialState[i]) - } - } - - // send events into the room in order of: - // 1- m.room.create - // 2- room creator join member - // 3- m.room.power_levels - // 4- m.room.join_rules - // 5- m.room.history_visibility - // 6- m.room.canonical_alias (opt) - // 7- m.room.guest_access (opt) - // 8- other initial state items - // 9- m.room.name (opt) - // 10- m.room.topic (opt) - // 11- invite events (opt) - with is_direct flag if applicable TODO - // 12- 3pid invite events (opt) TODO - // This differs from Synapse slightly. Synapse would vary the ordering of 3-7 - // depending on if those events were in "initial_state" or not. This made it - // harder to reason about, hence sticking to a strict static ordering. - // TODO: Synapse has txn/token ID on each event. Do we need to do this here? - eventsToMake := []fledglingEvent{ - createEvent, membershipEvent, powerLevelEvent, joinRuleEvent, historyVisibilityEvent, - } - if guestAccessEvent != nil { - eventsToMake = append(eventsToMake, *guestAccessEvent) - } - eventsToMake = append(eventsToMake, initialStateEvents...) - if nameEvent != nil { - eventsToMake = append(eventsToMake, *nameEvent) - } - if topicEvent != nil { - eventsToMake = append(eventsToMake, *topicEvent) - } - if aliasEvent != nil { - // TODO: bit of a chicken and egg problem here as the alias doesn't exist and cannot until we have made the room. - // This means we might fail creating the alias but say the canonical alias is something that doesn't exist. - eventsToMake = append(eventsToMake, *aliasEvent) - } - - // TODO: invite events - // TODO: 3pid invite events - - verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON("unknown room version"), - } - } - - var builtEvents []*types.HeaderedEvent - authEvents := gomatrixserverlib.NewAuthEvents(nil) - for i, e := range eventsToMake { - depth := i + 1 // depth starts at 1 - - builder := verImpl.NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{ - Sender: userID, - RoomID: roomID, - Type: e.Type, - StateKey: &e.StateKey, - Depth: int64(depth), - }) - err = builder.SetContent(e.Content) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("builder.SetContent failed") - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } - } - if i > 0 { - builder.PrevEvents = []string{builtEvents[i-1].EventID()} - } - var ev gomatrixserverlib.PDU - if err = builder.AddAuthEvents(&authEvents); err != nil { - util.GetLogger(ctx).WithError(err).Error("AddAuthEvents failed") - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } - } - ev, err = builder.Build(evTime, userDomain, cfg.Matrix.KeyID, cfg.Matrix.PrivateKey) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("buildEvent failed") - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } - } - - if err = gomatrixserverlib.Allowed(ev, &authEvents); err != nil { - util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed") - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } - } - - // Add the event to the list of auth events - builtEvents = append(builtEvents, &types.HeaderedEvent{PDU: ev}) - err = authEvents.AddEvent(ev) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("authEvents.AddEvent failed") - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } - } - } - - inputs := make([]roomserverAPI.InputRoomEvent, 0, len(builtEvents)) - for _, event := range builtEvents { - inputs = append(inputs, roomserverAPI.InputRoomEvent{ - Kind: roomserverAPI.KindNew, - Event: event, - Origin: userDomain, - SendAsServer: roomserverAPI.DoNotSendToOtherServers, - }) - } - if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, device.UserDomain(), inputs, false); err != nil { - util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } - } - - // TODO(#269): Reserve room alias while we create the room. This stops us - // from creating the room but still failing due to the alias having already - // been taken. - if roomAlias != "" { - aliasReq := roomserverAPI.SetRoomAliasRequest{ - Alias: roomAlias, - RoomID: roomID, - UserID: userID, - } - - var aliasResp roomserverAPI.SetRoomAliasResponse - err = rsAPI.SetRoomAlias(ctx, &aliasReq, &aliasResp) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("aliasAPI.SetRoomAlias failed") - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } - } - - if aliasResp.AliasExists { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.RoomInUse("Room alias already exists."), - } - } - } - - // If this is a direct message then we should invite the participants. - if len(r.Invite) > 0 { - // Build some stripped state for the invite. - var globalStrippedState []fclient.InviteV2StrippedState - for _, event := range builtEvents { - // Chosen events from the spec: - // https://spec.matrix.org/v1.3/client-server-api/#stripped-state - switch event.Type() { - case spec.MRoomCreate: - fallthrough - case spec.MRoomName: - fallthrough - case spec.MRoomAvatar: - fallthrough - case spec.MRoomTopic: - fallthrough - case spec.MRoomCanonicalAlias: - fallthrough - case spec.MRoomEncryption: - fallthrough - case spec.MRoomMember: - fallthrough - case spec.MRoomJoinRules: - ev := event.PDU - globalStrippedState = append( - globalStrippedState, - fclient.NewInviteV2StrippedState(ev), - ) - } - } - - // Process the invites. - var inviteEvent *types.HeaderedEvent - for _, invitee := range r.Invite { - // Build the invite event. - inviteEvent, err = buildMembershipEvent( - ctx, invitee, "", profileAPI, device, spec.Invite, - roomID, r.IsDirect, cfg, evTime, rsAPI, asAPI, - ) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed") - continue - } - inviteStrippedState := append( - globalStrippedState, - fclient.NewInviteV2StrippedState(inviteEvent.PDU), - ) - // Send the invite event to the roomserver. - event := inviteEvent - err = rsAPI.PerformInvite(ctx, &roomserverAPI.PerformInviteRequest{ - Event: event, - InviteRoomState: inviteStrippedState, - RoomVersion: event.Version(), - SendAsServer: string(userDomain), - }) - switch e := err.(type) { - case roomserverAPI.ErrInvalidID: - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.Unknown(e.Error()), - } - case roomserverAPI.ErrNotAllowed: - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden(e.Error()), - } - case nil: - default: - util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") - sentry.CaptureException(err) - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } - } - } - } - - if r.Visibility == spec.Public { - // expose this room in the published room list - if err = rsAPI.PerformPublish(ctx, &roomserverAPI.PerformPublishRequest{ - RoomID: roomID, - Visibility: spec.Public, - }); err != nil { - util.GetLogger(ctx).WithError(err).Error("failed to publish room") - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } - } + roomAlias, createRes := rsAPI.PerformCreateRoom(ctx, *userID, *roomID, &req) + if createRes != nil { + return *createRes } response := createRoomResponse{ - RoomID: roomID, + RoomID: roomID.String(), RoomAlias: roomAlias, } diff --git a/clientapi/routing/joinroom_test.go b/clientapi/routing/joinroom_test.go index 4b67b09f0..0ddff8a95 100644 --- a/clientapi/routing/joinroom_test.go +++ b/clientapi/routing/joinroom_test.go @@ -11,6 +11,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/roomserver" @@ -63,7 +64,7 @@ func TestJoinRoomByIDOrAlias(t *testing.T) { IsDirect: true, Topic: "testing", Visibility: "public", - Preset: presetPublicChat, + Preset: spec.PresetPublicChat, RoomAliasName: "alias", Invite: []string{bob.ID}, }, aliceDev, &cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now()) @@ -78,7 +79,7 @@ func TestJoinRoomByIDOrAlias(t *testing.T) { IsDirect: true, Topic: "testing", Visibility: "public", - Preset: presetPublicChat, + Preset: spec.PresetPublicChat, Invite: []string{charlie.ID}, }, aliceDev, &cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now()) crRespWithGuestAccess, ok := resp.JSON.(createRoomResponse) diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 4f2a0e394..0fe0a4ade 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -16,12 +16,14 @@ package routing import ( "context" + "crypto/ed25519" "fmt" "net/http" "time" "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" @@ -308,6 +310,41 @@ func sendInvite( }, nil } +func buildMembershipEventDirect( + ctx context.Context, + targetUserID, reason string, userDisplayName, userAvatarURL string, + sender string, senderDomain spec.ServerName, + membership, roomID string, isDirect bool, + keyID gomatrixserverlib.KeyID, privateKey ed25519.PrivateKey, evTime time.Time, + rsAPI roomserverAPI.ClientRoomserverAPI, +) (*types.HeaderedEvent, error) { + proto := gomatrixserverlib.ProtoEvent{ + Sender: sender, + RoomID: roomID, + Type: "m.room.member", + StateKey: &targetUserID, + } + + content := gomatrixserverlib.MemberContent{ + Membership: membership, + DisplayName: userDisplayName, + AvatarURL: userAvatarURL, + Reason: reason, + IsDirect: isDirect, + } + + if err := proto.SetContent(content); err != nil { + return nil, err + } + + identity := &fclient.SigningIdentity{ + ServerName: senderDomain, + KeyID: keyID, + PrivateKey: privateKey, + } + return eventutil.QueryAndBuildEvent(ctx, &proto, identity, evTime, rsAPI, nil) +} + func buildMembershipEvent( ctx context.Context, targetUserID, reason string, profileAPI userapi.ClientUserAPI, @@ -321,31 +358,13 @@ func buildMembershipEvent( return nil, err } - proto := gomatrixserverlib.ProtoEvent{ - Sender: device.UserID, - RoomID: roomID, - Type: "m.room.member", - StateKey: &targetUserID, - } - - content := gomatrixserverlib.MemberContent{ - Membership: membership, - DisplayName: profile.DisplayName, - AvatarURL: profile.AvatarURL, - Reason: reason, - IsDirect: isDirect, - } - - if err = proto.SetContent(content); err != nil { - return nil, err - } - identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) if err != nil { return nil, err } - return eventutil.QueryAndBuildEvent(ctx, &proto, cfg.Matrix, identity, evTime, rsAPI, nil) + return buildMembershipEventDirect(ctx, targetUserID, reason, profile.DisplayName, profile.AvatarURL, + device.UserID, device.UserDomain(), membership, roomID, isDirect, identity.KeyID, identity.PrivateKey, evTime, rsAPI) } // loadProfile lookups the profile of a given user from the database and returns diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 76129f0a8..2c9d0cbbe 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -387,7 +387,7 @@ func buildMembershipEvents( return nil, err } - event, err := eventutil.QueryAndBuildEvent(ctx, &proto, cfg.Matrix, identity, evTime, rsAPI, nil) + event, err := eventutil.QueryAndBuildEvent(ctx, &proto, identity, evTime, rsAPI, nil) if err != nil { return nil, err } diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index ed70e5c5c..883126423 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -137,7 +137,7 @@ func SendRedaction( } var queryRes roomserverAPI.QueryLatestEventsAndStateResponse - e, err := eventutil.QueryAndBuildEvent(req.Context(), &proto, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes) + e, err := eventutil.QueryAndBuildEvent(req.Context(), &proto, identity, time.Now(), rsAPI, &queryRes) if errors.Is(err, eventutil.ErrRoomNoExists{}) { return util.JSONResponse{ Code: http.StatusNotFound, diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index bc14642f8..1a2e25c9d 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -293,7 +293,7 @@ func generateSendEvent( } var queryRes api.QueryLatestEventsAndStateResponse - e, err := eventutil.QueryAndBuildEvent(ctx, &proto, cfg.Matrix, identity, evTime, rsAPI, &queryRes) + e, err := eventutil.QueryAndBuildEvent(ctx, &proto, identity, evTime, rsAPI, &queryRes) switch specificErr := err.(type) { case nil: case eventutil.ErrRoomNoExists: diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index ad50cc80b..06714ed1f 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -155,7 +155,7 @@ func SendServerNotice( Invite: []string{r.UserID}, Name: cfgNotices.RoomName, Visibility: "private", - Preset: presetPrivateChat, + Preset: spec.PresetPrivateChat, CreationContent: cc, RoomVersion: roomVersion, PowerLevelContentOverride: pl, diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index c296939d5..9f4f62e43 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -380,7 +380,7 @@ func emit3PIDInviteEvent( } queryRes := api.QueryLatestEventsAndStateResponse{} - event, err := eventutil.QueryAndBuildEvent(ctx, proto, cfg.Matrix, identity, evTime, rsAPI, &queryRes) + event, err := eventutil.QueryAndBuildEvent(ctx, proto, identity, evTime, rsAPI, &queryRes) if err != nil { return err } diff --git a/cmd/dendrite-upgrade-tests/tests.go b/cmd/dendrite-upgrade-tests/tests.go index 03438bd4d..692ab34ef 100644 --- a/cmd/dendrite-upgrade-tests/tests.go +++ b/cmd/dendrite-upgrade-tests/tests.go @@ -9,6 +9,7 @@ import ( "github.com/Masterminds/semver/v3" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const userPassword = "this_is_a_long_password" @@ -56,7 +57,7 @@ func runTests(baseURL string, v *semver.Version) error { // create DM room, join it and exchange messages createRoomResp, err := users[0].client.CreateRoom(&gomatrix.ReqCreateRoom{ - Preset: "trusted_private_chat", + Preset: spec.PresetTrustedPrivateChat, Invite: []string{users[1].userID}, IsDirect: true, }) @@ -98,7 +99,7 @@ func runTests(baseURL string, v *semver.Version) error { publicRoomID := "" createRoomResp, err = users[0].client.CreateRoom(&gomatrix.ReqCreateRoom{ RoomAliasName: "global", - Preset: "public_chat", + Preset: spec.PresetPublicChat, }) if err != nil { // this is okay and expected if the room already exists and the aliases clash // try to join it diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 4cbfc5e87..03d3309ae 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -121,7 +121,7 @@ func MakeJoin( queryRes := api.QueryLatestEventsAndStateResponse{ RoomVersion: roomVersion, } - event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), proto, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes) + event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), proto, identity, time.Now(), rsAPI, &queryRes) switch e := err.(type) { case nil: case eventutil.ErrRoomNoExists: diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index 3e576e09c..a767168d8 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -66,7 +66,7 @@ func MakeLeave( } queryRes := api.QueryLatestEventsAndStateResponse{} - event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), proto, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes) + event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), proto, identity, time.Now(), rsAPI, &queryRes) switch e := err.(type) { case nil: case eventutil.ErrRoomNoExists: diff --git a/go.mod b/go.mod index 360ddf5b2..0e979de6f 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230524095531-95ba6c68efb6 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230531143710-c681a0658246 github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 diff --git a/go.sum b/go.sum index 4a6054af8..8baa50e85 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230524095531-95ba6c68efb6 h1:FQpdh/KGCCQJytz4GAdG6pbx3DJ1HNzdKFc/BCZ0hP0= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230524095531-95ba6c68efb6/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230531143710-c681a0658246 h1:1sYXx7p9BIf0R7OIV/TZg3SCvNehEQPCKNqwV1ONfwU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230531143710-c681a0658246/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= diff --git a/internal/eventutil/events.go b/internal/eventutil/events.go index ca052c310..0f73db2d5 100644 --- a/internal/eventutil/events.go +++ b/internal/eventutil/events.go @@ -22,7 +22,6 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" @@ -51,7 +50,7 @@ func (e ErrRoomNoExists) Unwrap() error { // Returns an error if something else went wrong func QueryAndBuildEvent( ctx context.Context, - proto *gomatrixserverlib.ProtoEvent, cfg *config.Global, + proto *gomatrixserverlib.ProtoEvent, identity *fclient.SigningIdentity, evTime time.Time, rsAPI api.QueryLatestEventsAndStateAPI, queryRes *api.QueryLatestEventsAndStateResponse, ) (*types.HeaderedEvent, error) { @@ -64,14 +63,14 @@ func QueryAndBuildEvent( // This can pass through a ErrRoomNoExists to the caller return nil, err } - return BuildEvent(ctx, proto, cfg, identity, evTime, eventsNeeded, queryRes) + return BuildEvent(ctx, proto, identity, evTime, eventsNeeded, queryRes) } // BuildEvent builds a Matrix event from the builder and QueryLatestEventsAndStateResponse // provided. func BuildEvent( ctx context.Context, - proto *gomatrixserverlib.ProtoEvent, cfg *config.Global, + proto *gomatrixserverlib.ProtoEvent, identity *fclient.SigningIdentity, evTime time.Time, eventsNeeded *gomatrixserverlib.StateNeeded, queryRes *api.QueryLatestEventsAndStateResponse, ) (*types.HeaderedEvent, error) { diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 213e16e5d..571aa40b3 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -5,6 +5,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" asAPI "github.com/matrix-org/dendrite/appservice/api" fsAPI "github.com/matrix-org/dendrite/federationapi/api" @@ -169,6 +170,7 @@ type ClientRoomserverAPI interface { GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error + PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) (string, *util.JSONResponse) // PerformRoomUpgrade upgrades a room to a newer version PerformRoomUpgrade(ctx context.Context, roomID, userID string, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error) PerformAdminEvacuateRoom(ctx context.Context, roomID string) (affected []string, err error) diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index c6e5f5a1c..8d9742c69 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -1,6 +1,10 @@ package api import ( + "crypto/ed25519" + "encoding/json" + "time" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" @@ -8,6 +12,26 @@ import ( "github.com/matrix-org/util" ) +type PerformCreateRoomRequest struct { + InvitedUsers []string + RoomName string + Visibility string + Topic string + StatePreset string + CreationContent json.RawMessage + InitialState []gomatrixserverlib.FledglingEvent + RoomAliasName string + RoomVersion gomatrixserverlib.RoomVersion + PowerLevelContentOverride json.RawMessage + IsDirect bool + + UserDisplayName string + UserAvatarURL string + KeyID gomatrixserverlib.KeyID + PrivateKey ed25519.PrivateKey + EventTime time.Time +} + type PerformJoinRequest struct { RoomIDOrAlias string `json:"room_id_or_alias"` UserID string `json:"user_id"` diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index 4d2de9a5a..52b90cf4e 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -208,7 +208,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return err } - newEvent, err := eventutil.BuildEvent(ctx, proto, &r.Cfg.Global, identity, time.Now(), &eventsNeeded, stateRes) + newEvent, err := eventutil.BuildEvent(ctx, proto, identity, time.Now(), &eventsNeeded, stateRes) if err != nil { return err } diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 81904c8b8..f61f89183 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -6,6 +6,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" @@ -40,6 +41,7 @@ type RoomserverInternalAPI struct { *perform.Forgetter *perform.Upgrader *perform.Admin + *perform.Creator ProcessContext *process.ProcessContext DB storage.Database Cfg *config.Dendrite @@ -191,6 +193,11 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio Queryer: r.Queryer, Leaver: r.Leaver, } + r.Creator = &perform.Creator{ + DB: r.DB, + Cfg: &r.Cfg.RoomServer, + RSAPI: r, + } if err := r.Inputer.Start(); err != nil { logrus.WithError(err).Panic("failed to start roomserver input API") @@ -206,6 +213,12 @@ func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalA r.asAPI = asAPI } +func (r *RoomserverInternalAPI) PerformCreateRoom( + ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *api.PerformCreateRoomRequest, +) (string, *util.JSONResponse) { + return r.Creator.PerformCreateRoom(ctx, userID, roomID, createRequest) +} + func (r *RoomserverInternalAPI) PerformInvite( ctx context.Context, req *api.PerformInviteRequest, diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 02a1a2802..386083f6e 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -872,7 +872,7 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r return err } - event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, r.SigningIdentity, time.Now(), &eventsNeeded, latestRes) + event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.SigningIdentity, time.Now(), &eventsNeeded, latestRes) if err != nil { return err } diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index 8d21b7829..575525e21 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -119,7 +119,7 @@ func (r *Admin) PerformAdminEvacuateRoom( continue } - event, err = eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, latestRes) + event, err = eventutil.BuildEvent(ctx, fledglingEvent, identity, time.Now(), &eventsNeeded, latestRes) if err != nil { return nil, err } @@ -312,7 +312,7 @@ func (r *Admin) PerformAdminDownloadState( return err } - ev, err := eventutil.BuildEvent(ctx, proto, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, queryRes) + ev, err := eventutil.BuildEvent(ctx, proto, identity, time.Now(), &eventsNeeded, queryRes) if err != nil { return fmt.Errorf("eventutil.BuildEvent: %w", err) } diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go new file mode 100644 index 000000000..0f9170087 --- /dev/null +++ b/roomserver/internal/perform/perform_create_room.go @@ -0,0 +1,498 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package perform + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/getsentry/sentry-go" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" +) + +const ( + historyVisibilityShared = "shared" +) + +type Creator struct { + DB storage.Database + Cfg *config.RoomServer + RSAPI api.RoomserverInternalAPI +} + +// PerformCreateRoom handles all the steps necessary to create a new room. +// nolint: gocyclo +func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *api.PerformCreateRoomRequest) (string, *util.JSONResponse) { + verImpl, err := gomatrixserverlib.GetRoomVersion(createRequest.RoomVersion) + if err != nil { + return "", &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("unknown room version"), + } + } + + createContent := map[string]interface{}{} + if len(createRequest.CreationContent) > 0 { + if err = json.Unmarshal(createRequest.CreationContent, &createContent); err != nil { + util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for creation_content failed") + return "", &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("invalid create content"), + } + } + } + createContent["creator"] = userID.String() + createContent["room_version"] = createRequest.RoomVersion + powerLevelContent := eventutil.InitialPowerLevelsContent(userID.String()) + joinRuleContent := gomatrixserverlib.JoinRuleContent{ + JoinRule: spec.Invite, + } + historyVisibilityContent := gomatrixserverlib.HistoryVisibilityContent{ + HistoryVisibility: historyVisibilityShared, + } + + if createRequest.PowerLevelContentOverride != nil { + // Merge powerLevelContentOverride fields by unmarshalling it atop the defaults + err = json.Unmarshal(createRequest.PowerLevelContentOverride, &powerLevelContent) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for power_level_content_override failed") + return "", &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("malformed power_level_content_override"), + } + } + } + + var guestsCanJoin bool + switch createRequest.StatePreset { + case spec.PresetPrivateChat: + joinRuleContent.JoinRule = spec.Invite + historyVisibilityContent.HistoryVisibility = historyVisibilityShared + guestsCanJoin = true + case spec.PresetTrustedPrivateChat: + joinRuleContent.JoinRule = spec.Invite + historyVisibilityContent.HistoryVisibility = historyVisibilityShared + for _, invitee := range createRequest.InvitedUsers { + powerLevelContent.Users[invitee] = 100 + } + guestsCanJoin = true + case spec.PresetPublicChat: + joinRuleContent.JoinRule = spec.Public + historyVisibilityContent.HistoryVisibility = historyVisibilityShared + } + + createEvent := gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomCreate, + Content: createContent, + } + powerLevelEvent := gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomPowerLevels, + Content: powerLevelContent, + } + joinRuleEvent := gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomJoinRules, + Content: joinRuleContent, + } + historyVisibilityEvent := gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomHistoryVisibility, + Content: historyVisibilityContent, + } + membershipEvent := gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomMember, + StateKey: userID.String(), + Content: gomatrixserverlib.MemberContent{ + Membership: spec.Join, + DisplayName: createRequest.UserDisplayName, + AvatarURL: createRequest.UserAvatarURL, + }, + } + + var nameEvent *gomatrixserverlib.FledglingEvent + var topicEvent *gomatrixserverlib.FledglingEvent + var guestAccessEvent *gomatrixserverlib.FledglingEvent + var aliasEvent *gomatrixserverlib.FledglingEvent + + if createRequest.RoomName != "" { + nameEvent = &gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomName, + Content: eventutil.NameContent{ + Name: createRequest.RoomName, + }, + } + } + + if createRequest.Topic != "" { + topicEvent = &gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomTopic, + Content: eventutil.TopicContent{ + Topic: createRequest.Topic, + }, + } + } + + if guestsCanJoin { + guestAccessEvent = &gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomGuestAccess, + Content: eventutil.GuestAccessContent{ + GuestAccess: "can_join", + }, + } + } + + var roomAlias string + if createRequest.RoomAliasName != "" { + roomAlias = fmt.Sprintf("#%s:%s", createRequest.RoomAliasName, userID.Domain()) + // check it's free + // TODO: This races but is better than nothing + hasAliasReq := api.GetRoomIDForAliasRequest{ + Alias: roomAlias, + IncludeAppservices: false, + } + + var aliasResp api.GetRoomIDForAliasResponse + err = c.RSAPI.GetRoomIDForAlias(ctx, &hasAliasReq, &aliasResp) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + if aliasResp.RoomID != "" { + return "", &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.RoomInUse("Room ID already exists."), + } + } + + aliasEvent = &gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomCanonicalAlias, + Content: eventutil.CanonicalAlias{ + Alias: roomAlias, + }, + } + } + + var initialStateEvents []gomatrixserverlib.FledglingEvent + for i := range createRequest.InitialState { + if createRequest.InitialState[i].StateKey != "" { + initialStateEvents = append(initialStateEvents, createRequest.InitialState[i]) + continue + } + + switch createRequest.InitialState[i].Type { + case spec.MRoomCreate: + continue + + case spec.MRoomPowerLevels: + powerLevelEvent = createRequest.InitialState[i] + + case spec.MRoomJoinRules: + joinRuleEvent = createRequest.InitialState[i] + + case spec.MRoomHistoryVisibility: + historyVisibilityEvent = createRequest.InitialState[i] + + case spec.MRoomGuestAccess: + guestAccessEvent = &createRequest.InitialState[i] + + case spec.MRoomName: + nameEvent = &createRequest.InitialState[i] + + case spec.MRoomTopic: + topicEvent = &createRequest.InitialState[i] + + default: + initialStateEvents = append(initialStateEvents, createRequest.InitialState[i]) + } + } + + // send events into the room in order of: + // 1- m.room.create + // 2- room creator join member + // 3- m.room.power_levels + // 4- m.room.join_rules + // 5- m.room.history_visibility + // 6- m.room.canonical_alias (opt) + // 7- m.room.guest_access (opt) + // 8- other initial state items + // 9- m.room.name (opt) + // 10- m.room.topic (opt) + // 11- invite events (opt) - with is_direct flag if applicable TODO + // 12- 3pid invite events (opt) TODO + // This differs from Synapse slightly. Synapse would vary the ordering of 3-7 + // depending on if those events were in "initial_state" or not. This made it + // harder to reason about, hence sticking to a strict static ordering. + // TODO: Synapse has txn/token ID on each event. Do we need to do this here? + eventsToMake := []gomatrixserverlib.FledglingEvent{ + createEvent, membershipEvent, powerLevelEvent, joinRuleEvent, historyVisibilityEvent, + } + if guestAccessEvent != nil { + eventsToMake = append(eventsToMake, *guestAccessEvent) + } + eventsToMake = append(eventsToMake, initialStateEvents...) + if nameEvent != nil { + eventsToMake = append(eventsToMake, *nameEvent) + } + if topicEvent != nil { + eventsToMake = append(eventsToMake, *topicEvent) + } + if aliasEvent != nil { + // TODO: bit of a chicken and egg problem here as the alias doesn't exist and cannot until we have made the room. + // This means we might fail creating the alias but say the canonical alias is something that doesn't exist. + eventsToMake = append(eventsToMake, *aliasEvent) + } + + // TODO: invite events + // TODO: 3pid invite events + + var builtEvents []*types.HeaderedEvent + authEvents := gomatrixserverlib.NewAuthEvents(nil) + for i, e := range eventsToMake { + depth := i + 1 // depth starts at 1 + + builder := verImpl.NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{ + Sender: userID.String(), + RoomID: roomID.String(), + Type: e.Type, + StateKey: &e.StateKey, + Depth: int64(depth), + }) + err = builder.SetContent(e.Content) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("builder.SetContent failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + if i > 0 { + builder.PrevEvents = []string{builtEvents[i-1].EventID()} + } + var ev gomatrixserverlib.PDU + if err = builder.AddAuthEvents(&authEvents); err != nil { + util.GetLogger(ctx).WithError(err).Error("AddAuthEvents failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + ev, err = builder.Build(createRequest.EventTime, userID.Domain(), createRequest.KeyID, createRequest.PrivateKey) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("buildEvent failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + if err = gomatrixserverlib.Allowed(ev, &authEvents); err != nil { + util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + // Add the event to the list of auth events + builtEvents = append(builtEvents, &types.HeaderedEvent{PDU: ev}) + err = authEvents.AddEvent(ev) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("authEvents.AddEvent failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + } + + inputs := make([]api.InputRoomEvent, 0, len(builtEvents)) + for _, event := range builtEvents { + inputs = append(inputs, api.InputRoomEvent{ + Kind: api.KindNew, + Event: event, + Origin: userID.Domain(), + SendAsServer: api.DoNotSendToOtherServers, + }) + } + if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs, false); err != nil { + util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + // TODO(#269): Reserve room alias while we create the room. This stops us + // from creating the room but still failing due to the alias having already + // been taken. + if roomAlias != "" { + aliasReq := api.SetRoomAliasRequest{ + Alias: roomAlias, + RoomID: roomID.String(), + UserID: userID.String(), + } + + var aliasResp api.SetRoomAliasResponse + err = c.RSAPI.SetRoomAlias(ctx, &aliasReq, &aliasResp) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("aliasAPI.SetRoomAlias failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + if aliasResp.AliasExists { + return "", &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.RoomInUse("Room alias already exists."), + } + } + } + + // If this is a direct message then we should invite the participants. + if len(createRequest.InvitedUsers) > 0 { + // Build some stripped state for the invite. + var globalStrippedState []fclient.InviteV2StrippedState + for _, event := range builtEvents { + // Chosen events from the spec: + // https://spec.matrix.org/v1.3/client-server-api/#stripped-state + switch event.Type() { + case spec.MRoomCreate: + fallthrough + case spec.MRoomName: + fallthrough + case spec.MRoomAvatar: + fallthrough + case spec.MRoomTopic: + fallthrough + case spec.MRoomCanonicalAlias: + fallthrough + case spec.MRoomEncryption: + fallthrough + case spec.MRoomMember: + fallthrough + case spec.MRoomJoinRules: + ev := event.PDU + globalStrippedState = append( + globalStrippedState, + fclient.NewInviteV2StrippedState(ev), + ) + } + } + + // Process the invites. + var inviteEvent *types.HeaderedEvent + for _, invitee := range createRequest.InvitedUsers { + proto := gomatrixserverlib.ProtoEvent{ + Sender: userID.String(), + RoomID: roomID.String(), + Type: "m.room.member", + StateKey: &invitee, + } + + content := gomatrixserverlib.MemberContent{ + Membership: spec.Invite, + DisplayName: createRequest.UserDisplayName, + AvatarURL: createRequest.UserAvatarURL, + Reason: "", + IsDirect: createRequest.IsDirect, + } + + if err = proto.SetContent(content); err != nil { + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + // Build the invite event. + identity := &fclient.SigningIdentity{ + ServerName: userID.Domain(), + KeyID: createRequest.KeyID, + PrivateKey: createRequest.PrivateKey, + } + inviteEvent, err = eventutil.QueryAndBuildEvent(ctx, &proto, identity, createRequest.EventTime, c.RSAPI, nil) + + if err != nil { + util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed") + continue + } + inviteStrippedState := append( + globalStrippedState, + fclient.NewInviteV2StrippedState(inviteEvent.PDU), + ) + // Send the invite event to the roomserver. + event := inviteEvent + err = c.RSAPI.PerformInvite(ctx, &api.PerformInviteRequest{ + Event: event, + InviteRoomState: inviteStrippedState, + RoomVersion: event.Version(), + SendAsServer: string(userID.Domain()), + }) + switch e := err.(type) { + case api.ErrInvalidID: + return "", &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown(e.Error()), + } + case api.ErrNotAllowed: + return "", &util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden(e.Error()), + } + case nil: + default: + util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") + sentry.CaptureException(err) + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + } + } + + if createRequest.Visibility == spec.Public { + // expose this room in the published room list + if err = c.RSAPI.PerformPublish(ctx, &api.PerformPublishRequest{ + RoomID: roomID.String(), + Visibility: spec.Public, + }); err != nil { + util.GetLogger(ctx).WithError(err).Error("failed to publish room") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + } + + // TODO: visibility/presets/raw initial state + // TODO: Create room alias association + // Make sure this doesn't fall into an application service's namespace though! + + return roomAlias, nil +} diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 5f4ad1861..34bea5b6d 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -284,7 +284,7 @@ func (r *Joiner) performJoinRoomByID( if err != nil { return "", "", fmt.Errorf("error joining local room: %q", err) } - event, err := eventutil.QueryAndBuildEvent(ctx, &proto, r.Cfg.Matrix, identity, time.Now(), r.RSAPI, &buildRes) + event, err := eventutil.QueryAndBuildEvent(ctx, &proto, identity, time.Now(), r.RSAPI, &buildRes) switch err.(type) { case nil: diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index e71b3e908..90102aeeb 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -183,7 +183,7 @@ func (r *Leaver) performLeaveRoomByID( if err != nil { return nil, fmt.Errorf("SigningIdentityFor: %w", err) } - event, err := eventutil.QueryAndBuildEvent(ctx, &proto, r.Cfg.Matrix, identity, time.Now(), r.RSAPI, &buildRes) + event, err := eventutil.QueryAndBuildEvent(ctx, &proto, identity, time.Now(), r.RSAPI, &buildRes) if err != nil { return nil, fmt.Errorf("eventutil.QueryAndBuildEvent: %w", err) } diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index 60085cb6d..ff4a6a1dc 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -35,13 +35,6 @@ type Upgrader struct { URSAPI api.RoomserverInternalAPI } -// fledglingEvent is a helper representation of an event used when creating many events in succession. -type fledglingEvent struct { - Type string `json:"type"` - StateKey string `json:"state_key"` - Content interface{} `json:"content"` -} - // PerformRoomUpgrade upgrades a room from one version to another func (r *Upgrader) PerformRoomUpgrade( ctx context.Context, @@ -154,7 +147,7 @@ func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.T restrictedPowerLevelContent.EventsDefault = restrictedDefaultPowerLevel restrictedPowerLevelContent.Invite = restrictedDefaultPowerLevel - restrictedPowerLevelsHeadered, resErr := r.makeHeaderedEvent(ctx, evTime, userID, roomID, fledglingEvent{ + restrictedPowerLevelsHeadered, resErr := r.makeHeaderedEvent(ctx, evTime, userID, roomID, gomatrixserverlib.FledglingEvent{ Type: spec.MRoomPowerLevels, StateKey: "", Content: restrictedPowerLevelContent, @@ -216,7 +209,7 @@ func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api } } - emptyCanonicalAliasEvent, resErr := r.makeHeaderedEvent(ctx, evTime, userID, roomID, fledglingEvent{ + emptyCanonicalAliasEvent, resErr := r.makeHeaderedEvent(ctx, evTime, userID, roomID, gomatrixserverlib.FledglingEvent{ Type: spec.MRoomCanonicalAlias, Content: map[string]interface{}{}, }) @@ -298,7 +291,7 @@ func (r *Upgrader) userIsAuthorized(ctx context.Context, userID, roomID string, } // nolint:gocyclo -func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, userID, roomID string, newVersion gomatrixserverlib.RoomVersion, tombstoneEvent *types.HeaderedEvent) ([]fledglingEvent, error) { +func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, userID, roomID string, newVersion gomatrixserverlib.RoomVersion, tombstoneEvent *types.HeaderedEvent) ([]gomatrixserverlib.FledglingEvent, error) { state := make(map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent, len(oldRoom.StateEvents)) for _, event := range oldRoom.StateEvents { if event.StateKey() == nil { @@ -361,7 +354,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query EventID: tombstoneEvent.EventID(), RoomID: roomID, } - newCreateEvent := fledglingEvent{ + newCreateEvent := gomatrixserverlib.FledglingEvent{ Type: spec.MRoomCreate, StateKey: "", Content: newCreateContent, @@ -374,7 +367,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query newMembershipContent := map[string]interface{}{} _ = json.Unmarshal(oldMembershipEvent.Content(), &newMembershipContent) newMembershipContent["membership"] = spec.Join - newMembershipEvent := fledglingEvent{ + newMembershipEvent := gomatrixserverlib.FledglingEvent{ Type: spec.MRoomMember, StateKey: userID, Content: newMembershipContent, @@ -400,13 +393,13 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query "join_rule": spec.Invite, // sane default } _ = json.Unmarshal(oldJoinRulesEvent.Content(), &newJoinRulesContent) - newJoinRulesEvent := fledglingEvent{ + newJoinRulesEvent := gomatrixserverlib.FledglingEvent{ Type: spec.MRoomJoinRules, StateKey: "", Content: newJoinRulesContent, } - eventsToMake := make([]fledglingEvent, 0, len(state)) + eventsToMake := make([]gomatrixserverlib.FledglingEvent, 0, len(state)) eventsToMake = append( eventsToMake, newCreateEvent, newMembershipEvent, tempPowerLevelsEvent, newJoinRulesEvent, @@ -415,7 +408,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query // For some reason Sytest expects there to be a guest access event. // Create one if it doesn't exist. if _, ok := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomGuestAccess, StateKey: ""}]; !ok { - eventsToMake = append(eventsToMake, fledglingEvent{ + eventsToMake = append(eventsToMake, gomatrixserverlib.FledglingEvent{ Type: spec.MRoomGuestAccess, Content: map[string]string{ "guest_access": "forbidden", @@ -430,7 +423,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query // are already in `eventsToMake`. continue } - newEvent := fledglingEvent{ + newEvent := gomatrixserverlib.FledglingEvent{ Type: tuple.EventType, StateKey: tuple.StateKey, } @@ -444,7 +437,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query // If we sent a temporary power level event into the room before, // override that now by restoring the original power levels. if powerLevelsOverridden { - eventsToMake = append(eventsToMake, fledglingEvent{ + eventsToMake = append(eventsToMake, gomatrixserverlib.FledglingEvent{ Type: spec.MRoomPowerLevels, Content: powerLevelContent, }) @@ -452,7 +445,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query return eventsToMake, nil } -func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, userID string, userDomain spec.ServerName, newRoomID string, newVersion gomatrixserverlib.RoomVersion, eventsToMake []fledglingEvent) error { +func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, userID string, userDomain spec.ServerName, newRoomID string, newVersion gomatrixserverlib.RoomVersion, eventsToMake []gomatrixserverlib.FledglingEvent) error { var err error var builtEvents []*types.HeaderedEvent authEvents := gomatrixserverlib.NewAuthEvents(nil) @@ -527,14 +520,14 @@ func (r *Upgrader) makeTombstoneEvent( "body": "This room has been replaced", "replacement_room": newRoomID, } - event := fledglingEvent{ + event := gomatrixserverlib.FledglingEvent{ Type: "m.room.tombstone", Content: content, } return r.makeHeaderedEvent(ctx, evTime, userID, roomID, event) } -func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, userID, roomID string, event fledglingEvent) (*types.HeaderedEvent, error) { +func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, userID, roomID string, event gomatrixserverlib.FledglingEvent) (*types.HeaderedEvent, error) { proto := gomatrixserverlib.ProtoEvent{ Sender: userID, RoomID: roomID, @@ -555,7 +548,7 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user return nil, fmt.Errorf("failed to get signing identity for %q: %w", senderDomain, err) } var queryRes api.QueryLatestEventsAndStateResponse - headeredEvent, err := eventutil.QueryAndBuildEvent(ctx, &proto, r.Cfg.Matrix, identity, evTime, r.URSAPI, &queryRes) + headeredEvent, err := eventutil.QueryAndBuildEvent(ctx, &proto, identity, evTime, r.URSAPI, &queryRes) switch e := err.(type) { case nil: case eventutil.ErrRoomNoExists: @@ -581,7 +574,7 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user return headeredEvent, nil } -func createTemporaryPowerLevels(powerLevelContent *gomatrixserverlib.PowerLevelContent, userID string) (fledglingEvent, bool) { +func createTemporaryPowerLevels(powerLevelContent *gomatrixserverlib.PowerLevelContent, userID string) (gomatrixserverlib.FledglingEvent, bool) { // Work out what power level we need in order to be able to send events // of all types into the room. neededPowerLevel := powerLevelContent.StateDefault @@ -612,7 +605,7 @@ func createTemporaryPowerLevels(powerLevelContent *gomatrixserverlib.PowerLevelC } // Then return the temporary power levels event. - return fledglingEvent{ + return gomatrixserverlib.FledglingEvent{ Type: spec.MRoomPowerLevels, Content: tempPowerLevelContent, }, powerLevelsOverridden From ea6b368ad424a3d2e05135afb7fd0c0801b3609b Mon Sep 17 00:00:00 2001 From: devonh Date: Wed, 31 May 2023 16:33:49 +0000 Subject: [PATCH 13/35] Move Invite logic to GMSL (#3086) This is both the federation receiving & sending side logic (which were previously entangeld in a single function) --- clientapi/routing/createroom.go | 1 - federationapi/api/api.go | 8 +- federationapi/internal/perform.go | 47 ++- federationapi/routing/invite.go | 259 +++++++------- federationapi/routing/join.go | 21 +- federationapi/routing/routing.go | 20 +- go.mod | 2 +- go.sum | 4 +- roomserver/api/api.go | 5 + roomserver/api/perform.go | 11 +- roomserver/api/query.go | 20 ++ roomserver/internal/api.go | 29 +- roomserver/internal/helpers/auth.go | 38 +- .../internal/perform/perform_create_room.go | 6 +- roomserver/internal/perform/perform_invite.go | 330 +++++++----------- 15 files changed, 359 insertions(+), 442 deletions(-) diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index aaa305f06..799fc7976 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -135,7 +135,6 @@ func CreateRoom( // createRoom implements /createRoom func createRoom( ctx context.Context, - // TODO: remove dependency on createRoomRequest createRequest createRoomRequest, device *api.Device, cfg *config.ClientAPI, profileAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI, diff --git a/federationapi/api/api.go b/federationapi/api/api.go index b53ec3dd8..5b49e509e 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -62,7 +62,7 @@ type RoomserverFederationAPI interface { // Handle an instruction to make_leave & send_leave with a remote server. PerformLeave(ctx context.Context, request *PerformLeaveRequest, response *PerformLeaveResponse) error // Handle sending an invite to a remote server. - PerformInvite(ctx context.Context, request *PerformInviteRequest, response *PerformInviteResponse) error + SendInvite(ctx context.Context, event gomatrixserverlib.PDU, strippedState []gomatrixserverlib.InviteStrippedState) (gomatrixserverlib.PDU, error) // Handle an instruction to peek a room on a remote server. PerformOutboundPeek(ctx context.Context, request *PerformOutboundPeekRequest, response *PerformOutboundPeekResponse) error // Query the server names of the joined hosts in a room. @@ -190,9 +190,9 @@ type PerformLeaveResponse struct { } type PerformInviteRequest struct { - RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` - Event *rstypes.HeaderedEvent `json:"event"` - InviteRoomState []fclient.InviteV2StrippedState `json:"invite_room_state"` + RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` + Event *rstypes.HeaderedEvent `json:"event"` + InviteRoomState []gomatrixserverlib.InviteStrippedState `json:"invite_room_state"` } type PerformInviteResponse struct { diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 99943d86c..ed800d03a 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -503,60 +503,59 @@ func (r *FederationInternalAPI) PerformLeave( ) } -// PerformLeaveRequest implements api.FederationInternalAPI -func (r *FederationInternalAPI) PerformInvite( +// SendInvite implements api.FederationInternalAPI +func (r *FederationInternalAPI) SendInvite( ctx context.Context, - request *api.PerformInviteRequest, - response *api.PerformInviteResponse, -) (err error) { - _, origin, err := r.cfg.Matrix.SplitLocalID('@', request.Event.Sender()) + event gomatrixserverlib.PDU, + strippedState []gomatrixserverlib.InviteStrippedState, +) (gomatrixserverlib.PDU, error) { + _, origin, err := r.cfg.Matrix.SplitLocalID('@', event.Sender()) if err != nil { - return err + return nil, err } - if request.Event.StateKey() == nil { - return errors.New("invite must be a state event") + if event.StateKey() == nil { + return nil, errors.New("invite must be a state event") } - _, destination, err := gomatrixserverlib.SplitID('@', *request.Event.StateKey()) + _, destination, err := gomatrixserverlib.SplitID('@', *event.StateKey()) if err != nil { - return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) + return nil, fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } // TODO (devon): This should be allowed via a relay. Currently only transactions // can be sent to relays. Would need to extend relays to handle invites. if !r.shouldAttemptDirectFederation(destination) { - return fmt.Errorf("relay servers have no meaningful response for invite.") + return nil, fmt.Errorf("relay servers have no meaningful response for invite.") } logrus.WithFields(logrus.Fields{ - "event_id": request.Event.EventID(), - "user_id": *request.Event.StateKey(), - "room_id": request.Event.RoomID(), - "room_version": request.RoomVersion, + "event_id": event.EventID(), + "user_id": *event.StateKey(), + "room_id": event.RoomID(), + "room_version": event.Version(), "destination": destination, }).Info("Sending invite") - inviteReq, err := fclient.NewInviteV2Request(request.Event.PDU, request.InviteRoomState) + inviteReq, err := fclient.NewInviteV2Request(event, strippedState) if err != nil { - return fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err) + return nil, fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err) } inviteRes, err := r.federation.SendInviteV2(ctx, origin, destination, inviteReq) if err != nil { - return fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err) + return nil, fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err) } - verImpl, err := gomatrixserverlib.GetRoomVersion(request.RoomVersion) + verImpl, err := gomatrixserverlib.GetRoomVersion(event.Version()) if err != nil { - return err + return nil, err } inviteEvent, err := verImpl.NewEventFromUntrustedJSON(inviteRes.Event) if err != nil { - return fmt.Errorf("r.federation.SendInviteV2 failed to decode event response: %w", err) + return nil, fmt.Errorf("r.federation.SendInviteV2 failed to decode event response: %w", err) } - response.Event = &types.HeaderedEvent{PDU: inviteEvent} - return nil + return inviteEvent, nil } // PerformServersAlive implements api.FederationInternalAPI diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index 993d40466..78a09d949 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -20,7 +20,6 @@ import ( "fmt" "net/http" - "github.com/getsentry/sentry-go" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" @@ -34,7 +33,7 @@ import ( func InviteV2( httpReq *http.Request, request *fclient.FederationRequest, - roomID string, + roomID spec.RoomID, eventID string, cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, @@ -56,9 +55,55 @@ func InviteV2( JSON: spec.BadJSON(err.Error()), } case nil: - return processInvite( - httpReq.Context(), true, inviteReq.Event(), inviteReq.RoomVersion(), inviteReq.InviteRoomState(), roomID, eventID, cfg, rsAPI, keys, - ) + if inviteReq.Event().StateKey() == nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("The invite event has no state key"), + } + } + + invitedUser, userErr := spec.NewUserID(*inviteReq.Event().StateKey(), true) + if userErr != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("The user ID is invalid"), + } + } + if !cfg.Matrix.IsLocalServerName(invitedUser.Domain()) { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("The invited user domain does not belong to this server"), + } + } + + if inviteReq.Event().EventID() != eventID { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("The event ID in the request path must match the event ID in the invite event JSON"), + } + } + + input := gomatrixserverlib.HandleInviteInput{ + RoomVersion: inviteReq.RoomVersion(), + RoomID: roomID, + InvitedUser: *invitedUser, + KeyID: cfg.Matrix.KeyID, + PrivateKey: cfg.Matrix.PrivateKey, + Verifier: keys, + RoomQuerier: rsAPI, + MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI}, + StateQuerier: rsAPI.StateQuerier(), + InviteEvent: inviteReq.Event(), + StrippedState: inviteReq.InviteRoomState(), + } + event, jsonErr := handleInvite(httpReq.Context(), input, rsAPI) + if jsonErr != nil { + return *jsonErr + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: fclient.RespInviteV2{Event: event.JSON()}, + } default: return util.JSONResponse{ Code: http.StatusBadRequest, @@ -71,7 +116,7 @@ func InviteV2( func InviteV1( httpReq *http.Request, request *fclient.FederationRequest, - roomID string, + roomID spec.RoomID, eventID string, cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, @@ -94,55 +139,11 @@ func InviteV1( JSON: spec.NotJSON("The request body could not be decoded into an invite v1 request. " + err.Error()), } } - var strippedState []fclient.InviteV2StrippedState - if err := json.Unmarshal(event.Unsigned(), &strippedState); err != nil { + var strippedState []gomatrixserverlib.InviteStrippedState + if jsonErr := json.Unmarshal(event.Unsigned(), &strippedState); jsonErr != nil { // just warn, they may not have added any. util.GetLogger(httpReq.Context()).Warnf("failed to extract stripped state from invite event") } - return processInvite( - httpReq.Context(), false, event, roomVer, strippedState, roomID, eventID, cfg, rsAPI, keys, - ) -} - -func processInvite( - ctx context.Context, - isInviteV2 bool, - event gomatrixserverlib.PDU, - roomVer gomatrixserverlib.RoomVersion, - strippedState []fclient.InviteV2StrippedState, - roomID string, - eventID string, - cfg *config.FederationAPI, - rsAPI api.FederationRoomserverAPI, - keys gomatrixserverlib.JSONVerifier, -) util.JSONResponse { - - // Check that we can accept invites for this room version. - verImpl, err := gomatrixserverlib.GetRoomVersion(roomVer) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.UnsupportedRoomVersion( - fmt.Sprintf("Room version %q is not supported by this server.", roomVer), - ), - } - } - - // Check that the room ID is correct. - if event.RoomID() != roomID { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON("The room ID in the request path must match the room ID in the invite event JSON"), - } - } - - // Check that the event ID is correct. - if event.EventID() != eventID { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON("The event ID in the request path must match the event ID in the invite event JSON"), - } - } if event.StateKey() == nil { return util.JSONResponse{ @@ -151,105 +152,91 @@ func processInvite( } } - _, domain, err := cfg.Matrix.SplitLocalID('@', *event.StateKey()) + invitedUser, err := spec.NewUserID(*event.StateKey(), true) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: spec.InvalidParam(fmt.Sprintf("The user ID is invalid or domain %q does not belong to this server", domain)), + JSON: spec.InvalidParam("The user ID is invalid"), } } - - // Check that the event is signed by the server sending the request. - redacted, err := verImpl.RedactEventJSON(event.JSON()) - if err != nil { + if !cfg.Matrix.IsLocalServerName(invitedUser.Domain()) { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: spec.BadJSON("The event JSON could not be redacted"), + JSON: spec.InvalidParam("The invited user domain does not belong to this server"), } } - _, serverName, err := gomatrixserverlib.SplitID('@', event.Sender()) - if err != nil { + + if event.EventID() != eventID { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: spec.BadJSON("The event JSON contains an invalid sender"), - } - } - verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ - ServerName: serverName, - Message: redacted, - AtTS: event.OriginServerTS(), - StrictValidityChecking: true, - }} - verifyResults, err := keys.VerifyJSONs(ctx, verifyRequests) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("keys.VerifyJSONs failed") - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } - } - if verifyResults[0].Error != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("The invite must be signed by the server it originated on"), + JSON: spec.BadJSON("The event ID in the request path must match the event ID in the invite event JSON"), } } - // Sign the event so that other servers will know that we have received the invite. - signedEvent := event.Sign( - string(domain), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, - ) - - // Add the invite event to the roomserver. - inviteEvent := &types.HeaderedEvent{PDU: signedEvent} - request := &api.PerformInviteRequest{ - Event: inviteEvent, - InviteRoomState: strippedState, - RoomVersion: inviteEvent.Version(), - SendAsServer: string(api.DoNotSendToOtherServers), - TransactionID: nil, + input := gomatrixserverlib.HandleInviteInput{ + RoomVersion: roomVer, + RoomID: roomID, + InvitedUser: *invitedUser, + KeyID: cfg.Matrix.KeyID, + PrivateKey: cfg.Matrix.PrivateKey, + Verifier: keys, + RoomQuerier: rsAPI, + MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI}, + StateQuerier: rsAPI.StateQuerier(), + InviteEvent: event, + StrippedState: strippedState, } - - if err = rsAPI.PerformInvite(ctx, request); err != nil { - util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } + event, jsonErr := handleInvite(httpReq.Context(), input, rsAPI) + if jsonErr != nil { + return *jsonErr } - - switch e := err.(type) { - case api.ErrInvalidID: - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.Unknown(e.Error()), - } - case api.ErrNotAllowed: - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden(e.Error()), - } - case nil: - default: - util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") - sentry.CaptureException(err) - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } - } - - // Return the signed event to the originating server, it should then tell - // the other servers in the room that we have been invited. - if isInviteV2 { - return util.JSONResponse{ - Code: http.StatusOK, - JSON: fclient.RespInviteV2{Event: signedEvent.JSON()}, - } - } else { - return util.JSONResponse{ - Code: http.StatusOK, - JSON: fclient.RespInvite{Event: signedEvent.JSON()}, - } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: fclient.RespInvite{Event: event.JSON()}, } } + +func handleInvite(ctx context.Context, input gomatrixserverlib.HandleInviteInput, rsAPI api.FederationRoomserverAPI) (gomatrixserverlib.PDU, *util.JSONResponse) { + inviteEvent, err := gomatrixserverlib.HandleInvite(ctx, input) + switch e := err.(type) { + case nil: + case spec.InternalServerError: + util.GetLogger(ctx).WithError(err) + return nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + case spec.MatrixError: + util.GetLogger(ctx).WithError(err) + code := http.StatusInternalServerError + switch e.ErrCode { + case spec.ErrorForbidden: + code = http.StatusForbidden + case spec.ErrorUnsupportedRoomVersion: + fallthrough // http.StatusBadRequest + case spec.ErrorBadJSON: + code = http.StatusBadRequest + } + + return nil, &util.JSONResponse{ + Code: code, + JSON: e, + } + default: + util.GetLogger(ctx).WithError(err) + return nil, &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("unknown error"), + } + } + + headeredInvite := &types.HeaderedEvent{PDU: inviteEvent} + if err = rsAPI.HandleInvite(ctx, headeredInvite); err != nil { + util.GetLogger(ctx).WithError(err).Error("HandleInvite failed") + return nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + return inviteEvent, nil +} diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 03d3309ae..c6f96375e 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -216,25 +216,6 @@ func MakeJoin( } } -type MembershipQuerier struct { - roomserver api.FederationRoomserverAPI -} - -func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) { - req := api.QueryMembershipForUserRequest{ - RoomID: roomID.String(), - UserID: userID.String(), - } - res := api.QueryMembershipForUserResponse{} - err := mq.roomserver.QueryMembershipForUser(ctx, &req, &res) - - membership := "" - if err == nil { - membership = res.Membership - } - return membership, err -} - // SendJoin implements the /send_join API // The make-join send-join dance makes much more sense as a single // flow so the cyclomatic complexity is high: @@ -268,7 +249,7 @@ func SendJoin( KeyID: cfg.Matrix.KeyID, PrivateKey: cfg.Matrix.PrivateKey, Verifier: keys, - MembershipQuerier: &MembershipQuerier{roomserver: rsAPI}, + MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI}, } response, joinErr := gomatrixserverlib.HandleSendJoin(input) switch e := joinErr.(type) { diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index fad06c1cf..8865022ff 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -152,8 +152,16 @@ func Setup( JSON: spec.Forbidden("Forbidden by server ACLs"), } } + + roomID, err := spec.NewRoomID(vars["roomID"]) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Invalid RoomID"), + } + } return InviteV1( - httpReq, request, vars["roomID"], vars["eventID"], + httpReq, request, *roomID, vars["eventID"], cfg, rsAPI, keys, ) }, @@ -168,8 +176,16 @@ func Setup( JSON: spec.Forbidden("Forbidden by server ACLs"), } } + + roomID, err := spec.NewRoomID(vars["roomID"]) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Invalid RoomID"), + } + } return InviteV2( - httpReq, request, vars["roomID"], vars["eventID"], + httpReq, request, *roomID, vars["eventID"], cfg, rsAPI, keys, ) }, diff --git a/go.mod b/go.mod index 0e979de6f..a20757bbc 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230531143710-c681a0658246 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230531155817-0e3adf17bee6 github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 diff --git a/go.sum b/go.sum index 8baa50e85..a1946adaa 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230531143710-c681a0658246 h1:1sYXx7p9BIf0R7OIV/TZg3SCvNehEQPCKNqwV1ONfwU= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230531143710-c681a0658246/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230531155817-0e3adf17bee6 h1:Kh1TNvJDhWN5CdgtICNUC4G0wV2km51LGr46Dvl153A= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230531155817-0e3adf17bee6/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 571aa40b3..7cb3379e0 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -225,6 +225,8 @@ type FederationRoomserverAPI interface { QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error QueryRestrictedJoinAllowed(ctx context.Context, req *QueryRestrictedJoinAllowedRequest, res *QueryRestrictedJoinAllowedResponse) error PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error + HandleInvite(ctx context.Context, event *types.HeaderedEvent) error + PerformInvite(ctx context.Context, req *PerformInviteRequest) error // Query a given amount (or less) of events prior to a given set of events. PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error @@ -234,6 +236,9 @@ type FederationRoomserverAPI interface { QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types.RoomInfo, error) UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, userID spec.UserID) (bool, error) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error) + + IsKnownRoom(ctx context.Context, roomID spec.RoomID) (bool, error) + StateQuerier() gomatrixserverlib.StateQuerier } type KeyserverRoomserverAPI interface { diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index 8d9742c69..6cbaf5b19 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -7,7 +7,6 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -52,11 +51,11 @@ type PerformLeaveResponse struct { } type PerformInviteRequest struct { - RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` - Event *types.HeaderedEvent `json:"event"` - InviteRoomState []fclient.InviteV2StrippedState `json:"invite_room_state"` - SendAsServer string `json:"send_as_server"` - TransactionID *TransactionID `json:"transaction_id"` + RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` + Event *types.HeaderedEvent `json:"event"` + InviteRoomState []gomatrixserverlib.InviteStrippedState `json:"invite_room_state"` + SendAsServer string `json:"send_as_server"` + TransactionID *TransactionID `json:"transaction_id"` } type PerformPeekRequest struct { diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 1726bfe1f..b33698c82 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -17,6 +17,7 @@ package api import ( + "context" "encoding/json" "fmt" "strings" @@ -457,3 +458,22 @@ type QueryLeftUsersRequest struct { type QueryLeftUsersResponse struct { LeftUsers []string `json:"user_ids"` } + +type MembershipQuerier struct { + Roomserver FederationRoomserverAPI +} + +func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) { + req := QueryMembershipForUserRequest{ + RoomID: roomID.String(), + UserID: userID.String(), + } + res := QueryMembershipForUserResponse{} + err := mq.Roomserver.QueryMembershipForUser(ctx, &req, &res) + + membership := "" + if err == nil { + membership = res.Membership + } + return membership, err +} diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index f61f89183..ee433f0d2 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -20,6 +20,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/producers" "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" @@ -132,6 +133,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio DB: r.DB, Cfg: &r.Cfg.RoomServer, FSAPI: r.fsAPI, + RSAPI: r, Inputer: r.Inputer, } r.Joiner = &perform.Joiner{ @@ -213,6 +215,24 @@ func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalA r.asAPI = asAPI } +func (r *RoomserverInternalAPI) IsKnownRoom(ctx context.Context, roomID spec.RoomID) (bool, error) { + return r.Inviter.IsKnownRoom(ctx, roomID) +} + +func (r *RoomserverInternalAPI) StateQuerier() gomatrixserverlib.StateQuerier { + return r.Inviter.StateQuerier() +} + +func (r *RoomserverInternalAPI) HandleInvite( + ctx context.Context, inviteEvent *types.HeaderedEvent, +) error { + outputEvents, err := r.Inviter.ProcessInviteMembership(ctx, inviteEvent) + if err != nil { + return err + } + return r.OutputProducer.ProduceRoomEvents(inviteEvent.RoomID(), outputEvents) +} + func (r *RoomserverInternalAPI) PerformCreateRoom( ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *api.PerformCreateRoomRequest, ) (string, *util.JSONResponse) { @@ -223,14 +243,7 @@ func (r *RoomserverInternalAPI) PerformInvite( ctx context.Context, req *api.PerformInviteRequest, ) error { - outputEvents, err := r.Inviter.PerformInvite(ctx, req) - if err != nil { - return err - } - if len(outputEvents) == 0 { - return nil - } - return r.OutputProducer.ProduceRoomEvents(req.Event.RoomID(), outputEvents) + return r.Inviter.PerformInvite(ctx, req) } func (r *RoomserverInternalAPI) PerformLeave( diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 24958091b..7ec0892e4 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -70,7 +70,7 @@ func CheckForSoftFail( ) // Load the actual auth events from the database. - authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries) + authEvents, err := loadAuthEvents(ctx, db, roomInfo.RoomVersion, stateNeeded, authStateEntries) if err != nil { return true, fmt.Errorf("loadAuthEvents: %w", err) } @@ -83,15 +83,14 @@ func CheckForSoftFail( return false, nil } -// CheckAuthEvents checks that the event passes authentication checks -// Returns the numeric IDs for the auth events. -func CheckAuthEvents( +// GetAuthEvents returns the numeric IDs for the auth events. +func GetAuthEvents( ctx context.Context, db storage.RoomDatabase, - roomInfo *types.RoomInfo, - event *types.HeaderedEvent, + roomVersion gomatrixserverlib.RoomVersion, + event gomatrixserverlib.PDU, authEventIDs []string, -) ([]types.EventNID, error) { +) (gomatrixserverlib.AuthEventProvider, error) { // Grab the numeric IDs for the supplied auth state events from the database. authStateEntries, err := db.StateEntriesForEventIDs(ctx, authEventIDs, true) if err != nil { @@ -100,25 +99,14 @@ func CheckAuthEvents( authStateEntries = types.DeduplicateStateEntries(authStateEntries) // Work out which of the state events we actually need. - stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.PDU{event.PDU}) + stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.PDU{event}) // Load the actual auth events from the database. - authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries) + authEvents, err := loadAuthEvents(ctx, db, roomVersion, stateNeeded, authStateEntries) if err != nil { return nil, fmt.Errorf("loadAuthEvents: %w", err) } - - // Check if the event is allowed. - if err = gomatrixserverlib.Allowed(event.PDU, &authEvents); err != nil { - return nil, err - } - - // Return the numeric IDs for the auth events. - result := make([]types.EventNID, len(authStateEntries)) - for i := range authStateEntries { - result[i] = authStateEntries[i].EventNID - } - return result, nil + return &authEvents, nil } type authEvents struct { @@ -196,7 +184,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) g func loadAuthEvents( ctx context.Context, db state.StateResolutionStorage, - roomInfo *types.RoomInfo, + roomVersion gomatrixserverlib.RoomVersion, needed gomatrixserverlib.StateNeeded, state []types.StateEntry, ) (result authEvents, err error) { @@ -220,11 +208,7 @@ func loadAuthEvents( } } - if roomInfo == nil { - err = types.ErrorInvalidRoomInfo - return - } - if result.events, err = db.Events(ctx, roomInfo.RoomVersion, eventNIDs); err != nil { + if result.events, err = db.Events(ctx, roomVersion, eventNIDs); err != nil { return } roomID := "" diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index 0f9170087..41194832d 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -376,7 +376,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo // If this is a direct message then we should invite the participants. if len(createRequest.InvitedUsers) > 0 { // Build some stripped state for the invite. - var globalStrippedState []fclient.InviteV2StrippedState + var globalStrippedState []gomatrixserverlib.InviteStrippedState for _, event := range builtEvents { // Chosen events from the spec: // https://spec.matrix.org/v1.3/client-server-api/#stripped-state @@ -399,7 +399,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo ev := event.PDU globalStrippedState = append( globalStrippedState, - fclient.NewInviteV2StrippedState(ev), + gomatrixserverlib.NewInviteStrippedState(ev), ) } } @@ -443,7 +443,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } inviteStrippedState := append( globalStrippedState, - fclient.NewInviteV2StrippedState(inviteEvent.PDU), + gomatrixserverlib.NewInviteStrippedState(inviteEvent.PDU), ) // Send the invite event to the roomserver. event := inviteEvent diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index a3fa2e011..1930b5ace 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -28,186 +28,149 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" - log "github.com/sirupsen/logrus" ) +type QueryState struct { + storage.Database +} + +func (q *QueryState) GetAuthEvents(ctx context.Context, event gomatrixserverlib.PDU) (gomatrixserverlib.AuthEventProvider, error) { + return helpers.GetAuthEvents(ctx, q.Database, event.Version(), event, event.AuthEventIDs()) +} + +func (q *QueryState) GetState(ctx context.Context, roomID spec.RoomID, stateWanted []gomatrixserverlib.StateKeyTuple) ([]gomatrixserverlib.PDU, error) { + info, err := q.Database.RoomInfo(ctx, roomID.String()) + if err != nil { + return nil, fmt.Errorf("failed to load RoomInfo: %w", err) + } + if info != nil { + roomState := state.NewStateResolution(q.Database, info) + stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples( + ctx, info.StateSnapshotNID(), stateWanted, + ) + if err != nil { + return nil, nil + } + stateNIDs := []types.EventNID{} + for _, stateNID := range stateEntries { + stateNIDs = append(stateNIDs, stateNID.EventNID) + } + stateEvents, err := q.Database.Events(ctx, info.RoomVersion, stateNIDs) + if err != nil { + return nil, fmt.Errorf("failed to obtain required events: %w", err) + } + + events := []gomatrixserverlib.PDU{} + for _, event := range stateEvents { + events = append(events, event.PDU) + } + return events, nil + } + + return nil, nil +} + type Inviter struct { DB storage.Database Cfg *config.RoomServer FSAPI federationAPI.RoomserverFederationAPI + RSAPI api.RoomserverInternalAPI Inputer *input.Inputer } +func (r *Inviter) IsKnownRoom(ctx context.Context, roomID spec.RoomID) (bool, error) { + info, err := r.DB.RoomInfo(ctx, roomID.String()) + if err != nil { + return false, fmt.Errorf("failed to load RoomInfo: %w", err) + } + return (info != nil && !info.IsStub()), nil +} + +func (r *Inviter) StateQuerier() gomatrixserverlib.StateQuerier { + return &QueryState{Database: r.DB} +} + +func (r *Inviter) ProcessInviteMembership( + ctx context.Context, inviteEvent *types.HeaderedEvent, +) ([]api.OutputEvent, error) { + var outputUpdates []api.OutputEvent + var updater *shared.MembershipUpdater + _, domain, err := gomatrixserverlib.SplitID('@', *inviteEvent.StateKey()) + if err != nil { + return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())} + } + isTargetLocal := r.Cfg.Matrix.IsLocalServerName(domain) + if updater, err = r.DB.MembershipUpdater(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey(), isTargetLocal, inviteEvent.Version()); err != nil { + return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err) + } + outputUpdates, err = helpers.UpdateToInviteMembership(updater, &types.Event{ + EventNID: 0, + PDU: inviteEvent.PDU, + }, outputUpdates, inviteEvent.Version()) + if err != nil { + return nil, fmt.Errorf("updateToInviteMembership: %w", err) + } + if err = updater.Commit(); err != nil { + return nil, fmt.Errorf("updater.Commit: %w", err) + } + return outputUpdates, nil +} + // nolint:gocyclo func (r *Inviter) PerformInvite( ctx context.Context, req *api.PerformInviteRequest, -) ([]api.OutputEvent, error) { - var outputUpdates []api.OutputEvent +) error { event := req.Event + + sender, err := spec.NewUserID(event.Sender(), true) + if err != nil { + return spec.InvalidParam("The user ID is invalid") + } + if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) { + return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")} + } + if event.StateKey() == nil { - return nil, fmt.Errorf("invite must be a state event") + return fmt.Errorf("invite must be a state event") } - _, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender()) + invitedUser, err := spec.NewUserID(*event.StateKey(), true) if err != nil { - return nil, fmt.Errorf("sender %q is invalid", event.Sender()) + return spec.InvalidParam("The user ID is invalid") } + isTargetLocal := r.Cfg.Matrix.IsLocalServerName(invitedUser.Domain()) - roomID := event.RoomID() - targetUserID := *event.StateKey() - info, err := r.DB.RoomInfo(ctx, roomID) + validRoomID, err := spec.NewRoomID(event.RoomID()) if err != nil { - return nil, fmt.Errorf("failed to load RoomInfo: %w", err) + return err } - _, domain, err := gomatrixserverlib.SplitID('@', targetUserID) + input := gomatrixserverlib.PerformInviteInput{ + RoomID: *validRoomID, + InviteEvent: event.PDU, + InvitedUser: *invitedUser, + IsTargetLocal: isTargetLocal, + StrippedState: req.InviteRoomState, + MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI}, + StateQuerier: &QueryState{r.DB}, + } + inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI) if err != nil { - return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", targetUserID)} - } - isTargetLocal := r.Cfg.Matrix.IsLocalServerName(domain) - isOriginLocal := r.Cfg.Matrix.IsLocalServerName(senderDomain) - if !isOriginLocal && !isTargetLocal { - return nil, api.ErrInvalidID{Err: fmt.Errorf("the invite must be either from or to a local user")} + switch e := err.(type) { + case spec.MatrixError: + if e.ErrCode == spec.ErrorForbidden { + return api.ErrNotAllowed{Err: fmt.Errorf("%s", e.Err)} + } + } + return err } - logger := util.GetLogger(ctx).WithFields(map[string]interface{}{ - "inviter": event.Sender(), - "invitee": *event.StateKey(), - "room_id": roomID, - "event_id": event.EventID(), - }) - logger.WithFields(log.Fields{ - "room_version": req.RoomVersion, - "room_info_exists": info != nil, - "target_local": isTargetLocal, - "origin_local": isOriginLocal, - }).Debug("processing invite event") - - inviteState := req.InviteRoomState - if len(inviteState) == 0 && info != nil { - var is []fclient.InviteV2StrippedState - if is, err = buildInviteStrippedState(ctx, r.DB, info, req); err == nil { - inviteState = is - } - } - if len(inviteState) == 0 { - if err = event.SetUnsignedField("invite_room_state", struct{}{}); err != nil { - return nil, fmt.Errorf("event.SetUnsignedField: %w", err) - } - } else { - if err = event.SetUnsignedField("invite_room_state", inviteState); err != nil { - return nil, fmt.Errorf("event.SetUnsignedField: %w", err) - } - } - - updateMembershipTableManually := func() ([]api.OutputEvent, error) { - var updater *shared.MembershipUpdater - if updater, err = r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion); err != nil { - return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err) - } - outputUpdates, err = helpers.UpdateToInviteMembership(updater, &types.Event{ - EventNID: 0, - PDU: event.PDU, - }, outputUpdates, req.Event.Version()) - if err != nil { - return nil, fmt.Errorf("updateToInviteMembership: %w", err) - } - if err = updater.Commit(); err != nil { - return nil, fmt.Errorf("updater.Commit: %w", err) - } - logger.Debugf("updated membership to invite and sending invite OutputEvent") - return outputUpdates, nil - } - - if (info == nil || info.IsStub()) && !isOriginLocal && isTargetLocal { - // The invite came in over federation for a room that we don't know about - // yet. We need to handle this a bit differently to most invites because - // we don't know the room state, therefore the roomserver can't process - // an input event. Instead we will update the membership table with the - // new invite and generate an output event. - return updateMembershipTableManually() - } - - var isAlreadyJoined bool - if info != nil { - _, isAlreadyJoined, _, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey()) - if err != nil { - return nil, fmt.Errorf("r.DB.GetMembership: %w", err) - } - } - if isAlreadyJoined { - // If the user is joined to the room then that takes precedence over this - // invite event. It makes little sense to move a user that is already - // joined to the room into the invite state. - // This could plausibly happen if an invite request raced with a join - // request for a user. For example if a user was invited to a public - // room and they joined the room at the same time as the invite was sent. - // The other way this could plausibly happen is if an invite raced with - // a kick. For example if a user was kicked from a room in error and in - // response someone else in the room re-invited them then it is possible - // for the invite request to race with the leave event so that the - // target receives invite before it learns that it has been kicked. - // There are a few ways this could be plausibly handled in the roomserver. - // 1) Store the invite, but mark it as retired. That will result in the - // permanent rejection of that invite event. So even if the target - // user leaves the room and the invite is retransmitted it will be - // ignored. However a new invite with a new event ID would still be - // accepted. - // 2) Silently discard the invite event. This means that if the event - // was retransmitted at a later date after the target user had left - // the room we would accept the invite. However since we hadn't told - // the sending server that the invite had been discarded it would - // have no reason to attempt to retry. - // 3) Signal the sending server that the user is already joined to the - // room. - // For now we will implement option 2. Since in the abesence of a retry - // mechanism it will be equivalent to option 1, and we don't have a - // signalling mechanism to implement option 3. - logger.Debugf("user already joined") - return nil, api.ErrNotAllowed{Err: fmt.Errorf("user is already joined to room")} - } - - // If the invite originated remotely then we can't send an - // InputRoomEvent for the invite as it will never pass auth checks - // due to lacking room state, but we still need to tell the client - // about the invite so we can accept it, hence we return an output - // event to send to the Sync API. - if !isOriginLocal { - return updateMembershipTableManually() - } - - // The invite originated locally. Therefore we have a responsibility to - // try and see if the user is allowed to make this invite. We can't do - // this for invites coming in over federation - we have to take those on - // trust. - _, err = helpers.CheckAuthEvents(ctx, r.DB, info, event, event.AuthEventIDs()) - if err != nil { - logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error( - "processInviteEvent.checkAuthEvents failed for event", - ) - return nil, api.ErrNotAllowed{Err: err} - } - - // If the invite originated from us and the target isn't local then we - // should try and send the invite over federation first. It might be - // that the remote user doesn't exist, in which case we can give up - // processing here. - if req.SendAsServer != api.DoNotSendToOtherServers && !isTargetLocal { - fsReq := &federationAPI.PerformInviteRequest{ - RoomVersion: req.RoomVersion, - Event: event, - InviteRoomState: inviteState, - } - fsRes := &federationAPI.PerformInviteResponse{} - if err = r.FSAPI.PerformInvite(ctx, fsReq, fsRes); err != nil { - logger.WithError(err).WithField("event_id", event.EventID()).Error("r.FSAPI.PerformInvite failed") - return nil, api.ErrNotAllowed{Err: err} - } - event = fsRes.Event - logger.Debugf("Federated PerformInvite success with event ID %s", event.EventID()) + // Use the returned event if there was one (due to federation), otherwise + // send the original invite event to the roomserver. + if inviteEvent == nil { + inviteEvent = event } // Send the invite event to the roomserver input stream. This will @@ -219,67 +182,18 @@ func (r *Inviter) PerformInvite( InputRoomEvents: []api.InputRoomEvent{ { Kind: api.KindNew, - Event: event, - Origin: senderDomain, + Event: &types.HeaderedEvent{PDU: inviteEvent}, + Origin: sender.Domain(), SendAsServer: req.SendAsServer, }, }, } inputRes := &api.InputRoomEventsResponse{} r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes) - if err = inputRes.Err(); err != nil { - logger.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed") - return nil, api.ErrNotAllowed{Err: err} + if err := inputRes.Err(); err != nil { + util.GetLogger(ctx).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed") + return api.ErrNotAllowed{Err: err} } - // Don't notify the sync api of this event in the same way as a federated invite so the invitee - // gets the invite, as the roomserver will do this when it processes the m.room.member invite. - return outputUpdates, nil -} - -func buildInviteStrippedState( - ctx context.Context, - db storage.Database, - info *types.RoomInfo, - input *api.PerformInviteRequest, -) ([]fclient.InviteV2StrippedState, error) { - stateWanted := []gomatrixserverlib.StateKeyTuple{} - // "If they are set on the room, at least the state for m.room.avatar, m.room.canonical_alias, m.room.join_rules, and m.room.name SHOULD be included." - // https://matrix.org/docs/spec/client_server/r0.6.0#m-room-member - for _, t := range []string{ - spec.MRoomName, spec.MRoomCanonicalAlias, - spec.MRoomJoinRules, spec.MRoomAvatar, - spec.MRoomEncryption, spec.MRoomCreate, - } { - stateWanted = append(stateWanted, gomatrixserverlib.StateKeyTuple{ - EventType: t, - StateKey: "", - }) - } - roomState := state.NewStateResolution(db, info) - stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples( - ctx, info.StateSnapshotNID(), stateWanted, - ) - if err != nil { - return nil, err - } - stateNIDs := []types.EventNID{} - for _, stateNID := range stateEntries { - stateNIDs = append(stateNIDs, stateNID.EventNID) - } - if info == nil { - return nil, types.ErrorInvalidRoomInfo - } - stateEvents, err := db.Events(ctx, info.RoomVersion, stateNIDs) - if err != nil { - return nil, err - } - inviteState := []fclient.InviteV2StrippedState{ - fclient.NewInviteV2StrippedState(input.Event.PDU), - } - stateEvents = append(stateEvents, types.Event{PDU: input.Event.PDU}) - for _, event := range stateEvents { - inviteState = append(inviteState, fclient.NewInviteV2StrippedState(event.PDU)) - } - return inviteState, nil + return nil } From d11da6ec7cc683864e1e10b7f47764d1bb0c4f1a Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 2 Jun 2023 15:48:04 +0200 Subject: [PATCH 14/35] Fix newly found linter issues (#3099) Fixes the issues found in https://github.com/matrix-org/dendrite/actions/runs/5155539352/jobs/9285342056#step:5:22. Only naked returns in longer functions. --- roomserver/internal/perform/perform_peek.go | 4 ++-- syncapi/routing/messages.go | 7 ++++--- syncapi/streams/stream_pdu.go | 4 ++-- userapi/storage/shared/storage.go | 4 ++-- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/roomserver/internal/perform/perform_peek.go b/roomserver/internal/perform/perform_peek.go index 661fe20a8..88fa2a431 100644 --- a/roomserver/internal/perform/perform_peek.go +++ b/roomserver/internal/perform/perform_peek.go @@ -157,7 +157,7 @@ func (r *Peeker) performPeekRoomByID( content := map[string]string{} if err = json.Unmarshal(ev.Content(), &content); err != nil { util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for history visibility failed") - return + return "", err } if visibility, ok := content["history_visibility"]; ok { worldReadable = visibility == "world_readable" @@ -185,7 +185,7 @@ func (r *Peeker) performPeekRoomByID( }, }) if err != nil { - return + return "", err } // By this point, if req.RoomIDOrAlias contained an alias, then diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 58f663d0b..aeaec699b 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -314,11 +314,12 @@ func (r *messagesReq) retrieveEvents() ( clientEvents []synctypes.ClientEvent, start, end types.TopologyToken, err error, ) { + emptyToken := types.TopologyToken{} // Retrieve the events from the local database. streamEvents, err := r.snapshot.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering) if err != nil { err = fmt.Errorf("GetEventsInRange: %w", err) - return + return []synctypes.ClientEvent{}, emptyToken, emptyToken, err } var events []*rstypes.HeaderedEvent @@ -333,11 +334,11 @@ func (r *messagesReq) retrieveEvents() ( // on the ordering), or we've reached a backward extremity. if len(streamEvents) == 0 { if events, err = r.handleEmptyEventsSlice(); err != nil { - return + return []synctypes.ClientEvent{}, emptyToken, emptyToken, err } } else { if events, err = r.handleNonEmptyEventsSlice(streamEvents); err != nil { - return + return []synctypes.ClientEvent{}, emptyToken, emptyToken, err } } diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index b5fd5be8e..0ea48a9d3 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -489,7 +489,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( stateEvents, err := snapshot.CurrentState(ctx, roomID, stateFilter, excludingEventIDs) if err != nil { - return + return jr, err } jr.Summary, err = snapshot.GetRoomSummary(ctx, roomID, device.UserID) @@ -542,7 +542,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( } backwardTopologyPos, backwardStreamPos, err = snapshot.PositionInTopology(ctx, event.EventID()) if err != nil { - return + return jr, err } prevBatch = &types.TopologyToken{ Depth: backwardTopologyPos, diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 705707571..537bbbf4a 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -644,7 +644,7 @@ func (d *Database) CreateDevice( for i := 1; i <= 5; i++ { newDeviceID, returnErr = generateDeviceID() if returnErr != nil { - return + return nil, returnErr } returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -653,7 +653,7 @@ func (d *Database) CreateDevice( return err }) if returnErr == nil { - return + return dev, nil } } } From 725ff5567d2a3bc9992b065e72ccabefb595ec1c Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 6 Jun 2023 15:16:55 +0200 Subject: [PATCH 15/35] Make `StrictValidityChecking` a function (#3092) Companion PR to https://github.com/matrix-org/gomatrixserverlib/pull/388 --- federationapi/internal/keys.go | 2 +- federationapi/routing/backfill.go | 2 +- federationapi/routing/join.go | 52 +------ federationapi/routing/leave.go | 8 +- federationapi/routing/send.go | 2 +- federationapi/routing/threepid.go | 4 +- go.mod | 2 +- go.sum | 4 +- internal/caching/cache_serverkeys.go | 2 +- roomserver/api/api.go | 13 +- roomserver/api/query.go | 68 +++++--- roomserver/internal/api.go | 1 + roomserver/internal/perform/perform_join.go | 24 +-- roomserver/internal/query/query.go | 164 ++++++-------------- roomserver/roomserver_test.go | 49 ++---- 15 files changed, 145 insertions(+), 252 deletions(-) diff --git a/federationapi/internal/keys.go b/federationapi/internal/keys.go index 00e78a1c1..a642f3a4b 100644 --- a/federationapi/internal/keys.go +++ b/federationapi/internal/keys.go @@ -170,7 +170,7 @@ func (s *FederationInternalAPI) handleDatabaseKeys( // in that case. If the key isn't valid right now, then by // leaving it in the 'requests' map, we'll try to update the // key using the fetchers in handleFetcherKeys. - if res.WasValidAt(now, true) { + if res.WasValidAt(now, gomatrixserverlib.StrictValiditySignatureCheck) { delete(requests, req) } } diff --git a/federationapi/routing/backfill.go b/federationapi/routing/backfill.go index 9e1595053..552c4eac2 100644 --- a/federationapi/routing/backfill.go +++ b/federationapi/routing/backfill.go @@ -95,7 +95,7 @@ func Backfill( } } - // Query the roomserver. + // Query the Roomserver. if err = rsAPI.PerformBackfill(httpReq.Context(), &req, &res); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("query.PerformBackfill failed") return util.JSONResponse{ diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index c6f96375e..2980c2af2 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -15,7 +15,6 @@ package routing import ( - "context" "fmt" "net/http" "sort" @@ -33,53 +32,6 @@ import ( "github.com/matrix-org/dendrite/setup/config" ) -type JoinRoomQuerier struct { - roomserver api.FederationRoomserverAPI -} - -func (rq *JoinRoomQuerier) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) { - return rq.roomserver.CurrentStateEvent(ctx, roomID, eventType, stateKey) -} - -func (rq *JoinRoomQuerier) InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) { - return rq.roomserver.InvitePending(ctx, roomID, userID) -} - -func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { - roomInfo, err := rq.roomserver.QueryRoomInfo(ctx, roomID) - if err != nil || roomInfo == nil || roomInfo.IsStub() { - return nil, err - } - - req := api.QueryServerJoinedToRoomRequest{ - ServerName: localServerName, - RoomID: roomID.String(), - } - res := api.QueryServerJoinedToRoomResponse{} - if err = rq.roomserver.QueryServerJoinedToRoom(ctx, &req, &res); err != nil { - util.GetLogger(ctx).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed") - return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err) - } - - userJoinedToRoom, err := rq.roomserver.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed") - return nil, fmt.Errorf("InternalServerError: %w", err) - } - - locallyJoinedUsers, err := rq.roomserver.LocallyJoinedUsers(ctx, roomInfo.RoomVersion, types.RoomNID(roomInfo.RoomNID)) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("rsAPI.GetLocallyJoinedUsers failed") - return nil, fmt.Errorf("InternalServerError: %w", err) - } - - return &gomatrixserverlib.RestrictedRoomJoinInfo{ - LocalServerInRoom: res.RoomExists && res.IsInRoom, - UserJoinedToRoom: userJoinedToRoom, - JoinedUsers: locallyJoinedUsers, - }, nil -} - // MakeJoin implements the /make_join API func MakeJoin( httpReq *http.Request, @@ -142,8 +94,8 @@ func MakeJoin( return event, stateEvents, nil } - roomQuerier := JoinRoomQuerier{ - roomserver: rsAPI, + roomQuerier := api.JoinRoomQuerier{ + Roomserver: rsAPI, } input := gomatrixserverlib.HandleMakeJoinInput{ diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index a767168d8..d7d5b599d 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -291,10 +291,10 @@ func SendLeave( } } verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ - ServerName: serverName, - Message: redacted, - AtTS: event.OriginServerTS(), - StrictValidityChecking: true, + ServerName: serverName, + Message: redacted, + AtTS: event.OriginServerTS(), + ValidityCheckingFunc: gomatrixserverlib.StrictValiditySignatureCheck, }} verifyResults, err := keys.VerifyJSONs(httpReq.Context(), verifyRequests) if err != nil { diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 3c8e0cbef..966694541 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -34,7 +34,7 @@ import ( ) const ( - // Event was passed to the roomserver + // Event was passed to the Roomserver MetricsOutcomeOK = "ok" // Event failed to be processed MetricsOutcomeFail = "fail" diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index beeb52495..76a2f3d5a 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -223,7 +223,7 @@ func ExchangeThirdPartyInvite( } } - // Send the event to the roomserver + // Send the event to the Roomserver if err = api.SendEvents( httpReq.Context(), rsAPI, api.KindNew, @@ -324,7 +324,7 @@ func buildMembershipEvent( return nil, errors.New("expecting state tuples for event builder, got none") } - // Ask the roomserver for information about this room + // Ask the Roomserver for information about this room queryReq := api.QueryLatestEventsAndStateRequest{ RoomID: protoEvent.RoomID, StateToFetch: eventsNeeded.Tuples(), diff --git a/go.mod b/go.mod index a20757bbc..a49dfa0c9 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230531155817-0e3adf17bee6 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230606112941-1c41e92ddf9e github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 diff --git a/go.sum b/go.sum index a1946adaa..79154624a 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230531155817-0e3adf17bee6 h1:Kh1TNvJDhWN5CdgtICNUC4G0wV2km51LGr46Dvl153A= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230531155817-0e3adf17bee6/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230606112941-1c41e92ddf9e h1:I3Sfr8gZvVtLHOeI8lgc62kgLuzpMhBZ6EQOMyexXEA= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230606112941-1c41e92ddf9e/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= diff --git a/internal/caching/cache_serverkeys.go b/internal/caching/cache_serverkeys.go index 37e331ab0..7400b868c 100644 --- a/internal/caching/cache_serverkeys.go +++ b/internal/caching/cache_serverkeys.go @@ -28,7 +28,7 @@ func (c Caches) GetServerKey( ) (gomatrixserverlib.PublicKeyLookupResult, bool) { key := fmt.Sprintf("%s/%s", request.ServerName, request.KeyID) val, found := c.ServerKeys.Get(key) - if found && !val.WasValidAt(timestamp, true) { + if found && !val.WasValidAt(timestamp, gomatrixserverlib.StrictValiditySignatureCheck) { // The key wasn't valid at the requested timestamp so don't // return it. The caller will have to work out what to do. c.ServerKeys.Unset(key) diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 7cb3379e0..a37ade3a3 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -32,6 +32,16 @@ func (e ErrNotAllowed) Error() string { return e.Err.Error() } +type RestrictedJoinAPI interface { + CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) + InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) + RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) + QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types.RoomInfo, error) + QueryServerJoinedToRoom(ctx context.Context, req *QueryServerJoinedToRoomRequest, res *QueryServerJoinedToRoomResponse) error + UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, userID spec.UserID) (bool, error) + LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error) +} + // RoomserverInputAPI is used to write events to the room server. type RoomserverInternalAPI interface { SyncRoomserverAPI @@ -199,6 +209,7 @@ type UserRoomserverAPI interface { } type FederationRoomserverAPI interface { + RestrictedJoinAPI InputRoomEventsAPI QueryLatestEventsAndStateAPI QueryBulkStateContentAPI @@ -223,7 +234,7 @@ type FederationRoomserverAPI interface { // Query whether a server is allowed to see an event QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string) (allowed bool, err error) QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error - QueryRestrictedJoinAllowed(ctx context.Context, req *QueryRestrictedJoinAllowedRequest, res *QueryRestrictedJoinAllowedResponse) error + QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error HandleInvite(ctx context.Context, event *types.HeaderedEvent) error diff --git a/roomserver/api/query.go b/roomserver/api/query.go index b33698c82..e741c1402 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/roomserver/types" @@ -351,26 +352,6 @@ type QueryServerBannedFromRoomResponse struct { Banned bool `json:"banned"` } -type QueryRestrictedJoinAllowedRequest struct { - UserID string `json:"user_id"` - RoomID string `json:"room_id"` -} - -type QueryRestrictedJoinAllowedResponse struct { - // True if the room membership is restricted by the join rule being set to "restricted" - Restricted bool `json:"restricted"` - // True if our local server is joined to all of the allowed rooms specified in the "allow" - // key of the join rule, false if we are missing from some of them and therefore can't - // reliably decide whether or not we can satisfy the join - Resident bool `json:"resident"` - // True if the restricted join is allowed because we found the membership in one of the - // allowed rooms from the join rule, false if not - Allowed bool `json:"allowed"` - // Contains the user ID of the selected user ID that has power to issue invites, this will - // get populated into the "join_authorised_via_users_server" content in the membership - AuthorisedVia string `json:"authorised_via,omitempty"` -} - // MarshalJSON stringifies the room ID and StateKeyTuple keys so they can be sent over the wire in HTTP API mode. func (r *QueryBulkStateContentResponse) MarshalJSON() ([]byte, error) { se := make(map[string]string) @@ -459,6 +440,53 @@ type QueryLeftUsersResponse struct { LeftUsers []string `json:"user_ids"` } +type JoinRoomQuerier struct { + Roomserver RestrictedJoinAPI +} + +func (rq *JoinRoomQuerier) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) { + return rq.Roomserver.CurrentStateEvent(ctx, roomID, eventType, stateKey) +} + +func (rq *JoinRoomQuerier) InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) { + return rq.Roomserver.InvitePending(ctx, roomID, userID) +} + +func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { + roomInfo, err := rq.Roomserver.QueryRoomInfo(ctx, roomID) + if err != nil || roomInfo == nil || roomInfo.IsStub() { + return nil, err + } + + req := QueryServerJoinedToRoomRequest{ + ServerName: localServerName, + RoomID: roomID.String(), + } + res := QueryServerJoinedToRoomResponse{} + if err = rq.Roomserver.QueryServerJoinedToRoom(ctx, &req, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed") + return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err) + } + + userJoinedToRoom, err := rq.Roomserver.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed") + return nil, fmt.Errorf("InternalServerError: %w", err) + } + + locallyJoinedUsers, err := rq.Roomserver.LocallyJoinedUsers(ctx, roomInfo.RoomVersion, types.RoomNID(roomInfo.RoomNID)) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("rsAPI.GetLocallyJoinedUsers failed") + return nil, fmt.Errorf("InternalServerError: %w", err) + } + + return &gomatrixserverlib.RestrictedRoomJoinInfo{ + LocalServerInRoom: res.RoomExists && res.IsInRoom, + UserJoinedToRoom: userJoinedToRoom, + JoinedUsers: locallyJoinedUsers, + }, nil +} + type MembershipQuerier struct { Roomserver FederationRoomserverAPI } diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index ee433f0d2..35b7383a9 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -94,6 +94,7 @@ func NewRoomserverAPI( Cache: caches, IsLocalServerName: dendriteCfg.Global.IsLocalServerName, ServerACLs: serverACLs, + Cfg: dendriteCfg, }, enableMetrics: enableMetrics, // perform-er structs get initialised when we have a federation sender to use diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 34bea5b6d..181a93490 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -372,22 +372,14 @@ func (r *Joiner) populateAuthorisedViaUserForRestrictedJoin( ctx context.Context, joinReq *rsAPI.PerformJoinRequest, ) (string, error) { - req := &api.QueryRestrictedJoinAllowedRequest{ - UserID: joinReq.UserID, - RoomID: joinReq.RoomIDOrAlias, + roomID, err := spec.NewRoomID(joinReq.RoomIDOrAlias) + if err != nil { + return "", err } - res := &api.QueryRestrictedJoinAllowedResponse{} - if err := r.Queryer.QueryRestrictedJoinAllowed(ctx, req, res); err != nil { - return "", fmt.Errorf("r.Queryer.QueryRestrictedJoinAllowed: %w", err) + userID, err := spec.NewUserID(joinReq.UserID, true) + if err != nil { + return "", err } - if !res.Restricted { - return "", nil - } - if !res.Resident { - return "", nil - } - if !res.Allowed { - return "", rsAPI.ErrNotAllowed{Err: fmt.Errorf("the join to room %s was not allowed", joinReq.RoomIDOrAlias)} - } - return res.AuthorisedVia, nil + + return r.Queryer.QueryRestrictedJoinAllowed(ctx, *roomID, *userID) } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index effcc90d7..6d898e8ad 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -17,10 +17,11 @@ package query import ( "context" "database/sql" - "encoding/json" "errors" "fmt" + //"github.com/matrix-org/dendrite/roomserver/internal" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" @@ -44,6 +45,42 @@ type Queryer struct { Cache caching.RoomServerCaches IsLocalServerName func(spec.ServerName) bool ServerACLs *acls.ServerACLs + Cfg *config.Dendrite +} + +func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { + roomInfo, err := r.QueryRoomInfo(ctx, roomID) + if err != nil || roomInfo == nil || roomInfo.IsStub() { + return nil, err + } + + req := api.QueryServerJoinedToRoomRequest{ + ServerName: localServerName, + RoomID: roomID.String(), + } + res := api.QueryServerJoinedToRoomResponse{} + if err = r.QueryServerJoinedToRoom(ctx, &req, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed") + return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err) + } + + userJoinedToRoom, err := r.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed") + return nil, fmt.Errorf("InternalServerError: %w", err) + } + + locallyJoinedUsers, err := r.LocallyJoinedUsers(ctx, roomInfo.RoomVersion, types.RoomNID(roomInfo.RoomNID)) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("rsAPI.GetLocallyJoinedUsers failed") + return nil, fmt.Errorf("InternalServerError: %w", err) + } + + return &gomatrixserverlib.RestrictedRoomJoinInfo{ + LocalServerInRoom: res.RoomExists && res.IsInRoom, + UserJoinedToRoom: userJoinedToRoom, + JoinedUsers: locallyJoinedUsers, + }, nil } // QueryLatestEventsAndState implements api.RoomserverInternalAPI @@ -906,131 +943,20 @@ func (r *Queryer) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixse } // nolint:gocyclo -func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.QueryRestrictedJoinAllowedRequest, res *api.QueryRestrictedJoinAllowedResponse) error { +func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) { // Look up if we know anything about the room. If it doesn't exist // or is a stub entry then we can't do anything. - roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID) + roomInfo, err := r.DB.RoomInfo(ctx, roomID.String()) if err != nil { - return fmt.Errorf("r.DB.RoomInfo: %w", err) + return "", fmt.Errorf("r.DB.RoomInfo: %w", err) } if roomInfo == nil || roomInfo.IsStub() { - return nil // fmt.Errorf("room %q doesn't exist or is stub room", req.RoomID) + return "", nil // fmt.Errorf("room %q doesn't exist or is stub room", req.RoomID) } verImpl, err := gomatrixserverlib.GetRoomVersion(roomInfo.RoomVersion) if err != nil { - return err + return "", err } - // If the room version doesn't allow restricted joins then don't - // try to process any further. - allowRestrictedJoins := verImpl.MayAllowRestrictedJoinsInEventAuth() - if !allowRestrictedJoins { - return nil - } - // Start off by populating the "resident" flag in the response. If we - // come across any rooms in the request that are missing, we will unset - // the flag. - res.Resident = true - // Get the join rules to work out if the join rule is "restricted". - joinRulesEvent, err := r.DB.GetStateEvent(ctx, req.RoomID, spec.MRoomJoinRules, "") - if err != nil { - return fmt.Errorf("r.DB.GetStateEvent: %w", err) - } - if joinRulesEvent == nil { - return nil - } - var joinRules gomatrixserverlib.JoinRuleContent - if err = json.Unmarshal(joinRulesEvent.Content(), &joinRules); err != nil { - return fmt.Errorf("json.Unmarshal: %w", err) - } - // If the join rule isn't "restricted" or "knock_restricted" then there's nothing more to do. - res.Restricted = joinRules.JoinRule == spec.Restricted || joinRules.JoinRule == spec.KnockRestricted - if !res.Restricted { - return nil - } - // If the user is already invited to the room then the join is allowed - // but we don't specify an authorised via user, since the event auth - // will allow the join anyway. - var pending bool - if pending, _, _, _, err = helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID); err != nil { - return fmt.Errorf("helpers.IsInvitePending: %w", err) - } else if pending { - res.Allowed = true - return nil - } - // We need to get the power levels content so that we can determine which - // users in the room are entitled to issue invites. We need to use one of - // these users as the authorising user. - powerLevelsEvent, err := r.DB.GetStateEvent(ctx, req.RoomID, spec.MRoomPowerLevels, "") - if err != nil { - return fmt.Errorf("r.DB.GetStateEvent: %w", err) - } - powerLevels, err := powerLevelsEvent.PowerLevels() - if err != nil { - return fmt.Errorf("unable to get powerlevels: %w", err) - } - // Step through the join rules and see if the user matches any of them. - for _, rule := range joinRules.Allow { - // We only understand "m.room_membership" rules at this point in - // time, so skip any rule that doesn't match those. - if rule.Type != spec.MRoomMembership { - continue - } - // See if the room exists. If it doesn't exist or if it's a stub - // room entry then we can't check memberships. - targetRoomInfo, err := r.DB.RoomInfo(ctx, rule.RoomID) - if err != nil || targetRoomInfo == nil || targetRoomInfo.IsStub() { - res.Resident = false - continue - } - // First of all work out if *we* are still in the room, otherwise - // it's possible that the memberships will be out of date. - isIn, err := r.DB.GetLocalServerInRoom(ctx, targetRoomInfo.RoomNID) - if err != nil || !isIn { - // If we aren't in the room, we can no longer tell if the room - // memberships are up-to-date. - res.Resident = false - continue - } - // At this point we're happy that we are in the room, so now let's - // see if the target user is in the room. - _, isIn, _, err = r.DB.GetMembership(ctx, targetRoomInfo.RoomNID, req.UserID) - if err != nil { - continue - } - // If the user is not in the room then we will skip them. - if !isIn { - continue - } - // The user is in the room, so now we will need to authorise the - // join using the user ID of one of our own users in the room. Pick - // one. - joinNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, targetRoomInfo.RoomNID, true, true) - if err != nil || len(joinNIDs) == 0 { - // There should always be more than one join NID at this point - // because we are gated behind GetLocalServerInRoom, but y'know, - // sometimes strange things happen. - continue - } - // For each of the joined users, let's see if we can get a valid - // membership event. - for _, joinNID := range joinNIDs { - events, err := r.DB.Events(ctx, roomInfo.RoomVersion, []types.EventNID{joinNID}) - if err != nil || len(events) != 1 { - continue - } - event := events[0] - if event.Type() != spec.MRoomMember || event.StateKey() == nil { - continue // shouldn't happen - } - // Only users that have the power to invite should be chosen. - if powerLevels.UserLevel(*event.StateKey()) < powerLevels.Invite { - continue - } - res.Resident = true - res.Allowed = true - res.AuthorisedVia = *event.StateKey() - return nil - } - } - return nil + + return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, userID) } diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index d19ebebe4..11a0f5817 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -598,16 +598,15 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) { testCases := []struct { name string prepareRoomFunc func(t *testing.T) *test.Room - wantResponse api.QueryRestrictedJoinAllowedResponse + wantResponse string + wantError bool }{ { name: "public room unrestricted", prepareRoomFunc: func(t *testing.T) *test.Room { return test.NewRoom(t, alice) }, - wantResponse: api.QueryRestrictedJoinAllowedResponse{ - Resident: true, - }, + wantResponse: "", }, { name: "room version without restrictions", @@ -624,10 +623,7 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) { }, test.WithStateKey("")) return r }, - wantResponse: api.QueryRestrictedJoinAllowedResponse{ - Resident: true, - Restricted: true, - }, + wantError: true, }, { name: "knock_restricted", @@ -638,10 +634,7 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) { }, test.WithStateKey("")) return r }, - wantResponse: api.QueryRestrictedJoinAllowedResponse{ - Resident: true, - Restricted: true, - }, + wantError: true, }, { name: "restricted with pending invite", // bob should be allowed to join @@ -655,11 +648,7 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) { }, test.WithStateKey(bob.ID)) return r }, - wantResponse: api.QueryRestrictedJoinAllowedResponse{ - Resident: true, - Restricted: true, - Allowed: true, - }, + wantResponse: "", }, { name: "restricted with allowed room_id, but missing room", // bob should not be allowed to join, as we don't know about the room @@ -680,9 +669,7 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) { }, test.WithStateKey(bob.ID)) return r }, - wantResponse: api.QueryRestrictedJoinAllowedResponse{ - Restricted: true, - }, + wantError: true, }, { name: "restricted with allowed room_id", // bob should be allowed to join, as we know about the room @@ -703,12 +690,7 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) { }, test.WithStateKey(bob.ID)) return r }, - wantResponse: api.QueryRestrictedJoinAllowedResponse{ - Resident: true, - Restricted: true, - Allowed: true, - AuthorisedVia: alice.ID, - }, + wantResponse: alice.ID, }, } @@ -738,16 +720,17 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) { t.Errorf("failed to send events: %v", err) } - req := api.QueryRestrictedJoinAllowedRequest{ - UserID: bob.ID, - RoomID: testRoom.ID, + roomID, _ := spec.NewRoomID(testRoom.ID) + userID, _ := spec.NewUserID(bob.ID, true) + got, err := rsAPI.QueryRestrictedJoinAllowed(processCtx.Context(), *roomID, *userID) + if tc.wantError && err == nil { + t.Fatal("expected error, got none") } - res := api.QueryRestrictedJoinAllowedResponse{} - if err := rsAPI.QueryRestrictedJoinAllowed(processCtx.Context(), &req, &res); err != nil { + if !tc.wantError && err != nil { t.Fatal(err) } - if !reflect.DeepEqual(tc.wantResponse, res) { - t.Fatalf("unexpected response, want %#v - got %#v", tc.wantResponse, res) + if !reflect.DeepEqual(tc.wantResponse, got) { + t.Fatalf("unexpected response, want %#v - got %#v", tc.wantResponse, got) } }) } From 7a1fd7f512ce06a472a2051ee63eae4a270eb71a Mon Sep 17 00:00:00 2001 From: devonh Date: Tue, 6 Jun 2023 20:55:18 +0000 Subject: [PATCH 16/35] PDU Sender split (#3100) Initial cut of splitting PDU Sender into SenderID & looking up UserID where required. --- appservice/consumers/roomserver.go | 12 ++++- clientapi/routing/directory.go | 30 +++++++++++- clientapi/routing/redaction.go | 2 +- clientapi/routing/sendevent.go | 4 +- clientapi/routing/state.go | 21 +++++++-- cmd/resolve-state/main.go | 5 +- federationapi/federationapi_test.go | 4 ++ federationapi/internal/perform.go | 35 ++++++++------ federationapi/routing/invite.go | 6 +++ federationapi/routing/join.go | 24 ++++++---- federationapi/routing/leave.go | 13 ++++-- go.mod | 4 +- go.sum | 8 ++-- internal/pushrules/evaluate.go | 16 +++++-- internal/pushrules/evaluate_test.go | 17 ++++--- internal/transactionrequest.go | 4 +- internal/transactionrequest_test.go | 8 ++++ roomserver/api/alias.go | 2 +- roomserver/api/api.go | 12 +++++ roomserver/api/query.go | 4 +- roomserver/internal/alias.go | 21 +++++---- roomserver/internal/helpers/auth.go | 4 +- roomserver/internal/input/input_events.go | 32 +++++++++---- .../internal/input/input_events_test.go | 2 +- roomserver/internal/input/input_missing.go | 24 +++++++--- roomserver/internal/perform/perform_admin.go | 8 +++- .../internal/perform/perform_backfill.go | 12 +++-- .../internal/perform/perform_create_room.go | 4 +- roomserver/internal/perform/perform_invite.go | 12 +++-- .../internal/perform/perform_upgrade.go | 10 ++-- roomserver/internal/query/query.go | 30 ++++++++++-- roomserver/producers/roomevent.go | 2 +- roomserver/state/state.go | 9 +++- roomserver/storage/interface.go | 5 ++ .../storage/shared/membership_updater.go | 2 +- roomserver/storage/shared/room_updater.go | 5 ++ roomserver/storage/shared/storage.go | 28 +++++++++-- setup/mscs/msc2836/msc2836.go | 8 ++-- setup/mscs/msc2836/msc2836_test.go | 4 ++ setup/mscs/msc2946/msc2946.go | 2 +- syncapi/consumers/roomserver.go | 2 +- syncapi/routing/context.go | 23 ++++++++-- syncapi/routing/getevent.go | 7 ++- syncapi/routing/memberships.go | 6 ++- syncapi/routing/messages.go | 12 +++-- syncapi/routing/relations.go | 7 ++- syncapi/routing/routing.go | 2 +- syncapi/routing/search.go | 46 +++++++++++++------ syncapi/routing/search_test.go | 10 +++- .../postgres/current_room_state_table.go | 2 +- .../postgres/output_room_events_table.go | 2 +- syncapi/storage/shared/storage_consumer.go | 21 ++++++++- .../sqlite3/current_room_state_table.go | 2 +- .../sqlite3/output_room_events_table.go | 2 +- syncapi/streams/stream_invite.go | 12 ++++- syncapi/streams/stream_pdu.go | 38 ++++++++++----- syncapi/streams/streams.go | 1 + syncapi/syncapi_test.go | 4 ++ syncapi/synctypes/clientevent.go | 13 ++++-- syncapi/synctypes/clientevent_test.go | 17 +++++-- syncapi/types/types.go | 4 +- syncapi/types/types_test.go | 12 ++++- test/room.go | 6 ++- userapi/consumers/roomserver.go | 43 ++++++++++++----- userapi/consumers/roomserver_test.go | 11 ++++- userapi/util/notify_test.go | 9 +++- 66 files changed, 580 insertions(+), 189 deletions(-) diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index c02d90404..06625ad7e 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -181,7 +181,9 @@ func (s *OutputRoomEventConsumer) sendEvents( // Create the transaction body. transaction, err := json.Marshal( ApplicationServiceTransaction{ - Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll), + Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }), }, ) if err != nil { @@ -233,10 +235,16 @@ func (s *appserviceState) backoffAndPause(err error) error { // // TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682 func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool { + user := "" + userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if err == nil { + user = userID.String() + } + switch { case appservice.URL == "": return false - case appservice.IsInterestedInUserID(event.Sender()): + case appservice.IsInterestedInUserID(user): return true case appservice.IsInterestedInRoomID(event.RoomID()): return true diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index c786f8cc4..0c842e6a5 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -215,9 +215,35 @@ func RemoveLocalAlias( alias string, rsAPI roomserverAPI.ClientRoomserverAPI, ) util.JSONResponse { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{Err: "UserID for device is invalid"}, + } + } + + roomIDReq := roomserverAPI.GetRoomIDForAliasRequest{Alias: alias} + roomIDRes := roomserverAPI.GetRoomIDForAliasResponse{} + err = rsAPI.GetRoomIDForAlias(req.Context(), &roomIDReq, &roomIDRes) + if err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound("The alias does not exist."), + } + } + + deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomIDRes.RoomID, *userID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{Err: "Could not find SenderID for this device"}, + } + } + queryReq := roomserverAPI.RemoveRoomAliasRequest{ - Alias: alias, - UserID: device.UserID, + Alias: alias, + SenderID: deviceSenderID, } var queryRes roomserverAPI.RemoveRoomAliasResponse if err := rsAPI.RemoveRoomAlias(req.Context(), &queryReq, &queryRes); err != nil { diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 883126423..e94c7748e 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -76,7 +76,7 @@ func SendRedaction( // "Users may redact their own events, and any user with a power level greater than or equal // to the redact power level of the room may redact events there" // https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid - allowedToRedact := ev.Sender() == device.UserID + allowedToRedact := ev.SenderID() == device.UserID // TODO: Should replace device.UserID with device...PerRoomKey if !allowedToRedact { plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{ EventType: spec.MRoomPowerLevels, diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 1a2e25c9d..8b09f399a 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -331,7 +331,9 @@ func generateSendEvent( stateEvents[i] = queryRes.StateEvents[i].PDU } provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents)) - if err = gomatrixserverlib.Allowed(e.PDU, &provider); err != nil { + if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }); err != nil { return nil, &util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden(err.Error()), // TODO: Is this error string comprehensible to the client? diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index 319f4eba5..13f308998 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -140,9 +140,14 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a // use the result of the previous QueryLatestEventsAndState response // to find the state event, if provided. for _, ev := range stateRes.StateEvents { + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID()) + if err == nil && userID != nil { + sender = *userID + } stateEvents = append( stateEvents, - synctypes.ToClientEvent(ev, synctypes.FormatAll), + synctypes.ToClientEvent(ev, synctypes.FormatAll, sender), ) } } else { @@ -162,9 +167,14 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a } } for _, ev := range stateAfterRes.StateEvents { + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID()) + if err == nil && userID != nil { + sender = *userID + } stateEvents = append( stateEvents, - synctypes.ToClientEvent(ev, synctypes.FormatAll), + synctypes.ToClientEvent(ev, synctypes.FormatAll, sender), ) } } @@ -334,8 +344,13 @@ func OnIncomingStateTypeRequest( } } + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } stateEvent := stateEventInStateResp{ - ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll), + ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, sender), } var res interface{} diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go index 3a4255bae..360403094 100644 --- a/cmd/resolve-state/main.go +++ b/cmd/resolve-state/main.go @@ -18,6 +18,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // This is a utility for inspecting state snapshots and running state resolution @@ -182,7 +183,9 @@ func main() { fmt.Println("Resolving state") var resolved Events resolved, err = gomatrixserverlib.ResolveConflicts( - gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, + gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return roomserverDB.GetUserIDForSender(ctx, roomID, senderID) + }, ) if err != nil { panic(err) diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index beb648a48..a97bcdeab 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -36,6 +36,10 @@ type fedRoomserverAPI struct { queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error } +func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + // PerformJoin will call this function func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) { if f.inputRoomEvents == nil { diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index ed800d03a..2d59d0f93 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -156,15 +156,20 @@ func (r *FederationInternalAPI) performJoinUsingServer( } joinInput := gomatrixserverlib.PerformJoinInput{ - UserID: user, - RoomID: room, - ServerName: serverName, - Content: content, - Unsigned: unsigned, - PrivateKey: r.cfg.Matrix.PrivateKey, - KeyID: r.cfg.Matrix.KeyID, - KeyRing: r.keyRing, - EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName), + UserID: user, + RoomID: room, + ServerName: serverName, + Content: content, + Unsigned: unsigned, + PrivateKey: r.cfg.Matrix.PrivateKey, + KeyID: r.cfg.Matrix.KeyID, + KeyRing: r.keyRing, + EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID, senderID string) (*spec.UserID, error) { + return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }), + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, } response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput) @@ -358,8 +363,11 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( // authenticate the state returned (check its auth events etc) // the equivalent of CheckSendJoinResponse() + userIDProvider := func(roomID, senderID string) (*spec.UserID, error) { + return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + } authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse( - ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedEventProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName), + ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedEventProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName, userIDProvider), userIDProvider, ) if err != nil { return fmt.Errorf("error checking state returned from peeking: %w", err) @@ -509,7 +517,7 @@ func (r *FederationInternalAPI) SendInvite( event gomatrixserverlib.PDU, strippedState []gomatrixserverlib.InviteStrippedState, ) (gomatrixserverlib.PDU, error) { - _, origin, err := r.cfg.Matrix.SplitLocalID('@', event.Sender()) + inviter, err := r.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) if err != nil { return nil, err } @@ -542,7 +550,7 @@ func (r *FederationInternalAPI) SendInvite( return nil, fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err) } - inviteRes, err := r.federation.SendInviteV2(ctx, origin, destination, inviteReq) + inviteRes, err := r.federation.SendInviteV2(ctx, inviter.Domain(), destination, inviteReq) if err != nil { return nil, fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err) } @@ -635,6 +643,7 @@ func checkEventsContainCreateEvent(events []gomatrixserverlib.PDU) error { func federatedEventProvider( ctx context.Context, federation fclient.FederationClient, keyRing gomatrixserverlib.JSONVerifier, origin, server spec.ServerName, + userIDForSender spec.UserIDForSender, ) gomatrixserverlib.EventProvider { // A list of events that we have retried, if they were not included in // the auth events supplied in the send_join. @@ -684,7 +693,7 @@ func federatedEventProvider( } // Check the signatures of the event. - if err := gomatrixserverlib.VerifyEventSignatures(ctx, ev, keyRing); err != nil { + if err := gomatrixserverlib.VerifyEventSignatures(ctx, ev, keyRing, userIDForSender); err != nil { return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err) } diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index 78a09d949..d792335b9 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -95,6 +95,9 @@ func InviteV2( StateQuerier: rsAPI.StateQuerier(), InviteEvent: inviteReq.Event(), StrippedState: inviteReq.InviteRoomState(), + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) + }, } event, jsonErr := handleInvite(httpReq.Context(), input, rsAPI) if jsonErr != nil { @@ -185,6 +188,9 @@ func InviteV1( StateQuerier: rsAPI.StateQuerier(), InviteEvent: event, StrippedState: strippedState, + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) + }, } event, jsonErr := handleInvite(httpReq.Context(), input, rsAPI) if jsonErr != nil { diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 2980c2af2..9da059189 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -99,15 +99,18 @@ func MakeJoin( } input := gomatrixserverlib.HandleMakeJoinInput{ - Context: httpReq.Context(), - UserID: userID, - RoomID: roomID, - RoomVersion: roomVersion, - RemoteVersions: remoteVersions, - RequestOrigin: request.Origin(), - LocalServerName: cfg.Matrix.ServerName, - LocalServerInRoom: res.RoomExists && res.IsInRoom, - RoomQuerier: &roomQuerier, + Context: httpReq.Context(), + UserID: userID, + RoomID: roomID, + RoomVersion: roomVersion, + RemoteVersions: remoteVersions, + RequestOrigin: request.Origin(), + LocalServerName: cfg.Matrix.ServerName, + LocalServerInRoom: res.RoomExists && res.IsInRoom, + RoomQuerier: &roomQuerier, + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) + }, BuildEventTemplate: createJoinTemplate, } response, internalErr := gomatrixserverlib.HandleMakeJoin(input) @@ -202,6 +205,9 @@ func SendJoin( PrivateKey: cfg.Matrix.PrivateKey, Verifier: keys, MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI}, + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) + }, } response, joinErr := gomatrixserverlib.HandleSendJoin(input) switch e := joinErr.(type) { diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index d7d5b599d..30e99c4f7 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -95,6 +95,9 @@ func MakeLeave( LocalServerName: cfg.Matrix.ServerName, LocalServerInRoom: res.RoomExists && res.IsInRoom, BuildEventTemplate: createLeaveTemplate, + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) + }, } response, internalErr := gomatrixserverlib.HandleMakeLeave(input) @@ -213,7 +216,7 @@ func SendLeave( JSON: spec.BadJSON("No state key was provided in the leave event."), } } - if !event.StateKeyEquals(event.Sender()) { + if !event.StateKeyEquals(event.SenderID()) { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON("Event state key must match the event sender."), @@ -223,13 +226,13 @@ func SendLeave( // Check that the sender belongs to the server that is sending us // the request. By this point we've already asserted that the sender // and the state key are equal so we don't need to check both. - var serverName spec.ServerName - if _, serverName, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil { + sender, err := rsAPI.QueryUserIDForSender(httpReq.Context(), event.RoomID(), event.SenderID()) + if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden("The sender of the join is invalid"), } - } else if serverName != request.Origin() { + } else if sender.Domain() != request.Origin() { return util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden("The sender does not match the server that originated the request"), @@ -291,7 +294,7 @@ func SendLeave( } } verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ - ServerName: serverName, + ServerName: sender.Domain(), Message: redacted, AtTS: event.OriginServerTS(), ValidityCheckingFunc: gomatrixserverlib.StrictValiditySignatureCheck, diff --git a/go.mod b/go.mod index a49dfa0c9..10551f702 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230606112941-1c41e92ddf9e + github.com/matrix-org/gomatrixserverlib v0.0.0-20230606202811-a644d5d8fb66 github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 @@ -34,7 +34,7 @@ require ( github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.13.0 - github.com/sirupsen/logrus v1.9.2 + github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.8.2 github.com/tidwall/gjson v1.14.4 github.com/tidwall/sjson v1.2.5 diff --git a/go.sum b/go.sum index 79154624a..3ec1c115c 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230606112941-1c41e92ddf9e h1:I3Sfr8gZvVtLHOeI8lgc62kgLuzpMhBZ6EQOMyexXEA= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230606112941-1c41e92ddf9e/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230606202811-a644d5d8fb66 h1:6SixhMmB5Ir10xUJ6zh3A4NBxSaZCSz2s5U63Wg0eEU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230606202811-a644d5d8fb66/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= @@ -444,8 +444,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= -github.com/sirupsen/logrus v1.9.2 h1:oxx1eChJGI6Uks2ZC4W1zpLlVgqB8ner4EuQwV4Ik1Y= -github.com/sirupsen/logrus v1.9.2/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go index 7c98efd30..da33d3862 100644 --- a/internal/pushrules/evaluate.go +++ b/internal/pushrules/evaluate.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // A RuleSetEvaluator encapsulates context to evaluate an event @@ -53,7 +54,7 @@ func NewRuleSetEvaluator(ec EvaluationContext, ruleSet *RuleSet) *RuleSetEvaluat // MatchEvent returns the first matching rule. Returns nil if there // was no match rule. -func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU) (*Rule, error) { +func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender) (*Rule, error) { // TODO: server-default rules have lower priority than user rules, // but they are stored together with the user rules. It's a bit // unclear what the specification (11.14.1.4 Predefined rules) @@ -68,7 +69,7 @@ func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU) (*Rule, err if rule.Default != defRules { continue } - ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec) + ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec, userIDForSender) if err != nil { return nil, err } @@ -83,7 +84,7 @@ func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU) (*Rule, err return nil, nil } -func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec EvaluationContext) (bool, error) { +func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec EvaluationContext, userIDForSender spec.UserIDForSender) (bool, error) { if !rule.Enabled { return false, nil } @@ -113,7 +114,12 @@ func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec Evaluati return rule.RuleID == event.RoomID(), nil case SenderKind: - return rule.RuleID == event.Sender(), nil + userID := "" + sender, err := userIDForSender(event.RoomID(), event.SenderID()) + if err == nil { + userID = sender.String() + } + return rule.RuleID == userID, nil default: return false, nil @@ -143,7 +149,7 @@ func conditionMatches(cond *Condition, event gomatrixserverlib.PDU, ec Evaluatio return cmp(n), nil case SenderNotificationPermissionCondition: - return ec.HasPowerLevel(event.Sender(), cond.Key) + return ec.HasPowerLevel(event.SenderID(), cond.Key) default: return false, nil diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go index 5045a864e..34c1436f4 100644 --- a/internal/pushrules/evaluate_test.go +++ b/internal/pushrules/evaluate_test.go @@ -5,8 +5,13 @@ import ( "github.com/google/go-cmp/cmp" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) +func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func TestRuleSetEvaluatorMatchEvent(t *testing.T) { ev := mustEventFromJSON(t, `{}`) defaultEnabled := &Rule{ @@ -45,7 +50,7 @@ func TestRuleSetEvaluatorMatchEvent(t *testing.T) { for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { rse := NewRuleSetEvaluator(fakeEvaluationContext{3}, &tst.RuleSet) - got, err := rse.MatchEvent(tst.Event) + got, err := rse.MatchEvent(tst.Event, UserIDForSender) if err != nil { t.Fatalf("MatchEvent failed: %v", err) } @@ -82,15 +87,15 @@ func TestRuleMatches(t *testing.T) { {"contentMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("b")}, `{"content":{"body":"abc"}}`, true}, {"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("d")}, `{"content":{"body":"abc"}}`, false}, - {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true}, - {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false}, + {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!room:example.com"}`, true}, + {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!otherroom:example.com"}`, false}, - {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@user@example.com"}`, true}, - {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@otheruser@example.com"}`, false}, + {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@user:example.com"}`, true}, + {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@otheruser:example.com"}`, false}, } for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { - got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil) + got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil, UserIDForSender) if err != nil { t.Fatalf("ruleMatches failed: %v", err) } diff --git a/internal/transactionrequest.go b/internal/transactionrequest.go index c9d321f25..0bbe0720c 100644 --- a/internal/transactionrequest.go +++ b/internal/transactionrequest.go @@ -167,7 +167,9 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut } continue } - if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys); err != nil { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID, senderID string) (*spec.UserID, error) { + return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }); err != nil { util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) results[event.EventID()] = fclient.PDUResult{ Error: err.Error(), diff --git a/internal/transactionrequest_test.go b/internal/transactionrequest_test.go index fb30d410e..6f3ce0b3b 100644 --- a/internal/transactionrequest_test.go +++ b/internal/transactionrequest_test.go @@ -70,6 +70,10 @@ type FakeRsAPI struct { bannedFromRoom bool } +func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func (r *FakeRsAPI) QueryRoomVersionForRoom( ctx context.Context, roomID string, @@ -638,6 +642,10 @@ type testRoomserverAPI struct { queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse } +func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func (t *testRoomserverAPI) InputRoomEvents( ctx context.Context, request *rsAPI.InputRoomEventsRequest, diff --git a/roomserver/api/alias.go b/roomserver/api/alias.go index 37892a44a..1b9475404 100644 --- a/roomserver/api/alias.go +++ b/roomserver/api/alias.go @@ -62,7 +62,7 @@ type GetAliasesForRoomIDResponse struct { // RemoveRoomAliasRequest is a request to RemoveRoomAlias type RemoveRoomAliasRequest struct { // ID of the user removing the alias - UserID string `json:"user_id"` + SenderID string `json:"user_id"` // The room alias to remove Alias string `json:"alias"` } diff --git a/roomserver/api/api.go b/roomserver/api/api.go index a37ade3a3..d61a05534 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -49,6 +49,7 @@ type RoomserverInternalAPI interface { ClientRoomserverAPI UserRoomserverAPI FederationRoomserverAPI + QuerySenderIDAPI // needed to avoid chicken and egg scenario when setting up the // interdependencies between the roomserver and other input APIs @@ -75,6 +76,11 @@ type InputRoomEventsAPI interface { ) } +type QuerySenderIDAPI interface { + QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) + QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) +} + // Query the latest events and state for a room from the room server. type QueryLatestEventsAndStateAPI interface { QueryLatestEventsAndState(ctx context.Context, req *QueryLatestEventsAndStateRequest, res *QueryLatestEventsAndStateResponse) error @@ -102,6 +108,7 @@ type QueryEventsAPI interface { type SyncRoomserverAPI interface { QueryLatestEventsAndStateAPI QueryBulkStateContentAPI + QuerySenderIDAPI // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine @@ -142,6 +149,7 @@ type SyncRoomserverAPI interface { } type AppserviceRoomserverAPI interface { + QuerySenderIDAPI // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine // which room to use by querying the first events roomID. QueryEventsByID( @@ -168,6 +176,7 @@ type ClientRoomserverAPI interface { QueryLatestEventsAndStateAPI QueryBulkStateContentAPI QueryEventsAPI + QuerySenderIDAPI QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error @@ -200,6 +209,7 @@ type ClientRoomserverAPI interface { } type UserRoomserverAPI interface { + QuerySenderIDAPI QueryLatestEventsAndStateAPI KeyserverRoomserverAPI QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error @@ -213,6 +223,8 @@ type FederationRoomserverAPI interface { InputRoomEventsAPI QueryLatestEventsAndStateAPI QueryBulkStateContentAPI + QuerySenderIDAPI + // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error diff --git a/roomserver/api/query.go b/roomserver/api/query.go index e741c1402..d79dcebbb 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -491,10 +491,10 @@ type MembershipQuerier struct { Roomserver FederationRoomserverAPI } -func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) { +func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) { req := QueryMembershipForUserRequest{ RoomID: roomID.String(), - UserID: userID.String(), + UserID: string(senderID), } res := QueryMembershipForUserResponse{} err := mq.Roomserver.QueryMembershipForUser(ctx, &req, &res) diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index 52b90cf4e..dcfb26b8e 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -119,11 +119,6 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( request *api.RemoveRoomAliasRequest, response *api.RemoveRoomAliasResponse, ) error { - _, virtualHost, err := r.Cfg.Global.SplitLocalID('@', request.UserID) - if err != nil { - return err - } - roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias) if err != nil { return fmt.Errorf("r.DB.GetRoomIDForAlias: %w", err) @@ -134,13 +129,19 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return nil } + sender, err := r.QueryUserIDForSender(ctx, roomID, request.SenderID) + if err != nil { + return fmt.Errorf("r.QueryUserIDForSender: %w", err) + } + virtualHost := sender.Domain() + response.Found = true creatorID, err := r.DB.GetCreatorIDForAlias(ctx, request.Alias) if err != nil { return fmt.Errorf("r.DB.GetCreatorIDForAlias: %w", err) } - if creatorID != request.UserID { + if creatorID != request.SenderID { var plEvent *types.HeaderedEvent var pls *gomatrixserverlib.PowerLevelContent @@ -154,7 +155,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return fmt.Errorf("plEvent.PowerLevels: %w", err) } - if pls.UserLevel(request.UserID) < pls.EventLevel(spec.MRoomCanonicalAlias, true) { + if pls.UserLevel(request.SenderID) < pls.EventLevel(spec.MRoomCanonicalAlias, true) { response.Removed = false return nil } @@ -172,9 +173,9 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return err } - sender := request.UserID - if request.UserID != ev.Sender() { - sender = ev.Sender() + sender := request.SenderID + if request.SenderID != ev.SenderID() { + sender = ev.SenderID() } _, senderDomain, err := r.Cfg.Global.SplitLocalID('@', sender) diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 7ec0892e4..932ce6155 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -76,7 +76,9 @@ func CheckForSoftFail( } // Check if the event is allowed. - if err = gomatrixserverlib.Allowed(event.PDU, &authEvents); err != nil { + if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return db.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { // return true, nil return true, err } diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 386083f6e..764bdfe2c 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -128,9 +128,13 @@ func (r *Inputer) processRoomEvent( if roomInfo == nil && !isCreateEvent { return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) } - _, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender()) + sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()) if err != nil { - return fmt.Errorf("event has invalid sender %q", input.Event.Sender()) + return fmt.Errorf("failed getting userID for sender %q. %w", event.SenderID(), err) + } + senderDomain := spec.ServerName("") + if sender != nil { + senderDomain = sender.Domain() } // If we already know about this outlier and it hasn't been rejected @@ -193,7 +197,9 @@ func (r *Inputer) processRoomEvent( serverRes.ServerNames = append(serverRes.ServerNames, input.Origin) delete(servers, input.Origin) } - if senderDomain != input.Origin && senderDomain != r.Cfg.Matrix.ServerName { + // Only perform this check if the sender mxid_mapping can be resolved. + // Don't fail processing the event if we have no mxid_maping. + if sender != nil && senderDomain != input.Origin && senderDomain != r.Cfg.Matrix.ServerName { serverRes.ServerNames = append(serverRes.ServerNames, senderDomain) delete(servers, senderDomain) } @@ -276,7 +282,9 @@ func (r *Inputer) processRoomEvent( // Check if the event is allowed by its auth events. If it isn't then // we consider the event to be "rejected" — it will still be persisted. - if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { + if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { isRejected = true rejectionErr = err logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID()) @@ -493,7 +501,7 @@ func (r *Inputer) processRoomEvent( func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event gomatrixserverlib.PDU) error { oldRoomID := event.RoomID() newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str - return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, event.Sender()) + return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, event.SenderID()) } // processStateBefore works out what the state is before the event and @@ -579,7 +587,9 @@ func (r *Inputer) processStateBefore( stateBeforeAuth := gomatrixserverlib.NewAuthEvents( gomatrixserverlib.ToPDUs(stateBeforeEvent), ) - if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth); rejectionErr != nil { + if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }); rejectionErr != nil { rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr) return } @@ -690,7 +700,9 @@ nextAuthEvent: // Check the signatures of the event. If this fails then we'll simply // skip it, because gomatrixserverlib.Allowed() will notice a problem // if a critical event is missing anyway. - if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing()); err != nil { + if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { continue nextAuthEvent } @@ -706,7 +718,9 @@ nextAuthEvent: } // Check if the auth event should be rejected. - err := gomatrixserverlib.Allowed(authEvent, auth) + err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }) if isRejected = err != nil; isRejected { logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID()) } @@ -828,11 +842,13 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r continue } + // TODO: pseudoIDs: get userID for room using state key (which is now senderID) localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey()) if err != nil { continue } + // TODO: pseudoIDs: query account by state key (which is now senderID) accountRes := &userAPI.QueryAccountByLocalpartResponse{} if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{ Localpart: localpart, diff --git a/roomserver/internal/input/input_events_test.go b/roomserver/internal/input/input_events_test.go index 568038132..0ba7d19f5 100644 --- a/roomserver/internal/input/input_events_test.go +++ b/roomserver/internal/input/input_events_test.go @@ -58,7 +58,7 @@ func Test_EventAuth(t *testing.T) { } // Finally check that the event is NOT allowed - if err := gomatrixserverlib.Allowed(ev.PDU, &allower); err == nil { + if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) }); err == nil { t.Fatalf("event should not be allowed, but it was") } } diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 10486138d..ac0670fc3 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -473,14 +473,18 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion stateEventList = append(stateEventList, state.StateEvents...) } resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts( - roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), + roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) + }, ) if err != nil { return nil, err } // apply the current event retryAllowedState: - if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents); err != nil { + if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { switch missing := err.(type) { case gomatrixserverlib.MissingAuthEventError: h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true) @@ -565,7 +569,9 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver // will be added and duplicates will be removed. missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events)) for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys); err != nil { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { continue } missingEvents = append(missingEvents, t.cacheAndReturn(ev)) @@ -654,7 +660,9 @@ func (t *missingStateReq) lookupMissingStateViaState( authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{ StateEvents: state.GetStateEvents(), AuthEvents: state.GetAuthEvents(), - }, roomVersion, t.keys, nil) + }, roomVersion, t.keys, nil, func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) + }) if err != nil { return nil, err } @@ -889,14 +897,16 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers)) return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers)) } - if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys); err != nil { + if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) return nil, verifySigError{event.EventID(), err} } return t.cacheAndReturn(event), nil } -func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU) error { +func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender) error { authUsingState := gomatrixserverlib.NewAuthEvents(nil) for i := range stateEvents { err := authUsingState.AddEvent(stateEvents[i]) @@ -904,7 +914,7 @@ func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverli return err } } - return gomatrixserverlib.Allowed(e, &authUsingState) + return gomatrixserverlib.Allowed(e, &authUsingState, userIDForSender) } func (t *missingStateReq) hadEvent(eventID string) { diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index 575525e21..ca736cb65 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -262,13 +262,17 @@ func (r *Admin) PerformAdminDownloadState( return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err) } for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing); err != nil { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { continue } authEventMap[authEvent.EventID()] = authEvent } for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing); err != nil { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { continue } stateEventMap[stateEvent.EventID()] = stateEvent diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index fb579f03a..0f743f4e4 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -121,7 +121,9 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform // Specifically the test "Outbound federation can backfill events" events, err := gomatrixserverlib.RequestBackfill( ctx, req.VirtualHost, requester, - r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, + r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }, ) // Only return an error if we really couldn't get any events. if err != nil && len(events) == 0 { @@ -210,7 +212,9 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom continue } loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false) - result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents) + result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }) if err != nil { logger.WithError(err).Warn("failed to load and verify event") continue @@ -484,8 +488,8 @@ FindSuccessor: // Store the server names in a temporary map to avoid duplicates. serverSet := make(map[spec.ServerName]bool) for _, event := range memberEvents { - if _, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender()); err == nil { - serverSet[senderDomain] = true + if sender, err := b.db.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()); err == nil { + serverSet[sender.Domain()] = true } } var servers []spec.ServerName diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index 41194832d..897bd3a0e 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -308,7 +308,9 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } } - if err = gomatrixserverlib.Allowed(ev, &authEvents); err != nil { + if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return c.DB.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed") return "", &util.JSONResponse{ Code: http.StatusInternalServerError, diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 1930b5ace..e8e20ede2 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -97,11 +97,12 @@ func (r *Inviter) ProcessInviteMembership( ) ([]api.OutputEvent, error) { var outputUpdates []api.OutputEvent var updater *shared.MembershipUpdater - _, domain, err := gomatrixserverlib.SplitID('@', *inviteEvent.StateKey()) + + userID, err := r.RSAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey()) if err != nil { return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())} } - isTargetLocal := r.Cfg.Matrix.IsLocalServerName(domain) + isTargetLocal := r.Cfg.Matrix.IsLocalServerName(userID.Domain()) if updater, err = r.DB.MembershipUpdater(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey(), isTargetLocal, inviteEvent.Version()); err != nil { return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err) } @@ -125,9 +126,9 @@ func (r *Inviter) PerformInvite( ) error { event := req.Event - sender, err := spec.NewUserID(event.Sender(), true) + sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()) if err != nil { - return spec.InvalidParam("The user ID is invalid") + return spec.InvalidParam("The sender user ID is invalid") } if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) { return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")} @@ -155,6 +156,9 @@ func (r *Inviter) PerformInvite( StrippedState: req.InviteRoomState, MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI}, StateQuerier: &QueryState{r.DB}, + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }, } inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI) if err != nil { diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index ff4a6a1dc..8c0df1c46 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -176,7 +176,7 @@ func moveLocalAliases(ctx context.Context, } for _, alias := range aliasRes.Aliases { - removeAliasReq := api.RemoveRoomAliasRequest{UserID: userID, Alias: alias} + removeAliasReq := api.RemoveRoomAliasRequest{SenderID: userID, Alias: alias} removeAliasRes := api.RemoveRoomAliasResponse{} if err = URSAPI.RemoveRoomAlias(ctx, &removeAliasReq, &removeAliasRes); err != nil { return fmt.Errorf("Failed to remove old room alias: %w", err) @@ -484,7 +484,9 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user } - if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { + if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) + }); err != nil { return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err) } @@ -567,7 +569,9 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user stateEvents[i] = queryRes.StateEvents[i].PDU } provider := gomatrixserverlib.NewAuthEvents(stateEvents) - if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider); err != nil { + if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID, senderID string) (*spec.UserID, error) { + return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) + }); err != nil { return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client? } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 6d898e8ad..707e95b2a 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -159,7 +159,9 @@ func (r *Queryer) QueryStateAfterEvents( } stateEvents, err = gomatrixserverlib.ResolveConflicts( - info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), + info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }, ) if err != nil { return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err) @@ -386,7 +388,12 @@ func (r *Queryer) QueryMembershipsForRoom( return fmt.Errorf("r.DB.Events: %w", err) } for _, event := range events { - clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll) + sender := spec.UserID{} + userID, queryErr := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if queryErr == nil && userID != nil { + sender = *userID + } + clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender) response.JoinEvents = append(response.JoinEvents, clientEvent) } return nil @@ -435,7 +442,12 @@ func (r *Queryer) QueryMembershipsForRoom( } for _, event := range events { - clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll) + sender := spec.UserID{} + userID, err := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } + clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender) response.JoinEvents = append(response.JoinEvents, clientEvent) } @@ -625,7 +637,9 @@ func (r *Queryer) QueryStateAndAuthChain( if request.ResolveState { stateEvents, err = gomatrixserverlib.ResolveConflicts( - info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), + info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }, ) if err != nil { return err @@ -960,3 +974,11 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, userID) } + +func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) { + return r.DB.GetSenderIDForUser(ctx, roomID, userID) +} + +func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) +} diff --git a/roomserver/producers/roomevent.go b/roomserver/producers/roomevent.go index febe8ddf4..165304d49 100644 --- a/roomserver/producers/roomevent.go +++ b/roomserver/producers/roomevent.go @@ -60,7 +60,7 @@ func (r *RoomEventProducer) ProduceRoomEvents(roomID string, updates []api.Outpu "adds_state": len(update.NewRoomEvent.AddsStateEventIDs), "removes_state": len(update.NewRoomEvent.RemovesStateEventIDs), "send_as_server": update.NewRoomEvent.SendAsServer, - "sender": update.NewRoomEvent.Event.Sender(), + "sender": update.NewRoomEvent.Event.SenderID(), }) if update.NewRoomEvent.Event.StateKey() != nil { logger = logger.WithField("state_key", *update.NewRoomEvent.Event.StateKey()) diff --git a/roomserver/state/state.go b/roomserver/state/state.go index f38d8f96a..3131cbff2 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -24,6 +24,7 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" @@ -43,6 +44,7 @@ type StateResolutionStorage interface { AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) + GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) } type StateResolution struct { @@ -945,7 +947,9 @@ func (v *StateResolution) resolveConflictsV1( } // Resolve the conflicts. - resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents) + resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return v.db.GetUserIDForSender(ctx, roomID, senderID) + }) // Map from the full events back to numeric state entries. for _, resolvedEvent := range resolvedEvents { @@ -1057,6 +1061,9 @@ func (v *StateResolution) resolveConflictsV2( conflictedEvents, nonConflictedEvents, authEvents, + func(roomID, senderID string) (*spec.UserID, error) { + return v.db.GetUserIDForSender(ctx, roomID, senderID) + }, ) }() diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 7d22df008..2d007bed5 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -166,6 +166,10 @@ type Database interface { GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error) // GetKnownUsers searches all users that userID knows about. GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) + // GetKnownUsers tries to obtain the current mxid for a given user. + GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) + // GetKnownUsers tries to obtain the current senderID for a given user. + GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) // GetKnownRooms returns a list of all rooms we know about. GetKnownRooms(ctx context.Context) ([]string, error) // ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room @@ -211,6 +215,7 @@ type RoomDatabase interface { GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error) + GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) } type EventDatabase interface { diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index f9c889cb1..105e61df6 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -101,7 +101,7 @@ func (u *MembershipUpdater) Update(newMembership tables.MembershipState, event * var inserted bool // Did the query result in a membership change? var retired []string // Did we retire any updates in the process? return inserted, retired, u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.SenderID()) if err != nil { return fmt.Errorf("u.d.AssignStateKeyNID: %w", err) } diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 70672a33e..735001383 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -250,3 +251,7 @@ func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error { func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) } + +func (u *RoomUpdater) GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + return u.d.GetUserIDForSender(ctx, roomID, senderID) +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index cefa58a3d..406d7cf1c 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -988,8 +988,18 @@ func (d *EventDatabase) MaybeRedactEvent( return nil } - _, sender1, _ := gomatrixserverlib.SplitID('@', redactedEvent.Sender()) - _, sender2, _ := gomatrixserverlib.SplitID('@', redactionEvent.Sender()) + // TODO: Don't hack senderID into userID here (pseudoIDs) + sender1Domain := "" + sender1, err1 := spec.NewUserID(redactedEvent.SenderID(), true) + if err1 == nil { + sender1Domain = string(sender1.Domain()) + } + // TODO: Don't hack senderID into userID here (pseudoIDs) + sender2Domain := "" + sender2, err2 := spec.NewUserID(redactionEvent.SenderID(), true) + if err2 == nil { + sender2Domain = string(sender2.Domain()) + } var powerlevels *gomatrixserverlib.PowerLevelContent powerlevels, err = plResolver.Resolve(ctx, redactionEvent.EventID()) if err != nil { @@ -997,9 +1007,9 @@ func (d *EventDatabase) MaybeRedactEvent( } switch { - case powerlevels.UserLevel(redactionEvent.Sender()) >= powerlevels.Redact: + case powerlevels.UserLevel(redactionEvent.SenderID()) >= powerlevels.Redact: // 1. The power level of the redaction event’s sender is greater than or equal to the redact level. - case sender1 == sender2: + case sender1Domain == sender2Domain: // 2. The domain of the redaction event’s sender matches that of the original event’s sender. default: ignoreRedaction = true @@ -1514,6 +1524,16 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit) } +func (d *Database) GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + // TODO: Use real logic once DB for pseudoIDs is in place + return spec.NewUserID(senderID, true) +} + +func (d *Database) GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) { + // TODO: Use real logic once DB for pseudoIDs is in place + return userID.String(), nil +} + // GetKnownRooms returns a list of all rooms we know about. func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil) diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index f468b048a..5ce3b430b 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -92,9 +92,11 @@ type MSC2836EventRelationshipsResponse struct { ParsedAuthChain []gomatrixserverlib.PDU } -func toClientResponse(res *MSC2836EventRelationshipsResponse) *EventRelationshipResponse { +func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsResponse, rsAPI roomserver.RoomserverInternalAPI) *EventRelationshipResponse { out := &EventRelationshipResponse{ - Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll), + Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }), Limited: res.Limited, NextBatch: res.NextBatch, } @@ -187,7 +189,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP return util.JSONResponse{ Code: 200, - JSON: toClientResponse(res), + JSON: toClientResponse(req.Context(), res, rsAPI), } } } diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index 2c6f63d45..c463fd72b 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -525,6 +525,10 @@ type testRoomserverAPI struct { events map[string]*types.HeaderedEvent } +func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error { for _, eventID := range req.EventIDs { ev := r.events[eventID] diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index 291e0f3b2..f380d3d4f 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -730,7 +730,7 @@ func stripped(ev gomatrixserverlib.PDU) *fclient.MSC2946StrippedEvent { Type: ev.Type(), StateKey: *ev.StateKey(), Content: ev.Content(), - Sender: ev.Sender(), + Sender: ev.SenderID(), OriginServerTS: ev.OriginServerTS(), } } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 56285dbf4..c08364658 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -523,7 +523,7 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *rstypes.HeaderedEvent) prev := types.PrevEventRef{ PrevContent: prevEvent.Content(), ReplacesState: prevEvent.EventID(), - PrevSender: prevEvent.Sender(), + PrevSender: prevEvent.SenderID(), } event.PDU, err = event.SetUnsigned(prev) diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index ac17d39d2..27e99a357 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -193,14 +193,20 @@ func Context( } } - eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll) - eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll) + eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) + eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) newState := state if filter.LazyLoadMembers { allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...) allEvents = append(allEvents, &requestedEvent) - evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll) + evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache) if err != nil { logrus.WithError(err).Error("unable to load membership events") @@ -211,12 +217,19 @@ func Context( } } - ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll) + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(ctx, requestedEvent.RoomID(), requestedEvent.SenderID()) + if err == nil && userID != nil { + sender = *userID + } + ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, sender) response := ContextRespsonse{ Event: &ev, EventsAfter: eventsAfterClient, EventsBefore: eventsBeforeClient, - State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll), + State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }), } if len(response.State) > filter.Limit { diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go index 0d3d412f6..63df7e837 100644 --- a/syncapi/routing/getevent.go +++ b/syncapi/routing/getevent.go @@ -101,8 +101,13 @@ func GetEvent( } } + sender := spec.UserID{} + senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), roomID, events[0].SenderID()) + if err == nil && senderUserID != nil { + sender = *senderUserID + } return util.JSONResponse{ Code: http.StatusOK, - JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll), + JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender), } } diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index 7d2e137d3..9c2319dd9 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -144,7 +144,7 @@ func GetMemberships( JSON: spec.InternalServerError{}, } } - res.Joined[ev.Sender()] = joinedMember(content) + res.Joined[ev.SenderID()] = joinedMember(content) } return util.JSONResponse{ Code: http.StatusOK, @@ -153,6 +153,8 @@ func GetMemberships( } return util.JSONResponse{ Code: http.StatusOK, - JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll)}, + JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) + })}, } } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index aeaec699b..879739d00 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -241,7 +241,7 @@ func OnIncomingMessagesRequest( device: device, } - clientEvents, start, end, err := mReq.retrieveEvents() + clientEvents, start, end, err := mReq.retrieveEvents(req.Context(), rsAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("mreq.retrieveEvents failed") return util.JSONResponse{ @@ -273,7 +273,9 @@ func OnIncomingMessagesRequest( JSON: spec.InternalServerError{}, } } - res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll)...) + res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) + })...) } // If we didn't return any events, set the end to an empty string, so it will be omitted @@ -310,7 +312,7 @@ func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api. // homeserver in the room for older events. // Returns an error if there was an issue talking to the database or with the // remote homeserver. -func (r *messagesReq) retrieveEvents() ( +func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserverAPI) ( clientEvents []synctypes.ClientEvent, start, end types.TopologyToken, err error, ) { @@ -383,7 +385,9 @@ func (r *messagesReq) retrieveEvents() ( "events_before": len(events), "events_after": len(filteredEvents), }).Debug("applied history visibility (messages)") - return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll), start, end, err + return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }), start, end, err } func (r *messagesReq) getStartEnd(events []*rstypes.HeaderedEvent) (start, end types.TopologyToken, err error) { diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go index 8374bf5b0..f21c684c8 100644 --- a/syncapi/routing/relations.go +++ b/syncapi/routing/relations.go @@ -114,9 +114,14 @@ func Relations( // type if it was specified. res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents)) for _, event := range filteredEvents { + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } res.Chunk = append( res.Chunk, - synctypes.ToClientEvent(event.PDU, synctypes.FormatAll), + synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender), ) } diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index 9ad0c0476..8542c0b73 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -171,7 +171,7 @@ func Setup( nb := req.FormValue("next_batch") nextBatch = &nb } - return Search(req, device, syncDB, fts, nextBatch) + return Search(req, device, syncDB, fts, nextBatch, rsAPI) }), ).Methods(http.MethodPost, http.MethodOptions) diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index b7191873e..9cf3eabe2 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -31,6 +31,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/sqlutil" + roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" @@ -38,7 +39,7 @@ import ( ) // nolint:gocyclo -func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts fulltext.Indexer, from *string) util.JSONResponse { +func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts fulltext.Indexer, from *string, rsAPI roomserverAPI.SyncRoomserverAPI) util.JSONResponse { start := time.Now() var ( searchReq SearchRequest @@ -204,11 +205,17 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts profileInfos := make(map[string]ProfileInfoResponse) for _, ev := range append(eventsBefore, eventsAfter...) { - profile, ok := knownUsersProfiles[event.Sender()] + userID, queryErr := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID()) + if queryErr != nil { + logrus.WithError(queryErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") + continue + } + + profile, ok := knownUsersProfiles[userID.String()] if !ok { - stateEvent, err := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, ev.Sender()) - if err != nil { - logrus.WithError(err).WithField("user_id", event.Sender()).Warn("failed to query userprofile") + stateEvent, stateErr := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, ev.SenderID()) + if stateErr != nil { + logrus.WithError(stateErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") continue } if stateEvent == nil { @@ -218,21 +225,30 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts AvatarURL: gjson.GetBytes(stateEvent.Content(), "avatar_url").Str, DisplayName: gjson.GetBytes(stateEvent.Content(), "displayname").Str, } - knownUsersProfiles[event.Sender()] = profile + knownUsersProfiles[userID.String()] = profile } - profileInfos[ev.Sender()] = profile + profileInfos[userID.String()] = profile } + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } results = append(results, Result{ Context: SearchContextResponse{ - Start: startToken.String(), - End: endToken.String(), - EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync), - EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync), - ProfileInfo: profileInfos, + Start: startToken.String(), + End: endToken.String(), + EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) + }), + EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) + }), + ProfileInfo: profileInfos, }, Rank: eventScore[event.EventID()].Score, - Result: synctypes.ToClientEvent(event, synctypes.FormatAll), + Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender), }) roomGroup := groups[event.RoomID()] roomGroup.Results = append(roomGroup.Results, event.EventID()) @@ -247,7 +263,9 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts JSON: spec.InternalServerError{}, } } - stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync) + stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) + }) } } diff --git a/syncapi/routing/search_test.go b/syncapi/routing/search_test.go index 1cc95a873..b36be8238 100644 --- a/syncapi/routing/search_test.go +++ b/syncapi/routing/search_test.go @@ -2,6 +2,7 @@ package routing import ( "bytes" + "context" "encoding/json" "net/http" "net/http/httptest" @@ -9,6 +10,7 @@ import ( "github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/sqlutil" + rsapi "github.com/matrix-org/dendrite/roomserver/api" rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" @@ -21,6 +23,12 @@ import ( "github.com/stretchr/testify/assert" ) +type FakeSyncRoomserverAPI struct{ rsapi.SyncRoomserverAPI } + +func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func TestSearch(t *testing.T) { alice := test.NewUser(t) aliceDevice := userapi.Device{UserID: alice.ID} @@ -247,7 +255,7 @@ func TestSearch(t *testing.T) { assert.NoError(t, err) req := httptest.NewRequest(http.MethodPost, "/", reqBody) - res := Search(req, tc.device, db, fts, tc.from) + res := Search(req, tc.device, db, fts, tc.from, &FakeSyncRoomserverAPI{}) if !tc.wantOK && !res.Is2xx() { return } diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 0cc963731..bfe5e9bdd 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -343,7 +343,7 @@ func (s *currentRoomStateStatements) UpsertRoomState( event.RoomID(), event.EventID(), event.Type(), - event.Sender(), + event.SenderID(), containsURL, *event.StateKey(), headeredJSON, diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 3aadbccf8..e068afab1 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -407,7 +407,7 @@ func (s *outputRoomEventsStatements) InsertEvent( event.EventID(), headeredJSON, event.Type(), - event.Sender(), + event.SenderID(), containsURL, pq.StringArray(addState), pq.StringArray(removeState), diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index ecfd418fc..17a6a69c3 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -195,7 +195,21 @@ func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.Strea for i := 0; i < len(in); i++ { out[i] = in[i].HeaderedEvent if device != nil && in[i].TransactionID != nil { - if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": out[i].EventID(), + }).WithError(err).Warnf("Failed to add transaction ID to event") + continue + } + deviceSenderID, err := d.getSenderIDForUser(in[i].RoomID(), *userID) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": out[i].EventID(), + }).WithError(err).Warnf("Failed to add transaction ID to event") + continue + } + if deviceSenderID == in[i].SenderID() && device.SessionID == in[i].TransactionID.SessionID { err := out[i].SetUnsignedField( "transaction_id", in[i].TransactionID.TransactionID, ) @@ -210,6 +224,11 @@ func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.Strea return out } +func (d *Database) getSenderIDForUser(roomID string, userID spec.UserID) (string, error) { // nolint + // TODO: Repalce with actual logic for pseudoIDs + return userID.String(), nil +} + // handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of // the events listed in the event's 'prev_events'. This function also updates the backwards extremities table // to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 1b8632eb6..e432e483b 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -342,7 +342,7 @@ func (s *currentRoomStateStatements) UpsertRoomState( event.RoomID(), event.EventID(), event.Type(), - event.Sender(), + event.SenderID(), containsURL, *event.StateKey(), headeredJSON, diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index d63e76067..5a47aec44 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -348,7 +348,7 @@ func (s *outputRoomEventsStatements) InsertEvent( event.EventID(), headeredJSON, event.Type(), - event.Sender(), + event.SenderID(), containsURL, string(addStateJSON), string(removeStateJSON), diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index becd863a9..a8b0a7b66 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -10,6 +10,7 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" @@ -17,6 +18,7 @@ import ( type InviteStreamProvider struct { DefaultStreamProvider + rsAPI api.SyncRoomserverAPI } func (p *InviteStreamProvider) Setup( @@ -62,11 +64,17 @@ func (p *InviteStreamProvider) IncrementalSync( } for roomID, inviteEvent := range invites { + user := spec.UserID{} + sender, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), inviteEvent.SenderID()) + if err == nil && sender != nil { + user = *sender + } + // skip ignored user events - if _, ok := req.IgnoredUsers.List[inviteEvent.Sender()]; ok { + if _, ok := req.IgnoredUsers.List[user.String()]; ok { continue } - ir := types.NewInviteResponse(inviteEvent) + ir := types.NewInviteResponse(inviteEvent, user) req.Response.Rooms.Invite[roomID] = ir } diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 0ea48a9d3..8f83a0896 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -376,20 +376,28 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } } jr.Timeline.PrevBatch = &prevBatch - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync) + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync) + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) req.Response.Rooms.Join[delta.RoomID] = jr case spec.Peek: jr := types.NewJoinResponse() jr.Timeline.PrevBatch = &prevBatch // TODO: Apply history visibility on peeked rooms - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync) + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) jr.Timeline.Limited = limited - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync) + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) req.Response.Rooms.Peek[delta.RoomID] = jr case spec.Leave: @@ -398,11 +406,15 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( case spec.Ban: lr := types.NewLeaveResponse() lr.Timeline.PrevBatch = &prevBatch - lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync) + lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. lr.Timeline.Limited = limited && len(events) == len(recentEvents) - lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync) + lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) req.Response.Rooms.Leave[delta.RoomID] = lr } @@ -425,7 +437,7 @@ func applyHistoryVisibilityFilter( for _, ev := range recentEvents { if ev.StateKey() != nil { stateTypes = append(stateTypes, ev.Type()) - senders = append(senders, ev.Sender()) + senders = append(senders, ev.SenderID()) } } @@ -552,11 +564,15 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( } jr.Timeline.PrevBatch = prevBatch - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync) + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. jr.Timeline.Limited = limited && len(events) == len(recentEvents) - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync) + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) return jr, nil } @@ -577,8 +593,8 @@ func (p *PDUStreamProvider) lazyLoadMembers( // Add all users the client doesn't know about yet to a list for _, event := range timelineEvents { // Membership is not yet cached, add it to the list - if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, event.Sender()); !ok { - timelineUsers[event.Sender()] = struct{}{} + if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, event.SenderID()); !ok { + timelineUsers[event.SenderID()] = struct{}{} } } // Preallocate with the same amount, even if it will end up with fewer values diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go index a35491acf..f25bc978f 100644 --- a/syncapi/streams/streams.go +++ b/syncapi/streams/streams.go @@ -45,6 +45,7 @@ func NewSyncStreamProviders( }, InviteStreamProvider: &InviteStreamProvider{ DefaultStreamProvider: DefaultStreamProvider{DB: d}, + rsAPI: rsAPI, }, SendToDeviceStreamProvider: &SendToDeviceStreamProvider{ DefaultStreamProvider: DefaultStreamProvider{DB: d}, diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index bc766e663..78c857ab9 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -40,6 +40,10 @@ type syncRoomserverAPI struct { rooms []*test.Room } +func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func (s *syncRoomserverAPI) QueryLatestEventsAndState(ctx context.Context, req *rsapi.QueryLatestEventsAndStateRequest, res *rsapi.QueryLatestEventsAndStateResponse) error { var room *test.Room for _, r := range s.rooms { diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go index c722fe60a..66fb1d01f 100644 --- a/syncapi/synctypes/clientevent.go +++ b/syncapi/synctypes/clientevent.go @@ -44,22 +44,27 @@ type ClientEvent struct { } // ToClientEvents converts server events to client events. -func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat) []ClientEvent { +func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat, userIDForSender spec.UserIDForSender) []ClientEvent { evs := make([]ClientEvent, 0, len(serverEvs)) for _, se := range serverEvs { if se == nil { continue // TODO: shouldn't happen? } - evs = append(evs, ToClientEvent(se, format)) + sender := spec.UserID{} + userID, err := userIDForSender(se.RoomID(), se.SenderID()) + if err == nil && userID != nil { + sender = *userID + } + evs = append(evs, ToClientEvent(se, format, sender)) } return evs } // ToClientEvent converts a single server event to a client event. -func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat) ClientEvent { +func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID) ClientEvent { ce := ClientEvent{ Content: spec.RawJSON(se.Content()), - Sender: se.Sender(), + Sender: sender.String(), Type: se.Type(), StateKey: se.StateKey(), Unsigned: spec.RawJSON(se.Unsigned()), diff --git a/syncapi/synctypes/clientevent_test.go b/syncapi/synctypes/clientevent_test.go index b914e64f1..341795081 100644 --- a/syncapi/synctypes/clientevent_test.go +++ b/syncapi/synctypes/clientevent_test.go @@ -21,6 +21,7 @@ import ( "testing" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) func TestToClientEvent(t *testing.T) { // nolint: gocyclo @@ -43,7 +44,11 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo if err != nil { t.Fatalf("failed to create Event: %s", err) } - ce := ToClientEvent(ev, FormatAll) + userID, err := spec.NewUserID("@test:localhost", true) + if err != nil { + t.Fatalf("failed to create userID: %s", err) + } + ce := ToClientEvent(ev, FormatAll, *userID) if ce.EventID != ev.EventID() { t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID) } @@ -62,8 +67,8 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo if !bytes.Equal(ce.Unsigned, ev.Unsigned()) { t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned)) } - if ce.Sender != ev.Sender() { - t.Errorf("ClientEvent.Sender: wanted %s, got %s", ev.Sender(), ce.Sender) + if ce.Sender != userID.String() { + t.Errorf("ClientEvent.Sender: wanted %s, got %s", userID.String(), ce.Sender) } j, err := json.Marshal(ce) if err != nil { @@ -98,7 +103,11 @@ func TestToClientFormatSync(t *testing.T) { if err != nil { t.Fatalf("failed to create Event: %s", err) } - ce := ToClientEvent(ev, FormatSync) + userID, err := spec.NewUserID("@test:localhost", true) + if err != nil { + t.Fatalf("failed to create userID: %s", err) + } + ce := ToClientEvent(ev, FormatSync, *userID) if ce.RoomID != "" { t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID) } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 22c27fea5..526a120d0 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -539,7 +539,7 @@ type InviteResponse struct { } // NewInviteResponse creates an empty response with initialised arrays. -func NewInviteResponse(event *types.HeaderedEvent) *InviteResponse { +func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteResponse { res := InviteResponse{} res.InviteState.Events = []json.RawMessage{} @@ -552,7 +552,7 @@ func NewInviteResponse(event *types.HeaderedEvent) *InviteResponse { // Then we'll see if we can create a partial of the invite event itself. // This is needed for clients to work out *who* sent the invite. - inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync) + inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID) inviteEvent.Unsigned = nil if ev, err := json.Marshal(inviteEvent); err == nil { res.InviteState.Events = append(res.InviteState.Events, ev) diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index 8e0448fe7..a79ce5417 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -8,8 +8,13 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) +func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func TestSyncTokens(t *testing.T) { shouldPass := map[string]string{ "s4_0_0_0_0_0_0_0_3": StreamingToken{4, 0, 0, 0, 0, 0, 0, 0, 3}.String(), @@ -56,7 +61,12 @@ func TestNewInviteResponse(t *testing.T) { t.Fatal(err) } - res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}) + sender, err := spec.NewUserID("@neilalexander:matrix.org", true) + if err != nil { + t.Fatal(err) + } + + res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender) j, err := json.Marshal(res) if err != nil { t.Fatal(err) diff --git a/test/room.go b/test/room.go index 852e31533..4cdb73aa3 100644 --- a/test/room.go +++ b/test/room.go @@ -39,6 +39,10 @@ var ( roomIDCounter = int64(0) ) +func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + type Room struct { ID string Version gomatrixserverlib.RoomVersion @@ -195,7 +199,7 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten if err != nil { t.Fatalf("CreateEvent[%s]: failed to build event: %s", eventType, err) } - if err = gomatrixserverlib.Allowed(ev, &r.authEvents); err != nil { + if err = gomatrixserverlib.Allowed(ev, &r.authEvents, UserIDForSender); err != nil { t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err) } headeredEvent := &rstypes.HeaderedEvent{PDU: ev} diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 3cfdc0ce9..c025deee0 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -108,7 +108,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms } if s.cfg.Matrix.ReportStats.Enabled { - go s.storeMessageStats(ctx, event.Type(), event.Sender(), event.RoomID()) + go s.storeMessageStats(ctx, event.Type(), event.SenderID(), event.RoomID()) } log.WithFields(log.Fields{ @@ -301,7 +301,12 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst switch { case event.Type() == spec.MRoomMember: - cevent := synctypes.ToClientEvent(event, synctypes.FormatAll) + sender := spec.UserID{} + userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if queryErr == nil && userID != nil { + sender = *userID + } + cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender) var member *localMembership member, err = newLocalMembership(&cevent) if err != nil { @@ -529,12 +534,17 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype return fmt.Errorf("s.localPushDevices: %w", err) } + sender := spec.UserID{} + userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } n := &api.Notification{ Actions: actions, // UNSPEC: the spec doesn't say this is a ClientEvent, but the // fields seem to match. room_id should be missing, which // matches the behaviour of FormatSync. - Event: synctypes.ToClientEvent(event, synctypes.FormatSync), + Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender), // TODO: this is per-device, but it's not part of the primary // key. So inserting one notification per profile tag doesn't // make sense. What is this supposed to be? Sytests require it @@ -615,7 +625,12 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype // evaluatePushRules fetches and evaluates the push rules of a local // user. Returns actions (including dont_notify). func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { - if event.Sender() == mem.UserID { + user := "" + sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if err == nil { + user = sender.String() + } + if user == mem.UserID { // SPEC: Homeservers MUST NOT notify the Push Gateway for // events that the user has sent themselves. return nil, nil @@ -632,9 +647,8 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * if err != nil { return nil, err } - sender := event.Sender() - if _, ok := ignored.List[sender]; ok { - return nil, fmt.Errorf("user %s is ignored", sender) + if _, ok := ignored.List[sender.String()]; ok { + return nil, fmt.Errorf("user %s is ignored", sender.String()) } } ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart, mem.Domain) @@ -650,7 +664,9 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * roomSize: roomSize, } eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global) - rule, err := eval.MatchEvent(event.PDU) + rule, err := eval.MatchEvent(event.PDU, func(roomID, senderID string) (*spec.UserID, error) { + return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) if err != nil { return nil, err } @@ -682,7 +698,7 @@ func (rse *ruleSetEvalContext) UserDisplayName() string { return rse.mem.Display func (rse *ruleSetEvalContext) RoomMemberCount() (int, error) { return rse.roomSize, nil } -func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, error) { +func (rse *ruleSetEvalContext) HasPowerLevel(senderID, levelKey string) (bool, error) { req := &rsapi.QueryLatestEventsAndStateRequest{ RoomID: rse.roomID, StateToFetch: []gomatrixserverlib.StateKeyTuple{ @@ -702,7 +718,7 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err if err != nil { return false, err } - return plc.UserLevel(userID) >= plc.NotificationLevel(levelKey), nil + return plc.UserLevel(senderID) >= plc.NotificationLevel(levelKey), nil } return true, nil } @@ -756,6 +772,11 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes } default: + sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if err != nil { + logger.WithError(err).Errorf("Failed to get userID for sender %s", event.SenderID()) + return nil, err + } req = pushgateway.NotifyRequest{ Notification: pushgateway.Notification{ Content: event.Content(), @@ -767,7 +788,7 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes ID: event.EventID(), RoomID: event.RoomID(), RoomName: roomName, - Sender: event.Sender(), + Sender: sender.String(), Type: event.Type(), }, } diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 53977206f..899a5aaf0 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/internal/pushrules" + rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/userapi/storage" @@ -44,13 +45,19 @@ func mustCreateEvent(t *testing.T, content string) *types.HeaderedEvent { return &types.HeaderedEvent{PDU: ev} } +type FakeUserRoomserverAPI struct{ rsapi.UserRoomserverAPI } + +func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func Test_evaluatePushRules(t *testing.T) { ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) defer close() - consumer := OutputRoomEventConsumer{db: db} + consumer := OutputRoomEventConsumer{db: db, rsAPI: &FakeUserRoomserverAPI{}} testCases := []struct { name string @@ -86,7 +93,7 @@ func Test_evaluatePushRules(t *testing.T) { }, { name: "m.room.message highlights", - eventContent: `{"type":"m.room.message", "content": {"body": "test"} }`, + eventContent: `{"type":"m.room.message", "content": {"body": "test"}}`, wantNotify: true, wantAction: pushrules.NotifyAction, wantActions: []*pushrules.Action{ diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go index e1c88d47f..27dd373c2 100644 --- a/userapi/util/notify_test.go +++ b/userapi/util/notify_test.go @@ -11,6 +11,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "golang.org/x/crypto/bcrypt" @@ -87,7 +88,7 @@ func TestNotifyUserCountsAsync(t *testing.T) { } // Prepare pusher with our test server URL - if err := db.UpsertPusher(ctx, api.Pusher{ + if err = db.UpsertPusher(ctx, api.Pusher{ Kind: api.HTTPKind, AppID: appID, PushKey: pushKey, @@ -99,8 +100,12 @@ func TestNotifyUserCountsAsync(t *testing.T) { } // Insert a dummy event + sender, err := spec.NewUserID(alice.ID, true) + if err != nil { + t.Error(err) + } if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{ - Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll), + Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender), }); err != nil { t.Error(err) } From 8ea1a11105ea7e66aa459537bcbef0de606147cd Mon Sep 17 00:00:00 2001 From: devonh Date: Wed, 7 Jun 2023 17:14:35 +0000 Subject: [PATCH 17/35] Use SenderID Type (#3105) --- appservice/consumers/roomserver.go | 2 +- clientapi/routing/directory.go | 16 +++- clientapi/routing/membership.go | 36 ++++++++- clientapi/routing/profile.go | 15 +++- clientapi/routing/redaction.go | 27 +++++-- clientapi/routing/sendevent.go | 21 +++++- clientapi/threepid/invites.go | 12 ++- cmd/resolve-state/main.go | 2 +- federationapi/federationapi_test.go | 13 +++- federationapi/internal/perform.go | 28 ++++--- federationapi/routing/invite.go | 4 +- federationapi/routing/join.go | 32 +++++--- federationapi/routing/leave.go | 32 +++++--- federationapi/routing/threepid.go | 14 ++-- go.mod | 2 +- go.sum | 4 +- internal/pushrules/evaluate.go | 2 +- internal/pushrules/evaluate_test.go | 8 +- internal/transactionrequest.go | 2 +- internal/transactionrequest_test.go | 8 +- roomserver/api/alias.go | 8 +- roomserver/api/api.go | 4 +- roomserver/internal/alias.go | 17 +++-- roomserver/internal/helpers/auth.go | 6 +- roomserver/internal/input/input_events.go | 12 +-- .../internal/input/input_events_test.go | 4 +- roomserver/internal/input/input_missing.go | 10 +-- roomserver/internal/perform/perform_admin.go | 26 ++++--- .../internal/perform/perform_backfill.go | 4 +- .../internal/perform/perform_create_room.go | 33 +++++++- roomserver/internal/perform/perform_invite.go | 10 ++- roomserver/internal/perform/perform_join.go | 17 +++-- roomserver/internal/perform/perform_leave.go | 19 +++-- .../internal/perform/perform_upgrade.go | 68 +++++++++++++---- roomserver/internal/query/query.go | 8 +- roomserver/roomserver_test.go | 12 +-- roomserver/state/state.go | 6 +- roomserver/storage/interface.go | 6 +- .../storage/shared/membership_updater.go | 2 +- roomserver/storage/shared/room_updater.go | 2 +- roomserver/storage/shared/storage.go | 12 +-- setup/mscs/msc2836/msc2836.go | 2 +- setup/mscs/msc2836/msc2836_test.go | 6 +- setup/mscs/msc2946/msc2946.go | 2 +- syncapi/consumers/roomserver.go | 2 +- syncapi/routing/context.go | 8 +- syncapi/routing/memberships.go | 19 ++++- syncapi/routing/messages.go | 6 +- syncapi/routing/search.go | 8 +- syncapi/routing/search_test.go | 4 +- syncapi/storage/interface.go | 6 +- syncapi/storage/shared/storage_consumer.go | 75 +++++++++---------- syncapi/storage/shared/storage_sync.go | 19 ++--- syncapi/storage/storage_test.go | 2 +- syncapi/streams/stream_pdu.go | 30 ++++---- syncapi/syncapi_test.go | 4 +- syncapi/types/types.go | 2 +- test/room.go | 6 +- userapi/consumers/roomserver.go | 6 +- userapi/consumers/roomserver_test.go | 4 +- 60 files changed, 502 insertions(+), 275 deletions(-) diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index 06625ad7e..ff124514e 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -181,7 +181,7 @@ func (s *OutputRoomEventConsumer) sendEvents( // Create the transaction body. transaction, err := json.Marshal( ApplicationServiceTransaction{ - Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), }, diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index 0c842e6a5..034296f45 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -338,7 +338,21 @@ func SetVisibility( // NOTSPEC: Check if the user's power is greater than power required to change m.room.canonical_alias event power, _ := gomatrixserverlib.NewPowerLevelContentFromEvent(queryEventsRes.StateEvents[0].PDU) - if power.UserLevel(dev.UserID) < power.EventLevel(spec.MRoomCanonicalAlias, true) { + fullUserID, err := spec.NewUserID(dev.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to change visibility"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to change visibility"), + } + } + if power.UserLevel(senderID) < power.EventLevel(spec.MRoomCanonicalAlias, true) { return util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden("userID doesn't have power level to change visibility"), diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 0fe0a4ade..78829bec9 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -66,7 +66,21 @@ func SendBan( if errRes != nil { return *errRes } - allowedToBan := pl.UserLevel(device.UserID) >= pl.Ban + fullUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to ban this user, unknown senderID"), + } + } + allowedToBan := pl.UserLevel(senderID) >= pl.Ban if !allowedToBan { return util.JSONResponse{ Code: http.StatusForbidden, @@ -142,7 +156,21 @@ func SendKick( if errRes != nil { return *errRes } - allowedToKick := pl.UserLevel(device.UserID) >= pl.Kick + fullUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"), + } + } + allowedToKick := pl.UserLevel(senderID) >= pl.Kick if !allowedToKick { return util.JSONResponse{ Code: http.StatusForbidden, @@ -151,7 +179,7 @@ func SendKick( } var queryRes roomserverAPI.QueryMembershipForUserResponse - err := rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ + err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ RoomID: roomID, UserID: body.UserID, }, &queryRes) @@ -319,7 +347,7 @@ func buildMembershipEventDirect( rsAPI roomserverAPI.ClientRoomserverAPI, ) (*types.HeaderedEvent, error) { proto := gomatrixserverlib.ProtoEvent{ - Sender: sender, + SenderID: sender, RoomID: roomID, Type: "m.room.member", StateKey: &targetUserID, diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 2c9d0cbbe..e734e2e4f 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -363,12 +363,21 @@ func buildMembershipEvents( ) ([]*types.HeaderedEvent, error) { evs := []*types.HeaderedEvent{} + fullUserID, err := spec.NewUserID(userID, true) + if err != nil { + return nil, err + } for _, roomID := range roomIDs { + senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) + if err != nil { + return nil, err + } + senderIDString := string(senderID) proto := gomatrixserverlib.ProtoEvent{ - Sender: userID, + SenderID: senderIDString, RoomID: roomID, Type: "m.room.member", - StateKey: &userID, + StateKey: &senderIDString, } content := gomatrixserverlib.MemberContent{ @@ -378,7 +387,7 @@ func buildMembershipEvents( content.DisplayName = newProfile.DisplayName content.AvatarURL = newProfile.AvatarURL - if err := proto.SetContent(content); err != nil { + if err = proto.SetContent(content); err != nil { return nil, err } diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index e94c7748e..22474fc08 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -73,10 +73,25 @@ func SendRedaction( } } + fullUserID, userIDErr := spec.NewUserID(device.UserID, true) + if userIDErr != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to redact"), + } + } + senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) + if queryErr != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to redact"), + } + } + // "Users may redact their own events, and any user with a power level greater than or equal // to the redact power level of the room may redact events there" // https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid - allowedToRedact := ev.SenderID() == device.UserID // TODO: Should replace device.UserID with device...PerRoomKey + allowedToRedact := ev.SenderID() == senderID // TODO: Should replace device.UserID with device...PerRoomKey if !allowedToRedact { plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{ EventType: spec.MRoomPowerLevels, @@ -97,7 +112,7 @@ func SendRedaction( ), } } - allowedToRedact = pl.UserLevel(device.UserID) >= pl.Redact + allowedToRedact = pl.UserLevel(senderID) >= pl.Redact } if !allowedToRedact { return util.JSONResponse{ @@ -114,10 +129,10 @@ func SendRedaction( // create the new event and set all the fields we can proto := gomatrixserverlib.ProtoEvent{ - Sender: device.UserID, - RoomID: roomID, - Type: spec.MRoomRedaction, - Redacts: eventID, + SenderID: string(senderID), + RoomID: roomID, + Type: spec.MRoomRedaction, + Redacts: eventID, } err := proto.SetContent(r) if err != nil { diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 8b09f399a..4d0a9f24a 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -266,16 +266,29 @@ func generateSendEvent( evTime time.Time, ) (gomatrixserverlib.PDU, *util.JSONResponse) { // parse the incoming http request - userID := device.UserID + fullUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return nil, &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("Bad userID"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) + if err != nil { + return nil, &util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound("Unable to find senderID for user"), + } + } // create the new event and set all the fields we can proto := gomatrixserverlib.ProtoEvent{ - Sender: userID, + SenderID: string(senderID), RoomID: roomID, Type: eventType, StateKey: stateKey, } - err := proto.SetContent(r) + err = proto.SetContent(r) if err != nil { util.GetLogger(ctx).WithError(err).Error("proto.SetContent failed") return nil, &util.JSONResponse{ @@ -331,7 +344,7 @@ func generateSendEvent( stateEvents[i] = queryRes.StateEvents[i].PDU } provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents)) - if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID, senderID string) (*spec.UserID, error) { + if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { return nil, &util.JSONResponse{ diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index 9f4f62e43..e7ffbac2b 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -355,8 +355,16 @@ func emit3PIDInviteEvent( rsAPI api.ClientRoomserverAPI, evTime time.Time, ) error { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return err + } + sender, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID) + if err != nil { + return err + } proto := &gomatrixserverlib.ProtoEvent{ - Sender: device.UserID, + SenderID: string(sender), RoomID: roomID, Type: "m.room.third_party_invite", StateKey: &res.Token, @@ -370,7 +378,7 @@ func emit3PIDInviteEvent( PublicKeys: res.PublicKeys, } - if err := proto.SetContent(content); err != nil { + if err = proto.SetContent(content); err != nil { return err } diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go index 360403094..15c87f1a8 100644 --- a/cmd/resolve-state/main.go +++ b/cmd/resolve-state/main.go @@ -183,7 +183,7 @@ func main() { fmt.Println("Resolving state") var resolved Events resolved, err = gomatrixserverlib.ResolveConflicts( - gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID, senderID string) (*spec.UserID, error) { + gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return roomserverDB.GetUserIDForSender(ctx, roomID, senderID) }, ) diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index a97bcdeab..173908437 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -36,8 +36,12 @@ type fedRoomserverAPI struct { queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error } -func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { - return spec.NewUserID(senderID, true) +func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) +} + +func (f *fedRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { + return spec.SenderID(userID.String()), nil } // PerformJoin will call this function @@ -115,12 +119,13 @@ func (f *fedClient) MakeJoin(ctx context.Context, origin, s spec.ServerName, roo defer f.fedClientMutex.Unlock() for _, r := range f.allowJoins { if r.ID == roomID { + senderIDString := userID res.RoomVersion = r.Version res.JoinEvent = gomatrixserverlib.ProtoEvent{ - Sender: userID, + SenderID: senderIDString, RoomID: roomID, Type: "m.room.member", - StateKey: &userID, + StateKey: &senderIDString, Content: spec.RawJSON([]byte(`{"membership":"join"}`)), PrevEvents: r.ForwardExtremities(), } diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 2d59d0f93..485b79a03 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -154,9 +154,14 @@ func (r *FederationInternalAPI) performJoinUsingServer( if err != nil { return err } + senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, roomID, *user) + if err != nil { + return err + } joinInput := gomatrixserverlib.PerformJoinInput{ UserID: user, + SenderID: senderID, RoomID: room, ServerName: serverName, Content: content, @@ -164,10 +169,10 @@ func (r *FederationInternalAPI) performJoinUsingServer( PrivateKey: r.cfg.Matrix.PrivateKey, KeyID: r.cfg.Matrix.KeyID, KeyRing: r.keyRing, - EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID, senderID string) (*spec.UserID, error) { + EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), - UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }, } @@ -363,7 +368,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( // authenticate the state returned (check its auth events etc) // the equivalent of CheckSendJoinResponse() - userIDProvider := func(roomID, senderID string) (*spec.UserID, error) { + userIDProvider := func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) } authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse( @@ -414,7 +419,7 @@ func (r *FederationInternalAPI) PerformLeave( request *api.PerformLeaveRequest, response *api.PerformLeaveResponse, ) (err error) { - _, origin, err := r.cfg.Matrix.SplitLocalID('@', request.UserID) + userID, err := spec.NewUserID(request.UserID, true) if err != nil { return err } @@ -433,7 +438,7 @@ func (r *FederationInternalAPI) PerformLeave( // request. respMakeLeave, err := r.federation.MakeLeave( ctx, - origin, + userID.Domain(), serverName, request.RoomID, request.UserID, @@ -454,9 +459,14 @@ func (r *FederationInternalAPI) PerformLeave( // Set all the fields to be what they should be, this should be a no-op // but it's possible that the remote server returned us something "odd" + senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, request.RoomID, *userID) + if err != nil { + return err + } + senderIDString := string(senderID) respMakeLeave.LeaveEvent.Type = spec.MRoomMember - respMakeLeave.LeaveEvent.Sender = request.UserID - respMakeLeave.LeaveEvent.StateKey = &request.UserID + respMakeLeave.LeaveEvent.SenderID = senderIDString + respMakeLeave.LeaveEvent.StateKey = &senderIDString respMakeLeave.LeaveEvent.RoomID = request.RoomID respMakeLeave.LeaveEvent.Redacts = "" leaveEB := verImpl.NewEventBuilderFromProtoEvent(&respMakeLeave.LeaveEvent) @@ -478,7 +488,7 @@ func (r *FederationInternalAPI) PerformLeave( // Build the leave event. event, err := leaveEB.Build( time.Now(), - origin, + userID.Domain(), r.cfg.Matrix.KeyID, r.cfg.Matrix.PrivateKey, ) @@ -490,7 +500,7 @@ func (r *FederationInternalAPI) PerformLeave( // Try to perform a send_leave using the newly built event. err = r.federation.SendLeave( ctx, - origin, + userID.Domain(), serverName, event, ) diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index d792335b9..5b15f810d 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -95,7 +95,7 @@ func InviteV2( StateQuerier: rsAPI.StateQuerier(), InviteEvent: inviteReq.Event(), StrippedState: inviteReq.InviteRoomState(), - UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, } @@ -188,7 +188,7 @@ func InviteV1( StateQuerier: rsAPI.StateQuerier(), InviteEvent: event, StrippedState: strippedState, - UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, } diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 9da059189..d14801921 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -55,7 +55,7 @@ func MakeJoin( RoomID: roomID.String(), } res := api.QueryServerJoinedToRoomResponse{} - if err := rsAPI.QueryServerJoinedToRoom(httpReq.Context(), &req, &res); err != nil { + if err = rsAPI.QueryServerJoinedToRoom(httpReq.Context(), &req, &res); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed") return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -64,26 +64,26 @@ func MakeJoin( } createJoinTemplate := func(proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, []gomatrixserverlib.PDU, error) { - identity, err := cfg.Matrix.SigningIdentityFor(request.Destination()) - if err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Errorf("obtaining signing identity for %s failed", request.Destination()) + identity, signErr := cfg.Matrix.SigningIdentityFor(request.Destination()) + if signErr != nil { + util.GetLogger(httpReq.Context()).WithError(signErr).Errorf("obtaining signing identity for %s failed", request.Destination()) return nil, nil, spec.NotFound(fmt.Sprintf("Server name %q does not exist", request.Destination())) } queryRes := api.QueryLatestEventsAndStateResponse{ RoomVersion: roomVersion, } - event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), proto, identity, time.Now(), rsAPI, &queryRes) - switch e := err.(type) { + event, signErr := eventutil.QueryAndBuildEvent(httpReq.Context(), proto, identity, time.Now(), rsAPI, &queryRes) + switch e := signErr.(type) { case nil: case eventutil.ErrRoomNoExists: - util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") + util.GetLogger(httpReq.Context()).WithError(signErr).Error("eventutil.BuildEvent failed") return nil, nil, spec.NotFound("Room does not exist") case gomatrixserverlib.BadJSONError: - util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") + util.GetLogger(httpReq.Context()).WithError(signErr).Error("eventutil.BuildEvent failed") return nil, nil, spec.BadJSON(e.Error()) default: - util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") + util.GetLogger(httpReq.Context()).WithError(signErr).Error("eventutil.BuildEvent failed") return nil, nil, spec.InternalServerError{} } @@ -98,9 +98,19 @@ func MakeJoin( Roomserver: rsAPI, } + senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID.String(), userID) + if err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + input := gomatrixserverlib.HandleMakeJoinInput{ Context: httpReq.Context(), UserID: userID, + SenderID: senderID, RoomID: roomID, RoomVersion: roomVersion, RemoteVersions: remoteVersions, @@ -108,7 +118,7 @@ func MakeJoin( LocalServerName: cfg.Matrix.ServerName, LocalServerInRoom: res.RoomExists && res.IsInRoom, RoomQuerier: &roomQuerier, - UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, BuildEventTemplate: createJoinTemplate, @@ -205,7 +215,7 @@ func SendJoin( PrivateKey: cfg.Matrix.PrivateKey, Verifier: keys, MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI}, - UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, } diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index 30e99c4f7..716276bec 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -50,7 +50,7 @@ func MakeLeave( RoomID: roomID.String(), } res := api.QueryServerJoinedToRoomResponse{} - if err := rsAPI.QueryServerJoinedToRoom(httpReq.Context(), &req, &res); err != nil { + if err = rsAPI.QueryServerJoinedToRoom(httpReq.Context(), &req, &res); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed") return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -59,24 +59,24 @@ func MakeLeave( } createLeaveTemplate := func(proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, []gomatrixserverlib.PDU, error) { - identity, err := cfg.Matrix.SigningIdentityFor(request.Destination()) - if err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Errorf("obtaining signing identity for %s failed", request.Destination()) + identity, signErr := cfg.Matrix.SigningIdentityFor(request.Destination()) + if signErr != nil { + util.GetLogger(httpReq.Context()).WithError(signErr).Errorf("obtaining signing identity for %s failed", request.Destination()) return nil, nil, spec.NotFound(fmt.Sprintf("Server name %q does not exist", request.Destination())) } queryRes := api.QueryLatestEventsAndStateResponse{} - event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), proto, identity, time.Now(), rsAPI, &queryRes) - switch e := err.(type) { + event, buildErr := eventutil.QueryAndBuildEvent(httpReq.Context(), proto, identity, time.Now(), rsAPI, &queryRes) + switch e := buildErr.(type) { case nil: case eventutil.ErrRoomNoExists: - util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") + util.GetLogger(httpReq.Context()).WithError(buildErr).Error("eventutil.BuildEvent failed") return nil, nil, spec.NotFound("Room does not exist") case gomatrixserverlib.BadJSONError: - util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") + util.GetLogger(httpReq.Context()).WithError(buildErr).Error("eventutil.BuildEvent failed") return nil, nil, spec.BadJSON(e.Error()) default: - util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") + util.GetLogger(httpReq.Context()).WithError(buildErr).Error("eventutil.BuildEvent failed") return nil, nil, spec.InternalServerError{} } @@ -87,15 +87,25 @@ func MakeLeave( return event, stateEvents, nil } + senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID.String(), userID) + if err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + input := gomatrixserverlib.HandleMakeLeaveInput{ UserID: userID, + SenderID: senderID, RoomID: roomID, RoomVersion: roomVersion, RequestOrigin: request.Origin(), LocalServerName: cfg.Matrix.ServerName, LocalServerInRoom: res.RoomExists && res.IsInRoom, BuildEventTemplate: createLeaveTemplate, - UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, } @@ -216,7 +226,7 @@ func SendLeave( JSON: spec.BadJSON("No state key was provided in the leave event."), } } - if !event.StateKeyEquals(event.SenderID()) { + if !event.StateKeyEquals(string(event.SenderID())) { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON("Event state key must match the event sender."), diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index 76a2f3d5a..360802de5 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -140,22 +140,24 @@ func ExchangeThirdPartyInvite( } } - _, senderDomain, err := cfg.Matrix.SplitLocalID('@', proto.Sender) - if err != nil { + userID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, spec.SenderID(proto.SenderID)) + if err != nil || userID == nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: spec.BadJSON("Invalid sender ID: " + err.Error()), + JSON: spec.BadJSON("Invalid sender ID"), } } + senderDomain := userID.Domain() // Check that the state key is correct. - _, targetDomain, err := gomatrixserverlib.SplitID('@', *proto.StateKey) - if err != nil { + targetUserID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, spec.SenderID(*proto.StateKey)) + if err != nil || targetUserID == nil { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON("The event's state key isn't a Matrix user ID"), } } + targetDomain := targetUserID.Domain() // Check that the target user is from the requesting homeserver. if targetDomain != request.Origin() { @@ -271,7 +273,7 @@ func createInviteFrom3PIDInvite( // Build the event proto := &gomatrixserverlib.ProtoEvent{ Type: "m.room.member", - Sender: inv.Sender, + SenderID: inv.Sender, RoomID: inv.RoomID, StateKey: &inv.MXID, } diff --git a/go.mod b/go.mod index 10551f702..3621428c3 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230606202811-a644d5d8fb66 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 diff --git a/go.sum b/go.sum index 3ec1c115c..1ee0261f6 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230606202811-a644d5d8fb66 h1:6SixhMmB5Ir10xUJ6zh3A4NBxSaZCSz2s5U63Wg0eEU= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230606202811-a644d5d8fb66/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d h1:MjL8SXRzhO61aXDFL+gA3Bx1SicqLGL9gCWXDv8jkD8= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go index da33d3862..ac7608950 100644 --- a/internal/pushrules/evaluate.go +++ b/internal/pushrules/evaluate.go @@ -28,7 +28,7 @@ type EvaluationContext interface { // HasPowerLevel returns whether the user has at least the given // power in the room of the current event. - HasPowerLevel(userID, levelKey string) (bool, error) + HasPowerLevel(senderID spec.SenderID, levelKey string) (bool, error) } // A kindAndRules is just here to simplify iteration of the (ordered) diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go index 34c1436f4..859d1f8a6 100644 --- a/internal/pushrules/evaluate_test.go +++ b/internal/pushrules/evaluate_test.go @@ -8,8 +8,8 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" ) -func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { - return spec.NewUserID(senderID, true) +func UserIDForSender(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) } func TestRuleSetEvaluatorMatchEvent(t *testing.T) { @@ -158,8 +158,8 @@ type fakeEvaluationContext struct{ memberCount int } func (fakeEvaluationContext) UserDisplayName() string { return "Dear User" } func (f fakeEvaluationContext) RoomMemberCount() (int, error) { return f.memberCount, nil } -func (fakeEvaluationContext) HasPowerLevel(userID, levelKey string) (bool, error) { - return userID == "@poweruser:example.com" && levelKey == "powerlevel", nil +func (fakeEvaluationContext) HasPowerLevel(senderID spec.SenderID, levelKey string) (bool, error) { + return senderID == "@poweruser:example.com" && levelKey == "powerlevel", nil } func TestPatternMatches(t *testing.T) { diff --git a/internal/transactionrequest.go b/internal/transactionrequest.go index 0bbe0720c..b2929bb5d 100644 --- a/internal/transactionrequest.go +++ b/internal/transactionrequest.go @@ -167,7 +167,7 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut } continue } - if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID, senderID string) (*spec.UserID, error) { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) diff --git a/internal/transactionrequest_test.go b/internal/transactionrequest_test.go index 6f3ce0b3b..1d32c8060 100644 --- a/internal/transactionrequest_test.go +++ b/internal/transactionrequest_test.go @@ -70,8 +70,8 @@ type FakeRsAPI struct { bannedFromRoom bool } -func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { - return spec.NewUserID(senderID, true) +func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) } func (r *FakeRsAPI) QueryRoomVersionForRoom( @@ -642,8 +642,8 @@ type testRoomserverAPI struct { queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse } -func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { - return spec.NewUserID(senderID, true) +func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) } func (t *testRoomserverAPI) InputRoomEvents( diff --git a/roomserver/api/alias.go b/roomserver/api/alias.go index 1b9475404..c091cf6a3 100644 --- a/roomserver/api/alias.go +++ b/roomserver/api/alias.go @@ -14,7 +14,11 @@ package api -import "regexp" +import ( + "regexp" + + "github.com/matrix-org/gomatrixserverlib/spec" +) // SetRoomAliasRequest is a request to SetRoomAlias type SetRoomAliasRequest struct { @@ -62,7 +66,7 @@ type GetAliasesForRoomIDResponse struct { // RemoveRoomAliasRequest is a request to RemoveRoomAlias type RemoveRoomAliasRequest struct { // ID of the user removing the alias - SenderID string `json:"user_id"` + SenderID spec.SenderID `json:"user_id"` // The room alias to remove Alias string `json:"alias"` } diff --git a/roomserver/api/api.go b/roomserver/api/api.go index d61a05534..8c2cbd6b2 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -77,8 +77,8 @@ type InputRoomEventsAPI interface { } type QuerySenderIDAPI interface { - QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) - QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) + QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) + QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) } // Query the latest events and state for a room from the room server. diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index dcfb26b8e..c950024ad 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -130,7 +130,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( } sender, err := r.QueryUserIDForSender(ctx, roomID, request.SenderID) - if err != nil { + if err != nil || sender == nil { return fmt.Errorf("r.QueryUserIDForSender: %w", err) } virtualHost := sender.Domain() @@ -141,7 +141,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return fmt.Errorf("r.DB.GetCreatorIDForAlias: %w", err) } - if creatorID != request.SenderID { + if spec.SenderID(creatorID) != request.SenderID { var plEvent *types.HeaderedEvent var pls *gomatrixserverlib.PowerLevelContent @@ -173,23 +173,24 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return err } - sender := request.SenderID + senderID := request.SenderID if request.SenderID != ev.SenderID() { - sender = ev.SenderID() + senderID = ev.SenderID() } - - _, senderDomain, err := r.Cfg.Global.SplitLocalID('@', sender) - if err != nil { + sender, err := r.QueryUserIDForSender(ctx, roomID, senderID) + if err != nil || sender == nil { return err } + senderDomain := sender.Domain() + identity, err := r.Cfg.Global.SigningIdentityFor(senderDomain) if err != nil { return err } proto := &gomatrixserverlib.ProtoEvent{ - Sender: sender, + SenderID: string(senderID), RoomID: ev.RoomID(), Type: ev.Type(), StateKey: ev.StateKey(), diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 932ce6155..7782d07d2 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -76,7 +76,7 @@ func CheckForSoftFail( } // Check if the event is allowed. - if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { + if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return db.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { // return true, nil @@ -139,8 +139,8 @@ func (ae *authEvents) JoinRules() (gomatrixserverlib.PDU, error) { } // Memmber implements gomatrixserverlib.AuthEventProvider -func (ae *authEvents) Member(stateKey string) (gomatrixserverlib.PDU, error) { - return ae.lookupEvent(types.MRoomMemberNID, stateKey), nil +func (ae *authEvents) Member(stateKey spec.SenderID) (gomatrixserverlib.PDU, error) { + return ae.lookupEvent(types.MRoomMemberNID, string(stateKey)), nil } // ThirdPartyInvite implements gomatrixserverlib.AuthEventProvider diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 764bdfe2c..1f273da01 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -282,7 +282,7 @@ func (r *Inputer) processRoomEvent( // Check if the event is allowed by its auth events. If it isn't then // we consider the event to be "rejected" — it will still be persisted. - if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { + if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return r.DB.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { isRejected = true @@ -501,7 +501,7 @@ func (r *Inputer) processRoomEvent( func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event gomatrixserverlib.PDU) error { oldRoomID := event.RoomID() newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str - return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, event.SenderID()) + return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, string(event.SenderID())) } // processStateBefore works out what the state is before the event and @@ -587,7 +587,7 @@ func (r *Inputer) processStateBefore( stateBeforeAuth := gomatrixserverlib.NewAuthEvents( gomatrixserverlib.ToPDUs(stateBeforeEvent), ) - if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID, senderID string) (*spec.UserID, error) { + if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return r.DB.GetUserIDForSender(ctx, roomID, senderID) }); rejectionErr != nil { rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr) @@ -700,7 +700,7 @@ nextAuthEvent: // Check the signatures of the event. If this fails then we'll simply // skip it, because gomatrixserverlib.Allowed() will notice a problem // if a critical event is missing anyway. - if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID, senderID string) (*spec.UserID, error) { + if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return r.DB.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { continue nextAuthEvent @@ -718,7 +718,7 @@ nextAuthEvent: } // Check if the auth event should be rejected. - err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID, senderID string) (*spec.UserID, error) { + err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return r.DB.GetUserIDForSender(ctx, roomID, senderID) }) if isRejected = err != nil; isRejected { @@ -875,7 +875,7 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r RoomID: event.RoomID(), Type: spec.MRoomMember, StateKey: &stateKey, - Sender: stateKey, + SenderID: stateKey, PrevEvents: prevEvents, } diff --git a/roomserver/internal/input/input_events_test.go b/roomserver/internal/input/input_events_test.go index 0ba7d19f5..5f2cd9562 100644 --- a/roomserver/internal/input/input_events_test.go +++ b/roomserver/internal/input/input_events_test.go @@ -58,7 +58,9 @@ func Test_EventAuth(t *testing.T) { } // Finally check that the event is NOT allowed - if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) }); err == nil { + if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) + }); err == nil { t.Fatalf("event should not be allowed, but it was") } } diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index ac0670fc3..f0f974d26 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -473,7 +473,7 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion stateEventList = append(stateEventList, state.StateEvents...) } resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts( - roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID, senderID string) (*spec.UserID, error) { + roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return t.db.GetUserIDForSender(ctx, roomID, senderID) }, ) @@ -482,7 +482,7 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion } // apply the current event retryAllowedState: - if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID, senderID string) (*spec.UserID, error) { + if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return t.db.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { switch missing := err.(type) { @@ -569,7 +569,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver // will be added and duplicates will be removed. missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events)) for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID, senderID string) (*spec.UserID, error) { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return t.db.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { continue @@ -660,7 +660,7 @@ func (t *missingStateReq) lookupMissingStateViaState( authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{ StateEvents: state.GetStateEvents(), AuthEvents: state.GetAuthEvents(), - }, roomVersion, t.keys, nil, func(roomID, senderID string) (*spec.UserID, error) { + }, roomVersion, t.keys, nil, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return t.db.GetUserIDForSender(ctx, roomID, senderID) }) if err != nil { @@ -897,7 +897,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers)) return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers)) } - if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID, senderID string) (*spec.UserID, error) { + if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return t.db.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index ca736cb65..eeb1ac406 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -96,14 +96,15 @@ func (r *Admin) PerformAdminEvacuateRoom( RoomID: roomID, Type: spec.MRoomMember, StateKey: &stateKey, - Sender: stateKey, + SenderID: stateKey, PrevEvents: prevEvents, } - _, senderDomain, err = gomatrixserverlib.SplitID('@', fledglingEvent.Sender) - if err != nil { + userID, err := r.Queryer.QueryUserIDForSender(ctx, roomID, spec.SenderID(fledglingEvent.SenderID)) + if err != nil || userID == nil { continue } + senderDomain = userID.Domain() if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil { return nil, err @@ -233,10 +234,11 @@ func (r *Admin) PerformAdminDownloadState( ctx context.Context, roomID, userID string, serverName spec.ServerName, ) error { - _, senderDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID) + fullUserID, err := spec.NewUserID(userID, true) if err != nil { return err } + senderDomain := fullUserID.Domain() roomInfo, err := r.DB.RoomInfo(ctx, roomID) if err != nil { @@ -262,7 +264,7 @@ func (r *Admin) PerformAdminDownloadState( return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err) } for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID, senderID string) (*spec.UserID, error) { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return r.DB.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { continue @@ -270,7 +272,7 @@ func (r *Admin) PerformAdminDownloadState( authEventMap[authEvent.EventID()] = authEvent } for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID, senderID string) (*spec.UserID, error) { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return r.DB.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { continue @@ -291,11 +293,15 @@ func (r *Admin) PerformAdminDownloadState( stateIDs = append(stateIDs, stateEvent.EventID()) } + senderID, err := r.Queryer.QuerySenderIDForUser(ctx, roomID, *fullUserID) + if err != nil { + return err + } proto := &gomatrixserverlib.ProtoEvent{ - Type: "org.matrix.dendrite.state_download", - Sender: userID, - RoomID: roomID, - Content: spec.RawJSON("{}"), + Type: "org.matrix.dendrite.state_download", + SenderID: string(senderID), + RoomID: roomID, + Content: spec.RawJSON("{}"), } eventsNeeded, err := gomatrixserverlib.StateNeededForProtoEvent(proto) diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 0f743f4e4..388150936 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -121,7 +121,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform // Specifically the test "Outbound federation can backfill events" events, err := gomatrixserverlib.RequestBackfill( ctx, req.VirtualHost, requester, - r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID, senderID string) (*spec.UserID, error) { + r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return r.DB.GetUserIDForSender(ctx, roomID, senderID) }, ) @@ -212,7 +212,7 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom continue } loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false) - result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID, senderID string) (*spec.UserID, error) { + result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return r.DB.GetUserIDForSender(ctx, roomID, senderID) }) if err != nil { diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index 897bd3a0e..a3ba20f70 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -270,11 +270,19 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo var builtEvents []*types.HeaderedEvent authEvents := gomatrixserverlib.NewAuthEvents(nil) + senderID, err := c.RSAPI.QuerySenderIDForUser(ctx, roomID.String(), userID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("rsapi.QuerySenderIDForUser failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } for i, e := range eventsToMake { depth := i + 1 // depth starts at 1 builder := verImpl.NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{ - Sender: userID.String(), + SenderID: string(senderID), RoomID: roomID.String(), Type: e.Type, StateKey: &e.StateKey, @@ -308,7 +316,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } } - if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { + if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return c.DB.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed") @@ -409,11 +417,28 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo // Process the invites. var inviteEvent *types.HeaderedEvent for _, invitee := range createRequest.InvitedUsers { + inviteeUserID, userIDErr := spec.NewUserID(invitee, true) + if userIDErr != nil { + util.GetLogger(ctx).WithError(userIDErr).Error("invalid UserID") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + inviteeSenderID, queryErr := c.RSAPI.QuerySenderIDForUser(ctx, roomID.String(), *inviteeUserID) + if queryErr != nil { + util.GetLogger(ctx).WithError(queryErr).Error("rsapi.QuerySenderIDForUser failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + inviteeString := string(inviteeSenderID) proto := gomatrixserverlib.ProtoEvent{ - Sender: userID.String(), + SenderID: string(senderID), RoomID: roomID.String(), Type: "m.room.member", - StateKey: &invitee, + StateKey: &inviteeString, } content := gomatrixserverlib.MemberContent{ diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index e8e20ede2..56ee16065 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -98,7 +98,7 @@ func (r *Inviter) ProcessInviteMembership( var outputUpdates []api.OutputEvent var updater *shared.MembershipUpdater - userID, err := r.RSAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey()) + userID, err := r.RSAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey())) if err != nil { return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())} } @@ -148,15 +148,21 @@ func (r *Inviter) PerformInvite( return err } + invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, event.RoomID(), *invitedUser) + if err != nil { + return fmt.Errorf("failed looking up senderID for invited user") + } + input := gomatrixserverlib.PerformInviteInput{ RoomID: *validRoomID, InviteEvent: event.PDU, InvitedUser: *invitedUser, + InvitedSenderID: invitedSenderID, IsTargetLocal: isTargetLocal, StrippedState: req.InviteRoomState, MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI}, StateQuerier: &QueryState{r.DB}, - UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return r.DB.GetUserIDForSender(ctx, roomID, senderID) }, } diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 181a93490..d41cc214b 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -175,15 +175,20 @@ func (r *Joiner) performJoinRoomByID( } // Prepare the template for the join event. - userID := req.UserID - _, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID) + userID, err := spec.NewUserID(req.UserID, true) if err != nil { - return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", userID, err)} + return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)} } + senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomIDOrAlias, *userID) + if err != nil { + return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)} + } + senderIDString := string(senderID) + userDomain := userID.Domain() proto := gomatrixserverlib.ProtoEvent{ Type: spec.MRoomMember, - Sender: userID, - StateKey: &userID, + SenderID: senderIDString, + StateKey: &senderIDString, RoomID: req.RoomIDOrAlias, Redacts: "", } @@ -295,7 +300,7 @@ func (r *Joiner) performJoinRoomByID( // is really no harm in just sending another membership event. membershipReq := &api.QueryMembershipForUserRequest{ RoomID: req.RoomIDOrAlias, - UserID: userID, + UserID: userID.String(), } membershipRes := &api.QueryMembershipForUserResponse{} _ = r.Queryer.QueryMembershipForUser(ctx, membershipReq, membershipRes) diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 90102aeeb..094537f8b 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -152,11 +152,19 @@ func (r *Leaver) performLeaveRoomByID( } // Prepare the template for the leave event. - userID := req.UserID + fullUserID, err := spec.NewUserID(req.UserID, true) + if err != nil { + return nil, err + } + senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomID, *fullUserID) + if err != nil { + return nil, err + } + senderIDString := string(senderID) proto := gomatrixserverlib.ProtoEvent{ Type: spec.MRoomMember, - Sender: userID, - StateKey: &userID, + SenderID: senderIDString, + StateKey: &senderIDString, RoomID: req.RoomID, Redacts: "", } @@ -168,10 +176,7 @@ func (r *Leaver) performLeaveRoomByID( } // Get the sender domain. - _, senderDomain, serr := r.Cfg.Matrix.SplitLocalID('@', proto.Sender) - if serr != nil { - return nil, fmt.Errorf("sender %q is invalid", proto.Sender) - } + senderDomain := fullUserID.Domain() // We know that the user is in the room at this point so let's build // a leave event. diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index 8c0df1c46..5710352bb 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -175,8 +175,16 @@ func moveLocalAliases(ctx context.Context, return fmt.Errorf("Failed to get old room aliases: %w", err) } + fullUserID, err := spec.NewUserID(userID, true) + if err != nil { + return fmt.Errorf("Failed to get userID: %w", err) + } + senderID, err := URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) + if err != nil { + return fmt.Errorf("Failed to get senderID: %w", err) + } for _, alias := range aliasRes.Aliases { - removeAliasReq := api.RemoveRoomAliasRequest{SenderID: userID, Alias: alias} + removeAliasReq := api.RemoveRoomAliasRequest{SenderID: senderID, Alias: alias} removeAliasRes := api.RemoveRoomAliasResponse{} if err = URSAPI.RemoveRoomAlias(ctx, &removeAliasReq, &removeAliasRes); err != nil { return fmt.Errorf("Failed to remove old room alias: %w", err) @@ -287,7 +295,15 @@ func (r *Upgrader) userIsAuthorized(ctx context.Context, userID, roomID string, } // Check for power level required to send tombstone event (marks the current room as obsolete), // if not found, use the StateDefault power level - return pl.UserLevel(userID) >= pl.EventLevel("m.room.tombstone", true) + fullUserID, err := spec.NewUserID(userID, true) + if err != nil { + return false + } + senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) + if err != nil { + return false + } + return pl.UserLevel(senderID) >= pl.EventLevel("m.room.tombstone", true) } // nolint:gocyclo @@ -383,7 +399,16 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query util.GetLogger(ctx).WithError(err).Error() return nil, fmt.Errorf("Power level event content was invalid") } - tempPowerLevelsEvent, powerLevelsOverridden := createTemporaryPowerLevels(powerLevelContent, userID) + + fullUserID, err := spec.NewUserID(userID, true) + if err != nil { + return nil, err + } + senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) + if err != nil { + return nil, err + } + tempPowerLevelsEvent, powerLevelsOverridden := createTemporaryPowerLevels(powerLevelContent, senderID) // Now do the join rules event, same as the create and membership // events. We'll set a sane default of "invite" so that if the @@ -452,8 +477,16 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user for i, e := range eventsToMake { depth := i + 1 // depth starts at 1 + fullUserID, userIDErr := spec.NewUserID(userID, true) + if userIDErr != nil { + return userIDErr + } + senderID, queryErr := r.URSAPI.QuerySenderIDForUser(ctx, newRoomID, *fullUserID) + if queryErr != nil { + return queryErr + } proto := gomatrixserverlib.ProtoEvent{ - Sender: userID, + SenderID: string(senderID), RoomID: newRoomID, Type: e.Type, StateKey: &e.StateKey, @@ -484,7 +517,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user } - if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { + if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err) @@ -530,21 +563,26 @@ func (r *Upgrader) makeTombstoneEvent( } func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, userID, roomID string, event gomatrixserverlib.FledglingEvent) (*types.HeaderedEvent, error) { + fullUserID, err := spec.NewUserID(userID, true) + if err != nil { + return nil, err + } + senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) + if err != nil { + return nil, err + } proto := gomatrixserverlib.ProtoEvent{ - Sender: userID, + SenderID: string(senderID), RoomID: roomID, Type: event.Type, StateKey: &event.StateKey, } - err := proto.SetContent(event.Content) + err = proto.SetContent(event.Content) if err != nil { return nil, fmt.Errorf("failed to set new %q event content: %w", proto.Type, err) } // Get the sender domain. - _, senderDomain, serr := r.Cfg.Matrix.SplitLocalID('@', proto.Sender) - if serr != nil { - return nil, fmt.Errorf("Failed to split user ID %q: %w", proto.Sender, err) - } + senderDomain := fullUserID.Domain() identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) if err != nil { return nil, fmt.Errorf("failed to get signing identity for %q: %w", senderDomain, err) @@ -569,7 +607,7 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user stateEvents[i] = queryRes.StateEvents[i].PDU } provider := gomatrixserverlib.NewAuthEvents(stateEvents) - if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID, senderID string) (*spec.UserID, error) { + if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client? @@ -578,7 +616,7 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user return headeredEvent, nil } -func createTemporaryPowerLevels(powerLevelContent *gomatrixserverlib.PowerLevelContent, userID string) (gomatrixserverlib.FledglingEvent, bool) { +func createTemporaryPowerLevels(powerLevelContent *gomatrixserverlib.PowerLevelContent, senderID spec.SenderID) (gomatrixserverlib.FledglingEvent, bool) { // Work out what power level we need in order to be able to send events // of all types into the room. neededPowerLevel := powerLevelContent.StateDefault @@ -603,8 +641,8 @@ func createTemporaryPowerLevels(powerLevelContent *gomatrixserverlib.PowerLevelC // If the user who is upgrading the room doesn't already have sufficient // power, then elevate their power levels. - if tempPowerLevelContent.UserLevel(userID) < neededPowerLevel { - tempPowerLevelContent.Users[userID] = neededPowerLevel + if tempPowerLevelContent.UserLevel(senderID) < neededPowerLevel { + tempPowerLevelContent.Users[string(senderID)] = neededPowerLevel powerLevelsOverridden = true } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 707e95b2a..ae2b7cf57 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -159,7 +159,7 @@ func (r *Queryer) QueryStateAfterEvents( } stateEvents, err = gomatrixserverlib.ResolveConflicts( - info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID, senderID string) (*spec.UserID, error) { + info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return r.DB.GetUserIDForSender(ctx, roomID, senderID) }, ) @@ -637,7 +637,7 @@ func (r *Queryer) QueryStateAndAuthChain( if request.ResolveState { stateEvents, err = gomatrixserverlib.ResolveConflicts( - info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID, senderID string) (*spec.UserID, error) { + info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return r.DB.GetUserIDForSender(ctx, roomID, senderID) }, ) @@ -975,10 +975,10 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, userID) } -func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) { +func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { return r.DB.GetSenderIDForUser(ctx, roomID, userID) } -func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { +func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { return r.DB.GetUserIDForSender(ctx, roomID, senderID) } diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 11a0f5817..5e6ba7d4e 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -392,7 +392,7 @@ func TestPurgeRoom(t *testing.T) { type fledglingEvent struct { Type string StateKey *string - Sender string + SenderID string RoomID string Redacts string Depth int64 @@ -405,7 +405,7 @@ func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *types.HeaderedEve seed := make([]byte, ed25519.SeedSize) // zero seed key := ed25519.NewKeyFromSeed(seed) eb := gomatrixserverlib.MustGetRoomVersion(roomVer).NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{ - Sender: ev.Sender, + SenderID: ev.SenderID, Type: ev.Type, StateKey: ev.StateKey, RoomID: ev.RoomID, @@ -444,7 +444,7 @@ func TestRedaction(t *testing.T) { builderEv := mustCreateEvent(t, fledglingEvent{ Type: spec.MRoomRedaction, - Sender: alice.ID, + SenderID: alice.ID, RoomID: room.ID, Redacts: redactedEvent.EventID(), Depth: redactedEvent.Depth() + 1, @@ -461,7 +461,7 @@ func TestRedaction(t *testing.T) { builderEv := mustCreateEvent(t, fledglingEvent{ Type: spec.MRoomRedaction, - Sender: alice.ID, + SenderID: alice.ID, RoomID: room.ID, Redacts: redactedEvent.EventID(), Depth: redactedEvent.Depth() + 1, @@ -478,7 +478,7 @@ func TestRedaction(t *testing.T) { builderEv := mustCreateEvent(t, fledglingEvent{ Type: spec.MRoomRedaction, - Sender: bob.ID, + SenderID: bob.ID, RoomID: room.ID, Redacts: redactedEvent.EventID(), Depth: redactedEvent.Depth() + 1, @@ -494,7 +494,7 @@ func TestRedaction(t *testing.T) { builderEv := mustCreateEvent(t, fledglingEvent{ Type: spec.MRoomRedaction, - Sender: charlie.ID, + SenderID: charlie.ID, RoomID: room.ID, Redacts: redactedEvent.EventID(), Depth: redactedEvent.Depth() + 1, diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 3131cbff2..b9c5bbc4a 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -44,7 +44,7 @@ type StateResolutionStorage interface { AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) - GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) + GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) } type StateResolution struct { @@ -947,7 +947,7 @@ func (v *StateResolution) resolveConflictsV1( } // Resolve the conflicts. - resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID, senderID string) (*spec.UserID, error) { + resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return v.db.GetUserIDForSender(ctx, roomID, senderID) }) @@ -1061,7 +1061,7 @@ func (v *StateResolution) resolveConflictsV2( conflictedEvents, nonConflictedEvents, authEvents, - func(roomID, senderID string) (*spec.UserID, error) { + func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return v.db.GetUserIDForSender(ctx, roomID, senderID) }, ) diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 2d007bed5..523cc361a 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -167,9 +167,9 @@ type Database interface { // GetKnownUsers searches all users that userID knows about. GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) // GetKnownUsers tries to obtain the current mxid for a given user. - GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) + GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) // GetKnownUsers tries to obtain the current senderID for a given user. - GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) + GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) // GetKnownRooms returns a list of all rooms we know about. GetKnownRooms(ctx context.Context) ([]string, error) // ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room @@ -215,7 +215,7 @@ type RoomDatabase interface { GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error) - GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) + GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) } type EventDatabase interface { diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index 105e61df6..a96e87072 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -101,7 +101,7 @@ func (u *MembershipUpdater) Update(newMembership tables.MembershipState, event * var inserted bool // Did the query result in a membership change? var retired []string // Did we retire any updates in the process? return inserted, retired, u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.SenderID()) + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, string(event.SenderID())) if err != nil { return fmt.Errorf("u.d.AssignStateKeyNID: %w", err) } diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 735001383..6fb57332a 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -252,6 +252,6 @@ func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, ta return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) } -func (u *RoomUpdater) GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { +func (u *RoomUpdater) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { return u.d.GetUserIDForSender(ctx, roomID, senderID) } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 406d7cf1c..f2f842357 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -990,13 +990,13 @@ func (d *EventDatabase) MaybeRedactEvent( // TODO: Don't hack senderID into userID here (pseudoIDs) sender1Domain := "" - sender1, err1 := spec.NewUserID(redactedEvent.SenderID(), true) + sender1, err1 := spec.NewUserID(string(redactedEvent.SenderID()), true) if err1 == nil { sender1Domain = string(sender1.Domain()) } // TODO: Don't hack senderID into userID here (pseudoIDs) sender2Domain := "" - sender2, err2 := spec.NewUserID(redactionEvent.SenderID(), true) + sender2, err2 := spec.NewUserID(string(redactionEvent.SenderID()), true) if err2 == nil { sender2Domain = string(sender2.Domain()) } @@ -1524,14 +1524,14 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit) } -func (d *Database) GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { +func (d *Database) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { // TODO: Use real logic once DB for pseudoIDs is in place - return spec.NewUserID(senderID, true) + return spec.NewUserID(string(senderID), true) } -func (d *Database) GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) { +func (d *Database) GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { // TODO: Use real logic once DB for pseudoIDs is in place - return userID.String(), nil + return spec.SenderID(userID.String()), nil } // GetKnownRooms returns a list of all rooms we know about. diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 5ce3b430b..47eb544ea 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -94,7 +94,7 @@ type MSC2836EventRelationshipsResponse struct { func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsResponse, rsAPI roomserver.RoomserverInternalAPI) *EventRelationshipResponse { out := &EventRelationshipResponse{ - Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), Limited: res.Limited, diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index c463fd72b..551d7ad45 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -525,8 +525,8 @@ type testRoomserverAPI struct { events map[string]*types.HeaderedEvent } -func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { - return spec.NewUserID(senderID, true) +func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) } func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error { @@ -590,7 +590,7 @@ func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *types.HeaderedEve seed := make([]byte, ed25519.SeedSize) // zero seed key := ed25519.NewKeyFromSeed(seed) eb := gomatrixserverlib.MustGetRoomVersion(roomVer).NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{ - Sender: ev.Sender, + SenderID: ev.Sender, Depth: 999, Type: ev.Type, StateKey: ev.StateKey, diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index f380d3d4f..3e5ffda92 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -730,7 +730,7 @@ func stripped(ev gomatrixserverlib.PDU) *fclient.MSC2946StrippedEvent { Type: ev.Type(), StateKey: *ev.StateKey(), Content: ev.Content(), - Sender: ev.SenderID(), + Sender: string(ev.SenderID()), OriginServerTS: ev.OriginServerTS(), } } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index c08364658..8a2a0b1f6 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -523,7 +523,7 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *rstypes.HeaderedEvent) prev := types.PrevEventRef{ PrevContent: prevEvent.Content(), ReplacesState: prevEvent.EventID(), - PrevSender: prevEvent.SenderID(), + PrevSenderID: string(prevEvent.SenderID()), } event.PDU, err = event.SetUnsigned(prev) diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 27e99a357..7fb88faaa 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -193,10 +193,10 @@ func Context( } } - eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) - eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) @@ -204,7 +204,7 @@ func Context( if filter.LazyLoadMembers { allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...) allEvents = append(allEvents, &requestedEvent) - evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache) @@ -227,7 +227,7 @@ func Context( Event: &ev, EventsAfter: eventsAfterClient, EventsBefore: eventsBeforeClient, - State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), } diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index 9c2319dd9..813167a5e 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -144,7 +144,22 @@ func GetMemberships( JSON: spec.InternalServerError{}, } } - res.Joined[ev.SenderID()] = joinedMember(content) + + userID, err := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID()) + if err != nil || userID == nil { + util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryUserIDForSender failed") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"), + } + } + res.Joined[userID.String()] = joinedMember(content) } return util.JSONResponse{ Code: http.StatusOK, @@ -153,7 +168,7 @@ func GetMemberships( } return util.JSONResponse{ Code: http.StatusOK, - JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) })}, } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 879739d00..781fd53e7 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -273,7 +273,7 @@ func OnIncomingMessagesRequest( JSON: spec.InternalServerError{}, } } - res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) })...) } @@ -385,7 +385,7 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv "events_before": len(events), "events_after": len(filteredEvents), }).Debug("applied history visibility (messages)") - return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), start, end, err } @@ -495,7 +495,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent } // Append the events ve previously retrieved locally. - events = append(events, r.snapshot.StreamEventsToEvents(nil, streamEvents)...) + events = append(events, r.snapshot.StreamEventsToEvents(r.ctx, nil, streamEvents, r.rsAPI)...) sort.Sort(eventsByDepth(events)) return diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index 9cf3eabe2..add50b181 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -213,7 +213,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts profile, ok := knownUsersProfiles[userID.String()] if !ok { - stateEvent, stateErr := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, ev.SenderID()) + stateEvent, stateErr := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, string(ev.SenderID())) if stateErr != nil { logrus.WithError(stateErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") continue @@ -239,10 +239,10 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts Context: SearchContextResponse{ Start: startToken.String(), End: endToken.String(), - EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }), - EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }), ProfileInfo: profileInfos, @@ -263,7 +263,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts JSON: spec.InternalServerError{}, } } - stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }) } diff --git a/syncapi/routing/search_test.go b/syncapi/routing/search_test.go index b36be8238..5eb094ca3 100644 --- a/syncapi/routing/search_test.go +++ b/syncapi/routing/search_test.go @@ -25,8 +25,8 @@ import ( type FakeSyncRoomserverAPI struct{ rsapi.SyncRoomserverAPI } -func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { - return spec.NewUserID(senderID, true) +func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) } func TestSearch(t *testing.T) { diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 302b9bad8..8798b62ec 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -44,8 +44,8 @@ type DatabaseTransaction interface { MaxStreamPositionForRelations(ctx context.Context) (types.StreamPosition, error) CurrentState(ctx context.Context, roomID string, stateFilterPart *synctypes.StateFilter, excludeEventIDs []string) ([]*rstypes.HeaderedEvent, error) - GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter) ([]types.StateDelta, []string, error) - GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter) ([]types.StateDelta, []string, error) + GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter, rsAPI api.SyncRoomserverAPI) ([]types.StateDelta, []string, error) + GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter, rsAPI api.SyncRoomserverAPI) ([]types.StateDelta, []string, error) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) GetRoomSummary(ctx context.Context, roomID, userID string) (summary *types.Summary, err error) @@ -90,7 +90,7 @@ type DatabaseTransaction interface { // StreamEventsToEvents converts streamEvent to Event. If device is non-nil and // matches the streamevent.transactionID device then the transaction ID gets // added to the unsigned section of the output event. - StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*rstypes.HeaderedEvent + StreamEventsToEvents(ctx context.Context, device *userapi.Device, in []types.StreamEvent, rsAPI api.SyncRoomserverAPI) []*rstypes.HeaderedEvent // SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the // relevant events within the given ranges for the supplied user ID and device ID. SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error) diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index 17a6a69c3..5bd3b1f01 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -99,7 +99,41 @@ func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*rstypes.He // We don't include a device here as we only include transaction IDs in // incremental syncs. - return d.StreamEventsToEvents(nil, streamEvents), nil + return d.StreamEventsToEvents(ctx, nil, streamEvents, nil), nil +} + +func (d *Database) StreamEventsToEvents(ctx context.Context, device *userapi.Device, in []types.StreamEvent, rsAPI api.SyncRoomserverAPI) []*rstypes.HeaderedEvent { + out := make([]*rstypes.HeaderedEvent, len(in)) + for i := 0; i < len(in); i++ { + out[i] = in[i].HeaderedEvent + if device != nil && in[i].TransactionID != nil { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": out[i].EventID(), + }).WithError(err).Warnf("Failed to add transaction ID to event") + continue + } + deviceSenderID, err := rsAPI.QuerySenderIDForUser(ctx, in[i].RoomID(), *userID) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": out[i].EventID(), + }).WithError(err).Warnf("Failed to add transaction ID to event") + continue + } + if deviceSenderID == in[i].SenderID() && device.SessionID == in[i].TransactionID.SessionID { + err := out[i].SetUnsignedField( + "transaction_id", in[i].TransactionID.TransactionID, + ) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": out[i].EventID(), + }).WithError(err).Warnf("Failed to add transaction ID to event") + } + } + } + } + return out } // AddInviteEvent stores a new invite event for a user. @@ -190,45 +224,6 @@ func (d *Database) UpsertAccountData( return } -func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*rstypes.HeaderedEvent { - out := make([]*rstypes.HeaderedEvent, len(in)) - for i := 0; i < len(in); i++ { - out[i] = in[i].HeaderedEvent - if device != nil && in[i].TransactionID != nil { - userID, err := spec.NewUserID(device.UserID, true) - if err != nil { - logrus.WithFields(logrus.Fields{ - "event_id": out[i].EventID(), - }).WithError(err).Warnf("Failed to add transaction ID to event") - continue - } - deviceSenderID, err := d.getSenderIDForUser(in[i].RoomID(), *userID) - if err != nil { - logrus.WithFields(logrus.Fields{ - "event_id": out[i].EventID(), - }).WithError(err).Warnf("Failed to add transaction ID to event") - continue - } - if deviceSenderID == in[i].SenderID() && device.SessionID == in[i].TransactionID.SessionID { - err := out[i].SetUnsignedField( - "transaction_id", in[i].TransactionID.TransactionID, - ) - if err != nil { - logrus.WithFields(logrus.Fields{ - "event_id": out[i].EventID(), - }).WithError(err).Warnf("Failed to add transaction ID to event") - } - } - } - } - return out -} - -func (d *Database) getSenderIDForUser(roomID string, userID spec.UserID) (string, error) { // nolint - // TODO: Repalce with actual logic for pseudoIDs - return userID.String(), nil -} - // handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of // the events listed in the event's 'prev_events'. This function also updates the backwards extremities table // to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index f2b1c58dc..df9613850 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -10,6 +10,7 @@ import ( "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/roomserver/api" rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" @@ -186,7 +187,7 @@ func (d *DatabaseTransaction) Events(ctx context.Context, eventIDs []string) ([] // We don't include a device here as we only include transaction IDs in // incremental syncs. - return d.StreamEventsToEvents(nil, streamEvents), nil + return d.StreamEventsToEvents(ctx, nil, streamEvents, nil), nil } func (d *DatabaseTransaction) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { @@ -325,7 +326,7 @@ func (d *DatabaseTransaction) GetBackwardTopologyPos( func (d *DatabaseTransaction) GetStateDeltas( ctx context.Context, device *userapi.Device, r types.Range, userID string, - stateFilter *synctypes.StateFilter, + stateFilter *synctypes.StateFilter, rsAPI api.SyncRoomserverAPI, ) (deltas []types.StateDelta, joinedRoomsIDs []string, err error) { // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // - Get membership list changes for this user in this sync response @@ -417,7 +418,7 @@ func (d *DatabaseTransaction) GetStateDeltas( if !peek.Deleted { deltas = append(deltas, types.StateDelta{ Membership: spec.Peek, - StateEvents: d.StreamEventsToEvents(device, state[peek.RoomID]), + StateEvents: d.StreamEventsToEvents(ctx, device, state[peek.RoomID], rsAPI), RoomID: peek.RoomID, }) } @@ -462,7 +463,7 @@ func (d *DatabaseTransaction) GetStateDeltas( deltas = append(deltas, types.StateDelta{ Membership: membership, MembershipPos: ev.StreamPosition, - StateEvents: d.StreamEventsToEvents(device, stateFiltered[roomID]), + StateEvents: d.StreamEventsToEvents(ctx, device, stateFiltered[roomID], rsAPI), RoomID: roomID, }) break @@ -474,7 +475,7 @@ func (d *DatabaseTransaction) GetStateDeltas( for _, joinedRoomID := range joinedRoomIDs { deltas = append(deltas, types.StateDelta{ Membership: spec.Join, - StateEvents: d.StreamEventsToEvents(device, stateFiltered[joinedRoomID]), + StateEvents: d.StreamEventsToEvents(ctx, device, stateFiltered[joinedRoomID], rsAPI), RoomID: joinedRoomID, NewlyJoined: newlyJoinedRooms[joinedRoomID], }) @@ -490,7 +491,7 @@ func (d *DatabaseTransaction) GetStateDeltas( func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( ctx context.Context, device *userapi.Device, r types.Range, userID string, - stateFilter *synctypes.StateFilter, + stateFilter *synctypes.StateFilter, rsAPI api.SyncRoomserverAPI, ) ([]types.StateDelta, []string, error) { // Look up all memberships for the user. We only care about rooms that a // user has ever interacted with — joined to, kicked/banned from, left. @@ -531,7 +532,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( } deltas[peek.RoomID] = types.StateDelta{ Membership: spec.Peek, - StateEvents: d.StreamEventsToEvents(device, s), + StateEvents: d.StreamEventsToEvents(ctx, device, s, rsAPI), RoomID: peek.RoomID, } } @@ -560,7 +561,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( deltas[roomID] = types.StateDelta{ Membership: membership, MembershipPos: ev.StreamPosition, - StateEvents: d.StreamEventsToEvents(device, stateStreamEvents), + StateEvents: d.StreamEventsToEvents(ctx, device, stateStreamEvents, rsAPI), RoomID: roomID, } } @@ -581,7 +582,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( } deltas[joinedRoomID] = types.StateDelta{ Membership: spec.Join, - StateEvents: d.StreamEventsToEvents(device, s), + StateEvents: d.StreamEventsToEvents(ctx, device, s, rsAPI), RoomID: joinedRoomID, } } diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 08ca99a76..bc64aa50f 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -214,7 +214,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { if err != nil { t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err) } - gots := snapshot.StreamEventsToEvents(nil, paginatedEvents) + gots := snapshot.StreamEventsToEvents(context.Background(), nil, paginatedEvents, nil) test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:])) }) }) diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 8f83a0896..d214980bd 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -175,12 +175,12 @@ func (p *PDUStreamProvider) IncrementalSync( eventFilter := req.Filter.Room.Timeline if req.WantFullState { - if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { + if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter, p.rsAPI); err != nil { req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed") return from } } else { - if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { + if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter, p.rsAPI); err != nil { req.Log.WithError(err).Error("p.DB.GetStateDeltas failed") return from } @@ -275,7 +275,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( limited := dbEvents[delta.RoomID].Limited recEvents := gomatrixserverlib.ReverseTopologicalOrdering( - gomatrixserverlib.ToPDUs(snapshot.StreamEventsToEvents(device, recentStreamEvents)), + gomatrixserverlib.ToPDUs(snapshot.StreamEventsToEvents(ctx, device, recentStreamEvents, p.rsAPI)), gomatrixserverlib.TopologicalOrderByPrevEvents, ) recentEvents := make([]*rstypes.HeaderedEvent, len(recEvents)) @@ -376,13 +376,13 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } } jr.Timeline.PrevBatch = &prevBatch - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) req.Response.Rooms.Join[delta.RoomID] = jr @@ -391,11 +391,11 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( jr := types.NewJoinResponse() jr.Timeline.PrevBatch = &prevBatch // TODO: Apply history visibility on peeked rooms - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) jr.Timeline.Limited = limited - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) req.Response.Rooms.Peek[delta.RoomID] = jr @@ -406,13 +406,13 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( case spec.Ban: lr := types.NewLeaveResponse() lr.Timeline.PrevBatch = &prevBatch - lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. lr.Timeline.Limited = limited && len(events) == len(recentEvents) - lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) req.Response.Rooms.Leave[delta.RoomID] = lr @@ -437,7 +437,7 @@ func applyHistoryVisibilityFilter( for _, ev := range recentEvents { if ev.StateKey() != nil { stateTypes = append(stateTypes, ev.Type()) - senders = append(senders, ev.SenderID()) + senders = append(senders, string(ev.SenderID())) } } @@ -512,7 +512,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( // We don't include a device here as we don't need to send down // transaction IDs for complete syncs, but we do it anyway because Sytest demands it for: // "Can sync a room with a message with a transaction id" - which does a complete sync to check. - recentEvents := snapshot.StreamEventsToEvents(device, recentStreamEvents) + recentEvents := snapshot.StreamEventsToEvents(ctx, device, recentStreamEvents, p.rsAPI) events := recentEvents // Only apply history visibility checks if the response is for joined rooms @@ -564,13 +564,13 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( } jr.Timeline.PrevBatch = prevBatch - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. jr.Timeline.Limited = limited && len(events) == len(recentEvents) - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) return jr, nil @@ -593,8 +593,8 @@ func (p *PDUStreamProvider) lazyLoadMembers( // Add all users the client doesn't know about yet to a list for _, event := range timelineEvents { // Membership is not yet cached, add it to the list - if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, event.SenderID()); !ok { - timelineUsers[event.SenderID()] = struct{}{} + if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, string(event.SenderID())); !ok { + timelineUsers[string(event.SenderID())] = struct{}{} } } // Preallocate with the same amount, even if it will end up with fewer values diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index 78c857ab9..b9f13c517 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -40,8 +40,8 @@ type syncRoomserverAPI struct { rooms []*test.Room } -func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { - return spec.NewUserID(senderID, true) +func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) } func (s *syncRoomserverAPI) QueryLatestEventsAndState(ctx context.Context, req *rsapi.QueryLatestEventsAndStateRequest, res *rsapi.QueryLatestEventsAndStateResponse) error { diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 526a120d0..a3dc7f54b 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -343,7 +343,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { type PrevEventRef struct { PrevContent json.RawMessage `json:"prev_content"` ReplacesState string `json:"replaces_state"` - PrevSender string `json:"prev_sender"` + PrevSenderID string `json:"prev_sender"` } type DeviceLists struct { diff --git a/test/room.go b/test/room.go index 4cdb73aa3..b19c57ddc 100644 --- a/test/room.go +++ b/test/room.go @@ -39,8 +39,8 @@ var ( roomIDCounter = int64(0) ) -func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { - return spec.NewUserID(senderID, true) +func UserIDForSender(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) } type Room struct { @@ -168,7 +168,7 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten } builder := gomatrixserverlib.MustGetRoomVersion(r.Version).NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{ - Sender: creator.ID, + SenderID: creator.ID, RoomID: r.ID, Type: eventType, StateKey: mod.stateKey, diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index c025deee0..df507eb26 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -108,7 +108,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms } if s.cfg.Matrix.ReportStats.Enabled { - go s.storeMessageStats(ctx, event.Type(), event.SenderID(), event.RoomID()) + go s.storeMessageStats(ctx, event.Type(), string(event.SenderID()), event.RoomID()) } log.WithFields(log.Fields{ @@ -664,7 +664,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * roomSize: roomSize, } eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global) - rule, err := eval.MatchEvent(event.PDU, func(roomID, senderID string) (*spec.UserID, error) { + rule, err := eval.MatchEvent(event.PDU, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) if err != nil { @@ -698,7 +698,7 @@ func (rse *ruleSetEvalContext) UserDisplayName() string { return rse.mem.Display func (rse *ruleSetEvalContext) RoomMemberCount() (int, error) { return rse.roomSize, nil } -func (rse *ruleSetEvalContext) HasPowerLevel(senderID, levelKey string) (bool, error) { +func (rse *ruleSetEvalContext) HasPowerLevel(senderID spec.SenderID, levelKey string) (bool, error) { req := &rsapi.QueryLatestEventsAndStateRequest{ RoomID: rse.roomID, StateToFetch: []gomatrixserverlib.StateKeyTuple{ diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 899a5aaf0..954247155 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -47,8 +47,8 @@ func mustCreateEvent(t *testing.T, content string) *types.HeaderedEvent { type FakeUserRoomserverAPI struct{ rsapi.UserRoomserverAPI } -func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { - return spec.NewUserID(senderID, true) +func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) } func Test_evaluatePushRules(t *testing.T) { From 5713c5715c72953272b7b99fe64feb29bf1fbe6f Mon Sep 17 00:00:00 2001 From: Antonio Cheong Date: Mon, 12 Jun 2023 16:51:26 +0800 Subject: [PATCH 18/35] Update sample link (#3107) Leftover work by f956a8c1d9172f6bbfb9f7515feacd477a0e35f5 Signed-off-by: `Antonio Cheong ` [skip ci] --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 0b9788768..34604eff9 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ For a usable federating Dendrite deployment, you will also need: Also recommended are: - A PostgreSQL database engine, which will perform better than SQLite with many users and/or larger rooms -- A reverse proxy server, such as nginx, configured [like this sample](https://github.com/matrix-org/dendrite/blob/master/docs/nginx/monolith-sample.conf) +- A reverse proxy server, such as nginx, configured [like this sample](https://github.com/matrix-org/dendrite/blob/main/docs/nginx/dendrite-sample.conf) The [Federation Tester](https://federationtester.matrix.org) can be used to verify your deployment. From 832ccc32f6a023665e250eee44b5f678e985d50e Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 12 Jun 2023 12:45:42 +0200 Subject: [PATCH 19/35] Add initial support for storing user room keys (#3098) --- roomserver/storage/interface.go | 16 ++ roomserver/storage/postgres/storage.go | 9 ++ .../storage/postgres/user_room_keys_table.go | 132 ++++++++++++++++ roomserver/storage/shared/storage.go | 146 ++++++++++++++++++ roomserver/storage/shared/storage_test.go | 116 +++++++++++++- roomserver/storage/sqlite3/storage.go | 8 + .../storage/sqlite3/user_room_keys_table.go | 146 ++++++++++++++++++ roomserver/storage/tables/interface.go | 14 ++ .../tables/user_room_keys_table_test.go | 115 ++++++++++++++ roomserver/types/types.go | 5 + 10 files changed, 700 insertions(+), 7 deletions(-) create mode 100644 roomserver/storage/postgres/user_room_keys_table.go create mode 100644 roomserver/storage/sqlite3/user_room_keys_table.go create mode 100644 roomserver/storage/tables/user_room_keys_table_test.go diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 523cc361a..2d27d7999 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -16,6 +16,7 @@ package storage import ( "context" + "crypto/ed25519" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" @@ -27,6 +28,7 @@ import ( ) type Database interface { + UserRoomKeys // Do we support processing input events for more than one room at a time? SupportsConcurrentRoomInputs() bool // RoomInfo returns room information for the given room ID, or nil if there is no room. @@ -194,8 +196,22 @@ type Database interface { ) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error) } +type UserRoomKeys interface { + // InsertUserRoomPrivatePublicKey inserts the given private key as well as the public key for it. This should be used + // when creating keys locally. + InsertUserRoomPrivatePublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) + // InsertUserRoomPublicKey inserts the given public key, this should be used for users NOT local to this server + InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) + // SelectUserRoomPrivateKey selects the private key for the given user and room combination + SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error) + // SelectUserIDsForPublicKeys selects all userIDs for the requested senderKeys. Returns a map from roomID -> map from publicKey to userID. + // If a senderKey can't be found, it is omitted in the result. + SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (map[spec.RoomID]map[string]string, error) +} + type RoomDatabase interface { EventDatabase + UserRoomKeys // RoomInfo returns room information for the given room ID, or nil if there is no room. RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error) diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 19cde5410..453ff45da 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -131,6 +131,9 @@ func (d *Database) create(db *sql.DB) error { if err := CreateRedactionsTable(db); err != nil { return err } + if err := CreateUserRoomKeysTable(db); err != nil { + return err + } return nil } @@ -192,6 +195,11 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } + userRoomKeys, err := PrepareUserRoomKeysTable(db) + if err != nil { + return err + } + d.Database = shared.Database{ DB: db, EventDatabase: shared.EventDatabase{ @@ -215,6 +223,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room MembershipTable: membership, PublishedTable: published, Purge: purge, + UserRoomKeyTable: userRoomKeys, } return nil } diff --git a/roomserver/storage/postgres/user_room_keys_table.go b/roomserver/storage/postgres/user_room_keys_table.go new file mode 100644 index 000000000..22f978bf0 --- /dev/null +++ b/roomserver/storage/postgres/user_room_keys_table.go @@ -0,0 +1,132 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "crypto/ed25519" + "database/sql" + "errors" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const userRoomKeysSchema = ` +CREATE TABLE IF NOT EXISTS roomserver_user_room_keys ( + user_nid INTEGER NOT NULL, + room_nid INTEGER NOT NULL, + pseudo_id_key BYTEA NULL, -- may be null for users not local to the server + pseudo_id_pub_key BYTEA NOT NULL, + CONSTRAINT roomserver_user_room_keys_pk PRIMARY KEY (user_nid, room_nid) +); +` + +const insertUserRoomPrivateKeySQL = ` + INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key, pseudo_id_pub_key) VALUES ($1, $2, $3, $4) + ON CONFLICT ON CONSTRAINT roomserver_user_room_keys_pk DO UPDATE SET pseudo_id_key = roomserver_user_room_keys.pseudo_id_key + RETURNING (pseudo_id_key) +` + +const insertUserRoomPublicKeySQL = ` + INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_pub_key) VALUES ($1, $2, $3) + ON CONFLICT ON CONSTRAINT roomserver_user_room_keys_pk DO UPDATE SET pseudo_id_pub_key = $3 + RETURNING (pseudo_id_pub_key) +` + +const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` + +const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)` + +type userRoomKeysStatements struct { + insertUserRoomPrivateKeyStmt *sql.Stmt + insertUserRoomPublicKeyStmt *sql.Stmt + selectUserRoomKeyStmt *sql.Stmt + selectUserNIDsStmt *sql.Stmt +} + +func CreateUserRoomKeysTable(db *sql.DB) error { + _, err := db.Exec(userRoomKeysSchema) + return err +} + +func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) { + s := &userRoomKeysStatements{} + return s, sqlutil.StatementList{ + {&s.insertUserRoomPrivateKeyStmt, insertUserRoomPrivateKeySQL}, + {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, + {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, + {&s.selectUserNIDsStmt, selectUserNIDsSQL}, + }.Prepare(db) +} + +func (s *userRoomKeysStatements) InsertUserRoomPrivatePublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomPrivateKeyStmt) + err = stmt.QueryRowContext(ctx, userNID, roomNID, key, key.Public()).Scan(&result) + return result, err +} + +func (s *userRoomKeysStatements) InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomPublicKeyStmt) + err = stmt.QueryRowContext(ctx, userNID, roomNID, key).Scan(&result) + return result, err +} + +func (s *userRoomKeysStatements) SelectUserRoomPrivateKey( + ctx context.Context, + txn *sql.Tx, + userNID types.EventStateKeyNID, + roomNID types.RoomNID, +) (ed25519.PrivateKey, error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeyStmt) + var result ed25519.PrivateKey + err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return result, err +} + +func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserNIDsStmt) + + roomNIDs := make([]types.RoomNID, 0, len(senderKeys)) + var senders [][]byte + for roomNID := range senderKeys { + roomNIDs = append(roomNIDs, roomNID) + for _, key := range senderKeys[roomNID] { + senders = append(senders, key) + } + } + rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(senders)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows") + + result := make(map[string]types.UserRoomKeyPair, len(senders)+len(roomNIDs)) + var publicKey []byte + userRoomKeyPair := types.UserRoomKeyPair{} + for rows.Next() { + if err = rows.Scan(&userRoomKeyPair.EventStateKeyNID, &userRoomKeyPair.RoomNID, &publicKey); err != nil { + return nil, err + } + result[string(publicKey)] = userRoomKeyPair + } + return result, rows.Err() +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index f2f842357..cb12b3f57 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -2,14 +2,18 @@ package shared import ( "context" + "crypto/ed25519" "database/sql" "encoding/json" + "errors" "fmt" "sort" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/internal/caching" @@ -41,6 +45,7 @@ type Database struct { MembershipTable tables.Membership PublishedTable tables.Published Purge tables.Purge + UserRoomKeyTable tables.UserRoomKeys GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error) } @@ -1609,6 +1614,147 @@ func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventS }) } +// InsertUserRoomPrivatePublicKey inserts a new user room key for the given user and room. +// Returns the newly inserted private key or an existing private key. If there is +// an error talking to the database, returns that error. +func (d *Database) InsertUserRoomPrivatePublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) { + uID := userID.String() + stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID}) + if sErr != nil { + return nil, sErr + } + stateKeyNID := stateKeyNIDMap[uID] + + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String()) + if rErr != nil { + return rErr + } + if roomInfo == nil { + return eventutil.ErrRoomNoExists{} + } + + var iErr error + result, iErr = d.UserRoomKeyTable.InsertUserRoomPrivatePublicKey(ctx, txn, stateKeyNID, roomInfo.RoomNID, key) + return iErr + }) + return result, err +} + +// InsertUserRoomPublicKey inserts a new user room key for the given user and room. +// Returns the newly inserted public key or an existing public key. If there is +// an error talking to the database, returns that error. +func (d *Database) InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) { + uID := userID.String() + stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID}) + if sErr != nil { + return nil, sErr + } + stateKeyNID := stateKeyNIDMap[uID] + + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String()) + if rErr != nil { + return rErr + } + if roomInfo == nil { + return eventutil.ErrRoomNoExists{} + } + + var iErr error + result, iErr = d.UserRoomKeyTable.InsertUserRoomPublicKey(ctx, txn, stateKeyNID, roomInfo.RoomNID, key) + return iErr + }) + return result, err +} + +// SelectUserRoomPrivateKey queries the users room private key. +// If no key exists, returns no key and no error. Otherwise returns +// the key and a database error, if any. +func (d *Database) SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error) { + uID := userID.String() + stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID}) + if sErr != nil { + return nil, sErr + } + stateKeyNID := stateKeyNIDMap[uID] + + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String()) + if rErr != nil { + return rErr + } + if roomInfo == nil { + return nil + } + + key, sErr = d.UserRoomKeyTable.SelectUserRoomPrivateKey(ctx, txn, stateKeyNID, roomInfo.RoomNID) + if !errors.Is(sErr, sql.ErrNoRows) { + return sErr + } + return nil + }) + return +} + +// SelectUserIDsForPublicKeys returns a map from roomID -> map from senderKey -> userID +func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (result map[spec.RoomID]map[string]string, err error) { + result = make(map[spec.RoomID]map[string]string, len(publicKeys)) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + + // map all roomIDs to roomNIDs + query := make(map[types.RoomNID][]ed25519.PublicKey) + rooms := make(map[types.RoomNID]spec.RoomID) + for roomID, keys := range publicKeys { + roomNID, ok := d.Cache.GetRoomServerRoomNID(roomID.String()) + if !ok { + roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String()) + if rErr != nil { + return rErr + } + if roomInfo == nil { + logrus.Warnf("missing room info for %s, there will be missing users in the response", roomID.String()) + continue + } + roomNID = roomInfo.RoomNID + } + + query[roomNID] = keys + rooms[roomNID] = roomID + } + + // get the user room key pars + userRoomKeyPairMap, sErr := d.UserRoomKeyTable.BulkSelectUserNIDs(ctx, txn, query) + if sErr != nil { + return sErr + } + nids := make([]types.EventStateKeyNID, 0, len(userRoomKeyPairMap)) + for _, nid := range userRoomKeyPairMap { + nids = append(nids, nid.EventStateKeyNID) + } + // get the userIDs + nidMap, seErr := d.EventStateKeys(ctx, nids) + if seErr != nil { + return seErr + } + + // build the result map (roomID -> map publicKey -> userID) + for publicKey, userRoomKeyPair := range userRoomKeyPairMap { + userID := nidMap[userRoomKeyPair.EventStateKeyNID] + roomID := rooms[userRoomKeyPair.RoomNID] + resMap, exists := result[roomID] + if !exists { + resMap = map[string]string{} + } + resMap[publicKey] = userID + result[roomID] = resMap + } + + return nil + }) + return result, err +} + // FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops // it should live in this package! diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go index 941e84802..4fa451bcc 100644 --- a/roomserver/storage/shared/storage_test.go +++ b/roomserver/storage/shared/storage_test.go @@ -2,11 +2,15 @@ package shared_test import ( "context" + "crypto/ed25519" "testing" "time" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" + ed255192 "golang.org/x/crypto/ed25519" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/postgres" @@ -23,41 +27,62 @@ func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Dat connStr, clearDB := test.PrepareDBConnectionString(t, dbType) dbOpts := &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)} - db, err := sqlutil.Open(dbOpts, sqlutil.NewExclusiveWriter()) + writer := sqlutil.NewExclusiveWriter() + db, err := sqlutil.Open(dbOpts, writer) assert.NoError(t, err) var membershipTable tables.Membership var stateKeyTable tables.EventStateKeys + var userRoomKeys tables.UserRoomKeys + var roomsTable tables.Rooms switch dbType { case test.DBTypePostgres: + err = postgres.CreateRoomsTable(db) + assert.NoError(t, err) err = postgres.CreateEventStateKeysTable(db) assert.NoError(t, err) err = postgres.CreateMembershipTable(db) assert.NoError(t, err) + err = postgres.CreateUserRoomKeysTable(db) + assert.NoError(t, err) + roomsTable, err = postgres.PrepareRoomsTable(db) + assert.NoError(t, err) membershipTable, err = postgres.PrepareMembershipTable(db) assert.NoError(t, err) stateKeyTable, err = postgres.PrepareEventStateKeysTable(db) + assert.NoError(t, err) + userRoomKeys, err = postgres.PrepareUserRoomKeysTable(db) case test.DBTypeSQLite: + err = sqlite3.CreateRoomsTable(db) + assert.NoError(t, err) err = sqlite3.CreateEventStateKeysTable(db) assert.NoError(t, err) err = sqlite3.CreateMembershipTable(db) assert.NoError(t, err) + err = sqlite3.CreateUserRoomKeysTable(db) + assert.NoError(t, err) + roomsTable, err = sqlite3.PrepareRoomsTable(db) + assert.NoError(t, err) membershipTable, err = sqlite3.PrepareMembershipTable(db) assert.NoError(t, err) stateKeyTable, err = sqlite3.PrepareEventStateKeysTable(db) + assert.NoError(t, err) + userRoomKeys, err = sqlite3.PrepareUserRoomKeysTable(db) } assert.NoError(t, err) cache := caching.NewRistrettoCache(8*1024*1024, time.Hour, false) - evDb := shared.EventDatabase{EventStateKeysTable: stateKeyTable, Cache: cache} + evDb := shared.EventDatabase{EventStateKeysTable: stateKeyTable, Cache: cache, Writer: writer} return &shared.Database{ - DB: db, - EventDatabase: evDb, - MembershipTable: membershipTable, - Writer: sqlutil.NewExclusiveWriter(), - Cache: cache, + DB: db, + EventDatabase: evDb, + MembershipTable: membershipTable, + UserRoomKeyTable: userRoomKeys, + RoomsTable: roomsTable, + Writer: writer, + Cache: cache, }, func() { clearDB() err = db.Close() @@ -97,3 +122,80 @@ func Test_GetLeftUsers(t *testing.T) { assert.ElementsMatch(t, expectedUserIDs, leftUsers) }) } + +func TestUserRoomKeys(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + + userID, err := spec.NewUserID(alice.ID, true) + assert.NoError(t, err) + roomID, err := spec.NewRoomID(room.ID) + assert.NoError(t, err) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRoomserverDatabase(t, dbType) + defer close() + + // create a room NID so we can query the room + _, err = db.RoomsTable.InsertRoomNID(ctx, nil, roomID.String(), gomatrixserverlib.RoomVersionV10) + assert.NoError(t, err) + doesNotExist, err := spec.NewRoomID("!doesnotexist:localhost") + assert.NoError(t, err) + _, err = db.RoomsTable.InsertRoomNID(ctx, nil, doesNotExist.String(), gomatrixserverlib.RoomVersionV10) + assert.NoError(t, err) + + _, key, err := ed25519.GenerateKey(nil) + assert.NoError(t, err) + + gotKey, err := db.InsertUserRoomPrivatePublicKey(ctx, *userID, *roomID, key) + assert.NoError(t, err) + assert.Equal(t, gotKey, key) + + // again, this shouldn't result in an error, but return the existing key + _, key2, err := ed25519.GenerateKey(nil) + assert.NoError(t, err) + gotKey, err = db.InsertUserRoomPrivatePublicKey(ctx, *userID, *roomID, key2) + assert.NoError(t, err) + assert.Equal(t, gotKey, key) + + gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *roomID) + assert.NoError(t, err) + assert.Equal(t, key, gotKey) + + // Key doesn't exist, we shouldn't get anything back + assert.NoError(t, err) + gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *doesNotExist) + assert.NoError(t, err) + assert.Nil(t, gotKey) + + queryUserIDs := map[spec.RoomID][]ed25519.PublicKey{ + *roomID: {key.Public().(ed25519.PublicKey)}, + } + + userIDs, err := db.SelectUserIDsForPublicKeys(ctx, queryUserIDs) + assert.NoError(t, err) + wantKeys := map[spec.RoomID]map[string]string{ + *roomID: { + string(key.Public().(ed25519.PublicKey)): userID.String(), + }, + } + assert.Equal(t, wantKeys, userIDs) + + // insert key that came in over federation + var gotPublicKey, key4 ed255192.PublicKey + key4, _, err = ed25519.GenerateKey(nil) + assert.NoError(t, err) + gotPublicKey, err = db.InsertUserRoomPublicKey(context.Background(), *userID, *doesNotExist, key4) + assert.NoError(t, err) + assert.Equal(t, key4, gotPublicKey) + + // test invalid room + reallyDoesNotExist, err := spec.NewRoomID("!reallydoesnotexist:localhost") + assert.NoError(t, err) + _, err = db.InsertUserRoomPublicKey(context.Background(), *userID, *reallyDoesNotExist, key4) + assert.Error(t, err) + _, err = db.InsertUserRoomPrivatePublicKey(context.Background(), *userID, *reallyDoesNotExist, key) + assert.Error(t, err) + }) +} diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 6ab427a84..ef51a5b08 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -138,6 +138,9 @@ func (d *Database) create(db *sql.DB) error { if err := CreateRedactionsTable(db); err != nil { return err } + if err := CreateUserRoomKeysTable(db); err != nil { + return err + } return nil } @@ -199,6 +202,10 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } + userRoomKeys, err := PrepareUserRoomKeysTable(db) + if err != nil { + return err + } d.Database = shared.Database{ DB: db, @@ -224,6 +231,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room PublishedTable: published, GetRoomUpdaterFn: d.GetRoomUpdater, Purge: purge, + UserRoomKeyTable: userRoomKeys, } return nil } diff --git a/roomserver/storage/sqlite3/user_room_keys_table.go b/roomserver/storage/sqlite3/user_room_keys_table.go new file mode 100644 index 000000000..8af57ea0e --- /dev/null +++ b/roomserver/storage/sqlite3/user_room_keys_table.go @@ -0,0 +1,146 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "crypto/ed25519" + "database/sql" + "errors" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const userRoomKeysSchema = ` +CREATE TABLE IF NOT EXISTS roomserver_user_room_keys ( + user_nid INTEGER NOT NULL, + room_nid INTEGER NOT NULL, + pseudo_id_key TEXT NULL, -- may be null for users not local to the server + pseudo_id_pub_key TEXT NOT NULL, + CONSTRAINT roomserver_user_room_keys_pk PRIMARY KEY (user_nid, room_nid) +); +` + +const insertUserRoomKeySQL = ` + INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key, pseudo_id_pub_key) VALUES ($1, $2, $3, $4) + ON CONFLICT DO UPDATE SET pseudo_id_key = roomserver_user_room_keys.pseudo_id_key + RETURNING (pseudo_id_key) +` + +const insertUserRoomPublicKeySQL = ` + INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_pub_key) VALUES ($1, $2, $3) + ON CONFLICT DO UPDATE SET pseudo_id_pub_key = $3 + RETURNING (pseudo_id_pub_key) +` + +const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` + +const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)` + +type userRoomKeysStatements struct { + insertUserRoomPrivateKeyStmt *sql.Stmt + insertUserRoomPublicKeyStmt *sql.Stmt + selectUserRoomKeyStmt *sql.Stmt + //selectUserNIDsStmt *sql.Stmt //prepared at runtime +} + +func CreateUserRoomKeysTable(db *sql.DB) error { + _, err := db.Exec(userRoomKeysSchema) + return err +} + +func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) { + s := &userRoomKeysStatements{} + return s, sqlutil.StatementList{ + {&s.insertUserRoomPrivateKeyStmt, insertUserRoomKeySQL}, + {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, + {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, + //{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime + }.Prepare(db) +} + +func (s *userRoomKeysStatements) InsertUserRoomPrivatePublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomPrivateKeyStmt) + err = stmt.QueryRowContext(ctx, userNID, roomNID, key, key.Public()).Scan(&result) + return result, err +} + +func (s *userRoomKeysStatements) InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomPublicKeyStmt) + err = stmt.QueryRowContext(ctx, userNID, roomNID, key).Scan(&result) + return result, err +} + +func (s *userRoomKeysStatements) SelectUserRoomPrivateKey( + ctx context.Context, + txn *sql.Tx, + userNID types.EventStateKeyNID, + roomNID types.RoomNID, +) (ed25519.PrivateKey, error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeyStmt) + var result ed25519.PrivateKey + err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return result, err +} + +func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) { + + roomNIDs := make([]any, 0, len(senderKeys)) + var senders []any + for roomNID := range senderKeys { + roomNIDs = append(roomNIDs, roomNID) + + for _, key := range senderKeys[roomNID] { + senders = append(senders, []byte(key)) + } + } + + selectSQL := strings.Replace(selectUserNIDsSQL, "($2)", sqlutil.QueryVariadicOffset(len(senders), len(senderKeys)), 1) + selectSQL = strings.Replace(selectSQL, "($1)", sqlutil.QueryVariadic(len(senderKeys)), 1) // replace $1 with the roomNIDs + + selectStmt, err := txn.Prepare(selectSQL) + if err != nil { + return nil, err + } + + params := append(roomNIDs, senders...) + + stmt := sqlutil.TxStmt(txn, selectStmt) + defer internal.CloseAndLogIfError(ctx, stmt, "failed to close statement") + + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows") + + result := make(map[string]types.UserRoomKeyPair, len(params)) + var publicKey []byte + userRoomKeyPair := types.UserRoomKeyPair{} + for rows.Next() { + if err = rows.Scan(&userRoomKeyPair.EventStateKeyNID, &userRoomKeyPair.RoomNID, &publicKey); err != nil { + return nil, err + } + result[string(publicKey)] = userRoomKeyPair + } + return result, rows.Err() +} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 333483b32..cd0e51686 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -2,6 +2,7 @@ package tables import ( "context" + "crypto/ed25519" "database/sql" "errors" @@ -184,6 +185,19 @@ type Purge interface { ) error } +type UserRoomKeys interface { + // InsertUserRoomPrivatePublicKey inserts the given private key as well as the public key for it. This should be used + // when creating keys locally. + InsertUserRoomPrivatePublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (ed25519.PrivateKey, error) + // InsertUserRoomPublicKey inserts the given public key, this should be used for users NOT local to this server + InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (ed25519.PublicKey, error) + // SelectUserRoomPrivateKey selects the private key for the given user and room combination + SelectUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error) + // BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair. + // If a senderKey can't be found, it is omitted in the result. + BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) +} + // StrippedEvent represents a stripped event for returning extracted content values. type StrippedEvent struct { RoomID string diff --git a/roomserver/storage/tables/user_room_keys_table_test.go b/roomserver/storage/tables/user_room_keys_table_test.go new file mode 100644 index 000000000..284309481 --- /dev/null +++ b/roomserver/storage/tables/user_room_keys_table_test.go @@ -0,0 +1,115 @@ +package tables_test + +import ( + "context" + "crypto/ed25519" + "database/sql" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/stretchr/testify/assert" + ed255192 "golang.org/x/crypto/ed25519" +) + +func mustCreateUserRoomKeysTable(t *testing.T, dbType test.DBType) (tab tables.UserRoomKeys, db *sql.DB, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + switch dbType { + case test.DBTypePostgres: + err = postgres.CreateUserRoomKeysTable(db) + assert.NoError(t, err) + tab, err = postgres.PrepareUserRoomKeysTable(db) + case test.DBTypeSQLite: + err = sqlite3.CreateUserRoomKeysTable(db) + assert.NoError(t, err) + tab, err = sqlite3.PrepareUserRoomKeysTable(db) + } + assert.NoError(t, err) + + return tab, db, close +} + +func TestUserRoomKeysTable(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, db, close := mustCreateUserRoomKeysTable(t, dbType) + defer close() + userNID := types.EventStateKeyNID(1) + roomNID := types.RoomNID(1) + _, key, err := ed25519.GenerateKey(nil) + assert.NoError(t, err) + + err = sqlutil.WithTransaction(db, func(txn *sql.Tx) error { + var gotKey, key2, key3 ed25519.PrivateKey + gotKey, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID, roomNID, key) + assert.NoError(t, err) + assert.Equal(t, gotKey, key) + + // again, this shouldn't result in an error, but return the existing key + _, key2, err = ed25519.GenerateKey(nil) + assert.NoError(t, err) + gotKey, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID, roomNID, key2) + assert.NoError(t, err) + assert.Equal(t, gotKey, key) + + // add another user + _, key3, err = ed25519.GenerateKey(nil) + assert.NoError(t, err) + userNID2 := types.EventStateKeyNID(2) + _, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID2, roomNID, key3) + assert.NoError(t, err) + + gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, roomNID) + assert.NoError(t, err) + assert.Equal(t, key, gotKey) + + // try to update an existing key, this should only be done for users NOT on this homeserver + var gotPubKey ed25519.PublicKey + gotPubKey, err = tab.InsertUserRoomPublicKey(context.Background(), txn, userNID, roomNID, key2.Public().(ed25519.PublicKey)) + assert.NoError(t, err) + assert.Equal(t, key2.Public(), gotPubKey) + + // Key doesn't exist + gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, 2) + assert.NoError(t, err) + assert.Nil(t, gotKey) + + // query user NIDs for senderKeys + var gotKeys map[string]types.UserRoomKeyPair + query := map[types.RoomNID][]ed25519.PublicKey{ + roomNID: {key2.Public().(ed25519.PublicKey), key3.Public().(ed25519.PublicKey)}, + types.RoomNID(2): {key.Public().(ed25519.PublicKey), key3.Public().(ed25519.PublicKey)}, // doesn't exist + } + gotKeys, err = tab.BulkSelectUserNIDs(context.Background(), txn, query) + assert.NoError(t, err) + assert.NotNil(t, gotKeys) + + wantKeys := map[string]types.UserRoomKeyPair{ + string(key2.Public().(ed25519.PublicKey)): {RoomNID: roomNID, EventStateKeyNID: userNID}, + string(key3.Public().(ed25519.PublicKey)): {RoomNID: roomNID, EventStateKeyNID: userNID2}, + } + assert.Equal(t, wantKeys, gotKeys) + + // insert key that came in over federation + var gotPublicKey, key4 ed255192.PublicKey + key4, _, err = ed25519.GenerateKey(nil) + assert.NoError(t, err) + gotPublicKey, err = tab.InsertUserRoomPublicKey(context.Background(), txn, userNID, 2, key4) + assert.NoError(t, err) + assert.Equal(t, key4, gotPublicKey) + + return nil + }) + assert.NoError(t, err) + + }) +} diff --git a/roomserver/types/types.go b/roomserver/types/types.go index f57978ad5..45a3e25fc 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -44,6 +44,11 @@ type EventMetadata struct { RoomNID RoomNID } +type UserRoomKeyPair struct { + RoomNID RoomNID + EventStateKeyNID EventStateKeyNID +} + // StateSnapshotNID is a numeric ID for the state at an event. type StateSnapshotNID int64 From 77d9e4e93dd01f6baa82bd6236850c1007346cac Mon Sep 17 00:00:00 2001 From: devonh Date: Mon, 12 Jun 2023 11:19:25 +0000 Subject: [PATCH 20/35] Cleanup remaining statekey usage for senderIDs (#3106) --- clientapi/routing/account_data.go | 10 +- clientapi/routing/aliases.go | 9 +- clientapi/routing/createroom.go | 1 + clientapi/routing/directory.go | 33 ++-- clientapi/routing/leaveroom.go | 10 +- clientapi/routing/membership.go | 147 ++++++++++++------ clientapi/routing/redaction.go | 34 ++-- clientapi/routing/sendtyping.go | 10 +- clientapi/routing/server_notices.go | 13 +- clientapi/routing/state.go | 53 +++++-- clientapi/routing/upgrade_room.go | 10 +- federationapi/routing/eventauth.go | 2 +- federationapi/routing/events.go | 12 +- federationapi/routing/state.go | 2 +- go.mod | 2 +- go.sum | 4 +- roomserver/api/api.go | 21 +-- roomserver/api/output.go | 6 +- roomserver/api/perform.go | 4 +- roomserver/api/query.go | 20 +-- roomserver/auth/auth.go | 14 +- roomserver/auth/auth_test.go | 12 +- roomserver/internal/helpers/helpers.go | 37 +++-- roomserver/internal/helpers/helpers_test.go | 5 +- roomserver/internal/input/input_events.go | 12 +- roomserver/internal/input/input_membership.go | 21 ++- roomserver/internal/perform/perform_admin.go | 6 +- .../internal/perform/perform_backfill.go | 2 +- .../internal/perform/perform_create_room.go | 15 +- roomserver/internal/perform/perform_invite.go | 8 +- roomserver/internal/perform/perform_join.go | 35 ++--- roomserver/internal/perform/perform_leave.go | 77 ++++----- .../internal/perform/perform_upgrade.go | 116 +++++--------- roomserver/internal/query/query.go | 70 +++++---- roomserver/roomserver_test.go | 19 +-- roomserver/storage/interface.go | 2 +- roomserver/storage/shared/storage.go | 7 +- setup/mscs/msc2836/msc2836.go | 11 +- setup/mscs/msc2836/msc2836_test.go | 6 +- syncapi/consumers/roomserver.go | 29 +++- syncapi/internal/history_visibility.go | 14 +- syncapi/internal/keychange.go | 16 +- syncapi/internal/keychange_test.go | 4 + syncapi/notifier/notifier.go | 45 +++--- syncapi/notifier/notifier_test.go | 22 ++- syncapi/routing/context.go | 18 ++- syncapi/routing/getevent.go | 11 +- syncapi/routing/memberships.go | 13 +- syncapi/routing/messages.go | 6 +- syncapi/routing/relations.go | 11 +- syncapi/routing/search.go | 11 +- syncapi/storage/shared/storage_consumer.go | 16 +- syncapi/storage/shared/storage_sync.go | 4 +- syncapi/streams/stream_invite.go | 11 +- syncapi/streams/stream_pdu.go | 12 +- syncapi/syncapi.go | 2 +- syncapi/synctypes/clientevent.go | 35 ++++- syncapi/synctypes/clientevent_test.go | 6 +- syncapi/types/types.go | 4 +- syncapi/types/types_test.go | 8 +- userapi/consumers/roomserver.go | 36 ++++- userapi/util/notify_test.go | 3 +- 62 files changed, 760 insertions(+), 455 deletions(-) diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go index 7eacf9cc9..81afc3b13 100644 --- a/clientapi/routing/account_data.go +++ b/clientapi/routing/account_data.go @@ -145,8 +145,16 @@ func SaveReadMarker( userAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI, syncProducer *producers.SyncAPIProducer, device *api.Device, roomID string, ) util.JSONResponse { + deviceUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("userID for this device is invalid"), + } + } + // Verify that the user is a member of this room - resErr := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) + resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if resErr != nil { return *resErr } diff --git a/clientapi/routing/aliases.go b/clientapi/routing/aliases.go index f6603be8b..2d6b72d3e 100644 --- a/clientapi/routing/aliases.go +++ b/clientapi/routing/aliases.go @@ -55,9 +55,16 @@ func GetAliases( visibility = content.HistoryVisibility } if visibility != spec.WorldReadable { + deviceUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to change visibility"), + } + } queryReq := api.QueryMembershipForUserRequest{ RoomID: roomID, - UserID: device.UserID, + UserID: *deviceUserID, } var queryRes api.QueryMembershipForUserResponse if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil { diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index 799fc7976..320f236cb 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -224,6 +224,7 @@ func createRoom( PrivateKey: privateKey, EventTime: evTime, } + roomAlias, createRes := rsAPI.PerformCreateRoom(ctx, *userID, *roomID, &req) if createRes != nil { return *createRes diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index 034296f45..f01e24eca 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -314,7 +314,22 @@ func SetVisibility( req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, dev *userapi.Device, roomID string, ) util.JSONResponse { - resErr := checkMemberInRoom(req.Context(), rsAPI, dev.UserID, roomID) + deviceUserID, err := spec.NewUserID(dev.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("userID for this device is invalid"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("failed to find senderID for this user"), + } + } + + resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if resErr != nil { return *resErr } @@ -327,7 +342,7 @@ func SetVisibility( }}, } var queryEventsRes roomserverAPI.QueryLatestEventsAndStateResponse - err := rsAPI.QueryLatestEventsAndState(req.Context(), &queryEventsReq, &queryEventsRes) + err = rsAPI.QueryLatestEventsAndState(req.Context(), &queryEventsReq, &queryEventsRes) if err != nil || len(queryEventsRes.StateEvents) == 0 { util.GetLogger(req.Context()).WithError(err).Error("could not query events from room") return util.JSONResponse{ @@ -338,20 +353,6 @@ func SetVisibility( // NOTSPEC: Check if the user's power is greater than power required to change m.room.canonical_alias event power, _ := gomatrixserverlib.NewPowerLevelContentFromEvent(queryEventsRes.StateEvents[0].PDU) - fullUserID, err := spec.NewUserID(dev.UserID, true) - if err != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("userID doesn't have power level to change visibility"), - } - } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) - if err != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("userID doesn't have power level to change visibility"), - } - } if power.UserLevel(senderID) < power.EventLevel(spec.MRoomCanonicalAlias, true) { return util.JSONResponse{ Code: http.StatusForbidden, diff --git a/clientapi/routing/leaveroom.go b/clientapi/routing/leaveroom.go index fbf148264..7e8c066eb 100644 --- a/clientapi/routing/leaveroom.go +++ b/clientapi/routing/leaveroom.go @@ -29,10 +29,18 @@ func LeaveRoomByID( rsAPI roomserverAPI.ClientRoomserverAPI, roomID string, ) util.JSONResponse { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("device userID is invalid"), + } + } + // Prepare to ask the roomserver to perform the room join. leaveReq := roomserverAPI.PerformLeaveRequest{ RoomID: roomID, - UserID: device.UserID, + Leaver: *userID, } leaveRes := roomserverAPI.PerformLeaveResponse{} diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 78829bec9..03e85edbf 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -57,7 +57,22 @@ func SendBan( } } - errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) + deviceUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to ban this user, unknown senderID"), + } + } + + errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if errRes != nil { return *errRes } @@ -66,20 +81,6 @@ func SendBan( if errRes != nil { return *errRes } - fullUserID, err := spec.NewUserID(device.UserID, true) - if err != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"), - } - } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) - if err != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("You don't have permission to ban this user, unknown senderID"), - } - } allowedToBan := pl.UserLevel(senderID) >= pl.Ban if !allowedToBan { return util.JSONResponse{ @@ -147,7 +148,22 @@ func SendKick( } } - errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) + deviceUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"), + } + } + + errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if errRes != nil { return *errRes } @@ -156,20 +172,6 @@ func SendKick( if errRes != nil { return *errRes } - fullUserID, err := spec.NewUserID(device.UserID, true) - if err != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), - } - } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) - if err != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"), - } - } allowedToKick := pl.UserLevel(senderID) >= pl.Kick if !allowedToKick { return util.JSONResponse{ @@ -178,10 +180,17 @@ func SendKick( } } + bodyUserID, err := spec.NewUserID(body.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("body userID is invalid"), + } + } var queryRes roomserverAPI.QueryMembershipForUserResponse err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ RoomID: roomID, - UserID: body.UserID, + UserID: *bodyUserID, }, &queryRes) if err != nil { return util.ErrorResponse(err) @@ -213,15 +222,30 @@ func SendUnban( } } - errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) + deviceUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), + } + } + + errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if errRes != nil { return *errRes } + bodyUserID, err := spec.NewUserID(body.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("body userID is invalid"), + } + } var queryRes roomserverAPI.QueryMembershipForUserResponse - err := rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ + err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ RoomID: roomID, - UserID: body.UserID, + UserID: *bodyUserID, }, &queryRes) if err != nil { return util.ErrorResponse(err) @@ -272,7 +296,15 @@ func SendInvite( } } - errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) + deviceUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), + } + } + + errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if errRes != nil { return *errRes } @@ -340,17 +372,18 @@ func sendInvite( func buildMembershipEventDirect( ctx context.Context, - targetUserID, reason string, userDisplayName, userAvatarURL string, - sender string, senderDomain spec.ServerName, + targetSenderID spec.SenderID, reason string, userDisplayName, userAvatarURL string, + sender spec.SenderID, senderDomain spec.ServerName, membership, roomID string, isDirect bool, keyID gomatrixserverlib.KeyID, privateKey ed25519.PrivateKey, evTime time.Time, rsAPI roomserverAPI.ClientRoomserverAPI, ) (*types.HeaderedEvent, error) { + targetSenderString := string(targetSenderID) proto := gomatrixserverlib.ProtoEvent{ - SenderID: sender, + SenderID: string(sender), RoomID: roomID, Type: "m.room.member", - StateKey: &targetUserID, + StateKey: &targetSenderString, } content := gomatrixserverlib.MemberContent{ @@ -391,8 +424,25 @@ func buildMembershipEvent( return nil, err } - return buildMembershipEventDirect(ctx, targetUserID, reason, profile.DisplayName, profile.AvatarURL, - device.UserID, device.UserDomain(), membership, roomID, isDirect, identity.KeyID, identity.PrivateKey, evTime, rsAPI) + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return nil, err + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID) + if err != nil { + return nil, err + } + + targetID, err := spec.NewUserID(targetUserID, true) + if err != nil { + return nil, err + } + targetSenderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *targetID) + if err != nil { + return nil, err + } + return buildMembershipEventDirect(ctx, targetSenderID, reason, profile.DisplayName, profile.AvatarURL, + senderID, device.UserDomain(), membership, roomID, isDirect, identity.KeyID, identity.PrivateKey, evTime, rsAPI) } // loadProfile lookups the profile of a given user from the database and returns @@ -490,7 +540,7 @@ func checkAndProcessThreepid( return } -func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, userID, roomID string) *util.JSONResponse { +func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, userID spec.UserID, roomID string) *util.JSONResponse { var membershipRes roomserverAPI.QueryMembershipForUserResponse err := rsAPI.QueryMembershipForUser(ctx, &roomserverAPI.QueryMembershipForUserRequest{ RoomID: roomID, @@ -518,12 +568,21 @@ func SendForget( ) util.JSONResponse { ctx := req.Context() logger := util.GetLogger(ctx).WithField("roomID", roomID).WithField("userID", device.UserID) + + deviceUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), + } + } + var membershipRes roomserverAPI.QueryMembershipForUserResponse membershipReq := roomserverAPI.QueryMembershipForUserRequest{ RoomID: roomID, - UserID: device.UserID, + UserID: *deviceUserID, } - err := rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes) + err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes) if err != nil { logger.WithError(err).Error("QueryMembershipForUser: could not query membership for user") return util.JSONResponse{ diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 22474fc08..da48e84de 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -47,7 +47,22 @@ func SendRedaction( txnID *string, txnCache *transactions.Cache, ) util.JSONResponse { - resErr := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) + deviceUserID, userIDErr := spec.NewUserID(device.UserID, true) + if userIDErr != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to redact"), + } + } + senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) + if queryErr != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to redact"), + } + } + + resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if resErr != nil { return *resErr } @@ -73,25 +88,10 @@ func SendRedaction( } } - fullUserID, userIDErr := spec.NewUserID(device.UserID, true) - if userIDErr != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("userID doesn't have power level to redact"), - } - } - senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) - if queryErr != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("userID doesn't have power level to redact"), - } - } - // "Users may redact their own events, and any user with a power level greater than or equal // to the redact power level of the room may redact events there" // https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid - allowedToRedact := ev.SenderID() == senderID // TODO: Should replace device.UserID with device...PerRoomKey + allowedToRedact := ev.SenderID() == senderID if !allowedToRedact { plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{ EventType: spec.MRoomPowerLevels, diff --git a/clientapi/routing/sendtyping.go b/clientapi/routing/sendtyping.go index c5b29297a..979bced3b 100644 --- a/clientapi/routing/sendtyping.go +++ b/clientapi/routing/sendtyping.go @@ -43,8 +43,16 @@ func SendTyping( } } + deviceUserID, err := spec.NewUserID(userID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to change visibility"), + } + } + // Verify that the user is a member of this room - resErr := checkMemberInRoom(req.Context(), rsAPI, userID, roomID) + resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if resErr != nil { return *resErr } diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index 06714ed1f..7006ced46 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -52,6 +52,7 @@ type sendServerNoticeRequest struct { StateKey string `json:"state_key,omitempty"` } +// nolint:gocyclo // SendServerNotice sends a message to a specific user. It can only be invoked by an admin. func SendServerNotice( req *http.Request, @@ -187,9 +188,17 @@ func SendServerNotice( } } else { // we've found a room in common, check the membership + deviceUserID, err := spec.NewUserID(r.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to change visibility"), + } + } + roomID = commonRooms[0] membershipRes := api.QueryMembershipForUserResponse{} - err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: r.UserID, RoomID: roomID}, &membershipRes) + err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: *deviceUserID, RoomID: roomID}, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("unable to query membership for user") return util.JSONResponse{ @@ -234,7 +243,7 @@ func SendServerNotice( ctx, rsAPI, api.KindNew, []*types.HeaderedEvent{ - &types.HeaderedEvent{PDU: e}, + {PDU: e}, }, device.UserDomain(), cfgClient.Matrix.ServerName, diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index 13f308998..e3a209b6e 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -99,9 +99,17 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a if !worldReadable { // The room isn't world-readable so try to work out based on the // user's membership if we want the latest state or not. - err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("UserID is invalid") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("Device UserID is invalid"), + } + } + err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ RoomID: roomID, - UserID: device.UserID, + UserID: *userID, }, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") @@ -140,14 +148,11 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a // use the result of the previous QueryLatestEventsAndState response // to find the state event, if provided. for _, ev := range stateRes.StateEvents { - sender := spec.UserID{} - userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID()) - if err == nil && userID != nil { - sender = *userID - } stateEvents = append( stateEvents, - synctypes.ToClientEvent(ev, synctypes.FormatAll, sender), + synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, ev), ) } } else { @@ -172,9 +177,18 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a if err == nil && userID != nil { sender = *userID } + + sk := ev.StateKey() + if sk != nil && *sk != "" { + skUserID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } stateEvents = append( stateEvents, - synctypes.ToClientEvent(ev, synctypes.FormatAll, sender), + synctypes.ToClientEvent(ev, synctypes.FormatAll, sender, sk), ) } } @@ -259,11 +273,19 @@ func OnIncomingStateTypeRequest( // membershipRes will only be populated if the room is not world-readable. var membershipRes api.QueryMembershipForUserResponse if !worldReadable { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("UserID is invalid") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("Device UserID is invalid"), + } + } // The room isn't world-readable so try to work out based on the // user's membership if we want the latest state or not. - err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ + err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ RoomID: roomID, - UserID: device.UserID, + UserID: *userID, }, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") @@ -344,13 +366,10 @@ func OnIncomingStateTypeRequest( } } - sender := spec.UserID{} - userID, err := rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) - if err == nil && userID != nil { - sender = *userID - } stateEvent := stateEventInStateResp{ - ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, sender), + ClientEvent: synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, event), } var res interface{} diff --git a/clientapi/routing/upgrade_room.go b/clientapi/routing/upgrade_room.go index a0b280789..03c0230e6 100644 --- a/clientapi/routing/upgrade_room.go +++ b/clientapi/routing/upgrade_room.go @@ -59,7 +59,15 @@ func UpgradeRoom( } } - newRoomID, err := rsAPI.PerformRoomUpgrade(req.Context(), roomID, device.UserID, gomatrixserverlib.RoomVersion(r.NewVersion)) + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("device UserID is invalid") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + newRoomID, err := rsAPI.PerformRoomUpgrade(req.Context(), roomID, *userID, gomatrixserverlib.RoomVersion(r.NewVersion)) switch e := err.(type) { case nil: case roomserverAPI.ErrNotAllowed: diff --git a/federationapi/routing/eventauth.go b/federationapi/routing/eventauth.go index ca279ac22..c26aa2f15 100644 --- a/federationapi/routing/eventauth.go +++ b/federationapi/routing/eventauth.go @@ -45,7 +45,7 @@ func GetEventAuth( if event.RoomID() != roomID { return util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")} } - resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID) + resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID()) if resErr != nil { return *resErr } diff --git a/federationapi/routing/events.go b/federationapi/routing/events.go index 196a54db1..d3f0e81c3 100644 --- a/federationapi/routing/events.go +++ b/federationapi/routing/events.go @@ -35,10 +35,6 @@ func GetEvent( eventID string, origin spec.ServerName, ) util.JSONResponse { - err := allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID) - if err != nil { - return *err - } // /_matrix/federation/v1/event/{eventId} doesn't have a roomID, we use an empty string, // which results in `QueryEventsByID` to first get the event and use that to determine the roomID. event, err := fetchEvent(ctx, rsAPI, "", eventID) @@ -46,6 +42,11 @@ func GetEvent( return *err } + err = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID()) + if err != nil { + return *err + } + return util.JSONResponse{Code: http.StatusOK, JSON: gomatrixserverlib.Transaction{ Origin: origin, OriginServerTS: spec.AsTimestamp(time.Now()), @@ -62,8 +63,9 @@ func allowedToSeeEvent( origin spec.ServerName, rsAPI api.FederationRoomserverAPI, eventID string, + roomID string, ) *util.JSONResponse { - allowed, err := rsAPI.QueryServerAllowedToSeeEvent(ctx, origin, eventID) + allowed, err := rsAPI.QueryServerAllowedToSeeEvent(ctx, origin, eventID, roomID) if err != nil { resErr := util.ErrorResponse(err) return &resErr diff --git a/federationapi/routing/state.go b/federationapi/routing/state.go index fa0e9351e..11ad1ebfc 100644 --- a/federationapi/routing/state.go +++ b/federationapi/routing/state.go @@ -116,7 +116,7 @@ func getState( if event.RoomID() != roomID { return nil, nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")} } - resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID) + resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID()) if resErr != nil { return nil, nil, resErr } diff --git a/go.mod b/go.mod index 3621428c3..2fbae3148 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d + github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077 github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 diff --git a/go.sum b/go.sum index 1ee0261f6..ef8c298ab 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d h1:MjL8SXRzhO61aXDFL+gA3Bx1SicqLGL9gCWXDv8jkD8= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077 h1:AmKkAUjy9rZA2K+qHXm/O/dPEPnUYfRE2I6SL+Dj+LU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 8c2cbd6b2..bafde91c9 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -34,11 +34,11 @@ func (e ErrNotAllowed) Error() string { type RestrictedJoinAPI interface { CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) - InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) - RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) + InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error) + RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types.RoomInfo, error) QueryServerJoinedToRoom(ctx context.Context, req *QueryServerJoinedToRoomRequest, res *QueryServerJoinedToRoomResponse) error - UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, userID spec.UserID) (bool, error) + UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, senderID spec.SenderID) (bool, error) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error) } @@ -191,7 +191,7 @@ type ClientRoomserverAPI interface { PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) (string, *util.JSONResponse) // PerformRoomUpgrade upgrades a room to a newer version - PerformRoomUpgrade(ctx context.Context, roomID, userID string, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error) + PerformRoomUpgrade(ctx context.Context, roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error) PerformAdminEvacuateRoom(ctx context.Context, roomID string) (affected []string, err error) PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error) PerformAdminPurgeRoom(ctx context.Context, roomID string) error @@ -228,6 +228,7 @@ type FederationRoomserverAPI interface { // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error + QueryMembershipForSenderID(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, res *QueryMembershipForUserResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error @@ -238,15 +239,13 @@ type FederationRoomserverAPI interface { // Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate // the state and auth chain to return. QueryStateAndAuthChain(ctx context.Context, req *QueryStateAndAuthChainRequest, res *QueryStateAndAuthChainResponse) error - // Query if we think we're still in a room. - QueryServerJoinedToRoom(ctx context.Context, req *QueryServerJoinedToRoomRequest, res *QueryServerJoinedToRoomResponse) error QueryPublishedRooms(ctx context.Context, req *QueryPublishedRoomsRequest, res *QueryPublishedRoomsResponse) error // Query missing events for a room from roomserver QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error // Query whether a server is allowed to see an event - QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string) (allowed bool, err error) + QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string, roomID string) (allowed bool, err error) QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error - QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) + QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error HandleInvite(ctx context.Context, event *types.HeaderedEvent) error @@ -254,12 +253,6 @@ type FederationRoomserverAPI interface { // Query a given amount (or less) of events prior to a given set of events. PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error - CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) - InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) - QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types.RoomInfo, error) - UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, userID spec.UserID) (bool, error) - LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error) - IsKnownRoom(ctx context.Context, roomID spec.RoomID) (bool, error) StateQuerier() gomatrixserverlib.StateQuerier } diff --git a/roomserver/api/output.go b/roomserver/api/output.go index 16b504957..852b64206 100644 --- a/roomserver/api/output.go +++ b/roomserver/api/output.go @@ -215,8 +215,10 @@ type OutputNewInviteEvent struct { type OutputRetireInviteEvent struct { // The ID of the "m.room.member" invite event. EventID string - // The target user ID of the "m.room.member" invite event that was retired. - TargetUserID string + // The room ID of the "m.room.member" invite event. + RoomID string + // The target sender ID of the "m.room.member" invite event that was retired. + TargetSenderID spec.SenderID // Optional event ID of the event that replaced the invite. // This can be empty if the invite was rejected locally and we were unable // to reach the server that originally sent the invite. diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index 6cbaf5b19..b466b7ba8 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -41,8 +41,8 @@ type PerformJoinRequest struct { } type PerformLeaveRequest struct { - RoomID string `json:"room_id"` - UserID string `json:"user_id"` + RoomID string + Leaver spec.UserID } type PerformLeaveResponse struct { diff --git a/roomserver/api/query.go b/roomserver/api/query.go index d79dcebbb..684a5b0e3 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -113,9 +113,9 @@ type QueryEventsByIDResponse struct { // QueryMembershipForUserRequest is a request to QueryMembership type QueryMembershipForUserRequest struct { // ID of the room to fetch membership from - RoomID string `json:"room_id"` + RoomID string // ID of the user for whom membership is requested - UserID string `json:"user_id"` + UserID spec.UserID } // QueryMembershipForUserResponse is a response to QueryMembership @@ -145,7 +145,7 @@ type QueryMembershipsForRoomRequest struct { // Optional - ID of the user sending the request, for checking if the // user is allowed to see the memberships. If not specified then all // room memberships will be returned. - Sender string `json:"sender"` + SenderID spec.SenderID `json:"sender"` } // QueryMembershipsForRoomResponse is a response to QueryMembershipsForRoom @@ -448,11 +448,11 @@ func (rq *JoinRoomQuerier) CurrentStateEvent(ctx context.Context, roomID spec.Ro return rq.Roomserver.CurrentStateEvent(ctx, roomID, eventType, stateKey) } -func (rq *JoinRoomQuerier) InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) { - return rq.Roomserver.InvitePending(ctx, roomID, userID) +func (rq *JoinRoomQuerier) InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error) { + return rq.Roomserver.InvitePending(ctx, roomID, senderID) } -func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { +func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { roomInfo, err := rq.Roomserver.QueryRoomInfo(ctx, roomID) if err != nil || roomInfo == nil || roomInfo.IsStub() { return nil, err @@ -468,7 +468,7 @@ func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID sp return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err) } - userJoinedToRoom, err := rq.Roomserver.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID) + userJoinedToRoom, err := rq.Roomserver.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), senderID) if err != nil { util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed") return nil, fmt.Errorf("InternalServerError: %w", err) @@ -492,12 +492,8 @@ type MembershipQuerier struct { } func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) { - req := QueryMembershipForUserRequest{ - RoomID: roomID.String(), - UserID: string(senderID), - } res := QueryMembershipForUserResponse{} - err := mq.Roomserver.QueryMembershipForUser(ctx, &req, &res) + err := mq.Roomserver.QueryMembershipForSenderID(ctx, roomID, senderID, &res) membership := "" if err == nil { diff --git a/roomserver/auth/auth.go b/roomserver/auth/auth.go index b6168d38b..ba10a4332 100644 --- a/roomserver/auth/auth.go +++ b/roomserver/auth/auth.go @@ -13,6 +13,9 @@ package auth import ( + "context" + + "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" ) @@ -22,6 +25,7 @@ import ( // IsServerAllowed returns true if the server is allowed to see events in the room // at this particular state. This function implements https://matrix.org/docs/spec/client_server/r0.6.0#id87 func IsServerAllowed( + ctx context.Context, db storage.RoomDatabase, serverName spec.ServerName, serverCurrentlyInRoom bool, authEvents []gomatrixserverlib.PDU, @@ -37,7 +41,7 @@ func IsServerAllowed( return true } // 2. If the user's membership was join, allow. - joinedUserExists := IsAnyUserOnServerWithMembership(serverName, authEvents, spec.Join) + joinedUserExists := IsAnyUserOnServerWithMembership(ctx, db, serverName, authEvents, spec.Join) if joinedUserExists { return true } @@ -46,7 +50,7 @@ func IsServerAllowed( return true } // 4. If the user's membership was invite, and the history_visibility was set to invited, allow. - invitedUserExists := IsAnyUserOnServerWithMembership(serverName, authEvents, spec.Invite) + invitedUserExists := IsAnyUserOnServerWithMembership(ctx, db, serverName, authEvents, spec.Invite) if invitedUserExists && historyVisibility == gomatrixserverlib.HistoryVisibilityInvited { return true } @@ -70,7 +74,7 @@ func HistoryVisibilityForRoom(authEvents []gomatrixserverlib.PDU) gomatrixserver return visibility } -func IsAnyUserOnServerWithMembership(serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool { +func IsAnyUserOnServerWithMembership(ctx context.Context, db storage.RoomDatabase, serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool { for _, ev := range authEvents { if ev.Type() != spec.MRoomMember { continue @@ -85,12 +89,12 @@ func IsAnyUserOnServerWithMembership(serverName spec.ServerName, authEvents []go continue } - _, domain, err := gomatrixserverlib.SplitID('@', *stateKey) + userID, err := db.GetUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*stateKey)) if err != nil { continue } - if domain == serverName { + if userID.Domain() == serverName { return true } } diff --git a/roomserver/auth/auth_test.go b/roomserver/auth/auth_test.go index e3eea5d8b..192d9e5da 100644 --- a/roomserver/auth/auth_test.go +++ b/roomserver/auth/auth_test.go @@ -1,13 +1,23 @@ package auth import ( + "context" "testing" + "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" ) +type FakeStorageDB struct { + storage.RoomDatabase +} + +func (f *FakeStorageDB) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) +} + func TestIsServerAllowed(t *testing.T) { alice := test.NewUser(t) @@ -77,7 +87,7 @@ func TestIsServerAllowed(t *testing.T) { authEvents = append(authEvents, ev.PDU) } - if got := IsServerAllowed(tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want { + if got := IsServerAllowed(context.Background(), &FakeStorageDB{}, tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want { t.Errorf("IsServerAllowed() = %v, want %v", got, tt.want) } }) diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index 95397cd5e..263cb9f85 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "sort" - "strings" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" @@ -55,9 +54,10 @@ func UpdateToInviteMembership( Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ EventID: eventID, + RoomID: add.RoomID(), Membership: spec.Join, RetiredByEventID: add.EventID(), - TargetUserID: *add.StateKey(), + TargetSenderID: spec.SenderID(*add.StateKey()), }, }) } @@ -94,13 +94,13 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam for i := range events { gmslEvents[i] = events[i].PDU } - return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, spec.Join), nil + return auth.IsAnyUserOnServerWithMembership(ctx, db, serverName, gmslEvents, spec.Join), nil } func IsInvitePending( ctx context.Context, db storage.Database, - roomID, userID string, -) (bool, string, string, gomatrixserverlib.PDU, error) { + roomID string, senderID spec.SenderID, +) (bool, spec.SenderID, string, gomatrixserverlib.PDU, error) { // Look up the room NID for the supplied room ID. info, err := db.RoomInfo(ctx, roomID) if err != nil { @@ -111,13 +111,13 @@ func IsInvitePending( } // Look up the state key NID for the supplied user ID. - targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{userID}) + targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{string(senderID)}) if err != nil { return false, "", "", nil, fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err) } - targetUserNID, targetUserFound := targetUserNIDs[userID] + targetUserNID, targetUserFound := targetUserNIDs[string(senderID)] if !targetUserFound { - return false, "", "", nil, fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs) + return false, "", "", nil, fmt.Errorf("missing NID for user %q (%+v)", senderID, targetUserNIDs) } // Let's see if we have an event active for the user in the room. If @@ -156,7 +156,7 @@ func IsInvitePending( event, err := verImpl.NewEventFromTrustedJSON(eventJSON, false) - return true, senderUser, userNIDToEventID[senderUserNIDs[0]], event, err + return true, spec.SenderID(senderUser), userNIDToEventID[senderUserNIDs[0]], event, err } // GetMembershipsAtState filters the state events to @@ -264,7 +264,7 @@ func LoadStateEvents( } func CheckServerAllowedToSeeEvent( - ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName spec.ServerName, isServerInRoom bool, + ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool, ) (bool, error) { stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName)) switch err { @@ -273,7 +273,7 @@ func CheckServerAllowedToSeeEvent( case tables.OptimisationNotSupportedError: // The database engine didn't support this optimisation, so fall back to using // the old and slow method - stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, eventID, serverName) + stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName) if err != nil { return false, err } @@ -288,11 +288,11 @@ func CheckServerAllowedToSeeEvent( return false, err } } - return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil + return auth.IsServerAllowed(ctx, db, serverName, isServerInRoom, stateAtEvent), nil } func slowGetHistoryVisibilityState( - ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName spec.ServerName, + ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName, ) ([]gomatrixserverlib.PDU, error) { roomState := state.NewStateResolution(db, info) stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) @@ -319,8 +319,13 @@ func slowGetHistoryVisibilityState( // then we'll filter it out. This does preserve state keys that // are "" since these will contain history visibility etc. for nid, key := range stateKeys { - if key != "" && !strings.HasSuffix(key, ":"+string(serverName)) { - delete(stateKeys, nid) + if key != "" { + userID, err := db.GetUserIDForSender(ctx, roomID, spec.SenderID(key)) + if err == nil && userID != nil { + if userID.Domain() != serverName { + delete(stateKeys, nid) + } + } } } @@ -410,7 +415,7 @@ BFSLoop: // hasn't been seen before. if !visited[pre] { visited[pre] = true - allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, pre, serverName, isServerInRoom) + allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom) if err != nil { util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error( "Error checking if allowed to see event", diff --git a/roomserver/internal/helpers/helpers_test.go b/roomserver/internal/helpers/helpers_test.go index f1896277e..1cef83df7 100644 --- a/roomserver/internal/helpers/helpers_test.go +++ b/roomserver/internal/helpers/helpers_test.go @@ -8,6 +8,7 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/roomserver/types" @@ -58,12 +59,12 @@ func TestIsInvitePendingWithoutNID(t *testing.T) { } // Alice should have no pending invites and should have a NID - pendingInvite, _, _, _, err := IsInvitePending(context.Background(), db, room.ID, alice.ID) + pendingInvite, _, _, _, err := IsInvitePending(context.Background(), db, room.ID, spec.SenderID(alice.ID)) assert.NoError(t, err, "failed to get pending invites") assert.False(t, pendingInvite, "unexpected pending invite") // Bob should have no pending invites and receive a new NID - pendingInvite, _, _, _, err = IsInvitePending(context.Background(), db, room.ID, bob.ID) + pendingInvite, _, _, _, err = IsInvitePending(context.Background(), db, room.ID, spec.SenderID(bob.ID)) assert.NoError(t, err, "failed to get pending invites") assert.False(t, pendingInvite, "unexpected pending invite") }) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 1f273da01..7bb401632 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -842,17 +842,15 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r continue } - // TODO: pseudoIDs: get userID for room using state key (which is now senderID) - localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey()) + memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, memberEvent.RoomID(), spec.SenderID(*memberEvent.StateKey())) if err != nil { continue } - // TODO: pseudoIDs: query account by state key (which is now senderID) accountRes := &userAPI.QueryAccountByLocalpartResponse{} if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{ - Localpart: localpart, - ServerName: senderDomain, + Localpart: memberUserID.Local(), + ServerName: memberUserID.Domain(), }, accountRes); err != nil { return err } @@ -896,8 +894,8 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r inputEvents = append(inputEvents, api.InputRoomEvent{ Kind: api.KindNew, Event: event, - Origin: senderDomain, - SendAsServer: string(senderDomain), + Origin: memberUserID.Domain(), + SendAsServer: string(memberUserID.Domain()), }) prevEvents = []string{event.EventID()} } diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index 98d7d13b1..09c65dfe9 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -18,7 +18,6 @@ import ( "context" "fmt" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/internal" @@ -72,7 +71,7 @@ func (r *Inputer) updateMemberships( if change.addedEventNID != 0 { ae, _ = helpers.EventMap(events).Lookup(change.addedEventNID) } - if updates, err = r.updateMembership(updater, targetUserNID, re, ae, updates); err != nil { + if updates, err = r.updateMembership(ctx, updater, targetUserNID, re, ae, updates); err != nil { return nil, err } } @@ -80,6 +79,7 @@ func (r *Inputer) updateMemberships( } func (r *Inputer) updateMembership( + ctx context.Context, updater *shared.RoomUpdater, targetUserNID types.EventStateKeyNID, remove, add *types.Event, @@ -97,7 +97,7 @@ func (r *Inputer) updateMembership( var targetLocal bool if add != nil { - targetLocal = r.isLocalTarget(add) + targetLocal = r.isLocalTarget(ctx, add) } mu, err := updater.MembershipUpdater(targetUserNID, targetLocal) @@ -136,11 +136,14 @@ func (r *Inputer) updateMembership( } } -func (r *Inputer) isLocalTarget(event *types.Event) bool { +func (r *Inputer) isLocalTarget(ctx context.Context, event *types.Event) bool { isTargetLocalUser := false if statekey := event.StateKey(); statekey != nil { - _, domain, _ := gomatrixserverlib.SplitID('@', *statekey) - isTargetLocalUser = domain == r.ServerName + userID, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*statekey)) + if err != nil || userID == nil { + return isTargetLocalUser + } + isTargetLocalUser = userID.Domain() == r.ServerName } return isTargetLocalUser } @@ -161,9 +164,10 @@ func updateToJoinMembership( Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ EventID: eventID, + RoomID: add.RoomID(), Membership: spec.Join, RetiredByEventID: add.EventID(), - TargetUserID: *add.StateKey(), + TargetSenderID: spec.SenderID(*add.StateKey()), }, }) } @@ -187,9 +191,10 @@ func updateToLeaveMembership( Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ EventID: eventID, + RoomID: add.RoomID(), Membership: newMembership, RetiredByEventID: add.EventID(), - TargetUserID: *add.StateKey(), + TargetSenderID: spec.SenderID(*add.StateKey()), }, }) } diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index eeb1ac406..ec13bff87 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -149,11 +149,11 @@ func (r *Admin) PerformAdminEvacuateUser( ctx context.Context, userID string, ) (affected []string, err error) { - _, domain, err := gomatrixserverlib.SplitID('@', userID) + fullUserID, err := spec.NewUserID(userID, true) if err != nil { return nil, err } - if !r.Cfg.Matrix.IsLocalServerName(domain) { + if !r.Cfg.Matrix.IsLocalServerName(fullUserID.Domain()) { return nil, fmt.Errorf("can only evacuate local users using this endpoint") } @@ -172,7 +172,7 @@ func (r *Admin) PerformAdminEvacuateUser( for _, roomID := range allRooms { leaveReq := &api.PerformLeaveRequest{ RoomID: roomID, - UserID: userID, + Leaver: *fullUserID, } leaveRes := &api.PerformLeaveResponse{} outputEvents, err := r.Leaver.PerformLeave(ctx, leaveReq, leaveRes) diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 388150936..8e87359a3 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -582,7 +582,7 @@ func joinEventsFromHistoryVisibility( } // Can we see events in the room? - canSeeEvents := auth.IsServerAllowed(thisServer, true, events) + canSeeEvents := auth.IsServerAllowed(ctx, db, thisServer, true, events) visibility := auth.HistoryVisibilityForRoom(events) if !canSeeEvents { logrus.Infof("ServersAtEvent history not visible to us: %s", visibility) diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index a3ba20f70..475418aa3 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -63,9 +63,17 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } } } - createContent["creator"] = userID.String() + senderID, err := c.DB.GetSenderIDForUser(ctx, roomID.String(), userID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + createContent["creator"] = senderID createContent["room_version"] = createRequest.RoomVersion - powerLevelContent := eventutil.InitialPowerLevelsContent(userID.String()) + powerLevelContent := eventutil.InitialPowerLevelsContent(string(senderID)) joinRuleContent := gomatrixserverlib.JoinRuleContent{ JoinRule: spec.Invite, } @@ -121,7 +129,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } membershipEvent := gomatrixserverlib.FledglingEvent{ Type: spec.MRoomMember, - StateKey: userID.String(), + StateKey: string(senderID), Content: gomatrixserverlib.MemberContent{ Membership: spec.Join, DisplayName: createRequest.UserDisplayName, @@ -270,7 +278,6 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo var builtEvents []*types.HeaderedEvent authEvents := gomatrixserverlib.NewAuthEvents(nil) - senderID, err := c.RSAPI.QuerySenderIDForUser(ctx, roomID.String(), userID) if err != nil { util.GetLogger(ctx).WithError(err).Error("rsapi.QuerySenderIDForUser failed") return "", &util.JSONResponse{ diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 56ee16065..1440daad4 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -134,12 +134,12 @@ func (r *Inviter) PerformInvite( return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")} } - if event.StateKey() == nil { + if event.StateKey() == nil || *event.StateKey() == "" { return fmt.Errorf("invite must be a state event") } - invitedUser, err := spec.NewUserID(*event.StateKey(), true) - if err != nil { - return spec.InvalidParam("The user ID is invalid") + invitedUser, err := r.RSAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) + if err != nil || invitedUser == nil { + return spec.InvalidParam("Could not find the matching senderID for this user") } isTargetLocal := r.Cfg.Matrix.IsLocalServerName(invitedUser.Domain()) diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index d41cc214b..83c3b7c3e 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -162,7 +162,7 @@ func (r *Joiner) performJoinRoomByID( } // Get the domain part of the room ID. - _, domain, err := gomatrixserverlib.SplitID('!', req.RoomIDOrAlias) + roomID, err := spec.NewRoomID(req.RoomIDOrAlias) if err != nil { return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid: %w", req.RoomIDOrAlias, err)} } @@ -170,8 +170,8 @@ func (r *Joiner) performJoinRoomByID( // If the server name in the room ID isn't ours then it's a // possible candidate for finding the room via federation. Add // it to the list of servers to try. - if !r.Cfg.Matrix.IsLocalServerName(domain) { - req.ServerNames = append(req.ServerNames, domain) + if !r.Cfg.Matrix.IsLocalServerName(roomID.Domain()) { + req.ServerNames = append(req.ServerNames, roomID.Domain()) } // Prepare the template for the join event. @@ -203,7 +203,7 @@ func (r *Joiner) performJoinRoomByID( req.Content = map[string]interface{}{} } req.Content["membership"] = spec.Join - if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req); aerr != nil { + if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req, senderID); aerr != nil { return "", "", aerr } else if authorisedVia != "" { req.Content["join_authorised_via_users_server"] = authorisedVia @@ -226,17 +226,17 @@ func (r *Joiner) performJoinRoomByID( // Force a federated join if we're dealing with a pending invite // and we aren't in the room. - isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, req.UserID) + isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID) if err == nil && !serverInRoom && isInvitePending { - _, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender) - if ierr != nil { - return "", "", fmt.Errorf("gomatrixserverlib.SplitID: %w", err) + inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomIDOrAlias, inviteSender) + if queryErr != nil { + return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr) } // If we were invited by someone from another server then we can // assume they are in the room so we can join via them. - if !r.Cfg.Matrix.IsLocalServerName(inviterDomain) { - req.ServerNames = append(req.ServerNames, inviterDomain) + if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) { + req.ServerNames = append(req.ServerNames, inviter.Domain()) forceFederatedJoin = true memberEvent := gjson.Parse(string(inviteEvent.JSON())) // only set unsigned if we've got a content.membership, which we _should_ @@ -298,12 +298,8 @@ func (r *Joiner) performJoinRoomByID( // a member of the room. This is best-effort (as in we won't // fail if we can't find the existing membership) because there // is really no harm in just sending another membership event. - membershipReq := &api.QueryMembershipForUserRequest{ - RoomID: req.RoomIDOrAlias, - UserID: userID.String(), - } membershipRes := &api.QueryMembershipForUserResponse{} - _ = r.Queryer.QueryMembershipForUser(ctx, membershipReq, membershipRes) + _ = r.Queryer.QueryMembershipForSenderID(ctx, *roomID, senderID, membershipRes) // If we haven't already joined the room then send an event // into the room changing our membership status. @@ -328,7 +324,7 @@ func (r *Joiner) performJoinRoomByID( // The room doesn't exist locally. If the room ID looks like it should // be ours then this probably means that we've nuked our database at // some point. - if r.Cfg.Matrix.IsLocalServerName(domain) { + if r.Cfg.Matrix.IsLocalServerName(roomID.Domain()) { // If there are no more server names to try then give up here. // Otherwise we'll try a federated join as normal, since it's quite // possible that the room still exists on other servers. @@ -376,15 +372,12 @@ func (r *Joiner) performFederatedJoinRoomByID( func (r *Joiner) populateAuthorisedViaUserForRestrictedJoin( ctx context.Context, joinReq *rsAPI.PerformJoinRequest, + senderID spec.SenderID, ) (string, error) { roomID, err := spec.NewRoomID(joinReq.RoomIDOrAlias) if err != nil { return "", err } - userID, err := spec.NewUserID(joinReq.UserID, true) - if err != nil { - return "", err - } - return r.Queryer.QueryRestrictedJoinAllowed(ctx, *roomID, *userID) + return r.Queryer.QueryRestrictedJoinAllowed(ctx, *roomID, senderID) } diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 094537f8b..1b23cc1ff 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -53,16 +53,12 @@ func (r *Leaver) PerformLeave( req *api.PerformLeaveRequest, res *api.PerformLeaveResponse, ) ([]api.OutputEvent, error) { - _, domain, err := gomatrixserverlib.SplitID('@', req.UserID) - if err != nil { - return nil, fmt.Errorf("supplied user ID %q in incorrect format", req.UserID) - } - if !r.Cfg.Matrix.IsLocalServerName(domain) { - return nil, fmt.Errorf("user %q does not belong to this homeserver", req.UserID) + if !r.Cfg.Matrix.IsLocalServerName(req.Leaver.Domain()) { + return nil, fmt.Errorf("user %q does not belong to this homeserver", req.Leaver.String()) } logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ "room_id": req.RoomID, - "user_id": req.UserID, + "user_id": req.Leaver.String(), }) logger.Info("User requested to leave join") if strings.HasPrefix(req.RoomID, "!") { @@ -82,21 +78,26 @@ func (r *Leaver) performLeaveRoomByID( req *api.PerformLeaveRequest, res *api.PerformLeaveResponse, // nolint:unparam ) ([]api.OutputEvent, error) { + leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomID, req.Leaver) + if err != nil { + return nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String()) + } + // If there's an invite outstanding for the room then respond to // that. - isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID) + isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, leaver) if err == nil && isInvitePending { - _, senderDomain, serr := gomatrixserverlib.SplitID('@', senderUser) - if serr != nil { - return nil, fmt.Errorf("sender %q is invalid", senderUser) + sender, serr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomID, senderUser) + if serr != nil || sender == nil { + return nil, fmt.Errorf("sender %q has no matching userID", senderUser) } - if !r.Cfg.Matrix.IsLocalServerName(senderDomain) { - return r.performFederatedRejectInvite(ctx, req, res, senderUser, eventID) + if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) { + return r.performFederatedRejectInvite(ctx, req, res, *sender, eventID, leaver) } // check that this is not a "server notice room" accData := &userapi.QueryAccountDataResponse{} if err = r.UserAPI.QueryAccountData(ctx, &userapi.QueryAccountDataRequest{ - UserID: req.UserID, + UserID: req.Leaver.String(), RoomID: req.RoomID, DataType: "m.tag", }, accData); err != nil { @@ -127,7 +128,7 @@ func (r *Leaver) performLeaveRoomByID( StateToFetch: []gomatrixserverlib.StateKeyTuple{ { EventType: spec.MRoomMember, - StateKey: req.UserID, + StateKey: string(leaver), }, }, } @@ -141,26 +142,18 @@ func (r *Leaver) performLeaveRoomByID( // Now let's see if the user is in the room. if len(latestRes.StateEvents) == 0 { - return nil, fmt.Errorf("user %q is not a member of room %q", req.UserID, req.RoomID) + return nil, fmt.Errorf("user %q is not a member of room %q", req.Leaver.String(), req.RoomID) } membership, err := latestRes.StateEvents[0].Membership() if err != nil { return nil, fmt.Errorf("error getting membership: %w", err) } if membership != spec.Join && membership != spec.Invite { - return nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.UserID, membership) + return nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.Leaver.String(), membership) } // Prepare the template for the leave event. - fullUserID, err := spec.NewUserID(req.UserID, true) - if err != nil { - return nil, err - } - senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomID, *fullUserID) - if err != nil { - return nil, err - } - senderIDString := string(senderID) + senderIDString := string(leaver) proto := gomatrixserverlib.ProtoEvent{ Type: spec.MRoomMember, SenderID: senderIDString, @@ -175,16 +168,13 @@ func (r *Leaver) performLeaveRoomByID( return nil, fmt.Errorf("eb.SetUnsigned: %w", err) } - // Get the sender domain. - senderDomain := fullUserID.Domain() - // We know that the user is in the room at this point so let's build // a leave event. // TODO: Check what happens if the room exists on the server // but everyone has since left. I suspect it does the wrong thing. var buildRes rsAPI.QueryLatestEventsAndStateResponse - identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) + identity, err := r.Cfg.Matrix.SigningIdentityFor(req.Leaver.Domain()) if err != nil { return nil, fmt.Errorf("SigningIdentityFor: %w", err) } @@ -201,8 +191,8 @@ func (r *Leaver) performLeaveRoomByID( { Kind: api.KindNew, Event: event, - Origin: senderDomain, - SendAsServer: string(senderDomain), + Origin: req.Leaver.Domain(), + SendAsServer: string(req.Leaver.Domain()), }, }, } @@ -219,21 +209,17 @@ func (r *Leaver) performFederatedRejectInvite( ctx context.Context, req *api.PerformLeaveRequest, res *api.PerformLeaveResponse, // nolint:unparam - senderUser, eventID string, + inviteSender spec.UserID, eventID string, + leaver spec.SenderID, ) ([]api.OutputEvent, error) { - _, domain, err := gomatrixserverlib.SplitID('@', senderUser) - if err != nil { - return nil, fmt.Errorf("user ID %q invalid: %w", senderUser, err) - } - // Ask the federation sender to perform a federated leave for us. leaveReq := fsAPI.PerformLeaveRequest{ RoomID: req.RoomID, - UserID: req.UserID, - ServerNames: []spec.ServerName{domain}, + UserID: req.Leaver.String(), + ServerNames: []spec.ServerName{inviteSender.Domain()}, } leaveRes := fsAPI.PerformLeaveResponse{} - if err = r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil { + if err := r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil { // failures in PerformLeave should NEVER stop us from telling other components like the // sync API that the invite was withdrawn. Otherwise we can end up with stuck invites. util.GetLogger(ctx).WithError(err).Errorf("failed to PerformLeave, still retiring invite event") @@ -244,7 +230,7 @@ func (r *Leaver) performFederatedRejectInvite( util.GetLogger(ctx).WithError(err).Errorf("failed to get RoomInfo, still retiring invite event") } - updater, err := r.DB.MembershipUpdater(ctx, req.RoomID, req.UserID, true, info.RoomVersion) + updater, err := r.DB.MembershipUpdater(ctx, req.RoomID, string(leaver), true, info.RoomVersion) if err != nil { util.GetLogger(ctx).WithError(err).Errorf("failed to get MembershipUpdater, still retiring invite event") } @@ -267,9 +253,10 @@ func (r *Leaver) performFederatedRejectInvite( { Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ - EventID: eventID, - Membership: "leave", - TargetUserID: req.UserID, + EventID: eventID, + RoomID: req.RoomID, + Membership: "leave", + TargetSenderID: leaver, }, }, }, nil diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index 5710352bb..1aaa42c94 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -38,19 +38,15 @@ type Upgrader struct { // PerformRoomUpgrade upgrades a room from one version to another func (r *Upgrader) PerformRoomUpgrade( ctx context.Context, - roomID, userID string, roomVersion gomatrixserverlib.RoomVersion, + roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion, ) (newRoomID string, err error) { return r.performRoomUpgrade(ctx, roomID, userID, roomVersion) } func (r *Upgrader) performRoomUpgrade( ctx context.Context, - roomID, userID string, roomVersion gomatrixserverlib.RoomVersion, + roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion, ) (string, error) { - _, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID) - if err != nil { - return "", api.ErrNotAllowed{Err: fmt.Errorf("error validating the user ID")} - } evTime := time.Now() // Return an immediate error if the room does not exist @@ -58,14 +54,20 @@ func (r *Upgrader) performRoomUpgrade( return "", err } + senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, userID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user") + return "", err + } + // 1. Check if the user is authorized to actually perform the upgrade (can send m.room.tombstone) - if !r.userIsAuthorized(ctx, userID, roomID) { + if !r.userIsAuthorized(ctx, senderID, roomID) { return "", api.ErrNotAllowed{Err: fmt.Errorf("You don't have permission to upgrade the room, power level too low.")} } // TODO (#267): Check room ID doesn't clash with an existing one, and we // probably shouldn't be using pseudo-random strings, maybe GUIDs? - newRoomID := fmt.Sprintf("!%s:%s", util.RandomString(16), userDomain) + newRoomID := fmt.Sprintf("!%s:%s", util.RandomString(16), userID.Domain()) // Get the existing room state for the old room. oldRoomReq := &api.QueryLatestEventsAndStateRequest{ @@ -77,25 +79,25 @@ func (r *Upgrader) performRoomUpgrade( } // Make the tombstone event - tombstoneEvent, pErr := r.makeTombstoneEvent(ctx, evTime, userID, roomID, newRoomID) + tombstoneEvent, pErr := r.makeTombstoneEvent(ctx, evTime, senderID, userID.Domain(), roomID, newRoomID) if pErr != nil { return "", pErr } // Generate the initial events we need to send into the new room. This includes copied state events and bans // as well as the power level events needed to set up the room - eventsToMake, pErr := r.generateInitialEvents(ctx, oldRoomRes, userID, roomID, roomVersion, tombstoneEvent) + eventsToMake, pErr := r.generateInitialEvents(ctx, oldRoomRes, senderID, roomID, roomVersion, tombstoneEvent) if pErr != nil { return "", pErr } // Send the setup events to the new room - if pErr = r.sendInitialEvents(ctx, evTime, userID, userDomain, newRoomID, roomVersion, eventsToMake); pErr != nil { + if pErr = r.sendInitialEvents(ctx, evTime, senderID, userID.Domain(), newRoomID, roomVersion, eventsToMake); pErr != nil { return "", pErr } // 5. Send the tombstone event to the old room - if pErr = r.sendHeaderedEvent(ctx, userDomain, tombstoneEvent, string(userDomain)); pErr != nil { + if pErr = r.sendHeaderedEvent(ctx, userID.Domain(), tombstoneEvent, string(userID.Domain())); pErr != nil { return "", pErr } @@ -105,17 +107,17 @@ func (r *Upgrader) performRoomUpgrade( } // If the old room had a canonical alias event, it should be deleted in the old room - if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, userID, userDomain, roomID); pErr != nil { + if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, senderID, userID.Domain(), roomID); pErr != nil { return "", pErr } // 4. Move local aliases to the new room - if pErr = moveLocalAliases(ctx, roomID, newRoomID, userID, r.URSAPI); pErr != nil { + if pErr = moveLocalAliases(ctx, roomID, newRoomID, senderID, userID, r.URSAPI); pErr != nil { return "", pErr } // 6. Restrict power levels in the old room - if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, userID, userDomain, roomID); pErr != nil { + if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, senderID, userID.Domain(), roomID); pErr != nil { return "", pErr } @@ -130,7 +132,7 @@ func (r *Upgrader) getRoomPowerLevels(ctx context.Context, roomID string) (*goma return oldPowerLevelsEvent.PowerLevels() } -func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, userID string, userDomain spec.ServerName, roomID string) error { +func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, senderID spec.SenderID, userDomain spec.ServerName, roomID string) error { restrictedPowerLevelContent, pErr := r.getRoomPowerLevels(ctx, roomID) if pErr != nil { return pErr @@ -147,7 +149,7 @@ func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.T restrictedPowerLevelContent.EventsDefault = restrictedDefaultPowerLevel restrictedPowerLevelContent.Invite = restrictedDefaultPowerLevel - restrictedPowerLevelsHeadered, resErr := r.makeHeaderedEvent(ctx, evTime, userID, roomID, gomatrixserverlib.FledglingEvent{ + restrictedPowerLevelsHeadered, resErr := r.makeHeaderedEvent(ctx, evTime, senderID, userDomain, roomID, gomatrixserverlib.FledglingEvent{ Type: spec.MRoomPowerLevels, StateKey: "", Content: restrictedPowerLevelContent, @@ -165,7 +167,7 @@ func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.T } func moveLocalAliases(ctx context.Context, - roomID, newRoomID, userID string, + roomID, newRoomID string, senderID spec.SenderID, userID spec.UserID, URSAPI api.RoomserverInternalAPI, ) (err error) { @@ -175,14 +177,6 @@ func moveLocalAliases(ctx context.Context, return fmt.Errorf("Failed to get old room aliases: %w", err) } - fullUserID, err := spec.NewUserID(userID, true) - if err != nil { - return fmt.Errorf("Failed to get userID: %w", err) - } - senderID, err := URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) - if err != nil { - return fmt.Errorf("Failed to get senderID: %w", err) - } for _, alias := range aliasRes.Aliases { removeAliasReq := api.RemoveRoomAliasRequest{SenderID: senderID, Alias: alias} removeAliasRes := api.RemoveRoomAliasResponse{} @@ -190,7 +184,7 @@ func moveLocalAliases(ctx context.Context, return fmt.Errorf("Failed to remove old room alias: %w", err) } - setAliasReq := api.SetRoomAliasRequest{UserID: userID, Alias: alias, RoomID: newRoomID} + setAliasReq := api.SetRoomAliasRequest{UserID: userID.String(), Alias: alias, RoomID: newRoomID} setAliasRes := api.SetRoomAliasResponse{} if err = URSAPI.SetRoomAlias(ctx, &setAliasReq, &setAliasRes); err != nil { return fmt.Errorf("Failed to set new room alias: %w", err) @@ -199,7 +193,7 @@ func moveLocalAliases(ctx context.Context, return nil } -func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, userID string, userDomain spec.ServerName, roomID string) error { +func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, senderID spec.SenderID, userDomain spec.ServerName, roomID string) error { for _, event := range oldRoom.StateEvents { if event.Type() != spec.MRoomCanonicalAlias || !event.StateKeyEquals("") { continue @@ -217,7 +211,7 @@ func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api } } - emptyCanonicalAliasEvent, resErr := r.makeHeaderedEvent(ctx, evTime, userID, roomID, gomatrixserverlib.FledglingEvent{ + emptyCanonicalAliasEvent, resErr := r.makeHeaderedEvent(ctx, evTime, senderID, userDomain, roomID, gomatrixserverlib.FledglingEvent{ Type: spec.MRoomCanonicalAlias, Content: map[string]interface{}{}, }) @@ -280,7 +274,7 @@ func (r *Upgrader) validateRoomExists(ctx context.Context, roomID string) error return nil } -func (r *Upgrader) userIsAuthorized(ctx context.Context, userID, roomID string, +func (r *Upgrader) userIsAuthorized(ctx context.Context, senderID spec.SenderID, roomID string, ) bool { plEvent := api.GetStateEvent(ctx, r.URSAPI, roomID, gomatrixserverlib.StateKeyTuple{ EventType: spec.MRoomPowerLevels, @@ -295,26 +289,18 @@ func (r *Upgrader) userIsAuthorized(ctx context.Context, userID, roomID string, } // Check for power level required to send tombstone event (marks the current room as obsolete), // if not found, use the StateDefault power level - fullUserID, err := spec.NewUserID(userID, true) - if err != nil { - return false - } - senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) - if err != nil { - return false - } return pl.UserLevel(senderID) >= pl.EventLevel("m.room.tombstone", true) } // nolint:gocyclo -func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, userID, roomID string, newVersion gomatrixserverlib.RoomVersion, tombstoneEvent *types.HeaderedEvent) ([]gomatrixserverlib.FledglingEvent, error) { +func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, senderID spec.SenderID, roomID string, newVersion gomatrixserverlib.RoomVersion, tombstoneEvent *types.HeaderedEvent) ([]gomatrixserverlib.FledglingEvent, error) { state := make(map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent, len(oldRoom.StateEvents)) for _, event := range oldRoom.StateEvents { if event.StateKey() == nil { // This shouldn't ever happen, but better to be safe than sorry. continue } - if event.Type() == spec.MRoomMember && !event.StateKeyEquals(userID) { + if event.Type() == spec.MRoomMember && !event.StateKeyEquals(string(senderID)) { // With the exception of bans which we do want to copy, we // should ignore membership events that aren't our own, as event auth will // prevent us from being able to create membership events on behalf of other @@ -330,6 +316,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query } } // skip events that rely on a specific user being present + // TODO: What to do here for pseudoIDs? It's checking non-member events for state keys with userIDs. sKey := *event.StateKey() if event.Type() != spec.MRoomMember && len(sKey) > 0 && sKey[:1] == "@" { continue @@ -340,10 +327,10 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query // The following events are ones that we are going to override manually // in the following section. override := map[gomatrixserverlib.StateKeyTuple]struct{}{ - {EventType: spec.MRoomCreate, StateKey: ""}: {}, - {EventType: spec.MRoomMember, StateKey: userID}: {}, - {EventType: spec.MRoomPowerLevels, StateKey: ""}: {}, - {EventType: spec.MRoomJoinRules, StateKey: ""}: {}, + {EventType: spec.MRoomCreate, StateKey: ""}: {}, + {EventType: spec.MRoomMember, StateKey: string(senderID)}: {}, + {EventType: spec.MRoomPowerLevels, StateKey: ""}: {}, + {EventType: spec.MRoomJoinRules, StateKey: ""}: {}, } // The overridden events are essential events that must be present in the @@ -355,7 +342,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query } oldCreateEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomCreate, StateKey: ""}] - oldMembershipEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomMember, StateKey: userID}] + oldMembershipEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomMember, StateKey: string(senderID)}] oldPowerLevelsEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomPowerLevels, StateKey: ""}] oldJoinRulesEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomJoinRules, StateKey: ""}] @@ -364,7 +351,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query // in the create event (such as for the room types MSC). newCreateContent := map[string]interface{}{} _ = json.Unmarshal(oldCreateEvent.Content(), &newCreateContent) - newCreateContent["creator"] = userID + newCreateContent["creator"] = string(senderID) newCreateContent["room_version"] = newVersion newCreateContent["predecessor"] = gomatrixserverlib.PreviousRoom{ EventID: tombstoneEvent.EventID(), @@ -385,7 +372,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query newMembershipContent["membership"] = spec.Join newMembershipEvent := gomatrixserverlib.FledglingEvent{ Type: spec.MRoomMember, - StateKey: userID, + StateKey: string(senderID), Content: newMembershipContent, } @@ -400,14 +387,6 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query return nil, fmt.Errorf("Power level event content was invalid") } - fullUserID, err := spec.NewUserID(userID, true) - if err != nil { - return nil, err - } - senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) - if err != nil { - return nil, err - } tempPowerLevelsEvent, powerLevelsOverridden := createTemporaryPowerLevels(powerLevelContent, senderID) // Now do the join rules event, same as the create and membership @@ -470,21 +449,13 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query return eventsToMake, nil } -func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, userID string, userDomain spec.ServerName, newRoomID string, newVersion gomatrixserverlib.RoomVersion, eventsToMake []gomatrixserverlib.FledglingEvent) error { +func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, senderID spec.SenderID, userDomain spec.ServerName, newRoomID string, newVersion gomatrixserverlib.RoomVersion, eventsToMake []gomatrixserverlib.FledglingEvent) error { var err error var builtEvents []*types.HeaderedEvent authEvents := gomatrixserverlib.NewAuthEvents(nil) for i, e := range eventsToMake { depth := i + 1 // depth starts at 1 - fullUserID, userIDErr := spec.NewUserID(userID, true) - if userIDErr != nil { - return userIDErr - } - senderID, queryErr := r.URSAPI.QuerySenderIDForUser(ctx, newRoomID, *fullUserID) - if queryErr != nil { - return queryErr - } proto := gomatrixserverlib.ProtoEvent{ SenderID: string(senderID), RoomID: newRoomID, @@ -549,7 +520,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user func (r *Upgrader) makeTombstoneEvent( ctx context.Context, evTime time.Time, - userID, roomID, newRoomID string, + senderID spec.SenderID, senderDomain spec.ServerName, roomID, newRoomID string, ) (*types.HeaderedEvent, error) { content := map[string]interface{}{ "body": "This room has been replaced", @@ -559,30 +530,21 @@ func (r *Upgrader) makeTombstoneEvent( Type: "m.room.tombstone", Content: content, } - return r.makeHeaderedEvent(ctx, evTime, userID, roomID, event) + return r.makeHeaderedEvent(ctx, evTime, senderID, senderDomain, roomID, event) } -func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, userID, roomID string, event gomatrixserverlib.FledglingEvent) (*types.HeaderedEvent, error) { - fullUserID, err := spec.NewUserID(userID, true) - if err != nil { - return nil, err - } - senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) - if err != nil { - return nil, err - } +func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, senderID spec.SenderID, senderDomain spec.ServerName, roomID string, event gomatrixserverlib.FledglingEvent) (*types.HeaderedEvent, error) { proto := gomatrixserverlib.ProtoEvent{ SenderID: string(senderID), RoomID: roomID, Type: event.Type, StateKey: &event.StateKey, } - err = proto.SetContent(event.Content) + err := proto.SetContent(event.Content) if err != nil { return nil, fmt.Errorf("failed to set new %q event content: %w", proto.Type, err) } // Get the sender domain. - senderDomain := fullUserID.Domain() identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) if err != nil { return nil, fmt.Errorf("failed to get signing identity for %q: %w", senderDomain, err) diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index ae2b7cf57..caea6b526 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -48,7 +48,7 @@ type Queryer struct { Cfg *config.Dendrite } -func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { +func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { roomInfo, err := r.QueryRoomInfo(ctx, roomID) if err != nil || roomInfo == nil || roomInfo.IsStub() { return nil, err @@ -64,7 +64,7 @@ func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err) } - userJoinedToRoom, err := r.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID) + userJoinedToRoom, err := r.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), senderID) if err != nil { util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed") return nil, fmt.Errorf("InternalServerError: %w", err) @@ -220,13 +220,14 @@ func (r *Queryer) QueryEventsByID( return nil } -// QueryMembershipForUser implements api.RoomserverInternalAPI -func (r *Queryer) QueryMembershipForUser( +// QueryMembershipForSenderID implements api.RoomserverInternalAPI +func (r *Queryer) QueryMembershipForSenderID( ctx context.Context, - request *api.QueryMembershipForUserRequest, + roomID spec.RoomID, + senderID spec.SenderID, response *api.QueryMembershipForUserResponse, ) error { - info, err := r.DB.RoomInfo(ctx, request.RoomID) + info, err := r.DB.RoomInfo(ctx, roomID.String()) if err != nil { return err } @@ -236,7 +237,7 @@ func (r *Queryer) QueryMembershipForUser( } response.RoomExists = true - membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.UserID) + membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, senderID) if err != nil { return err } @@ -264,6 +265,24 @@ func (r *Queryer) QueryMembershipForUser( return err } +// QueryMembershipForUser implements api.RoomserverInternalAPI +func (r *Queryer) QueryMembershipForUser( + ctx context.Context, + request *api.QueryMembershipForUserRequest, + response *api.QueryMembershipForUserResponse, +) error { + senderID, err := r.DB.GetSenderIDForUser(ctx, request.RoomID, request.UserID) + if err != nil { + return err + } + + roomID, err := spec.NewRoomID(request.RoomID) + if err != nil { + return err + } + return r.QueryMembershipForSenderID(ctx, *roomID, senderID, response) +} + // QueryMembershipAtEvent returns the known memberships at a given event. // If the state before an event is not known, an empty list will be returned // for that event instead. @@ -373,7 +392,7 @@ func (r *Queryer) QueryMembershipsForRoom( // If no sender is specified then we will just return the entire // set of memberships for the room, regardless of whether a specific // user is allowed to see them or not. - if request.Sender == "" { + if request.SenderID == "" { var events []types.Event var eventNIDs []types.EventNID eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, request.LocalOnly) @@ -388,18 +407,15 @@ func (r *Queryer) QueryMembershipsForRoom( return fmt.Errorf("r.DB.Events: %w", err) } for _, event := range events { - sender := spec.UserID{} - userID, queryErr := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) - if queryErr == nil && userID != nil { - sender = *userID - } - clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender) + clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return r.QueryUserIDForSender(ctx, roomID, senderID) + }, event) response.JoinEvents = append(response.JoinEvents, clientEvent) } return nil } - membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.Sender) + membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.SenderID) if err != nil { return err } @@ -442,12 +458,9 @@ func (r *Queryer) QueryMembershipsForRoom( } for _, event := range events { - sender := spec.UserID{} - userID, err := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) - if err == nil && userID != nil { - sender = *userID - } - clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender) + clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return r.QueryUserIDForSender(ctx, roomID, senderID) + }, event) response.JoinEvents = append(response.JoinEvents, clientEvent) } @@ -489,6 +502,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent( ctx context.Context, serverName spec.ServerName, eventID string, + roomID string, ) (allowed bool, err error) { events, err := r.DB.EventNIDs(ctx, []string{eventID}) if err != nil { @@ -518,7 +532,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent( } return helpers.CheckServerAllowedToSeeEvent( - ctx, r.DB, info, eventID, serverName, isInRoom, + ctx, r.DB, info, roomID, eventID, serverName, isInRoom, ) } @@ -909,8 +923,8 @@ func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainReq return nil } -func (r *Queryer) InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) { - pending, _, _, _, err := helpers.IsInvitePending(ctx, r.DB, roomID.String(), userID.String()) +func (r *Queryer) InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error) { + pending, _, _, _, err := helpers.IsInvitePending(ctx, r.DB, roomID.String(), senderID) return pending, err } @@ -926,8 +940,8 @@ func (r *Queryer) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eve return res, err } -func (r *Queryer) UserJoinedToRoom(ctx context.Context, roomNID types.RoomNID, userID spec.UserID) (bool, error) { - _, isIn, _, err := r.DB.GetMembership(ctx, roomNID, userID.String()) +func (r *Queryer) UserJoinedToRoom(ctx context.Context, roomNID types.RoomNID, senderID spec.SenderID) (bool, error) { + _, isIn, _, err := r.DB.GetMembership(ctx, roomNID, senderID) return isIn, err } @@ -957,7 +971,7 @@ func (r *Queryer) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixse } // nolint:gocyclo -func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) { +func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) { // Look up if we know anything about the room. If it doesn't exist // or is a stub entry then we can't do anything. roomInfo, err := r.DB.RoomInfo(ctx, roomID.String()) @@ -972,7 +986,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro return "", err } - return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, userID) + return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, senderID) } func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 5e6ba7d4e..90c94bbce 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -722,7 +722,7 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) { roomID, _ := spec.NewRoomID(testRoom.ID) userID, _ := spec.NewUserID(bob.ID, true) - got, err := rsAPI.QueryRestrictedJoinAllowed(processCtx.Context(), *roomID, *userID) + got, err := rsAPI.QueryRestrictedJoinAllowed(processCtx.Context(), *roomID, spec.SenderID(userID.String())) if tc.wantError && err == nil { t.Fatal("expected error, got none") } @@ -821,17 +821,6 @@ func TestUpgrade(t *testing.T) { validateFunc func(t *testing.T, oldRoomID, newRoomID string, rsAPI api.RoomserverInternalAPI) wantNewRoom bool }{ - { - name: "invalid userID", - upgradeUser: "!notvalid:test", - roomFunc: func(rsAPI api.RoomserverInternalAPI) string { - room := test.NewRoom(t, alice) - if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { - t.Errorf("failed to send events: %v", err) - } - return room.ID - }, - }, { name: "invalid roomID", upgradeUser: alice.ID, @@ -1049,7 +1038,11 @@ func TestUpgrade(t *testing.T) { } roomID := tc.roomFunc(rsAPI) - newRoomID, err := rsAPI.PerformRoomUpgrade(processCtx.Context(), roomID, tc.upgradeUser, version.DefaultRoomVersion()) + userID, err := spec.NewUserID(tc.upgradeUser, true) + if err != nil { + t.Fatalf("upgrade userID is invalid") + } + newRoomID, err := rsAPI.PerformRoomUpgrade(processCtx.Context(), roomID, *userID, version.DefaultRoomVersion()) if err != nil && tc.wantNewRoom { t.Fatal(err) } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 2d27d7999..ef4463781 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -131,7 +131,7 @@ type Database interface { // in this room, along a boolean set to true if the user is still in this room, // false if not. // Returns an error if there was a problem talking to the database. - GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom, isRoomForgotten bool, err error) + GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderID spec.SenderID) (membershipEventNID types.EventNID, stillInRoom, isRoomForgotten bool, err error) // Lookup the membership event numeric IDs for all user that are or have // been members of a given room. Only lookup events of "join" membership if // joinOnly is set to true. diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index cb12b3f57..85a1ba7a1 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -490,10 +490,10 @@ func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { }) } -func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom, isRoomforgotten bool, err error) { +func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderID spec.SenderID) (membershipEventNID types.EventNID, stillInRoom, isRoomforgotten bool, err error) { var requestSenderUserNID types.EventStateKeyNID err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - requestSenderUserNID, err = d.assignStateKeyNID(ctx, txn, requestSenderUserID) + requestSenderUserNID, err = d.assignStateKeyNID(ctx, txn, string(requestSenderID)) return err }) if err != nil { @@ -936,6 +936,7 @@ func extractRoomVersionFromCreateEvent(event gomatrixserverlib.PDU) ( return roomVersion, err } +// nolint:gocyclo // MaybeRedactEvent manages the redacted status of events. There's two cases to consider in order to comply with the spec: // "servers should not apply or send redactions to clients until both the redaction event and original event have been seen, and are valid." // https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events @@ -1014,7 +1015,7 @@ func (d *EventDatabase) MaybeRedactEvent( switch { case powerlevels.UserLevel(redactionEvent.SenderID()) >= powerlevels.Redact: // 1. The power level of the redaction event’s sender is greater than or equal to the redact level. - case sender1Domain == sender2Domain: + case sender1Domain != "" && sender2Domain != "" && sender1Domain == sender2Domain: // 2. The domain of the redaction event’s sender matches that of the original event’s sender. default: ignoreRedaction = true diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 47eb544ea..d3f1c9dd2 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -154,7 +154,7 @@ type reqCtx struct { rsAPI roomserver.RoomserverInternalAPI db Database req *EventRelationshipRequest - userID string + userID spec.UserID roomVersion gomatrixserverlib.RoomVersion // federated request args @@ -173,10 +173,17 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)), } } + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: 400, + JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)), + } + } rc := reqCtx{ ctx: req.Context(), req: relation, - userID: device.UserID, + userID: *userID, rsAPI: rsAPI, fsAPI: fsAPI, isFederatedRequest: false, diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index 551d7ad45..e32d6a9f2 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -529,6 +529,10 @@ func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID str return spec.NewUserID(string(senderID), true) } +func (r *testRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { + return spec.SenderID(userID.String()), nil +} + func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error { for _, eventID := range req.EventIDs { ev := r.events[eventID] @@ -540,7 +544,7 @@ func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver } func (r *testRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *roomserver.QueryMembershipForUserRequest, res *roomserver.QueryMembershipForUserResponse) error { - rooms := r.userToJoinedRooms[req.UserID] + rooms := r.userToJoinedRooms[req.UserID.String()] for _, roomID := range rooms { if roomID == req.RoomID { res.IsInRoom = true diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 8a2a0b1f6..c5f2db9c8 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -373,7 +373,15 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *rst // TODO: check that it's a join and not a profile change (means unmarshalling prev_content) if membership == spec.Join { // check it's a local join - if _, _, err := s.cfg.Matrix.SplitLocalID('@', *ev.StateKey()); err != nil { + if ev.StateKey() == nil { + return sp, fmt.Errorf("unexpected nil state_key") + } + + userID, err := s.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey())) + if err != nil || userID == nil { + return sp, fmt.Errorf("failed getting userID for sender: %w", err) + } + if !s.cfg.Matrix.IsLocalServerName(userID.Domain()) { return sp, nil } @@ -395,9 +403,15 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent( if msg.Event.StateKey() == nil { return } - if _, _, err := s.cfg.Matrix.SplitLocalID('@', *msg.Event.StateKey()); err != nil { + + userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.Event.RoomID(), spec.SenderID(*msg.Event.StateKey())) + if err != nil || userID == nil { return } + if !s.cfg.Matrix.IsLocalServerName(userID.Domain()) { + return + } + pduPos, err := s.db.AddInviteEvent(ctx, msg.Event) if err != nil { sentry.CaptureException(err) @@ -440,7 +454,16 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent( // Notify any active sync requests that the invite has been retired. s.inviteStream.Advance(pduPos) - s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, msg.TargetUserID) + userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.RoomID, msg.TargetSenderID) + if err != nil || userID == nil { + log.WithFields(log.Fields{ + "event_id": msg.EventID, + "sender_id": msg.TargetSenderID, + log.ErrorKey: err, + }).Errorf("failed to find userID for sender") + return + } + s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, userID.String()) } func (s *OutputRoomEventConsumer) onNewPeek( diff --git a/syncapi/internal/history_visibility.go b/syncapi/internal/history_visibility.go index 7449b4647..ab1a7f83d 100644 --- a/syncapi/internal/history_visibility.go +++ b/syncapi/internal/history_visibility.go @@ -134,9 +134,17 @@ func ApplyHistoryVisibilityFilter( } } // NOTSPEC: Always allow user to see their own membership events (spec contains more "rules") - if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(userID) { - eventsFiltered = append(eventsFiltered, ev) - continue + + user, err := spec.NewUserID(userID, true) + if err != nil { + return nil, err + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *user) + if err == nil { + if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) { + eventsFiltered = append(eventsFiltered, ev) + continue + } } // Always allow history evVis events on boundaries. This is done // by setting the effective evVis to the least restrictive diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index ad5935cdc..f4b6ace59 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -169,12 +169,16 @@ func TrackChangedUsers( if err != nil { return nil, nil, err } - for _, state := range stateRes.Rooms { + for roomID, state := range stateRes.Rooms { for tuple, membership := range state { if membership != spec.Join { continue } - queryRes.UserIDsToCount[tuple.StateKey]-- + user, queryErr := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey)) + if queryErr != nil || user == nil { + continue + } + queryRes.UserIDsToCount[user.String()]-- } } @@ -211,14 +215,18 @@ func TrackChangedUsers( if err != nil { return nil, left, err } - for _, state := range stateRes.Rooms { + for roomID, state := range stateRes.Rooms { for tuple, membership := range state { if membership != spec.Join { continue } // new user who we weren't previously sharing rooms with if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok { - changed = append(changed, tuple.StateKey) // changed is returned + user, err := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey)) + if err != nil || user == nil { + continue + } + changed = append(changed, user.String()) // changed is returned } } } diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index 23c2ecbaa..efa641475 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -64,6 +64,10 @@ type mockRoomserverAPI struct { roomIDToJoinedMembers map[string][]string } +func (s *mockRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) +} + // QueryRoomsForUser retrieves a list of room IDs matching the given query. func (s *mockRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error { return nil diff --git a/syncapi/notifier/notifier.go b/syncapi/notifier/notifier.go index f76456859..4ee7c8605 100644 --- a/syncapi/notifier/notifier.go +++ b/syncapi/notifier/notifier.go @@ -20,6 +20,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/api" rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" @@ -36,7 +37,8 @@ import ( // the event, but the token has already advanced by the time they fetch it, resulting // in missed events. type Notifier struct { - lock *sync.RWMutex + lock *sync.RWMutex + rsAPI api.SyncRoomserverAPI // A map of RoomID => Set : Must only be accessed by the OnNewEvent goroutine roomIDToJoinedUsers map[string]*userIDSet // A map of RoomID => Set : Must only be accessed by the OnNewEvent goroutine @@ -55,8 +57,9 @@ type Notifier struct { // NewNotifier creates a new notifier set to the given sync position. // In order for this to be of any use, the Notifier needs to be told all rooms and // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). -func NewNotifier() *Notifier { +func NewNotifier(rsAPI api.SyncRoomserverAPI) *Notifier { return &Notifier{ + rsAPI: rsAPI, roomIDToJoinedUsers: make(map[string]*userIDSet), roomIDToPeekingDevices: make(map[string]peekingDeviceSet), userDeviceStreams: make(map[string]map[string]*UserDeviceStream), @@ -104,26 +107,32 @@ func (n *Notifier) OnNewEvent( peekingDevicesToNotify := n._peekingDevices(ev.RoomID()) // If this is an invite, also add in the invitee to this list. if ev.Type() == "m.room.member" && ev.StateKey() != nil { - targetUserID := *ev.StateKey() - membership, err := ev.Membership() + targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), ev.RoomID(), spec.SenderID(*ev.StateKey())) if err != nil { log.WithError(err).WithField("event_id", ev.EventID()).Errorf( - "Notifier.OnNewEvent: Failed to unmarshal member event", + "Notifier.OnNewEvent: Failed to find the userID for this event", ) } else { - // Keep the joined user map up-to-date - switch membership { - case spec.Invite: - usersToNotify = append(usersToNotify, targetUserID) - case spec.Join: - // Manually append the new user's ID so they get notified - // along all members in the room - usersToNotify = append(usersToNotify, targetUserID) - n._addJoinedUser(ev.RoomID(), targetUserID) - case spec.Leave: - fallthrough - case spec.Ban: - n._removeJoinedUser(ev.RoomID(), targetUserID) + membership, err := ev.Membership() + if err != nil { + log.WithError(err).WithField("event_id", ev.EventID()).Errorf( + "Notifier.OnNewEvent: Failed to unmarshal member event", + ) + } else { + // Keep the joined user map up-to-date + switch membership { + case spec.Invite: + usersToNotify = append(usersToNotify, targetUserID.String()) + case spec.Join: + // Manually append the new user's ID so they get notified + // along all members in the room + usersToNotify = append(usersToNotify, targetUserID.String()) + n._addJoinedUser(ev.RoomID(), targetUserID.String()) + case spec.Leave: + fallthrough + case spec.Ban: + n._removeJoinedUser(ev.RoomID(), targetUserID.String()) + } } } } diff --git a/syncapi/notifier/notifier_test.go b/syncapi/notifier/notifier_test.go index 36577a0ee..7076f7134 100644 --- a/syncapi/notifier/notifier_test.go +++ b/syncapi/notifier/notifier_test.go @@ -22,9 +22,11 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/roomserver/api" rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -105,9 +107,15 @@ func mustEqualPositions(t *testing.T, got, want types.StreamingToken) { } } +type TestRoomServer struct{ api.SyncRoomserverAPI } + +func (t *TestRoomServer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) +} + // Test that the current position is returned if a request is already behind. func TestImmediateNotification(t *testing.T) { - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionVeryOld)) if err != nil { @@ -118,7 +126,7 @@ func TestImmediateNotification(t *testing.T) { // Test that new events to a joined room unblocks the request. func TestNewEventAndJoinedToRoom(t *testing.T) { - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) n.setUsersJoinedToRooms(map[string][]string{ roomID: {alice, bob}, @@ -144,7 +152,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) { } func TestCorrectStream(t *testing.T) { - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) stream := lockedFetchUserStream(n, bob, bobDev) if stream.UserID != bob { @@ -156,7 +164,7 @@ func TestCorrectStream(t *testing.T) { } func TestCorrectStreamWakeup(t *testing.T) { - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) awoken := make(chan string) @@ -184,7 +192,7 @@ func TestCorrectStreamWakeup(t *testing.T) { // Test that an invite unblocks the request func TestNewInviteEventForUser(t *testing.T) { - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) n.setUsersJoinedToRooms(map[string][]string{ roomID: {alice, bob}, @@ -241,7 +249,7 @@ func TestEDUWakeup(t *testing.T) { // Test that all blocked requests get woken up on a new event. func TestMultipleRequestWakeup(t *testing.T) { - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) n.setUsersJoinedToRooms(map[string][]string{ roomID: {alice, bob}, @@ -278,7 +286,7 @@ func TestMultipleRequestWakeup(t *testing.T) { func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { // listen as bob. Make bob leave room. Make alice send event to room. // Make sure alice gets woken up only and not bob as well. - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) n.setUsersJoinedToRooms(map[string][]string{ roomID: {alice, bob}, diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 7fb88faaa..55fd3c5a2 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -85,9 +85,16 @@ func Context( *filter.Rooms = append(*filter.Rooms, roomID) } + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Device UserID is invalid"), + } + } ctx := req.Context() membershipRes := roomserver.QueryMembershipForUserResponse{} - membershipReq := roomserver.QueryMembershipForUserRequest{UserID: device.UserID, RoomID: roomID} + membershipReq := roomserver.QueryMembershipForUserRequest{UserID: *userID, RoomID: roomID} if err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes); err != nil { logrus.WithError(err).Error("unable to query membership") return util.JSONResponse{ @@ -217,12 +224,9 @@ func Context( } } - sender := spec.UserID{} - userID, err := rsAPI.QueryUserIDForSender(ctx, requestedEvent.RoomID(), requestedEvent.SenderID()) - if err == nil && userID != nil { - sender = *userID - } - ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, sender) + ev := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, requestedEvent) response := ContextRespsonse{ Event: &ev, EventsAfter: eventsAfterClient, diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go index 63df7e837..de790e5cd 100644 --- a/syncapi/routing/getevent.go +++ b/syncapi/routing/getevent.go @@ -106,8 +106,17 @@ func GetEvent( if err == nil && senderUserID != nil { sender = *senderUserID } + + sk := events[0].StateKey() + if sk != nil && *sk != "" { + skUserID, err := rsAPI.QueryUserIDForSender(ctx, events[0].RoomID(), spec.SenderID(*events[0].StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } return util.JSONResponse{ Code: http.StatusOK, - JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender), + JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender, sk), } } diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index 813167a5e..cf6769ba4 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -59,14 +59,21 @@ func GetMemberships( syncDB storage.Database, rsAPI api.SyncRoomserverAPI, joinedOnly bool, membership, notMembership *string, at string, ) util.JSONResponse { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Device UserID is invalid"), + } + } queryReq := api.QueryMembershipForUserRequest{ RoomID: roomID, - UserID: device.UserID, + UserID: *userID, } var queryRes api.QueryMembershipForUserResponse - if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryMembershipsForRoom failed") + if queryErr := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); queryErr != nil { + util.GetLogger(req.Context()).WithError(queryErr).Error("rsAPI.QueryMembershipsForRoom failed") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}, diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 781fd53e7..6784a27bd 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -296,9 +296,13 @@ func OnIncomingMessagesRequest( } func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (resp api.QueryMembershipForUserResponse, err error) { + fullUserID, err := spec.NewUserID(userID, true) + if err != nil { + return resp, err + } req := api.QueryMembershipForUserRequest{ RoomID: roomID, - UserID: userID, + UserID: *fullUserID, } if err := rsAPI.QueryMembershipForUser(ctx, &req, &resp); err != nil { return api.QueryMembershipForUserResponse{}, err diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go index f21c684c8..6efa065a9 100644 --- a/syncapi/routing/relations.go +++ b/syncapi/routing/relations.go @@ -119,9 +119,18 @@ func Relations( if err == nil && userID != nil { sender = *userID } + + sk := event.StateKey() + if sk != nil && *sk != "" { + skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } res.Chunk = append( res.Chunk, - synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender), + synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender, sk), ) } diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index add50b181..7d9182f47 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -235,6 +235,15 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts if err == nil && userID != nil { sender = *userID } + + sk := event.StateKey() + if sk != nil && *sk != "" { + skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } results = append(results, Result{ Context: SearchContextResponse{ Start: startToken.String(), @@ -248,7 +257,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts ProfileInfo: profileInfos, }, Rank: eventScore[event.EventID()].Score, - Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender), + Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk), }) roomGroup := groups[event.RoomID()] roomGroup.Results = append(roomGroup.Results, event.EventID()) diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index 5bd3b1f01..799e3d166 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -507,8 +507,20 @@ func (d *Database) CleanSendToDeviceUpdates( // getMembershipFromEvent returns the value of content.membership iff the event is a state event // with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned. -func getMembershipFromEvent(ev gomatrixserverlib.PDU, userID string) (string, string) { - if ev.Type() != "m.room.member" || !ev.StateKeyEquals(userID) { +func getMembershipFromEvent(ctx context.Context, ev gomatrixserverlib.PDU, userID string, rsAPI api.SyncRoomserverAPI) (string, string) { + if ev.StateKey() == nil || *ev.StateKey() == "" { + return "", "" + } + fullUser, err := spec.NewUserID(userID, true) + if err != nil { + return "", "" + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *fullUser) + if err != nil { + return "", "" + } + + if ev.Type() != "m.room.member" || !ev.StateKeyEquals(string(senderID)) { return "", "" } membership, err := ev.Membership() diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index df9613850..8e79b71df 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -430,7 +430,7 @@ func (d *DatabaseTransaction) GetStateDeltas( for _, ev := range stateStreamEvents { // Look for our membership in the state events and skip over any // membership events that are not related to us. - membership, prevMembership := getMembershipFromEvent(ev.PDU, userID) + membership, prevMembership := getMembershipFromEvent(ctx, ev.PDU, userID, rsAPI) if membership == "" { continue } @@ -556,7 +556,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( for roomID, stateStreamEvents := range state { for _, ev := range stateStreamEvents { - if membership, _ := getMembershipFromEvent(ev.PDU, userID); membership != "" { + if membership, _ := getMembershipFromEvent(ctx, ev.PDU, userID, rsAPI); membership != "" { if membership != spec.Join { // We've already added full state for all joined rooms above. deltas[roomID] = types.StateDelta{ Membership: membership, diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index a8b0a7b66..3a5badd92 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -70,11 +70,20 @@ func (p *InviteStreamProvider) IncrementalSync( user = *sender } + sk := inviteEvent.StateKey() + if sk != nil && *sk != "" { + skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } + // skip ignored user events if _, ok := req.IgnoredUsers.List[user.String()]; ok { continue } - ir := types.NewInviteResponse(inviteEvent, user) + ir := types.NewInviteResponse(inviteEvent, user, sk) req.Response.Rooms.Invite[roomID] = ir } diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index d214980bd..f728d4aea 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -605,13 +605,17 @@ func (p *PDUStreamProvider) lazyLoadMembers( // If this is a gapped incremental sync, we still want this membership isGappedIncremental := limited && incremental // We want this users membership event, keep it in the list - stateKey := *event.StateKey() - if _, ok := timelineUsers[stateKey]; ok || isGappedIncremental || stateKey == device.UserID { + userID := "" + stateKeyUserID, err := p.rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(*event.StateKey())) + if err == nil && stateKeyUserID != nil { + userID = stateKeyUserID.String() + } + if _, ok := timelineUsers[userID]; ok || isGappedIncremental || userID == device.UserID { newStateEvents = append(newStateEvents, event) if !stateFilter.IncludeRedundantMembers { - p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, stateKey, event.EventID()) + p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, userID, event.EventID()) } - delete(timelineUsers, stateKey) + delete(timelineUsers, userID) } } else { newStateEvents = append(newStateEvents, event) diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index ecbe05dd8..64a4af757 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -60,7 +60,7 @@ func AddPublicRoutes( } eduCache := caching.NewTypingCache() - notifier := notifier.NewNotifier() + notifier := notifier.NewNotifier(rsAPI) streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, eduCache, caches, notifier) notifier.SetCurrentPosition(streams.Latest(context.Background())) if err = notifier.Load(context.Background(), syncDB); err != nil { diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go index 66fb1d01f..358a0c971 100644 --- a/syncapi/synctypes/clientevent.go +++ b/syncapi/synctypes/clientevent.go @@ -55,18 +55,27 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat, if err == nil && userID != nil { sender = *userID } - evs = append(evs, ToClientEvent(se, format, sender)) + + sk := se.StateKey() + if sk != nil && *sk != "" { + skUserID, err := userIDForSender(se.RoomID(), spec.SenderID(*sk)) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } + evs = append(evs, ToClientEvent(se, format, sender, sk)) } return evs } // ToClientEvent converts a single server event to a client event. -func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID) ClientEvent { +func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID, stateKey *string) ClientEvent { ce := ClientEvent{ Content: spec.RawJSON(se.Content()), Sender: sender.String(), Type: se.Type(), - StateKey: se.StateKey(), + StateKey: stateKey, Unsigned: spec.RawJSON(se.Unsigned()), OriginServerTS: se.OriginServerTS(), EventID: se.EventID(), @@ -77,3 +86,23 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender sp } return ce } + +// ToClientEvent converts a single server event to a client event. +// It provides default logic for event.SenderID & event.StateKey -> userID conversions. +func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent { + sender := spec.UserID{} + userID, err := userIDQuery(event.RoomID(), event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } + + sk := event.StateKey() + if sk != nil && *sk != "" { + skUserID, err := userIDQuery(event.RoomID(), spec.SenderID(*event.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } + return ToClientEvent(event, FormatAll, sender, sk) +} diff --git a/syncapi/synctypes/clientevent_test.go b/syncapi/synctypes/clientevent_test.go index 341795081..63c65b2af 100644 --- a/syncapi/synctypes/clientevent_test.go +++ b/syncapi/synctypes/clientevent_test.go @@ -48,7 +48,8 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo if err != nil { t.Fatalf("failed to create userID: %s", err) } - ce := ToClientEvent(ev, FormatAll, *userID) + sk := "" + ce := ToClientEvent(ev, FormatAll, *userID, &sk) if ce.EventID != ev.EventID() { t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID) } @@ -107,7 +108,8 @@ func TestToClientFormatSync(t *testing.T) { if err != nil { t.Fatalf("failed to create userID: %s", err) } - ce := ToClientEvent(ev, FormatSync, *userID) + sk := "" + ce := ToClientEvent(ev, FormatSync, *userID, &sk) if ce.RoomID != "" { t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID) } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index a3dc7f54b..cb3c362d5 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -539,7 +539,7 @@ type InviteResponse struct { } // NewInviteResponse creates an empty response with initialised arrays. -func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteResponse { +func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID, stateKey *string) *InviteResponse { res := InviteResponse{} res.InviteState.Events = []json.RawMessage{} @@ -552,7 +552,7 @@ func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteRe // Then we'll see if we can create a partial of the invite event itself. // This is needed for clients to work out *who* sent the invite. - inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID) + inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID, stateKey) inviteEvent.Unsigned = nil if ev, err := json.Marshal(inviteEvent); err == nil { res.InviteState.Events = append(res.InviteState.Events, ev) diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index a79ce5417..c1b7f70bd 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -65,8 +65,14 @@ func TestNewInviteResponse(t *testing.T) { if err != nil { t.Fatal(err) } + skUserID, err := spec.NewUserID("@neilalexander:dendrite.neilalexander.dev", true) + if err != nil { + t.Fatal(err) + } + skString := skUserID.String() + sk := &skString - res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender) + res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender, sk) j, err := json.Marshal(res) if err != nil { t.Fatal(err) diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index df507eb26..b2dc477aa 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -306,7 +306,16 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst if queryErr == nil && userID != nil { sender = *userID } - cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender) + + sk := event.StateKey() + if sk != nil && *sk != "" { + skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) + if queryErr == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } + cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk) var member *localMembership member, err = newLocalMembership(&cevent) if err != nil { @@ -539,12 +548,21 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype if err == nil && userID != nil { sender = *userID } + + sk := event.StateKey() + if sk != nil && *sk != "" { + skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) + if queryErr == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } n := &api.Notification{ Actions: actions, // UNSPEC: the spec doesn't say this is a ClientEvent, but the // fields seem to match. room_id should be missing, which // matches the behaviour of FormatSync. - Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender), + Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender, sk), // TODO: this is per-device, but it's not part of the primary // key. So inserting one notification per profile tag doesn't // make sense. What is this supposed to be? Sytests require it @@ -792,10 +810,20 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes Type: event.Type(), }, } - if mem, err := event.Membership(); err == nil { + if mem, memberErr := event.Membership(); memberErr == nil { req.Notification.Membership = mem } - if event.StateKey() != nil && *event.StateKey() == fmt.Sprintf("@%s:%s", localpart, s.cfg.Matrix.ServerName) { + userID, err := spec.NewUserID(fmt.Sprintf("@%s:%s", localpart, s.cfg.Matrix.ServerName), true) + if err != nil { + logger.WithError(err).Errorf("Failed to convert local user to userID %s", localpart) + return nil, err + } + localSender, err := s.rsAPI.QuerySenderIDForUser(ctx, event.RoomID(), *userID) + if err != nil { + logger.WithError(err).Errorf("Failed to get local user senderID for room %s: %s", userID.String(), event.RoomID()) + return nil, err + } + if event.StateKey() != nil && *event.StateKey() == string(localSender) { req.Notification.UserIsTarget = true } } diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go index 27dd373c2..3017069bc 100644 --- a/userapi/util/notify_test.go +++ b/userapi/util/notify_test.go @@ -104,8 +104,9 @@ func TestNotifyUserCountsAsync(t *testing.T) { if err != nil { t.Error(err) } + sk := "" if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{ - Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender), + Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender, &sk), }); err != nil { t.Error(err) } From 82b73a49068771168ed52351f7be3b033692be4a Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 13 Jun 2023 12:50:22 +0200 Subject: [PATCH 21/35] Add `sender_key` to ClientEvent (#3110) --- syncapi/synctypes/clientevent.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go index 358a0c971..433be39f8 100644 --- a/syncapi/synctypes/clientevent.go +++ b/syncapi/synctypes/clientevent.go @@ -37,6 +37,7 @@ type ClientEvent struct { OriginServerTS spec.Timestamp `json:"origin_server_ts,omitempty"` // OriginServerTS is omitted on receipt events RoomID string `json:"room_id,omitempty"` // RoomID is omitted on /sync responses Sender string `json:"sender,omitempty"` // Sender is omitted on receipt events + SenderKey spec.SenderID `json:"sender_key,omitempty"` // The SenderKey for events in pseudo ID rooms StateKey *string `json:"state_key,omitempty"` Type string `json:"type"` Unsigned spec.RawJSON `json:"unsigned,omitempty"` @@ -84,6 +85,9 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender sp if format == FormatAll { ce.RoomID = se.RoomID() } + if se.Version() == gomatrixserverlib.RoomVersionPseudoIDs { + ce.SenderKey = se.SenderID() + } return ce } From 2c87972a3a84be400e5c69e2e5a727f21b4e457e Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 13 Jun 2023 14:19:31 +0200 Subject: [PATCH 22/35] Create user room key if needed (#3108) --- roomserver/api/api.go | 4 +++ roomserver/internal/api.go | 21 ++++++++++++++++ .../internal/perform/perform_create_room.go | 25 ++++++++++++++++++- roomserver/internal/perform/perform_invite.go | 8 ++++++ roomserver/internal/perform/perform_join.go | 9 +++++++ roomserver/storage/shared/storage.go | 2 +- 6 files changed, 67 insertions(+), 2 deletions(-) diff --git a/roomserver/api/api.go b/roomserver/api/api.go index bafde91c9..fec28841e 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -2,6 +2,7 @@ package api import ( "context" + "crypto/ed25519" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" @@ -66,6 +67,9 @@ type RoomserverInternalAPI interface { req *QueryAuthChainRequest, res *QueryAuthChainResponse, ) error + + // GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created. + GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error) } type InputRoomEventsAPI interface { diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 35b7383a9..4bcd3f3ed 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -2,6 +2,7 @@ package internal import ( "context" + "crypto/ed25519" "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" @@ -270,3 +271,23 @@ func (r *RoomserverInternalAPI) PerformForget( ) error { return r.Forgetter.PerformForget(ctx, req, resp) } + +// GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created. +func (r *RoomserverInternalAPI) GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error) { + key, err := r.DB.SelectUserRoomPrivateKey(ctx, userID, roomID) + if err != nil { + return nil, err + } + // no key found, create one + if len(key) == 0 { + _, key, err = ed25519.GenerateKey(nil) + if err != nil { + return nil, err + } + key, err = r.DB.InsertUserRoomPrivatePublicKey(ctx, userID, roomID, key) + if err != nil { + return nil, err + } + } + return key, nil +} diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index 475418aa3..121b257ed 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -354,7 +354,30 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo SendAsServer: api.DoNotSendToOtherServers, }) } - if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs, false); err != nil { + + // first send the `m.room.create` event, so we have a roomNID + if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs[:1], false); err != nil { + util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + // create user room key if needed + if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { + _, err = c.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("GetOrCreateUserRoomPrivateKey failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + } + + // send the remaining events + if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs[1:], false); err != nil { util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") return "", &util.JSONResponse{ Code: http.StatusInternalServerError, diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 1440daad4..cc2c5c191 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -183,6 +183,14 @@ func (r *Inviter) PerformInvite( inviteEvent = event } + // if we invited a local user, we can also create a user room key, if it doesn't exist yet. + if isTargetLocal && event.Version() == gomatrixserverlib.RoomVersionPseudoIDs { + _, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *invitedUser, *validRoomID) + if err != nil { + return fmt.Errorf("failed to get user room private key: %w", err) + } + } + // Send the invite event to the roomserver input stream. This will // notify existing users in the room about the invite, update the // membership table and ensure that the event is ready and available diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 83c3b7c3e..74ed87c74 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -293,6 +293,15 @@ func (r *Joiner) performJoinRoomByID( switch err.(type) { case nil: + // create user room key if needed + if buildRes.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { + _, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *userID, *roomID) + if err != nil { + logrus.WithError(err).Error("GetOrCreateUserRoomPrivateKey failed") + return "", "", fmt.Errorf("failed to get user room private key: %w", err) + } + } + // The room join is local. Send the new join event into the // roomserver. First of all check that the user isn't already // a member of the room. This is best-effort (as in we won't diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 85a1ba7a1..d7ca3cefd 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1686,7 +1686,7 @@ func (d *Database) SelectUserRoomPrivateKey(ctx context.Context, userID spec.Use return rErr } if roomInfo == nil { - return nil + return eventutil.ErrRoomNoExists{} } key, sErr = d.UserRoomKeyTable.SelectUserRoomPrivateKey(ctx, txn, stateKeyNID, roomInfo.RoomNID) From 7a2e325d1014d76188b47a011730a42443f3c174 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 13 Jun 2023 16:28:41 +0200 Subject: [PATCH 23/35] Add `AssignRoomNID` to pre-assign roomNIDs (#3111) --- roomserver/storage/interface.go | 2 ++ roomserver/storage/shared/storage.go | 20 ++++++++++++++++++++ roomserver/storage/shared/storage_test.go | 22 ++++++++++++++++++++++ 3 files changed, 44 insertions(+) diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index ef4463781..7787d9f85 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -31,6 +31,7 @@ type Database interface { UserRoomKeys // Do we support processing input events for more than one room at a time? SupportsConcurrentRoomInputs() bool + AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error) // RoomInfo returns room information for the given room ID, or nil if there is no room. RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error) @@ -212,6 +213,7 @@ type UserRoomKeys interface { type RoomDatabase interface { EventDatabase UserRoomKeys + AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error) // RoomInfo returns room information for the given room ID, or nil if there is no room. RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error) diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index d7ca3cefd..bda51da81 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -662,6 +662,26 @@ func (d *Database) IsEventRejected(ctx context.Context, roomNID types.RoomNID, e return d.EventsTable.SelectEventRejected(ctx, nil, roomNID, eventID) } +func (d *Database) AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error) { + // This should already be checked, let's check it anyway. + _, err = gomatrixserverlib.GetRoomVersion(roomVersion) + if err != nil { + return 0, err + } + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomNID, err = d.assignRoomNID(ctx, txn, roomID.String(), roomVersion) + if err != nil { + return err + } + return nil + }) + if err != nil { + return 0, err + } + // Not setting caches, as assignRoomNID already does this + return roomNID, err +} + // GetOrCreateRoomInfo gets or creates a new RoomInfo, which is only safe to use with functions only needing a roomVersion or roomNID. func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserverlib.PDU) (roomInfo *types.RoomInfo, err error) { // Get the default room version. If the client doesn't supply a room_version diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go index 4fa451bcc..581d83ee4 100644 --- a/roomserver/storage/shared/storage_test.go +++ b/roomserver/storage/shared/storage_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" @@ -199,3 +200,24 @@ func TestUserRoomKeys(t *testing.T) { assert.Error(t, err) }) } + +func TestAssignRoomNID(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + + roomID, err := spec.NewRoomID(room.ID) + assert.NoError(t, err) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRoomserverDatabase(t, dbType) + defer close() + + nid, err := db.AssignRoomNID(ctx, *roomID, room.Version) + assert.NoError(t, err) + assert.Greater(t, nid, types.EventNID(0)) + + _, err = db.AssignRoomNID(ctx, spec.RoomID{}, "notaroomversion") + assert.Error(t, err) + }) +} From e4665979bfbe006368d55189f074e456fe19b198 Mon Sep 17 00:00:00 2001 From: devonh Date: Wed, 14 Jun 2023 14:23:46 +0000 Subject: [PATCH 24/35] Merge SenderID & Per Room User Key work (#3109) --- appservice/consumers/roomserver.go | 8 +- clientapi/routing/directory.go | 23 ++- clientapi/routing/membership.go | 26 ++- clientapi/routing/profile.go | 6 +- clientapi/routing/redaction.go | 15 +- clientapi/routing/sendevent.go | 13 +- clientapi/routing/state.go | 13 +- clientapi/threepid/invites.go | 6 +- cmd/resolve-state/main.go | 12 +- federationapi/federationapi_test.go | 4 +- federationapi/internal/perform.go | 31 ++-- federationapi/routing/invite.go | 4 +- federationapi/routing/join.go | 6 +- federationapi/routing/leave.go | 13 +- federationapi/routing/threepid.go | 11 +- go.mod | 10 +- go.sum | 20 +-- internal/pushrules/evaluate.go | 6 +- internal/pushrules/evaluate_test.go | 8 +- internal/transactionrequest.go | 2 +- internal/transactionrequest_test.go | 4 +- roomserver/api/api.go | 8 +- roomserver/auth/auth.go | 16 +- roomserver/auth/auth_test.go | 10 +- roomserver/internal/alias.go | 12 +- roomserver/internal/api.go | 1 + roomserver/internal/helpers/auth.go | 8 +- roomserver/internal/helpers/helpers.go | 38 +++-- roomserver/internal/input/input_events.go | 35 ++-- .../internal/input/input_events_test.go | 2 +- .../internal/input/input_latest_events.go | 2 +- roomserver/internal/input/input_membership.go | 6 +- roomserver/internal/input/input_missing.go | 22 +-- roomserver/internal/perform/perform_admin.go | 20 ++- .../internal/perform/perform_backfill.go | 38 +++-- .../internal/perform/perform_create_room.go | 37 ++--- .../internal/perform/perform_inbound_peek.go | 2 +- roomserver/internal/perform/perform_invite.go | 31 ++-- roomserver/internal/perform/perform_join.go | 153 ++++++++++-------- roomserver/internal/perform/perform_leave.go | 10 +- .../internal/perform/perform_upgrade.go | 10 +- roomserver/internal/query/query.go | 81 +++++++--- roomserver/roomserver_test.go | 5 +- roomserver/state/state.go | 14 +- roomserver/storage/interface.go | 10 +- .../storage/postgres/user_room_keys_table.go | 19 +++ roomserver/storage/shared/room_updater.go | 5 - roomserver/storage/shared/storage.go | 55 +++++-- roomserver/storage/shared/storage_test.go | 7 +- .../storage/sqlite3/user_room_keys_table.go | 19 +++ roomserver/storage/tables/interface.go | 2 + .../tables/user_room_keys_table_test.go | 7 + setup/mscs/msc2836/msc2836.go | 2 +- setup/mscs/msc2836/msc2836_test.go | 4 +- syncapi/consumers/roomserver.go | 23 ++- syncapi/internal/history_visibility.go | 6 +- syncapi/internal/keychange.go | 12 +- syncapi/internal/keychange_test.go | 2 +- syncapi/notifier/notifier.go | 9 +- syncapi/notifier/notifier_test.go | 2 +- syncapi/routing/context.go | 10 +- syncapi/routing/getevent.go | 18 ++- syncapi/routing/memberships.go | 12 +- syncapi/routing/messages.go | 4 +- syncapi/routing/relations.go | 9 +- syncapi/routing/search.go | 24 ++- syncapi/routing/search_test.go | 2 +- syncapi/storage/shared/storage_consumer.go | 15 +- syncapi/streams/stream_invite.go | 8 +- syncapi/streams/stream_pdu.go | 24 +-- syncapi/syncapi_test.go | 2 +- syncapi/synctypes/clientevent.go | 16 +- test/room.go | 2 +- userapi/consumers/roomserver.go | 38 ++++- userapi/consumers/roomserver_test.go | 10 +- 75 files changed, 801 insertions(+), 379 deletions(-) diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index ff124514e..1877de37a 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -181,7 +181,7 @@ func (s *OutputRoomEventConsumer) sendEvents( // Create the transaction body. transaction, err := json.Marshal( ApplicationServiceTransaction{ - Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), }, @@ -236,7 +236,11 @@ func (s *appserviceState) backoffAndPause(err error) error { // TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682 func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool { user := "" - userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return false + } + userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) if err == nil { user = userID.String() } diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index f01e24eca..d9129d1bd 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -233,11 +233,18 @@ func RemoveLocalAlias( } } - deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomIDRes.RoomID, *userID) + validRoomID, err := spec.NewRoomID(roomIDRes.RoomID) if err != nil { return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{Err: "Could not find SenderID for this device"}, + Code: http.StatusNotFound, + JSON: spec.NotFound("The alias does not exist."), + } + } + deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *userID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound("The alias does not exist."), } } @@ -321,7 +328,15 @@ func SetVisibility( JSON: spec.BadJSON("userID for this device is invalid"), } } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("roomID is invalid") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("RoomID is invalid"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 03e85edbf..bafc37b67 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -64,7 +64,14 @@ func SendBan( JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"), } } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("RoomID is invalid"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID) if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, @@ -155,7 +162,14 @@ func SendKick( JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), } } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("RoomID is invalid"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID) if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, @@ -428,7 +442,11 @@ func buildMembershipEvent( if err != nil { return nil, err } - senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return nil, err + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *userID) if err != nil { return nil, err } @@ -437,7 +455,7 @@ func buildMembershipEvent( if err != nil { return nil, err } - targetSenderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *targetID) + targetSenderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *targetID) if err != nil { return nil, err } diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index e734e2e4f..8a44834e1 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -368,7 +368,11 @@ func buildMembershipEvents( return nil, err } for _, roomID := range roomIDs { - senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return nil, err + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID) if err != nil { return nil, err } diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index da48e84de..42f029395 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -54,7 +54,14 @@ func SendRedaction( JSON: spec.Forbidden("userID doesn't have power level to redact"), } } - senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("RoomID is invalid"), + } + } + senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID) if queryErr != nil { return util.JSONResponse{ Code: http.StatusForbidden, @@ -103,8 +110,8 @@ func SendRedaction( JSON: spec.Forbidden("You don't have permission to redact this event, no power_levels event in this room."), } } - pl, err := plEvent.PowerLevels() - if err != nil { + pl, plErr := plEvent.PowerLevels() + if plErr != nil { return util.JSONResponse{ Code: 403, JSON: spec.Forbidden( @@ -134,7 +141,7 @@ func SendRedaction( Type: spec.MRoomRedaction, Redacts: eventID, } - err := proto.SetContent(r) + err = proto.SetContent(r) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("proto.SetContent failed") return util.JSONResponse{ diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 4d0a9f24a..d51a570de 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -273,7 +273,14 @@ func generateSendEvent( JSON: spec.BadJSON("Bad userID"), } } - senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return nil, &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("RoomID is invalid"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID) if err != nil { return nil, &util.JSONResponse{ Code: http.StatusNotFound, @@ -344,8 +351,8 @@ func generateSendEvent( stateEvents[i] = queryRes.StateEvents[i].PDU } provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents)) - if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, *validRoomID, senderID) }); err != nil { return nil, &util.JSONResponse{ Code: http.StatusForbidden, diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index e3a209b6e..f53cb3013 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -150,7 +150,7 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a for _, ev := range stateRes.StateEvents { stateEvents = append( stateEvents, - synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }, ev), ) @@ -173,14 +173,19 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a } for _, ev := range stateAfterRes.StateEvents { sender := spec.UserID{} - userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID()) + evRoomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("Event roomID is invalid") + continue + } + userID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, ev.SenderID()) if err == nil && userID != nil { sender = *userID } sk := ev.StateKey() if sk != nil && *sk != "" { - skUserID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey())) + skUserID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, spec.SenderID(*ev.StateKey())) if err == nil && skUserID != nil { skString := skUserID.String() sk = &skString @@ -367,7 +372,7 @@ func OnIncomingStateTypeRequest( } stateEvent := stateEventInStateResp{ - ClientEvent: synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + ClientEvent: synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }, event), } diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index e7ffbac2b..d15cc6d46 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -359,7 +359,11 @@ func emit3PIDInviteEvent( if err != nil { return err } - sender, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return err + } + sender, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *userID) if err != nil { return err } diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go index 15c87f1a8..3ffcac9e6 100644 --- a/cmd/resolve-state/main.go +++ b/cmd/resolve-state/main.go @@ -11,11 +11,13 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" @@ -66,10 +68,14 @@ func main() { panic(err) } + natsInstance := &jetstream.NATSInstance{} + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, + natsInstance, caching.NewRistrettoCache(128*1024*1024, time.Hour, true), false) + roomInfo := &types.RoomInfo{ RoomVersion: gomatrixserverlib.RoomVersion(*roomVersion), } - stateres := state.NewStateResolution(roomserverDB, roomInfo) + stateres := state.NewStateResolution(roomserverDB, roomInfo, rsAPI) if *difference { if len(snapshotNIDs) != 2 { @@ -183,8 +189,8 @@ func main() { fmt.Println("Resolving state") var resolved Events resolved, err = gomatrixserverlib.ResolveConflicts( - gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return roomserverDB.GetUserIDForSender(ctx, roomID, senderID) + gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }, ) if err != nil { diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index 173908437..5d167c0ee 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -36,11 +36,11 @@ type fedRoomserverAPI struct { queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error } -func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } -func (f *fedRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { +func (f *fedRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) { return spec.SenderID(userID.String()), nil } diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 485b79a03..7f61dba41 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -154,14 +154,9 @@ func (r *FederationInternalAPI) performJoinUsingServer( if err != nil { return err } - senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, roomID, *user) - if err != nil { - return err - } joinInput := gomatrixserverlib.PerformJoinInput{ UserID: user, - SenderID: senderID, RoomID: room, ServerName: serverName, Content: content, @@ -169,12 +164,20 @@ func (r *FederationInternalAPI) performJoinUsingServer( PrivateKey: r.cfg.Matrix.PrivateKey, KeyID: r.cfg.Matrix.KeyID, KeyRing: r.keyRing, - EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), - UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }, + SenderIDCreator: func(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (spec.SenderID, error) { + key, keyErr := r.rsAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID) + if keyErr != nil { + return "", keyErr + } + + return spec.SenderID(spec.Base64Bytes(key).Encode()), nil + }, } response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput) @@ -368,7 +371,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( // authenticate the state returned (check its auth events etc) // the equivalent of CheckSendJoinResponse() - userIDProvider := func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + userIDProvider := func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) } authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse( @@ -459,7 +462,11 @@ func (r *FederationInternalAPI) PerformLeave( // Set all the fields to be what they should be, this should be a no-op // but it's possible that the remote server returned us something "odd" - senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, request.RoomID, *userID) + roomID, err := spec.NewRoomID(request.RoomID) + if err != nil { + return err + } + senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID) if err != nil { return err } @@ -527,7 +534,11 @@ func (r *FederationInternalAPI) SendInvite( event gomatrixserverlib.PDU, strippedState []gomatrixserverlib.InviteStrippedState, ) (gomatrixserverlib.PDU, error) { - inviter, err := r.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return nil, err + } + inviter, err := r.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) if err != nil { return nil, err } diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index 5b15f810d..e45209a2f 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -95,7 +95,7 @@ func InviteV2( StateQuerier: rsAPI.StateQuerier(), InviteEvent: inviteReq.Event(), StrippedState: inviteReq.InviteRoomState(), - UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, } @@ -188,7 +188,7 @@ func InviteV1( StateQuerier: rsAPI.StateQuerier(), InviteEvent: event, StrippedState: strippedState, - UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, } diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index d14801921..7aa50f65a 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -98,7 +98,7 @@ func MakeJoin( Roomserver: rsAPI, } - senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID.String(), userID) + senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID, userID) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed") return util.JSONResponse{ @@ -118,7 +118,7 @@ func MakeJoin( LocalServerName: cfg.Matrix.ServerName, LocalServerInRoom: res.RoomExists && res.IsInRoom, RoomQuerier: &roomQuerier, - UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, BuildEventTemplate: createJoinTemplate, @@ -215,7 +215,7 @@ func SendJoin( PrivateKey: cfg.Matrix.PrivateKey, Verifier: keys, MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI}, - UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, } diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index 716276bec..5c8dd00f3 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -87,7 +87,7 @@ func MakeLeave( return event, stateEvents, nil } - senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID.String(), userID) + senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID, userID) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed") return util.JSONResponse{ @@ -105,7 +105,7 @@ func MakeLeave( LocalServerName: cfg.Matrix.ServerName, LocalServerInRoom: res.RoomExists && res.IsInRoom, BuildEventTemplate: createLeaveTemplate, - UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, } @@ -236,7 +236,14 @@ func SendLeave( // Check that the sender belongs to the server that is sending us // the request. By this point we've already asserted that the sender // and the state key are equal so we don't need to check both. - sender, err := rsAPI.QueryUserIDForSender(httpReq.Context(), event.RoomID(), event.SenderID()) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("Room ID is invalid."), + } + } + sender, err := rsAPI.QueryUserIDForSender(httpReq.Context(), *validRoomID, event.SenderID()) if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index 360802de5..42ba8bfe5 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -140,7 +140,14 @@ func ExchangeThirdPartyInvite( } } - userID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, spec.SenderID(proto.SenderID)) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("Invalid room ID"), + } + } + userID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), *validRoomID, spec.SenderID(proto.SenderID)) if err != nil || userID == nil { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -150,7 +157,7 @@ func ExchangeThirdPartyInvite( senderDomain := userID.Domain() // Check that the state key is correct. - targetUserID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, spec.SenderID(*proto.StateKey)) + targetUserID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), *validRoomID, spec.SenderID(*proto.StateKey)) if err != nil || targetUserID == nil { return util.JSONResponse{ Code: http.StatusBadRequest, diff --git a/go.mod b/go.mod index 2fbae3148..930db3958 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230614140620-4dea2171c8f1 github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 @@ -42,11 +42,11 @@ require ( github.com/uber/jaeger-lib v2.4.1+incompatible github.com/yggdrasil-network/yggdrasil-go v0.4.6 go.uber.org/atomic v1.10.0 - golang.org/x/crypto v0.9.0 + golang.org/x/crypto v0.10.0 golang.org/x/image v0.5.0 golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e golang.org/x/sync v0.1.0 - golang.org/x/term v0.8.0 + golang.org/x/term v0.9.0 gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/yaml.v2 v2.4.0 gotest.tools/v3 v3.4.0 @@ -127,8 +127,8 @@ require ( golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect golang.org/x/mod v0.8.0 // indirect golang.org/x/net v0.10.0 // indirect - golang.org/x/sys v0.8.0 // indirect - golang.org/x/text v0.9.0 // indirect + golang.org/x/sys v0.9.0 // indirect + golang.org/x/text v0.10.0 // indirect golang.org/x/time v0.3.0 // indirect golang.org/x/tools v0.6.0 // indirect google.golang.org/protobuf v1.28.1 // indirect diff --git a/go.sum b/go.sum index ef8c298ab..cf6993938 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077 h1:AmKkAUjy9rZA2K+qHXm/O/dPEPnUYfRE2I6SL+Dj+LU= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230614140620-4dea2171c8f1 h1:k75Fy0iQVbDjvddip/x898+BdyopBNAfL1BMNx0awA0= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230614140620-4dea2171c8f1/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= @@ -511,8 +511,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= -golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= +golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -669,12 +669,12 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= +golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.8.0 h1:n5xxQn2i3PC0yLAbjTpNT85q/Kgzcr2gIoX9OrJUols= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.9.0 h1:GRRCnKYhdQrD8kfRAdQ6Zcw1P0OcELxGLKJvtjVMZ28= +golang.org/x/term v0.9.0/go.mod h1:M6DEAAIenWoTxdKrOltXcmDY3rSplQUkrvaDU5FcQyo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -683,8 +683,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58= +golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go index ac7608950..28dea97c4 100644 --- a/internal/pushrules/evaluate.go +++ b/internal/pushrules/evaluate.go @@ -115,7 +115,11 @@ func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec Evaluati case SenderKind: userID := "" - sender, err := userIDForSender(event.RoomID(), event.SenderID()) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return false, err + } + sender, err := userIDForSender(*validRoomID, event.SenderID()) if err == nil { userID = sender.String() } diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go index 859d1f8a6..a4ccc3d0f 100644 --- a/internal/pushrules/evaluate_test.go +++ b/internal/pushrules/evaluate_test.go @@ -8,7 +8,7 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" ) -func UserIDForSender(roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func UserIDForSender(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } @@ -73,7 +73,7 @@ func TestRuleMatches(t *testing.T) { {"emptyOverride", OverrideKind, emptyRule, `{}`, true}, {"emptyContent", ContentKind, emptyRule, `{}`, false}, {"emptyRoom", RoomKind, emptyRule, `{}`, true}, - {"emptySender", SenderKind, emptyRule, `{}`, true}, + {"emptySender", SenderKind, emptyRule, `{"room_id":"!room:example.com"}`, true}, {"emptyUnderride", UnderrideKind, emptyRule, `{}`, true}, {"disabled", OverrideKind, Rule{}, `{}`, false}, @@ -90,8 +90,8 @@ func TestRuleMatches(t *testing.T) { {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!room:example.com"}`, true}, {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!otherroom:example.com"}`, false}, - {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@user:example.com"}`, true}, - {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@otheruser:example.com"}`, false}, + {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@user:example.com","room_id":"!room:example.com"}`, true}, + {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@otheruser:example.com","room_id":"!room:example.com"}`, false}, } for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { diff --git a/internal/transactionrequest.go b/internal/transactionrequest.go index b2929bb5d..5bf7d819c 100644 --- a/internal/transactionrequest.go +++ b/internal/transactionrequest.go @@ -167,7 +167,7 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut } continue } - if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) diff --git a/internal/transactionrequest_test.go b/internal/transactionrequest_test.go index 1d32c8060..ffc1cd89a 100644 --- a/internal/transactionrequest_test.go +++ b/internal/transactionrequest_test.go @@ -70,7 +70,7 @@ type FakeRsAPI struct { bannedFromRoom bool } -func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } @@ -642,7 +642,7 @@ type testRoomserverAPI struct { queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse } -func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } diff --git a/roomserver/api/api.go b/roomserver/api/api.go index fec28841e..e2dd5dd73 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -51,6 +51,7 @@ type RoomserverInternalAPI interface { UserRoomserverAPI FederationRoomserverAPI QuerySenderIDAPI + UserRoomPrivateKeyCreator // needed to avoid chicken and egg scenario when setting up the // interdependencies between the roomserver and other input APIs @@ -67,7 +68,9 @@ type RoomserverInternalAPI interface { req *QueryAuthChainRequest, res *QueryAuthChainResponse, ) error +} +type UserRoomPrivateKeyCreator interface { // GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created. GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error) } @@ -81,8 +84,8 @@ type InputRoomEventsAPI interface { } type QuerySenderIDAPI interface { - QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) - QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) + QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) + QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) } // Query the latest events and state for a room from the room server. @@ -228,6 +231,7 @@ type FederationRoomserverAPI interface { QueryLatestEventsAndStateAPI QueryBulkStateContentAPI QuerySenderIDAPI + UserRoomPrivateKeyCreator // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error diff --git a/roomserver/auth/auth.go b/roomserver/auth/auth.go index ba10a4332..d6c10cf92 100644 --- a/roomserver/auth/auth.go +++ b/roomserver/auth/auth.go @@ -15,7 +15,7 @@ package auth import ( "context" - "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" ) @@ -25,7 +25,7 @@ import ( // IsServerAllowed returns true if the server is allowed to see events in the room // at this particular state. This function implements https://matrix.org/docs/spec/client_server/r0.6.0#id87 func IsServerAllowed( - ctx context.Context, db storage.RoomDatabase, + ctx context.Context, querier api.QuerySenderIDAPI, serverName spec.ServerName, serverCurrentlyInRoom bool, authEvents []gomatrixserverlib.PDU, @@ -41,7 +41,7 @@ func IsServerAllowed( return true } // 2. If the user's membership was join, allow. - joinedUserExists := IsAnyUserOnServerWithMembership(ctx, db, serverName, authEvents, spec.Join) + joinedUserExists := IsAnyUserOnServerWithMembership(ctx, querier, serverName, authEvents, spec.Join) if joinedUserExists { return true } @@ -50,7 +50,7 @@ func IsServerAllowed( return true } // 4. If the user's membership was invite, and the history_visibility was set to invited, allow. - invitedUserExists := IsAnyUserOnServerWithMembership(ctx, db, serverName, authEvents, spec.Invite) + invitedUserExists := IsAnyUserOnServerWithMembership(ctx, querier, serverName, authEvents, spec.Invite) if invitedUserExists && historyVisibility == gomatrixserverlib.HistoryVisibilityInvited { return true } @@ -74,7 +74,7 @@ func HistoryVisibilityForRoom(authEvents []gomatrixserverlib.PDU) gomatrixserver return visibility } -func IsAnyUserOnServerWithMembership(ctx context.Context, db storage.RoomDatabase, serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool { +func IsAnyUserOnServerWithMembership(ctx context.Context, querier api.QuerySenderIDAPI, serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool { for _, ev := range authEvents { if ev.Type() != spec.MRoomMember { continue @@ -89,7 +89,11 @@ func IsAnyUserOnServerWithMembership(ctx context.Context, db storage.RoomDatabas continue } - userID, err := db.GetUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*stateKey)) + validRoomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + continue + } + userID, err := querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*stateKey)) if err != nil { continue } diff --git a/roomserver/auth/auth_test.go b/roomserver/auth/auth_test.go index 192d9e5da..058361e6e 100644 --- a/roomserver/auth/auth_test.go +++ b/roomserver/auth/auth_test.go @@ -4,17 +4,17 @@ import ( "context" "testing" - "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" ) -type FakeStorageDB struct { - storage.RoomDatabase +type FakeQuerier struct { + api.QuerySenderIDAPI } -func (f *FakeStorageDB) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func (f *FakeQuerier) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } @@ -87,7 +87,7 @@ func TestIsServerAllowed(t *testing.T) { authEvents = append(authEvents, ev.PDU) } - if got := IsServerAllowed(context.Background(), &FakeStorageDB{}, tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want { + if got := IsServerAllowed(context.Background(), &FakeQuerier{}, tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want { t.Errorf("IsServerAllowed() = %v, want %v", got, tt.want) } }) diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index c950024ad..e6fb73383 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -113,6 +113,7 @@ func (r *RoomserverInternalAPI) GetAliasesForRoomID( return nil } +// nolint:gocyclo // RemoveRoomAlias implements alias.RoomserverInternalAPI func (r *RoomserverInternalAPI) RemoveRoomAlias( ctx context.Context, @@ -129,7 +130,12 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return nil } - sender, err := r.QueryUserIDForSender(ctx, roomID, request.SenderID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return err + } + + sender, err := r.QueryUserIDForSender(ctx, *validRoomID, request.SenderID) if err != nil || sender == nil { return fmt.Errorf("r.QueryUserIDForSender: %w", err) } @@ -177,7 +183,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( if request.SenderID != ev.SenderID() { senderID = ev.SenderID() } - sender, err := r.QueryUserIDForSender(ctx, roomID, senderID) + sender, err := r.QueryUserIDForSender(ctx, *validRoomID, senderID) if err != nil || sender == nil { return err } @@ -206,7 +212,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( } stateRes := &api.QueryLatestEventsAndStateResponse{} - if err = helpers.QueryLatestEventsAndState(ctx, r.DB, &api.QueryLatestEventsAndStateRequest{RoomID: roomID, StateToFetch: eventsNeeded.Tuples()}, stateRes); err != nil { + if err = helpers.QueryLatestEventsAndState(ctx, r.DB, r, &api.QueryLatestEventsAndStateRequest{RoomID: roomID, StateToFetch: eventsNeeded.Tuples()}, stateRes); err != nil { return err } diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 4bcd3f3ed..7943ae5c0 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -177,6 +177,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio IsLocalServerName: r.Cfg.Global.IsLocalServerName, DB: r.DB, FSAPI: r.fsAPI, + Querier: r.Queryer, KeyRing: r.KeyRing, // Perspective servers are trusted to not lie about server keys, so we will also // prefer these servers when backfilling (assuming they are in the room) rather diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 7782d07d2..89fae244f 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" @@ -36,6 +37,7 @@ func CheckForSoftFail( roomInfo *types.RoomInfo, event *types.HeaderedEvent, stateEventIDs []string, + querier api.QuerySenderIDAPI, ) (bool, error) { rewritesState := len(stateEventIDs) > 1 @@ -49,7 +51,7 @@ func CheckForSoftFail( } else { // Then get the state entries for the current state snapshot. // We'll use this to check if the event is allowed right now. - roomState := state.NewStateResolution(db, roomInfo) + roomState := state.NewStateResolution(db, roomInfo, querier) authStateEntries, err = roomState.LoadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID()) if err != nil { return true, fmt.Errorf("roomState.LoadStateAtSnapshot: %w", err) @@ -76,8 +78,8 @@ func CheckForSoftFail( } // Check if the event is allowed. - if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return db.GetUserIDForSender(ctx, roomID, senderID) + if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return querier.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { // return true, nil return true, err diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index 263cb9f85..febabf411 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -68,7 +68,7 @@ func UpdateToInviteMembership( // memberships. If the servername is not supplied then the local server will be // checked instead using a faster code path. // TODO: This should probably be replaced by an API call. -func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverName spec.ServerName, roomID string) (bool, error) { +func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI, serverName spec.ServerName, roomID string) (bool, error) { info, err := db.RoomInfo(ctx, roomID) if err != nil { return false, err @@ -94,7 +94,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam for i := range events { gmslEvents[i] = events[i].PDU } - return auth.IsAnyUserOnServerWithMembership(ctx, db, serverName, gmslEvents, spec.Join), nil + return auth.IsAnyUserOnServerWithMembership(ctx, querier, serverName, gmslEvents, spec.Join), nil } func IsInvitePending( @@ -211,8 +211,8 @@ func GetMembershipsAtState( return events, nil } -func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) { - roomState := state.NewStateResolution(db, info) +func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventNID types.EventNID, querier api.QuerySenderIDAPI) ([]types.StateEntry, error) { + roomState := state.NewStateResolution(db, info, querier) // Lookup the event NID eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID}) if err != nil { @@ -229,8 +229,8 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.Room return roomState.LoadCombinedStateAfterEvents(ctx, prevState) } -func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) { - roomState := state.NewStateResolution(db, info) +func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID, querier api.QuerySenderIDAPI) (map[string][]types.StateEntry, error) { + roomState := state.NewStateResolution(db, info, querier) // Fetch the state as it was when this event was fired return roomState.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID) } @@ -264,7 +264,7 @@ func LoadStateEvents( } func CheckServerAllowedToSeeEvent( - ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool, + ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool, querier api.QuerySenderIDAPI, ) (bool, error) { stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName)) switch err { @@ -273,7 +273,7 @@ func CheckServerAllowedToSeeEvent( case tables.OptimisationNotSupportedError: // The database engine didn't support this optimisation, so fall back to using // the old and slow method - stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName) + stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName, querier) if err != nil { return false, err } @@ -288,13 +288,13 @@ func CheckServerAllowedToSeeEvent( return false, err } } - return auth.IsServerAllowed(ctx, db, serverName, isServerInRoom, stateAtEvent), nil + return auth.IsServerAllowed(ctx, querier, serverName, isServerInRoom, stateAtEvent), nil } func slowGetHistoryVisibilityState( - ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName, + ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName, querier api.QuerySenderIDAPI, ) ([]gomatrixserverlib.PDU, error) { - roomState := state.NewStateResolution(db, info) + roomState := state.NewStateResolution(db, info, querier) stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -318,9 +318,13 @@ func slowGetHistoryVisibilityState( // If the event state key doesn't match the given servername // then we'll filter it out. This does preserve state keys that // are "" since these will contain history visibility etc. + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return nil, err + } for nid, key := range stateKeys { if key != "" { - userID, err := db.GetUserIDForSender(ctx, roomID, spec.SenderID(key)) + userID, err := querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(key)) if err == nil && userID != nil { if userID.Domain() != serverName { delete(stateKeys, nid) @@ -349,7 +353,7 @@ func slowGetHistoryVisibilityState( // TODO: Remove this when we have tests to assert correctness of this function func ScanEventTree( ctx context.Context, db storage.Database, info *types.RoomInfo, front []string, visited map[string]bool, limit int, - serverName spec.ServerName, + serverName spec.ServerName, querier api.QuerySenderIDAPI, ) ([]types.EventNID, map[string]struct{}, error) { var resultNIDs []types.EventNID var err error @@ -392,7 +396,7 @@ BFSLoop: // It's nasty that we have to extract the room ID from an event, but many federation requests // only talk in event IDs, no room IDs at all (!!!) ev := events[0] - isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, serverName, ev.RoomID()) + isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, querier, serverName, ev.RoomID()) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.") } @@ -415,7 +419,7 @@ BFSLoop: // hasn't been seen before. if !visited[pre] { visited[pre] = true - allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom) + allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom, querier) if err != nil { util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error( "Error checking if allowed to see event", @@ -444,7 +448,7 @@ BFSLoop: } func QueryLatestEventsAndState( - ctx context.Context, db storage.Database, + ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI, request *api.QueryLatestEventsAndStateRequest, response *api.QueryLatestEventsAndStateResponse, ) error { @@ -457,7 +461,7 @@ func QueryLatestEventsAndState( return nil } - roomState := state.NewStateResolution(db, roomInfo) + roomState := state.NewStateResolution(db, roomInfo, querier) response.RoomExists = true response.RoomVersion = roomInfo.RoomVersion diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 7bb401632..aa05d9594 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -128,7 +128,11 @@ func (r *Inputer) processRoomEvent( if roomInfo == nil && !isCreateEvent { return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) } - sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return err + } + sender, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) if err != nil { return fmt.Errorf("failed getting userID for sender %q. %w", event.SenderID(), err) } @@ -282,8 +286,8 @@ func (r *Inputer) processRoomEvent( // Check if the event is allowed by its auth events. If it isn't then // we consider the event to be "rejected" — it will still be persisted. - if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) + if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { isRejected = true rejectionErr = err @@ -321,7 +325,7 @@ func (r *Inputer) processRoomEvent( if input.Kind == api.KindNew && !isCreateEvent { // Check that the event passes authentication checks based on the // current room state. - softfail, err = helpers.CheckForSoftFail(ctx, r.DB, roomInfo, headered, input.StateEventIDs) + softfail, err = helpers.CheckForSoftFail(ctx, r.DB, roomInfo, headered, input.StateEventIDs, r.Queryer) if err != nil { logger.WithError(err).Warn("Error authing soft-failed event") } @@ -401,7 +405,7 @@ func (r *Inputer) processRoomEvent( redactedEvent gomatrixserverlib.PDU ) if !isRejected && !isCreateEvent { - resolver := state.NewStateResolution(r.DB, roomInfo) + resolver := state.NewStateResolution(r.DB, roomInfo, r.Queryer) redactionEvent, redactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, &resolver) if err != nil { return err @@ -587,8 +591,8 @@ func (r *Inputer) processStateBefore( stateBeforeAuth := gomatrixserverlib.NewAuthEvents( gomatrixserverlib.ToPDUs(stateBeforeEvent), ) - if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) + if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }); rejectionErr != nil { rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr) return @@ -700,8 +704,8 @@ nextAuthEvent: // Check the signatures of the event. If this fails then we'll simply // skip it, because gomatrixserverlib.Allowed() will notice a problem // if a critical event is missing anyway. - if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) + if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { continue nextAuthEvent } @@ -718,8 +722,8 @@ nextAuthEvent: } // Check if the auth event should be rejected. - err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) + err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }) if isRejected = err != nil; isRejected { logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID()) @@ -783,7 +787,7 @@ func (r *Inputer) calculateAndSetState( return fmt.Errorf("r.DB.GetRoomUpdater: %w", err) } defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) - roomState := state.NewStateResolution(updater, roomInfo) + roomState := state.NewStateResolution(updater, roomInfo, r.Queryer) if input.HasState { // We've been told what the state at the event is so we don't need to calculate it. @@ -836,13 +840,18 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r return err } + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return err + } + prevEvents := latestRes.LatestEvents for _, memberEvent := range memberEvents { if memberEvent.StateKey() == nil { continue } - memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, memberEvent.RoomID(), spec.SenderID(*memberEvent.StateKey())) + memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*memberEvent.StateKey())) if err != nil { continue } diff --git a/roomserver/internal/input/input_events_test.go b/roomserver/internal/input/input_events_test.go index 5f2cd9562..4ee6d2110 100644 --- a/roomserver/internal/input/input_events_test.go +++ b/roomserver/internal/input/input_events_test.go @@ -58,7 +58,7 @@ func Test_EventAuth(t *testing.T) { } // Finally check that the event is NOT allowed - if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) }); err == nil { t.Fatalf("event should not be allowed, but it was") diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index 7a7a021a3..940783e03 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -213,7 +213,7 @@ func (u *latestEventsUpdater) latestState() error { defer trace.EndRegion() var err error - roomState := state.NewStateResolution(u.updater, u.roomInfo) + roomState := state.NewStateResolution(u.updater, u.roomInfo, u.api.Queryer) // Work out if the state at the extremities has actually changed // or not. If they haven't then we won't bother doing all of the diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index 09c65dfe9..c46f8dba1 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -139,7 +139,11 @@ func (r *Inputer) updateMembership( func (r *Inputer) isLocalTarget(ctx context.Context, event *types.Event) bool { isTargetLocalUser := false if statekey := event.StateKey(); statekey != nil { - userID, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*statekey)) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return isTargetLocalUser + } + userID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*statekey)) if err != nil || userID == nil { return isTargetLocalUser } diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index f0f974d26..7ee84e4c0 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -383,7 +383,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even defer trace.EndRegion() var res parsedRespState - roomState := state.NewStateResolution(t.db, t.roomInfo) + roomState := state.NewStateResolution(t.db, t.roomInfo, t.inputer.Queryer) stateAtEvents, err := t.db.StateAtEventIDs(ctx, []string{eventID}) if err != nil { t.log.WithError(err).Warnf("failed to get state after %s locally", eventID) @@ -473,8 +473,8 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion stateEventList = append(stateEventList, state.StateEvents...) } resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts( - roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return t.db.GetUserIDForSender(ctx, roomID, senderID) + roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }, ) if err != nil { @@ -482,8 +482,8 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion } // apply the current event retryAllowedState: - if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return t.db.GetUserIDForSender(ctx, roomID, senderID) + if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { switch missing := err.(type) { case gomatrixserverlib.MissingAuthEventError: @@ -569,8 +569,8 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver // will be added and duplicates will be removed. missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events)) for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return t.db.GetUserIDForSender(ctx, roomID, senderID) + if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { continue } @@ -660,8 +660,8 @@ func (t *missingStateReq) lookupMissingStateViaState( authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{ StateEvents: state.GetStateEvents(), AuthEvents: state.GetAuthEvents(), - }, roomVersion, t.keys, nil, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return t.db.GetUserIDForSender(ctx, roomID, senderID) + }, roomVersion, t.keys, nil, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }) if err != nil { return nil, err @@ -897,8 +897,8 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers)) return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers)) } - if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return t.db.GetUserIDForSender(ctx, roomID, senderID) + if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) return nil, verifySigError{event.EventID(), err} diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index ec13bff87..12b557f51 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -74,6 +74,10 @@ func (r *Admin) PerformAdminEvacuateRoom( if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil { return nil, err } + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return nil, err + } prevEvents := latestRes.LatestEvents var senderDomain spec.ServerName @@ -100,7 +104,7 @@ func (r *Admin) PerformAdminEvacuateRoom( PrevEvents: prevEvents, } - userID, err := r.Queryer.QueryUserIDForSender(ctx, roomID, spec.SenderID(fledglingEvent.SenderID)) + userID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(fledglingEvent.SenderID)) if err != nil || userID == nil { continue } @@ -264,16 +268,16 @@ func (r *Admin) PerformAdminDownloadState( return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err) } for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) + if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { continue } authEventMap[authEvent.EventID()] = authEvent } for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) + if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { continue } @@ -293,7 +297,11 @@ func (r *Admin) PerformAdminDownloadState( stateIDs = append(stateIDs, stateEvent.EventID()) } - senderID, err := r.Queryer.QuerySenderIDForUser(ctx, roomID, *fullUserID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return err + } + senderID, err := r.Queryer.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID) if err != nil { return err } diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 8e87359a3..533ad25bf 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -42,6 +42,7 @@ type Backfiller struct { DB storage.Database FSAPI federationAPI.RoomserverFederationAPI KeyRing gomatrixserverlib.JSONVerifier + Querier api.QuerySenderIDAPI // The servers which should be preferred above other servers when backfilling PreferServers []spec.ServerName @@ -79,7 +80,7 @@ func (r *Backfiller) PerformBackfill( } // Scan the event tree for events to send back. - resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName) + resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName, r.Querier) if err != nil { return err } @@ -113,7 +114,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform if info == nil || info.IsStub() { return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID) } - requester := newBackfillRequester(r.DB, r.FSAPI, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers) + requester := newBackfillRequester(r.DB, r.FSAPI, r.Querier, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers) // Request 100 items regardless of what the query asks for. // We don't want to go much higher than this. // We can't honour exactly the limit as some sytests rely on requesting more for tests to pass @@ -121,8 +122,8 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform // Specifically the test "Outbound federation can backfill events" events, err := gomatrixserverlib.RequestBackfill( ctx, req.VirtualHost, requester, - r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) + r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.Querier.QueryUserIDForSender(ctx, roomID, senderID) }, ) // Only return an error if we really couldn't get any events. @@ -135,7 +136,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform logrus.WithError(err).WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) // persist these new events - auth checks have already been done - roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events) + roomNID, backfilledEventMap := persistEvents(ctx, r.DB, r.Querier, events) for _, ev := range backfilledEventMap { // now add state for these events @@ -212,8 +213,8 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom continue } loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false) - result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) + result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.Querier.QueryUserIDForSender(ctx, roomID, senderID) }) if err != nil { logger.WithError(err).Warn("failed to load and verify event") @@ -246,13 +247,14 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom } } util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents)) - persistEvents(ctx, r.DB, newEvents) + persistEvents(ctx, r.DB, r.Querier, newEvents) } // backfillRequester implements gomatrixserverlib.BackfillRequester type backfillRequester struct { db storage.Database fsAPI federationAPI.RoomserverFederationAPI + querier api.QuerySenderIDAPI virtualHost spec.ServerName isLocalServerName func(spec.ServerName) bool preferServer map[spec.ServerName]bool @@ -268,6 +270,7 @@ type backfillRequester struct { func newBackfillRequester( db storage.Database, fsAPI federationAPI.RoomserverFederationAPI, + querier api.QuerySenderIDAPI, virtualHost spec.ServerName, isLocalServerName func(spec.ServerName) bool, bwExtrems map[string][]string, preferServers []spec.ServerName, @@ -279,6 +282,7 @@ func newBackfillRequester( return &backfillRequester{ db: db, fsAPI: fsAPI, + querier: querier, virtualHost: virtualHost, isLocalServerName: isLocalServerName, eventIDToBeforeStateIDs: make(map[string][]string), @@ -460,14 +464,14 @@ FindSuccessor: return nil } - stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID].EventNID) + stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID].EventNID, b.querier) if err != nil { logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event") return nil } // possibly return all joined servers depending on history visiblity - memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, info, stateEntries, b.virtualHost) + memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, b.querier, info, stateEntries, b.virtualHost) b.historyVisiblity = visibility if err != nil { logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules") @@ -488,7 +492,11 @@ FindSuccessor: // Store the server names in a temporary map to avoid duplicates. serverSet := make(map[spec.ServerName]bool) for _, event := range memberEvents { - if sender, err := b.db.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()); err == nil { + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + continue + } + if sender, err := b.querier.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()); err == nil { serverSet[sender.Domain()] = true } } @@ -554,7 +562,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, // TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just // pull all events and then filter by that table. func joinEventsFromHistoryVisibility( - ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry, + ctx context.Context, db storage.RoomDatabase, querier api.QuerySenderIDAPI, roomInfo *types.RoomInfo, stateEntries []types.StateEntry, thisServer spec.ServerName) ([]types.Event, gomatrixserverlib.HistoryVisibility, error) { var eventNIDs []types.EventNID @@ -582,7 +590,7 @@ func joinEventsFromHistoryVisibility( } // Can we see events in the room? - canSeeEvents := auth.IsServerAllowed(ctx, db, thisServer, true, events) + canSeeEvents := auth.IsServerAllowed(ctx, querier, thisServer, true, events) visibility := auth.HistoryVisibilityForRoom(events) if !canSeeEvents { logrus.Infof("ServersAtEvent history not visible to us: %s", visibility) @@ -597,7 +605,7 @@ func joinEventsFromHistoryVisibility( return evs, visibility, err } -func persistEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.PDU) (types.RoomNID, map[string]types.Event) { +func persistEvents(ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI, events []gomatrixserverlib.PDU) (types.RoomNID, map[string]types.Event) { var roomNID types.RoomNID var eventNID types.EventNID backfilledEventMap := make(map[string]types.Event) @@ -639,7 +647,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []gomatrixse continue } - resolver := state.NewStateResolution(db, roomInfo) + resolver := state.NewStateResolution(db, roomInfo, querier) _, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev, &resolver) if err != nil { diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index 121b257ed..fd8055e09 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -63,13 +63,20 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } } } - senderID, err := c.DB.GetSenderIDForUser(ctx, roomID.String(), userID) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user") - return "", &util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, + var senderID spec.SenderID + if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { + // create user room key if needed + key, keyErr := c.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID) + if keyErr != nil { + util.GetLogger(ctx).WithError(keyErr).Error("GetOrCreateUserRoomPrivateKey failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } + senderID = spec.SenderID(spec.Base64Bytes(key).Encode()) + } else { + senderID = spec.SenderID(userID.String()) } createContent["creator"] = senderID createContent["room_version"] = createRequest.RoomVersion @@ -323,8 +330,8 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } } - if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return c.DB.GetUserIDForSender(ctx, roomID, senderID) + if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return c.RSAPI.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed") return "", &util.JSONResponse{ @@ -364,18 +371,6 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } } - // create user room key if needed - if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { - _, err = c.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("GetOrCreateUserRoomPrivateKey failed") - return "", &util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } - } - } - // send the remaining events if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs[1:], false); err != nil { util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") @@ -455,7 +450,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo JSON: spec.InternalServerError{}, } } - inviteeSenderID, queryErr := c.RSAPI.QuerySenderIDForUser(ctx, roomID.String(), *inviteeUserID) + inviteeSenderID, queryErr := c.RSAPI.QuerySenderIDForUser(ctx, roomID, *inviteeUserID) if queryErr != nil { util.GetLogger(ctx).WithError(queryErr).Error("rsapi.QuerySenderIDForUser failed") return "", &util.JSONResponse{ diff --git a/roomserver/internal/perform/perform_inbound_peek.go b/roomserver/internal/perform/perform_inbound_peek.go index 3ac0f6f4d..7fbec3710 100644 --- a/roomserver/internal/perform/perform_inbound_peek.go +++ b/roomserver/internal/perform/perform_inbound_peek.go @@ -79,7 +79,7 @@ func (r *InboundPeeker) PerformInboundPeek( response.LatestEvent = &types.HeaderedEvent{PDU: sortedLatestEvents[0]} // XXX: do we actually need to do a state resolution here? - roomState := state.NewStateResolution(r.DB, info) + roomState := state.NewStateResolution(r.DB, info, r.Inputer.Queryer) var stateEntries []types.StateEntry stateEntries, err = roomState.LoadStateAtSnapshot( diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index cc2c5c191..babd5f812 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -34,6 +34,7 @@ import ( type QueryState struct { storage.Database + querier api.QuerySenderIDAPI } func (q *QueryState) GetAuthEvents(ctx context.Context, event gomatrixserverlib.PDU) (gomatrixserverlib.AuthEventProvider, error) { @@ -46,7 +47,7 @@ func (q *QueryState) GetState(ctx context.Context, roomID spec.RoomID, stateWant return nil, fmt.Errorf("failed to load RoomInfo: %w", err) } if info != nil { - roomState := state.NewStateResolution(q.Database, info) + roomState := state.NewStateResolution(q.Database, info, q.querier) stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples( ctx, info.StateSnapshotNID(), stateWanted, ) @@ -98,7 +99,11 @@ func (r *Inviter) ProcessInviteMembership( var outputUpdates []api.OutputEvent var updater *shared.MembershipUpdater - userID, err := r.RSAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey())) + validRoomID, err := spec.NewRoomID(inviteEvent.RoomID()) + if err != nil { + return nil, err + } + userID, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*inviteEvent.StateKey())) if err != nil { return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())} } @@ -126,7 +131,12 @@ func (r *Inviter) PerformInvite( ) error { event := req.Event - sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return err + } + + sender, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) if err != nil { return spec.InvalidParam("The sender user ID is invalid") } @@ -137,18 +147,13 @@ func (r *Inviter) PerformInvite( if event.StateKey() == nil || *event.StateKey() == "" { return fmt.Errorf("invite must be a state event") } - invitedUser, err := r.RSAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) + invitedUser, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey())) if err != nil || invitedUser == nil { return spec.InvalidParam("Could not find the matching senderID for this user") } isTargetLocal := r.Cfg.Matrix.IsLocalServerName(invitedUser.Domain()) - validRoomID, err := spec.NewRoomID(event.RoomID()) - if err != nil { - return err - } - - invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, event.RoomID(), *invitedUser) + invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, *validRoomID, *invitedUser) if err != nil { return fmt.Errorf("failed looking up senderID for invited user") } @@ -161,9 +166,9 @@ func (r *Inviter) PerformInvite( IsTargetLocal: isTargetLocal, StrippedState: req.InviteRoomState, MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI}, - StateQuerier: &QueryState{r.DB}, - UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) + StateQuerier: &QueryState{r.DB, r.RSAPI}, + UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.RSAPI.QueryUserIDForSender(ctx, roomID, senderID) }, } inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI) diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 74ed87c74..5867ee6e0 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -25,6 +25,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -174,44 +175,6 @@ func (r *Joiner) performJoinRoomByID( req.ServerNames = append(req.ServerNames, roomID.Domain()) } - // Prepare the template for the join event. - userID, err := spec.NewUserID(req.UserID, true) - if err != nil { - return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)} - } - senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomIDOrAlias, *userID) - if err != nil { - return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)} - } - senderIDString := string(senderID) - userDomain := userID.Domain() - proto := gomatrixserverlib.ProtoEvent{ - Type: spec.MRoomMember, - SenderID: senderIDString, - StateKey: &senderIDString, - RoomID: req.RoomIDOrAlias, - Redacts: "", - } - if err = proto.SetUnsigned(struct{}{}); err != nil { - return "", "", fmt.Errorf("eb.SetUnsigned: %w", err) - } - - // It is possible for the request to include some "content" for the - // event. We'll always overwrite the "membership" key, but the rest, - // like "display_name" or "avatar_url", will be kept if supplied. - if req.Content == nil { - req.Content = map[string]interface{}{} - } - req.Content["membership"] = spec.Join - if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req, senderID); aerr != nil { - return "", "", aerr - } else if authorisedVia != "" { - req.Content["join_authorised_via_users_server"] = authorisedVia - } - if err = proto.SetContent(req.Content); err != nil { - return "", "", fmt.Errorf("eb.SetContent: %w", err) - } - // Force a federated join if we aren't in the room and we've been // given some server names to try joining by. inRoomReq := &rsAPI.QueryServerJoinedToRoomRequest{ @@ -224,29 +187,63 @@ func (r *Joiner) performJoinRoomByID( serverInRoom := inRoomRes.IsInRoom forceFederatedJoin := len(req.ServerNames) > 0 && !serverInRoom + userID, err := spec.NewUserID(req.UserID, true) + if err != nil { + return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)} + } + + // Look up the room NID for the supplied room ID. + var senderID spec.SenderID + checkInvitePending := false + info, err := r.DB.RoomInfo(ctx, req.RoomIDOrAlias) + if err == nil && info != nil { + switch info.RoomVersion { + case gomatrixserverlib.RoomVersionPseudoIDs: + senderID, err = r.Queryer.QuerySenderIDForUser(ctx, *roomID, *userID) + if err == nil { + checkInvitePending = true + } else { + // create user room key if needed + key, keyErr := r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *userID, *roomID) + if keyErr != nil { + util.GetLogger(ctx).WithError(keyErr).Error("GetOrCreateUserRoomPrivateKey failed") + return "", "", fmt.Errorf("GetOrCreateUserRoomPrivateKey failed: %w", keyErr) + } + senderID = spec.SenderID(spec.Base64Bytes(key).Encode()) + } + default: + checkInvitePending = true + senderID = spec.SenderID(userID.String()) + } + } + + userDomain := userID.Domain() + // Force a federated join if we're dealing with a pending invite // and we aren't in the room. - isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID) - if err == nil && !serverInRoom && isInvitePending { - inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomIDOrAlias, inviteSender) - if queryErr != nil { - return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr) - } + if checkInvitePending { + isInvitePending, inviteSender, _, inviteEvent, inviteErr := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID) + if inviteErr == nil && !serverInRoom && isInvitePending { + inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, inviteSender) + if queryErr != nil { + return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr) + } - // If we were invited by someone from another server then we can - // assume they are in the room so we can join via them. - if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) { - req.ServerNames = append(req.ServerNames, inviter.Domain()) - forceFederatedJoin = true - memberEvent := gjson.Parse(string(inviteEvent.JSON())) - // only set unsigned if we've got a content.membership, which we _should_ - if memberEvent.Get("content.membership").Exists() { - req.Unsigned = map[string]interface{}{ - "prev_sender": memberEvent.Get("sender").Str, - "prev_content": map[string]interface{}{ - "is_direct": memberEvent.Get("content.is_direct").Bool(), - "membership": memberEvent.Get("content.membership").Str, - }, + // If we were invited by someone from another server then we can + // assume they are in the room so we can join via them. + if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) { + req.ServerNames = append(req.ServerNames, inviter.Domain()) + forceFederatedJoin = true + memberEvent := gjson.Parse(string(inviteEvent.JSON())) + // only set unsigned if we've got a content.membership, which we _should_ + if memberEvent.Get("content.membership").Exists() { + req.Unsigned = map[string]interface{}{ + "prev_sender": memberEvent.Get("sender").Str, + "prev_content": map[string]interface{}{ + "is_direct": memberEvent.Get("content.is_direct").Bool(), + "membership": memberEvent.Get("content.membership").Str, + }, + } } } } @@ -274,6 +271,7 @@ func (r *Joiner) performJoinRoomByID( // If we should do a forced federated join then do that. var joinedVia spec.ServerName if forceFederatedJoin { + // TODO : pseudoIDs - pass through userID here since we don't know what the senderID should be yet joinedVia, err = r.performFederatedJoinRoomByID(ctx, req) return req.RoomIDOrAlias, joinedVia, err } @@ -289,19 +287,40 @@ func (r *Joiner) performJoinRoomByID( if err != nil { return "", "", fmt.Errorf("error joining local room: %q", err) } + + senderIDString := string(senderID) + + // Prepare the template for the join event. + proto := gomatrixserverlib.ProtoEvent{ + Type: spec.MRoomMember, + SenderID: senderIDString, + StateKey: &senderIDString, + RoomID: req.RoomIDOrAlias, + Redacts: "", + } + if err = proto.SetUnsigned(struct{}{}); err != nil { + return "", "", fmt.Errorf("eb.SetUnsigned: %w", err) + } + + // It is possible for the request to include some "content" for the + // event. We'll always overwrite the "membership" key, but the rest, + // like "display_name" or "avatar_url", will be kept if supplied. + if req.Content == nil { + req.Content = map[string]interface{}{} + } + req.Content["membership"] = spec.Join + if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req, senderID); aerr != nil { + return "", "", aerr + } else if authorisedVia != "" { + req.Content["join_authorised_via_users_server"] = authorisedVia + } + if err = proto.SetContent(req.Content); err != nil { + return "", "", fmt.Errorf("eb.SetContent: %w", err) + } event, err := eventutil.QueryAndBuildEvent(ctx, &proto, identity, time.Now(), r.RSAPI, &buildRes) switch err.(type) { case nil: - // create user room key if needed - if buildRes.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { - _, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *userID, *roomID) - if err != nil { - logrus.WithError(err).Error("GetOrCreateUserRoomPrivateKey failed") - return "", "", fmt.Errorf("failed to get user room private key: %w", err) - } - } - // The room join is local. Send the new join event into the // roomserver. First of all check that the user isn't already // a member of the room. This is best-effort (as in we won't diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 1b23cc1ff..e1ddb9b50 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -78,7 +78,11 @@ func (r *Leaver) performLeaveRoomByID( req *api.PerformLeaveRequest, res *api.PerformLeaveResponse, // nolint:unparam ) ([]api.OutputEvent, error) { - leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomID, req.Leaver) + roomID, err := spec.NewRoomID(req.RoomID) + if err != nil { + return nil, err + } + leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, *roomID, req.Leaver) if err != nil { return nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String()) } @@ -87,7 +91,7 @@ func (r *Leaver) performLeaveRoomByID( // that. isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, leaver) if err == nil && isInvitePending { - sender, serr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomID, senderUser) + sender, serr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, senderUser) if serr != nil || sender == nil { return nil, fmt.Errorf("sender %q has no matching userID", senderUser) } @@ -133,7 +137,7 @@ func (r *Leaver) performLeaveRoomByID( }, } latestRes := api.QueryLatestEventsAndStateResponse{} - if err = helpers.QueryLatestEventsAndState(ctx, r.DB, &latestReq, &latestRes); err != nil { + if err = helpers.QueryLatestEventsAndState(ctx, r.DB, r.RSAPI, &latestReq, &latestRes); err != nil { return nil, err } if !latestRes.RoomExists { diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index 1aaa42c94..32f547dc1 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -54,7 +54,11 @@ func (r *Upgrader) performRoomUpgrade( return "", err } - senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, userID) + fullRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return "", err + } + senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, *fullRoomID, userID) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user") return "", err @@ -488,7 +492,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, send } - if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err) @@ -569,7 +573,7 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, send stateEvents[i] = queryRes.StateEvents[i].PDU } provider := gomatrixserverlib.NewAuthEvents(stateEvents) - if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client? diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index caea6b526..19fd456b5 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -16,6 +16,7 @@ package query import ( "context" + "crypto/ed25519" "database/sql" "errors" "fmt" @@ -89,7 +90,7 @@ func (r *Queryer) QueryLatestEventsAndState( request *api.QueryLatestEventsAndStateRequest, response *api.QueryLatestEventsAndStateResponse, ) error { - return helpers.QueryLatestEventsAndState(ctx, r.DB, request, response) + return helpers.QueryLatestEventsAndState(ctx, r.DB, r, request, response) } // QueryStateAfterEvents implements api.RoomserverInternalAPI @@ -106,7 +107,7 @@ func (r *Queryer) QueryStateAfterEvents( return nil } - roomState := state.NewStateResolution(r.DB, info) + roomState := state.NewStateResolution(r.DB, info, r) response.RoomExists = true response.RoomVersion = info.RoomVersion @@ -159,8 +160,8 @@ func (r *Queryer) QueryStateAfterEvents( } stateEvents, err = gomatrixserverlib.ResolveConflicts( - info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) + info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.QueryUserIDForSender(ctx, roomID, senderID) }, ) if err != nil { @@ -271,15 +272,15 @@ func (r *Queryer) QueryMembershipForUser( request *api.QueryMembershipForUserRequest, response *api.QueryMembershipForUserResponse, ) error { - senderID, err := r.DB.GetSenderIDForUser(ctx, request.RoomID, request.UserID) - if err != nil { - return err - } - roomID, err := spec.NewRoomID(request.RoomID) if err != nil { return err } + senderID, err := r.QuerySenderIDForUser(ctx, *roomID, request.UserID) + if err != nil { + return err + } + return r.QueryMembershipForSenderID(ctx, *roomID, senderID, response) } @@ -320,7 +321,7 @@ func (r *Queryer) QueryMembershipAtEvent( } response.Membership = make(map[string]*types.HeaderedEvent) - stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID]) + stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID], r) if err != nil { return fmt.Errorf("unable to get state before event: %w", err) } @@ -407,7 +408,7 @@ func (r *Queryer) QueryMembershipsForRoom( return fmt.Errorf("r.DB.Events: %w", err) } for _, event := range events { - clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + clientEvent := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return r.QueryUserIDForSender(ctx, roomID, senderID) }, event) response.JoinEvents = append(response.JoinEvents, clientEvent) @@ -445,7 +446,7 @@ func (r *Queryer) QueryMembershipsForRoom( events, err = r.DB.Events(ctx, info.RoomVersion, eventNIDs) } else { - stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID) + stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID, r) if err != nil { logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") return err @@ -458,7 +459,7 @@ func (r *Queryer) QueryMembershipsForRoom( } for _, event := range events { - clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + clientEvent := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return r.QueryUserIDForSender(ctx, roomID, senderID) }, event) response.JoinEvents = append(response.JoinEvents, clientEvent) @@ -532,7 +533,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent( } return helpers.CheckServerAllowedToSeeEvent( - ctx, r.DB, info, roomID, eventID, serverName, isInRoom, + ctx, r.DB, info, roomID, eventID, serverName, isInRoom, r, ) } @@ -573,7 +574,7 @@ func (r *Queryer) QueryMissingEvents( return fmt.Errorf("missing RoomInfo for room %d", events[front[0]].RoomNID) } - resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName) + resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName, r) if err != nil { return err } @@ -651,8 +652,8 @@ func (r *Queryer) QueryStateAndAuthChain( if request.ResolveState { stateEvents, err = gomatrixserverlib.ResolveConflicts( - info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) + info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.QueryUserIDForSender(ctx, roomID, senderID) }, ) if err != nil { @@ -673,7 +674,7 @@ func (r *Queryer) QueryStateAndAuthChain( // first bool: is rejected, second bool: state missing func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]gomatrixserverlib.PDU, bool, bool, error) { - roomState := state.NewStateResolution(r.DB, roomInfo) + roomState := state.NewStateResolution(r.DB, roomInfo, r) prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) if err != nil { switch err.(type) { @@ -989,10 +990,46 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, senderID) } -func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { - return r.DB.GetSenderIDForUser(ctx, roomID, userID) +func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) { + version, err := r.DB.GetRoomVersion(ctx, roomID.String()) + if err != nil { + return "", err + } + + switch version { + case gomatrixserverlib.RoomVersionPseudoIDs: + key, err := r.DB.SelectUserRoomPublicKey(ctx, userID, roomID) + if err != nil { + return "", err + } + return spec.SenderID(spec.Base64Bytes(key).Encode()), nil + default: + return spec.SenderID(userID.String()), nil + } } -func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) +func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + userID, err := spec.NewUserID(string(senderID), true) + if err == nil { + return userID, nil + } + + bytes := spec.Base64Bytes{} + err = bytes.Decode(string(senderID)) + if err != nil { + return nil, err + } + queryMap := map[spec.RoomID][]ed25519.PublicKey{roomID: {ed25519.PublicKey(bytes)}} + result, err := r.DB.SelectUserIDsForPublicKeys(ctx, queryMap) + if err != nil { + return nil, err + } + + if userKeys, ok := result[roomID]; ok { + if userID, ok := userKeys[string(senderID)]; ok { + return spec.NewUserID(userID, true) + } + } + + return nil, nil } diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 90c94bbce..077957fa1 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -516,6 +516,9 @@ func TestRedaction(t *testing.T) { t.Fatal(err) } + natsInstance := &jetstream.NATSInstance{} + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics) + for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { authEvents := []types.EventNID{} @@ -551,7 +554,7 @@ func TestRedaction(t *testing.T) { } // Calculate the snapshotNID etc. - plResolver := state.NewStateResolution(db, roomInfo) + plResolver := state.NewStateResolution(db, roomInfo, rsAPI) stateAtEvent.BeforeStateSnapshotNID, err = plResolver.CalculateAndStoreStateBeforeEvent(ctx, ev.PDU, false) assert.NoError(t, err) diff --git a/roomserver/state/state.go b/roomserver/state/state.go index b9c5bbc4a..1e776ff6c 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -29,6 +29,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -44,20 +45,21 @@ type StateResolutionStorage interface { AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) - GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) } type StateResolution struct { db StateResolutionStorage roomInfo *types.RoomInfo events map[types.EventNID]gomatrixserverlib.PDU + Querier api.QuerySenderIDAPI } -func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) StateResolution { +func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo, querier api.QuerySenderIDAPI) StateResolution { return StateResolution{ db: db, roomInfo: roomInfo, events: make(map[types.EventNID]gomatrixserverlib.PDU), + Querier: querier, } } @@ -947,8 +949,8 @@ func (v *StateResolution) resolveConflictsV1( } // Resolve the conflicts. - resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return v.db.GetUserIDForSender(ctx, roomID, senderID) + resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return v.Querier.QueryUserIDForSender(ctx, roomID, senderID) }) // Map from the full events back to numeric state entries. @@ -1061,8 +1063,8 @@ func (v *StateResolution) resolveConflictsV2( conflictedEvents, nonConflictedEvents, authEvents, - func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return v.db.GetUserIDForSender(ctx, roomID, senderID) + func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return v.Querier.QueryUserIDForSender(ctx, roomID, senderID) }, ) }() diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 7787d9f85..7156c11cc 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -169,10 +169,6 @@ type Database interface { GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error) // GetKnownUsers searches all users that userID knows about. GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) - // GetKnownUsers tries to obtain the current mxid for a given user. - GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) - // GetKnownUsers tries to obtain the current senderID for a given user. - GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) // GetKnownRooms returns a list of all rooms we know about. GetKnownRooms(ctx context.Context) ([]string, error) // ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room @@ -190,6 +186,7 @@ type Database interface { ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string, ) (map[string]*types.HeaderedEvent, error) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserverlib.PDU) (*types.RoomInfo, error) + GetRoomVersion(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) MaybeRedactEvent( @@ -205,8 +202,12 @@ type UserRoomKeys interface { InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) // SelectUserRoomPrivateKey selects the private key for the given user and room combination SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error) + // SelectUserRoomPublicKey selects the public key for the given user and room combination + SelectUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PublicKey, err error) // SelectUserIDsForPublicKeys selects all userIDs for the requested senderKeys. Returns a map from roomID -> map from publicKey to userID. // If a senderKey can't be found, it is omitted in the result. + // TODO: Why is the result map indexed by string not public key? + // TODO: Shouldn't the input & result map be changed to be indexed by string instead of the RoomID struct? SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (map[spec.RoomID]map[string]string, error) } @@ -233,7 +234,6 @@ type RoomDatabase interface { GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error) - GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) } type EventDatabase interface { diff --git a/roomserver/storage/postgres/user_room_keys_table.go b/roomserver/storage/postgres/user_room_keys_table.go index 22f978bf0..dbb4af34a 100644 --- a/roomserver/storage/postgres/user_room_keys_table.go +++ b/roomserver/storage/postgres/user_room_keys_table.go @@ -51,12 +51,15 @@ const insertUserRoomPublicKeySQL = ` const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` +const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` + const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)` type userRoomKeysStatements struct { insertUserRoomPrivateKeyStmt *sql.Stmt insertUserRoomPublicKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt + selectUserRoomPublicKeyStmt *sql.Stmt selectUserNIDsStmt *sql.Stmt } @@ -71,6 +74,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) { {&s.insertUserRoomPrivateKeyStmt, insertUserRoomPrivateKeySQL}, {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, + {&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL}, {&s.selectUserNIDsStmt, selectUserNIDsSQL}, }.Prepare(db) } @@ -102,6 +106,21 @@ func (s *userRoomKeysStatements) SelectUserRoomPrivateKey( return result, err } +func (s *userRoomKeysStatements) SelectUserRoomPublicKey( + ctx context.Context, + txn *sql.Tx, + userNID types.EventStateKeyNID, + roomNID types.RoomNID, +) (ed25519.PublicKey, error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomPublicKeyStmt) + var result ed25519.PublicKey + err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return result, err +} + func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) { stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserNIDsStmt) diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 6fb57332a..70672a33e 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -6,7 +6,6 @@ import ( "fmt" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -251,7 +250,3 @@ func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error { func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) } - -func (u *RoomUpdater) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return u.d.GetUserIDForSender(ctx, roomID, senderID) -} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index bda51da81..61a3520a4 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -721,6 +721,22 @@ func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserver }, err } +func (d *Database) GetRoomVersion(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) { + cachedRoomVersion, versionOK := d.Cache.GetRoomVersion(roomID) + if versionOK { + return cachedRoomVersion, nil + } + + roomInfo, err := d.RoomInfo(ctx, roomID) + if err != nil { + return "", err + } + if roomInfo == nil { + return "", nil + } + return roomInfo.RoomVersion, nil +} + func (d *Database) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, eventType); err != nil { @@ -1550,16 +1566,6 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit) } -func (d *Database) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { - // TODO: Use real logic once DB for pseudoIDs is in place - return spec.NewUserID(string(senderID), true) -} - -func (d *Database) GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { - // TODO: Use real logic once DB for pseudoIDs is in place - return spec.SenderID(userID.String()), nil -} - // GetKnownRooms returns a list of all rooms we know about. func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil) @@ -1718,6 +1724,35 @@ func (d *Database) SelectUserRoomPrivateKey(ctx context.Context, userID spec.Use return } +// SelectUserRoomPublicKey queries the users room public key. +// If no key exists, returns no key and no error. Otherwise returns +// the key and a database error, if any. +func (d *Database) SelectUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PublicKey, err error) { + uID := userID.String() + stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID}) + if sErr != nil { + return nil, sErr + } + stateKeyNID := stateKeyNIDMap[uID] + + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String()) + if rErr != nil { + return rErr + } + if roomInfo == nil { + return nil + } + + key, sErr = d.UserRoomKeyTable.SelectUserRoomPublicKey(ctx, txn, stateKeyNID, roomInfo.RoomNID) + if !errors.Is(sErr, sql.ErrNoRows) { + return sErr + } + return nil + }) + return +} + // SelectUserIDsForPublicKeys returns a map from roomID -> map from senderKey -> userID func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (result map[spec.RoomID]map[string]string, err error) { result = make(map[spec.RoomID]map[string]string, len(publicKeys)) diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go index 581d83ee4..c7b915c7d 100644 --- a/roomserver/storage/shared/storage_test.go +++ b/roomserver/storage/shared/storage_test.go @@ -163,12 +163,17 @@ func TestUserRoomKeys(t *testing.T) { gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *roomID) assert.NoError(t, err) assert.Equal(t, key, gotKey) + pubKey, err := db.SelectUserRoomPublicKey(context.Background(), *userID, *roomID) + assert.NoError(t, err) + assert.Equal(t, key.Public(), pubKey) // Key doesn't exist, we shouldn't get anything back - assert.NoError(t, err) gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *doesNotExist) assert.NoError(t, err) assert.Nil(t, gotKey) + pubKey, err = db.SelectUserRoomPublicKey(context.Background(), *userID, *doesNotExist) + assert.NoError(t, err) + assert.Nil(t, pubKey) queryUserIDs := map[spec.RoomID][]ed25519.PublicKey{ *roomID: {key.Public().(ed25519.PublicKey)}, diff --git a/roomserver/storage/sqlite3/user_room_keys_table.go b/roomserver/storage/sqlite3/user_room_keys_table.go index 8af57ea0e..84c8b54ec 100644 --- a/roomserver/storage/sqlite3/user_room_keys_table.go +++ b/roomserver/storage/sqlite3/user_room_keys_table.go @@ -51,12 +51,15 @@ const insertUserRoomPublicKeySQL = ` const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` +const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` + const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)` type userRoomKeysStatements struct { insertUserRoomPrivateKeyStmt *sql.Stmt insertUserRoomPublicKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt + selectUserRoomPublicKeyStmt *sql.Stmt //selectUserNIDsStmt *sql.Stmt //prepared at runtime } @@ -71,6 +74,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) { {&s.insertUserRoomPrivateKeyStmt, insertUserRoomKeySQL}, {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, + {&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL}, //{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime }.Prepare(db) } @@ -102,6 +106,21 @@ func (s *userRoomKeysStatements) SelectUserRoomPrivateKey( return result, err } +func (s *userRoomKeysStatements) SelectUserRoomPublicKey( + ctx context.Context, + txn *sql.Tx, + userNID types.EventStateKeyNID, + roomNID types.RoomNID, +) (ed25519.PublicKey, error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomPublicKeyStmt) + var result ed25519.PublicKey + err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return result, err +} + func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) { roomNIDs := make([]any, 0, len(senderKeys)) diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index cd0e51686..445c1223f 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -193,6 +193,8 @@ type UserRoomKeys interface { InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (ed25519.PublicKey, error) // SelectUserRoomPrivateKey selects the private key for the given user and room combination SelectUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error) + // SelectUserRoomPublicKey selects the public key for the given user and room combination + SelectUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PublicKey, error) // BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair. // If a senderKey can't be found, it is omitted in the result. BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) diff --git a/roomserver/storage/tables/user_room_keys_table_test.go b/roomserver/storage/tables/user_room_keys_table_test.go index 284309481..8802a3c6e 100644 --- a/roomserver/storage/tables/user_room_keys_table_test.go +++ b/roomserver/storage/tables/user_room_keys_table_test.go @@ -50,6 +50,7 @@ func TestUserRoomKeysTable(t *testing.T) { err = sqlutil.WithTransaction(db, func(txn *sql.Tx) error { var gotKey, key2, key3 ed25519.PrivateKey + var pubKey ed25519.PublicKey gotKey, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID, roomNID, key) assert.NoError(t, err) assert.Equal(t, gotKey, key) @@ -71,6 +72,9 @@ func TestUserRoomKeysTable(t *testing.T) { gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, roomNID) assert.NoError(t, err) assert.Equal(t, key, gotKey) + pubKey, err = tab.SelectUserRoomPublicKey(context.Background(), txn, userNID, roomNID) + assert.NoError(t, err) + assert.Equal(t, key.Public(), pubKey) // try to update an existing key, this should only be done for users NOT on this homeserver var gotPubKey ed25519.PublicKey @@ -82,6 +86,9 @@ func TestUserRoomKeysTable(t *testing.T) { gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, 2) assert.NoError(t, err) assert.Nil(t, gotKey) + pubKey, err = tab.SelectUserRoomPublicKey(context.Background(), txn, userNID, 2) + assert.NoError(t, err) + assert.Nil(t, pubKey) // query user NIDs for senderKeys var gotKeys map[string]types.UserRoomKeyPair diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index d3f1c9dd2..f28419905 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -94,7 +94,7 @@ type MSC2836EventRelationshipsResponse struct { func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsResponse, rsAPI roomserver.RoomserverInternalAPI) *EventRelationshipResponse { out := &EventRelationshipResponse{ - Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), Limited: res.Limited, diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index e32d6a9f2..16fb3efe1 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -525,11 +525,11 @@ type testRoomserverAPI struct { events map[string]*types.HeaderedEvent } -func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } -func (r *testRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { +func (r *testRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) { return spec.SenderID(userID.String()), nil } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index c5f2db9c8..d468dfc98 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -377,7 +377,11 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *rst return sp, fmt.Errorf("unexpected nil state_key") } - userID, err := s.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey())) + validRoomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + return sp, err + } + userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*ev.StateKey())) if err != nil || userID == nil { return sp, fmt.Errorf("failed getting userID for sender: %w", err) } @@ -404,7 +408,11 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent( return } - userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.Event.RoomID(), spec.SenderID(*msg.Event.StateKey())) + validRoomID, err := spec.NewRoomID(msg.Event.RoomID()) + if err != nil { + return + } + userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*msg.Event.StateKey())) if err != nil || userID == nil { return } @@ -454,7 +462,16 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent( // Notify any active sync requests that the invite has been retired. s.inviteStream.Advance(pduPos) - userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.RoomID, msg.TargetSenderID) + validRoomID, err := spec.NewRoomID(msg.RoomID) + if err != nil { + log.WithFields(log.Fields{ + "event_id": msg.EventID, + "room_id": msg.RoomID, + log.ErrorKey: err, + }).Errorf("roomID is invalid") + return + } + userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, msg.TargetSenderID) if err != nil || userID == nil { log.WithFields(log.Fields{ "event_id": msg.EventID, diff --git a/syncapi/internal/history_visibility.go b/syncapi/internal/history_visibility.go index ab1a7f83d..ce6846ca4 100644 --- a/syncapi/internal/history_visibility.go +++ b/syncapi/internal/history_visibility.go @@ -139,7 +139,11 @@ func ApplyHistoryVisibilityFilter( if err != nil { return nil, err } - senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *user) + roomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + return nil, err + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *user) if err == nil { if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) { eventsFiltered = append(eventsFiltered, ev) diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index f4b6ace59..24ffcc041 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -170,11 +170,15 @@ func TrackChangedUsers( return nil, nil, err } for roomID, state := range stateRes.Rooms { + validRoomID, roomErr := spec.NewRoomID(roomID) + if roomErr != nil { + continue + } for tuple, membership := range state { if membership != spec.Join { continue } - user, queryErr := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey)) + user, queryErr := rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(tuple.StateKey)) if queryErr != nil || user == nil { continue } @@ -216,13 +220,17 @@ func TrackChangedUsers( return nil, left, err } for roomID, state := range stateRes.Rooms { + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + continue + } for tuple, membership := range state { if membership != spec.Join { continue } // new user who we weren't previously sharing rooms with if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok { - user, err := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey)) + user, err := rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(tuple.StateKey)) if err != nil || user == nil { continue } diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index efa641475..3f5e990c4 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -64,7 +64,7 @@ type mockRoomserverAPI struct { roomIDToJoinedMembers map[string][]string } -func (s *mockRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func (s *mockRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } diff --git a/syncapi/notifier/notifier.go b/syncapi/notifier/notifier.go index 4ee7c8605..af8ab0102 100644 --- a/syncapi/notifier/notifier.go +++ b/syncapi/notifier/notifier.go @@ -101,13 +101,20 @@ func (n *Notifier) OnNewEvent( n._removeEmptyUserStreams() if ev != nil { + validRoomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + log.WithError(err).WithField("event_id", ev.EventID()).Errorf( + "Notifier.OnNewEvent: RoomID is invalid", + ) + return + } // Map this event's room_id to a list of joined users, and wake them up. usersToNotify := n._joinedUsers(ev.RoomID()) // Map this event's room_id to a list of peeking devices, and wake them up. peekingDevicesToNotify := n._peekingDevices(ev.RoomID()) // If this is an invite, also add in the invitee to this list. if ev.Type() == "m.room.member" && ev.StateKey() != nil { - targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), ev.RoomID(), spec.SenderID(*ev.StateKey())) + targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), *validRoomID, spec.SenderID(*ev.StateKey())) if err != nil { log.WithError(err).WithField("event_id", ev.EventID()).Errorf( "Notifier.OnNewEvent: Failed to find the userID for this event", diff --git a/syncapi/notifier/notifier_test.go b/syncapi/notifier/notifier_test.go index 7076f7134..f86301a06 100644 --- a/syncapi/notifier/notifier_test.go +++ b/syncapi/notifier/notifier_test.go @@ -109,7 +109,7 @@ func mustEqualPositions(t *testing.T, got, want types.StreamingToken) { type TestRoomServer struct{ api.SyncRoomserverAPI } -func (t *TestRoomServer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func (t *TestRoomServer) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 55fd3c5a2..649d77b41 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -200,10 +200,10 @@ func Context( } } - eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) - eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) @@ -211,7 +211,7 @@ func Context( if filter.LazyLoadMembers { allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...) allEvents = append(allEvents, &requestedEvent) - evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache) @@ -224,14 +224,14 @@ func Context( } } - ev := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + ev := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }, requestedEvent) response := ContextRespsonse{ Event: &ev, EventsAfter: eventsAfterClient, EventsBefore: eventsBeforeClient, - State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), } diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go index de790e5cd..09c2aef02 100644 --- a/syncapi/routing/getevent.go +++ b/syncapi/routing/getevent.go @@ -102,14 +102,28 @@ func GetEvent( } sender := spec.UserID{} - senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), roomID, events[0].SenderID()) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("roomID is invalid"), + } + } + senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, events[0].SenderID()) if err == nil && senderUserID != nil { sender = *senderUserID } sk := events[0].StateKey() if sk != nil && *sk != "" { - skUserID, err := rsAPI.QueryUserIDForSender(ctx, events[0].RoomID(), spec.SenderID(*events[0].StateKey())) + evRoomID, err := spec.NewRoomID(events[0].RoomID()) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("roomID is invalid"), + } + } + skUserID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, spec.SenderID(*events[0].StateKey())) if err == nil && skUserID != nil { skString := skUserID.String() sk = &skString diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index cf6769ba4..5e5d0125f 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -152,7 +152,15 @@ func GetMemberships( } } - userID, err := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID()) + validRoomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("roomID is invalid") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, ev.SenderID()) if err != nil || userID == nil { util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryUserIDForSender failed") return util.JSONResponse{ @@ -175,7 +183,7 @@ func GetMemberships( } return util.JSONResponse{ Code: http.StatusOK, - JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) })}, } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 6784a27bd..937e20ad8 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -273,7 +273,7 @@ func OnIncomingMessagesRequest( JSON: spec.InternalServerError{}, } } - res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) })...) } @@ -389,7 +389,7 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv "events_before": len(events), "events_after": len(filteredEvents), }).Debug("applied history visibility (messages)") - return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), start, end, err } diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go index 6efa065a9..17933b2fb 100644 --- a/syncapi/routing/relations.go +++ b/syncapi/routing/relations.go @@ -110,19 +110,24 @@ func Relations( return util.ErrorResponse(err) } + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return util.ErrorResponse(err) + } + // Convert the events into client events, and optionally filter based on the event // type if it was specified. res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents)) for _, event := range filteredEvents { sender := spec.UserID{} - userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID()) + userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, event.SenderID()) if err == nil && userID != nil { sender = *userID } sk := event.StateKey() if sk != nil && *sk != "" { - skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey())) + skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, spec.SenderID(*event.StateKey())) if err == nil && skUserID != nil { skString := skUserID.String() sk = &skString diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index 7d9182f47..d892b604a 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -205,9 +205,14 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts profileInfos := make(map[string]ProfileInfoResponse) for _, ev := range append(eventsBefore, eventsAfter...) { - userID, queryErr := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID()) + validRoomID, roomErr := spec.NewRoomID(ev.RoomID()) + if err != nil { + logrus.WithError(roomErr).WithField("room_id", ev.RoomID()).Warn("failed to query userprofile") + continue + } + userID, queryErr := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, ev.SenderID()) if queryErr != nil { - logrus.WithError(queryErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") + logrus.WithError(queryErr).WithField("sender_id", ev.SenderID()).Warn("failed to query userprofile") continue } @@ -231,14 +236,19 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts } sender := spec.UserID{} - userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID()) + validRoomID, roomErr := spec.NewRoomID(event.RoomID()) + if err != nil { + logrus.WithError(roomErr).WithField("room_id", event.RoomID()).Warn("failed to query userprofile") + continue + } + userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, event.SenderID()) if err == nil && userID != nil { sender = *userID } sk := event.StateKey() if sk != nil && *sk != "" { - skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey())) + skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, spec.SenderID(*event.StateKey())) if err == nil && skUserID != nil { skString := skUserID.String() sk = &skString @@ -248,10 +258,10 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts Context: SearchContextResponse{ Start: startToken.String(), End: endToken.String(), - EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }), - EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }), ProfileInfo: profileInfos, @@ -272,7 +282,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts JSON: spec.InternalServerError{}, } } - stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }) } diff --git a/syncapi/routing/search_test.go b/syncapi/routing/search_test.go index 5eb094ca3..f6d7fb4eb 100644 --- a/syncapi/routing/search_test.go +++ b/syncapi/routing/search_test.go @@ -25,7 +25,7 @@ import ( type FakeSyncRoomserverAPI struct{ rsapi.SyncRoomserverAPI } -func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index 799e3d166..1827218b6 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -114,7 +114,14 @@ func (d *Database) StreamEventsToEvents(ctx context.Context, device *userapi.Dev }).WithError(err).Warnf("Failed to add transaction ID to event") continue } - deviceSenderID, err := rsAPI.QuerySenderIDForUser(ctx, in[i].RoomID(), *userID) + roomID, err := spec.NewRoomID(in[i].RoomID()) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": out[i].EventID(), + }).WithError(err).Warnf("Room ID is invalid") + continue + } + deviceSenderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID) if err != nil { logrus.WithFields(logrus.Fields{ "event_id": out[i].EventID(), @@ -515,7 +522,11 @@ func getMembershipFromEvent(ctx context.Context, ev gomatrixserverlib.PDU, userI if err != nil { return "", "" } - senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *fullUser) + roomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + return "", "" + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *fullUser) if err != nil { return "", "" } diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index 3a5badd92..7c29d84ae 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -65,14 +65,18 @@ func (p *InviteStreamProvider) IncrementalSync( for roomID, inviteEvent := range invites { user := spec.UserID{} - sender, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), inviteEvent.SenderID()) + validRoomID, err := spec.NewRoomID(inviteEvent.RoomID()) + if err != nil { + continue + } + sender, err := p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, inviteEvent.SenderID()) if err == nil && sender != nil { user = *sender } sk := inviteEvent.StateKey() if sk != nil && *sk != "" { - skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey())) + skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*inviteEvent.StateKey())) if err == nil && skUserID != nil { skString := skUserID.String() sk = &skString diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index f728d4aea..7939dd8fa 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -376,13 +376,13 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } } jr.Timeline.PrevBatch = &prevBatch - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) req.Response.Rooms.Join[delta.RoomID] = jr @@ -391,11 +391,11 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( jr := types.NewJoinResponse() jr.Timeline.PrevBatch = &prevBatch // TODO: Apply history visibility on peeked rooms - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) jr.Timeline.Limited = limited - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) req.Response.Rooms.Peek[delta.RoomID] = jr @@ -406,13 +406,13 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( case spec.Ban: lr := types.NewLeaveResponse() lr.Timeline.PrevBatch = &prevBatch - lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. lr.Timeline.Limited = limited && len(events) == len(recentEvents) - lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) req.Response.Rooms.Leave[delta.RoomID] = lr @@ -564,13 +564,13 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( } jr.Timeline.PrevBatch = prevBatch - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. jr.Timeline.Limited = limited && len(events) == len(recentEvents) - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) return jr, nil @@ -585,6 +585,10 @@ func (p *PDUStreamProvider) lazyLoadMembers( if len(timelineEvents) == 0 { return stateEvents, nil } + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return nil, err + } // Work out which memberships to include timelineUsers := make(map[string]struct{}) if !incremental { @@ -606,8 +610,8 @@ func (p *PDUStreamProvider) lazyLoadMembers( isGappedIncremental := limited && incremental // We want this users membership event, keep it in the list userID := "" - stateKeyUserID, err := p.rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(*event.StateKey())) - if err == nil && stateKeyUserID != nil { + stateKeyUserID, queryErr := p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey())) + if queryErr == nil && stateKeyUserID != nil { userID = stateKeyUserID.String() } if _, ok := timelineUsers[userID]; ok || isGappedIncremental || userID == device.UserID { diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index b9f13c517..19815b79b 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -40,7 +40,7 @@ type syncRoomserverAPI struct { rooms []*test.Room } -func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go index 433be39f8..6f03d9ff0 100644 --- a/syncapi/synctypes/clientevent.go +++ b/syncapi/synctypes/clientevent.go @@ -52,14 +52,18 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat, continue // TODO: shouldn't happen? } sender := spec.UserID{} - userID, err := userIDForSender(se.RoomID(), se.SenderID()) + validRoomID, err := spec.NewRoomID(se.RoomID()) + if err != nil { + continue + } + userID, err := userIDForSender(*validRoomID, se.SenderID()) if err == nil && userID != nil { sender = *userID } sk := se.StateKey() if sk != nil && *sk != "" { - skUserID, err := userIDForSender(se.RoomID(), spec.SenderID(*sk)) + skUserID, err := userIDForSender(*validRoomID, spec.SenderID(*sk)) if err == nil && skUserID != nil { skString := skUserID.String() sk = &skString @@ -95,14 +99,18 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender sp // It provides default logic for event.SenderID & event.StateKey -> userID conversions. func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent { sender := spec.UserID{} - userID, err := userIDQuery(event.RoomID(), event.SenderID()) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return ClientEvent{} + } + userID, err := userIDQuery(*validRoomID, event.SenderID()) if err == nil && userID != nil { sender = *userID } sk := event.StateKey() if sk != nil && *sk != "" { - skUserID, err := userIDQuery(event.RoomID(), spec.SenderID(*event.StateKey())) + skUserID, err := userIDQuery(*validRoomID, spec.SenderID(*event.StateKey())) if err == nil && skUserID != nil { skString := skUserID.String() sk = &skString diff --git a/test/room.go b/test/room.go index b19c57ddc..da09de7c2 100644 --- a/test/room.go +++ b/test/room.go @@ -39,7 +39,7 @@ var ( roomIDCounter = int64(0) ) -func UserIDForSender(roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func UserIDForSender(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index b2dc477aa..9cb9419d4 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -302,14 +302,18 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst switch { case event.Type() == spec.MRoomMember: sender := spec.UserID{} - userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + validRoomID, roomErr := spec.NewRoomID(event.RoomID()) + if roomErr != nil { + return roomErr + } + userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) if queryErr == nil && userID != nil { sender = *userID } sk := event.StateKey() if sk != nil && *sk != "" { - skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) + skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey())) if queryErr == nil && skUserID != nil { skString := skUserID.String() sk = &skString @@ -544,14 +548,18 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype } sender := spec.UserID{} - userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return err + } + userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) if err == nil && userID != nil { sender = *userID } sk := event.StateKey() if sk != nil && *sk != "" { - skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) + skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey())) if queryErr == nil && skUserID != nil { skString := skUserID.String() sk = &skString @@ -644,7 +652,11 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype // user. Returns actions (including dont_notify). func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { user := "" - sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return nil, err + } + sender, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) if err == nil { user = sender.String() } @@ -682,7 +694,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * roomSize: roomSize, } eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global) - rule, err := eval.MatchEvent(event.PDU, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + rule, err := eval.MatchEvent(event.PDU, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) if err != nil { @@ -790,7 +802,11 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes } default: - sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return nil, err + } + sender, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) if err != nil { logger.WithError(err).Errorf("Failed to get userID for sender %s", event.SenderID()) return nil, err @@ -818,7 +834,13 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes logger.WithError(err).Errorf("Failed to convert local user to userID %s", localpart) return nil, err } - localSender, err := s.rsAPI.QuerySenderIDForUser(ctx, event.RoomID(), *userID) + roomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + logger.WithError(err).Errorf("event roomID is invalid %s", event.RoomID()) + return nil, err + } + + localSender, err := s.rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID) if err != nil { logger.WithError(err).Errorf("Failed to get local user senderID for room %s: %s", userID.String(), event.RoomID()) return nil, err diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 954247155..4dc81e74a 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -47,7 +47,7 @@ func mustCreateEvent(t *testing.T, content string) *types.HeaderedEvent { type FakeUserRoomserverAPI struct{ rsapi.UserRoomserverAPI } -func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } @@ -68,13 +68,13 @@ func Test_evaluatePushRules(t *testing.T) { }{ { name: "m.receipt doesn't notify", - eventContent: `{"type":"m.receipt"}`, + eventContent: `{"type":"m.receipt","room_id":"!room:example.com"}`, wantAction: pushrules.UnknownAction, wantActions: nil, }, { name: "m.reaction doesn't notify", - eventContent: `{"type":"m.reaction"}`, + eventContent: `{"type":"m.reaction","room_id":"!room:example.com"}`, wantAction: pushrules.DontNotifyAction, wantActions: []*pushrules.Action{ { @@ -84,7 +84,7 @@ func Test_evaluatePushRules(t *testing.T) { }, { name: "m.room.message notifies", - eventContent: `{"type":"m.room.message"}`, + eventContent: `{"type":"m.room.message","room_id":"!room:example.com"}`, wantNotify: true, wantAction: pushrules.NotifyAction, wantActions: []*pushrules.Action{ @@ -93,7 +93,7 @@ func Test_evaluatePushRules(t *testing.T) { }, { name: "m.room.message highlights", - eventContent: `{"type":"m.room.message", "content": {"body": "test"}}`, + eventContent: `{"type":"m.room.message", "content": {"body": "test"},"room_id":"!room:example.com"}`, wantNotify: true, wantAction: pushrules.NotifyAction, wantActions: []*pushrules.Action{ From 5aaa539e3eb8ef1f1f601468c786f2d7f891394f Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Wed, 14 Jun 2023 16:42:09 +0100 Subject: [PATCH 25/35] Fix senderID/key conversions --- roomserver/internal/perform/perform_create_room.go | 3 ++- roomserver/storage/postgres/user_room_keys_table.go | 2 +- roomserver/storage/sqlite3/user_room_keys_table.go | 3 ++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index fd8055e09..dcaf8dca6 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -16,6 +16,7 @@ package perform import ( "context" + "crypto/ed25519" "encoding/json" "fmt" "net/http" @@ -74,7 +75,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo JSON: spec.InternalServerError{}, } } - senderID = spec.SenderID(spec.Base64Bytes(key).Encode()) + senderID = spec.SenderID(spec.Base64Bytes(key.Public().(ed25519.PublicKey)).Encode()) } else { senderID = spec.SenderID(userID.String()) } diff --git a/roomserver/storage/postgres/user_room_keys_table.go b/roomserver/storage/postgres/user_room_keys_table.go index dbb4af34a..dd4d9ab13 100644 --- a/roomserver/storage/postgres/user_room_keys_table.go +++ b/roomserver/storage/postgres/user_room_keys_table.go @@ -145,7 +145,7 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sq if err = rows.Scan(&userRoomKeyPair.EventStateKeyNID, &userRoomKeyPair.RoomNID, &publicKey); err != nil { return nil, err } - result[string(publicKey)] = userRoomKeyPair + result[spec.Base64Bytes(publicKey).Encode()] = userRoomKeyPair } return result, rows.Err() } diff --git a/roomserver/storage/sqlite3/user_room_keys_table.go b/roomserver/storage/sqlite3/user_room_keys_table.go index 84c8b54ec..d58b8ac3f 100644 --- a/roomserver/storage/sqlite3/user_room_keys_table.go +++ b/roomserver/storage/sqlite3/user_room_keys_table.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib/spec" ) const userRoomKeysSchema = ` @@ -159,7 +160,7 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sq if err = rows.Scan(&userRoomKeyPair.EventStateKeyNID, &userRoomKeyPair.RoomNID, &publicKey); err != nil { return nil, err } - result[string(publicKey)] = userRoomKeyPair + result[spec.Base64Bytes(publicKey).Encode()] = userRoomKeyPair } return result, rows.Err() } From 3f4df25b31a403e936488a1920d1aed3de471a71 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Wed, 14 Jun 2023 17:04:19 +0100 Subject: [PATCH 26/35] Add missing dep --- roomserver/storage/postgres/user_room_keys_table.go | 1 + 1 file changed, 1 insertion(+) diff --git a/roomserver/storage/postgres/user_room_keys_table.go b/roomserver/storage/postgres/user_room_keys_table.go index dd4d9ab13..202b0abc1 100644 --- a/roomserver/storage/postgres/user_room_keys_table.go +++ b/roomserver/storage/postgres/user_room_keys_table.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib/spec" ) const userRoomKeysSchema = ` From 8cf6c381e21d0710f0290c97dfa5616036749a81 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Wed, 14 Jun 2023 17:11:27 +0100 Subject: [PATCH 27/35] Fix senderID/key conversion unit tests --- roomserver/storage/shared/storage_test.go | 2 +- roomserver/storage/tables/user_room_keys_table_test.go | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go index c7b915c7d..612e4ef06 100644 --- a/roomserver/storage/shared/storage_test.go +++ b/roomserver/storage/shared/storage_test.go @@ -183,7 +183,7 @@ func TestUserRoomKeys(t *testing.T) { assert.NoError(t, err) wantKeys := map[spec.RoomID]map[string]string{ *roomID: { - string(key.Public().(ed25519.PublicKey)): userID.String(), + spec.Base64Bytes(key.Public().(ed25519.PublicKey)).Encode(): userID.String(), }, } assert.Equal(t, wantKeys, userIDs) diff --git a/roomserver/storage/tables/user_room_keys_table_test.go b/roomserver/storage/tables/user_room_keys_table_test.go index 8802a3c6e..2809771b4 100644 --- a/roomserver/storage/tables/user_room_keys_table_test.go +++ b/roomserver/storage/tables/user_room_keys_table_test.go @@ -13,6 +13,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" ed255192 "golang.org/x/crypto/ed25519" ) @@ -101,8 +102,8 @@ func TestUserRoomKeysTable(t *testing.T) { assert.NotNil(t, gotKeys) wantKeys := map[string]types.UserRoomKeyPair{ - string(key2.Public().(ed25519.PublicKey)): {RoomNID: roomNID, EventStateKeyNID: userNID}, - string(key3.Public().(ed25519.PublicKey)): {RoomNID: roomNID, EventStateKeyNID: userNID2}, + string(spec.Base64Bytes(key2.Public().(ed25519.PublicKey)).Encode()): {RoomNID: roomNID, EventStateKeyNID: userNID}, + string(spec.Base64Bytes(key3.Public().(ed25519.PublicKey)).Encode()): {RoomNID: roomNID, EventStateKeyNID: userNID2}, } assert.Equal(t, wantKeys, gotKeys) From 420e7ec81fedf9ff531c75ece4c80a9b63046ba9 Mon Sep 17 00:00:00 2001 From: Josh Qou <97894002+joshqou@users.noreply.github.com> Date: Thu, 15 Jun 2023 12:28:34 +0100 Subject: [PATCH 28/35] Fix unsafe hotserving behaviour for multimedia uploads. (#3113) Return multimedia with a disposition type of attachment instead of inline. NVT#1548992 Signed-off-by: Josh Qou [jqou@icloud.com](mailto:jqou@icloud.com) Co-authored-by: Jon --- mediaapi/routing/download.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go index e9f161a3c..8fb1b6534 100644 --- a/mediaapi/routing/download.go +++ b/mediaapi/routing/download.go @@ -341,6 +341,7 @@ func (r *downloadRequest) addDownloadFilenameToHeaders( } if len(filename) == 0 { + w.Header().Set("Content-Disposition", "attachment") return nil } @@ -376,13 +377,13 @@ func (r *downloadRequest) addDownloadFilenameToHeaders( // that would otherwise be parsed as a control character in the // Content-Disposition header w.Header().Set("Content-Disposition", fmt.Sprintf( - `inline; filename=%s%s%s`, + `attachment; filename=%s%s%s`, quote, unescaped, quote, )) } else { // For UTF-8 filenames, we quote always, as that's the standard w.Header().Set("Content-Disposition", fmt.Sprintf( - `inline; filename*=utf-8''%s`, + `attachment; filename*=utf-8''%s`, url.QueryEscape(unescaped), )) } From d13466c1eed040a97048c8b30b64df9f4bc84727 Mon Sep 17 00:00:00 2001 From: CicadaCinema <52425971+CicadaCinema@users.noreply.github.com> Date: Sun, 18 Jun 2023 22:54:16 +0100 Subject: [PATCH 29/35] rearrange order of sections about signing keys and configuring dendrite, fix a dead link (#3114) I thought I would rearrange these pages since the configuration step requires that a signing key has been generated. Co-authored-by: kegsay --- docs/installation/manual/{4_signingkey.md => 3_signingkey.md} | 2 +- .../manual/{3_configuration.md => 4_configuration.md} | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) rename docs/installation/manual/{4_signingkey.md => 3_signingkey.md} (99%) rename docs/installation/manual/{3_configuration.md => 4_configuration.md} (98%) diff --git a/docs/installation/manual/4_signingkey.md b/docs/installation/manual/3_signingkey.md similarity index 99% rename from docs/installation/manual/4_signingkey.md rename to docs/installation/manual/3_signingkey.md index bd9c242ab..91289fd6a 100644 --- a/docs/installation/manual/4_signingkey.md +++ b/docs/installation/manual/3_signingkey.md @@ -2,7 +2,7 @@ title: Generating signing keys parent: Manual grand_parent: Installation -nav_order: 4 +nav_order: 3 permalink: /installation/manual/signingkeys --- diff --git a/docs/installation/manual/3_configuration.md b/docs/installation/manual/4_configuration.md similarity index 98% rename from docs/installation/manual/3_configuration.md rename to docs/installation/manual/4_configuration.md index a9dd81c87..624cc4155 100644 --- a/docs/installation/manual/3_configuration.md +++ b/docs/installation/manual/4_configuration.md @@ -2,7 +2,7 @@ title: Configuring Dendrite parent: Manual grand_parent: Installation -nav_order: 3 +nav_order: 4 permalink: /installation/manual/configuration --- @@ -21,7 +21,7 @@ sections: First of all, you will need to configure the server name of your Matrix homeserver. This must match the domain name that you have selected whilst [configuring the domain -name delegation](domainname#delegation). +name delegation](../domainname#delegation). In the `global` section, set the `server_name` to your delegated domain name: From a734b112c6577a23b87c6b54c50fb2e9a629cf2b Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 20 Jun 2023 16:52:29 +0200 Subject: [PATCH 30/35] Fix backfilling (#3117) This should fix two issues with backfilling: 1. right after creating and joining a room over federation, we are doing a `/backfill` request, which would return redacted events, because the `authEvents` are empty. Even though the spec states that, in the absence of a history visibility event, it should be handled as `shared`. 2. `gomatrixserverlib: unsupported room version ''` - because, well, we were never setting the `roomInfo` field.. --- roomserver/auth/auth.go | 4 ---- roomserver/internal/perform/perform_backfill.go | 12 +++++------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/roomserver/auth/auth.go b/roomserver/auth/auth.go index d6c10cf92..df95851e3 100644 --- a/roomserver/auth/auth.go +++ b/roomserver/auth/auth.go @@ -30,10 +30,6 @@ func IsServerAllowed( serverCurrentlyInRoom bool, authEvents []gomatrixserverlib.PDU, ) bool { - // In practice should not happen, but avoids unneeded CPU cycles - if serverName == "" || len(authEvents) == 0 { - return false - } historyVisibility := HistoryVisibilityForRoom(authEvents) // 1. If the history_visibility was set to world_readable, allow. diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 533ad25bf..3fdc8e4d0 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -114,7 +114,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform if info == nil || info.IsStub() { return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID) } - requester := newBackfillRequester(r.DB, r.FSAPI, r.Querier, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers) + requester := newBackfillRequester(r.DB, r.FSAPI, r.Querier, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers, info.RoomVersion) // Request 100 items regardless of what the query asks for. // We don't want to go much higher than this. // We can't honour exactly the limit as some sytests rely on requesting more for tests to pass @@ -265,7 +265,7 @@ type backfillRequester struct { eventIDToBeforeStateIDs map[string][]string eventIDMap map[string]gomatrixserverlib.PDU historyVisiblity gomatrixserverlib.HistoryVisibility - roomInfo types.RoomInfo + roomVersion gomatrixserverlib.RoomVersion } func newBackfillRequester( @@ -274,6 +274,7 @@ func newBackfillRequester( virtualHost spec.ServerName, isLocalServerName func(spec.ServerName) bool, bwExtrems map[string][]string, preferServers []spec.ServerName, + roomVersion gomatrixserverlib.RoomVersion, ) *backfillRequester { preferServer := make(map[spec.ServerName]bool) for _, p := range preferServers { @@ -290,6 +291,7 @@ func newBackfillRequester( bwExtrems: bwExtrems, preferServer: preferServer, historyVisiblity: gomatrixserverlib.HistoryVisibilityShared, + roomVersion: roomVersion, } } @@ -537,15 +539,11 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, } eventNIDs := make([]types.EventNID, len(nidMap)) i := 0 - roomNID := b.roomInfo.RoomNID for _, nid := range nidMap { eventNIDs[i] = nid.EventNID i++ - if roomNID == 0 { - roomNID = nid.RoomNID - } } - eventsWithNids, err := b.db.Events(ctx, b.roomInfo.RoomVersion, eventNIDs) + eventsWithNids, err := b.db.Events(ctx, b.roomVersion, eventNIDs) if err != nil { logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events") return nil, err From 45082d4dcefadceada1b4374f3876365887cfd4a Mon Sep 17 00:00:00 2001 From: santhoshivan23 <47689668+santhoshivan23@users.noreply.github.com> Date: Thu, 22 Jun 2023 22:07:21 +0530 Subject: [PATCH 31/35] feat: admin APIs for token authenticated registration (#3101) ### Pull Request Checklist * [x] I have added Go unit tests or [Complement integration tests](https://github.com/matrix-org/complement) for this PR _or_ I have justified why this PR doesn't need tests * [x] Pull request includes a [sign off below using a legally identifiable name](https://matrix-org.github.io/dendrite/development/contributing#sign-off) _or_ I have already signed off privately Signed-off-by: `Santhoshivan Amudhan santhoshivan23@gmail.com` --- clientapi/admin_test.go | 638 ++++++++++++++++++ clientapi/api/api.go | 8 + clientapi/routing/admin.go | 242 +++++++ clientapi/routing/routing.go | 30 + setup/config/config_clientapi.go | 5 + userapi/api/api.go | 6 + userapi/internal/user_api.go | 32 + userapi/storage/interface.go | 11 + .../postgres/registration_tokens_table.go | 222 ++++++ userapi/storage/postgres/storage.go | 5 + userapi/storage/shared/storage.go | 38 ++ .../sqlite3/registration_tokens_table.go | 222 ++++++ userapi/storage/sqlite3/storage.go | 6 +- userapi/storage/tables/interface.go | 10 + 14 files changed, 1474 insertions(+), 1 deletion(-) create mode 100644 userapi/storage/postgres/registration_tokens_table.go create mode 100644 userapi/storage/sqlite3/registration_tokens_table.go diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index 1145cb12d..9d2acd68e 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -2,6 +2,7 @@ package clientapi import ( "context" + "fmt" "net/http" "net/http/httptest" "reflect" @@ -23,12 +24,649 @@ import ( "github.com/matrix-org/util" "github.com/tidwall/gjson" + capi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi" uapi "github.com/matrix-org/dendrite/userapi/api" ) +func TestAdminCreateToken(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RegistrationRequiresToken = true + defer close() + natsInstance := jetstream.NATSInstance{} + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, + bob: {}, + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + testCases := []struct { + name string + requestingUser *test.User + requestOpt test.HTTPRequestOpt + wantOK bool + withHeader bool + }{ + { + name: "Missing auth", + requestingUser: bob, + wantOK: false, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "token": "token1", + }, + ), + }, + { + name: "Bob is denied access", + requestingUser: bob, + wantOK: false, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "token": "token2", + }, + ), + }, + { + name: "Alice can create a token without specifyiing any information", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{}), + }, + { + name: "Alice can to create a token specifying a name", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "token": "token3", + }, + ), + }, + { + name: "Alice cannot to create a token that already exists", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "token": "token3", + }, + ), + }, + { + name: "Alice can create a token specifying valid params", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "token": "token4", + "uses_allowed": 5, + "expiry_time": time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond), + }, + ), + }, + { + name: "Alice cannot create a token specifying invalid name", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "token": "token@", + }, + ), + }, + { + name: "Alice cannot create a token specifying invalid uses_allowed", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "token": "token5", + "uses_allowed": -1, + }, + ), + }, + { + name: "Alice cannot create a token specifying invalid expiry_time", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "token": "token6", + "expiry_time": time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond), + }, + ), + }, + { + name: "Alice cannot to create a token specifying invalid length", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "length": 80, + }, + ), + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/registrationTokens/new") + if tc.requestOpt != nil { + req = test.NewRequest(t, http.MethodPost, "/_dendrite/admin/registrationTokens/new", tc.requestOpt) + } + if tc.withHeader { + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken) + } + rec := httptest.NewRecorder() + routers.DendriteAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + }) +} + +func TestAdminListRegistrationTokens(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RegistrationRequiresToken = true + defer close() + natsInstance := jetstream.NATSInstance{} + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, + bob: {}, + } + tokens := []capi.RegistrationToken{ + { + Token: getPointer("valid"), + UsesAllowed: getPointer(int32(10)), + ExpiryTime: getPointer(time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond)), + Pending: getPointer(int32(0)), + Completed: getPointer(int32(0)), + }, + { + Token: getPointer("invalid"), + UsesAllowed: getPointer(int32(10)), + ExpiryTime: getPointer(time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond)), + Pending: getPointer(int32(0)), + Completed: getPointer(int32(0)), + }, + } + for _, tkn := range tokens { + tkn := tkn + userAPI.PerformAdminCreateRegistrationToken(ctx, &tkn) + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + testCases := []struct { + name string + requestingUser *test.User + valid string + isValidSpecified bool + wantOK bool + withHeader bool + }{ + { + name: "Missing auth", + requestingUser: bob, + wantOK: false, + isValidSpecified: false, + }, + { + name: "Bob is denied access", + requestingUser: bob, + wantOK: false, + withHeader: true, + isValidSpecified: false, + }, + { + name: "Alice can list all tokens", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + }, + { + name: "Alice can list all valid tokens", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + valid: "true", + isValidSpecified: true, + }, + { + name: "Alice can list all invalid tokens", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + valid: "false", + isValidSpecified: true, + }, + { + name: "No response when valid has a bad value", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + valid: "trueee", + isValidSpecified: true, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + var path string + if tc.isValidSpecified { + path = fmt.Sprintf("/_dendrite/admin/registrationTokens?valid=%v", tc.valid) + } else { + path = "/_dendrite/admin/registrationTokens" + } + req := test.NewRequest(t, http.MethodGet, path) + if tc.withHeader { + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken) + } + rec := httptest.NewRecorder() + routers.DendriteAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + }) +} + +func TestAdminGetRegistrationToken(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RegistrationRequiresToken = true + defer close() + natsInstance := jetstream.NATSInstance{} + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, + bob: {}, + } + tokens := []capi.RegistrationToken{ + { + Token: getPointer("alice_token1"), + UsesAllowed: getPointer(int32(10)), + ExpiryTime: getPointer(time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond)), + Pending: getPointer(int32(0)), + Completed: getPointer(int32(0)), + }, + { + Token: getPointer("alice_token2"), + UsesAllowed: getPointer(int32(10)), + ExpiryTime: getPointer(time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond)), + Pending: getPointer(int32(0)), + Completed: getPointer(int32(0)), + }, + } + for _, tkn := range tokens { + tkn := tkn + userAPI.PerformAdminCreateRegistrationToken(ctx, &tkn) + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + testCases := []struct { + name string + requestingUser *test.User + token string + wantOK bool + withHeader bool + }{ + { + name: "Missing auth", + requestingUser: bob, + wantOK: false, + }, + { + name: "Bob is denied access", + requestingUser: bob, + wantOK: false, + withHeader: true, + }, + { + name: "Alice can GET alice_token1", + token: "alice_token1", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + }, + { + name: "Alice can GET alice_token2", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + token: "alice_token2", + }, + { + name: "Alice cannot GET a token that does not exists", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + token: "alice_token3", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + path := fmt.Sprintf("/_dendrite/admin/registrationTokens/%s", tc.token) + req := test.NewRequest(t, http.MethodGet, path) + if tc.withHeader { + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken) + } + rec := httptest.NewRecorder() + routers.DendriteAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + }) +} + +func TestAdminDeleteRegistrationToken(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RegistrationRequiresToken = true + defer close() + natsInstance := jetstream.NATSInstance{} + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, + bob: {}, + } + tokens := []capi.RegistrationToken{ + { + Token: getPointer("alice_token1"), + UsesAllowed: getPointer(int32(10)), + ExpiryTime: getPointer(time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond)), + Pending: getPointer(int32(0)), + Completed: getPointer(int32(0)), + }, + { + Token: getPointer("alice_token2"), + UsesAllowed: getPointer(int32(10)), + ExpiryTime: getPointer(time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond)), + Pending: getPointer(int32(0)), + Completed: getPointer(int32(0)), + }, + } + for _, tkn := range tokens { + tkn := tkn + userAPI.PerformAdminCreateRegistrationToken(ctx, &tkn) + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + testCases := []struct { + name string + requestingUser *test.User + token string + wantOK bool + withHeader bool + }{ + { + name: "Missing auth", + requestingUser: bob, + wantOK: false, + }, + { + name: "Bob is denied access", + requestingUser: bob, + wantOK: false, + withHeader: true, + }, + { + name: "Alice can DELETE alice_token1", + token: "alice_token1", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + }, + { + name: "Alice can DELETE alice_token2", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + token: "alice_token2", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + path := fmt.Sprintf("/_dendrite/admin/registrationTokens/%s", tc.token) + req := test.NewRequest(t, http.MethodDelete, path) + if tc.withHeader { + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken) + } + rec := httptest.NewRecorder() + routers.DendriteAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + }) +} + +func TestAdminUpdateRegistrationToken(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RegistrationRequiresToken = true + defer close() + natsInstance := jetstream.NATSInstance{} + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, + bob: {}, + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + tokens := []capi.RegistrationToken{ + { + Token: getPointer("alice_token1"), + UsesAllowed: getPointer(int32(10)), + ExpiryTime: getPointer(time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond)), + Pending: getPointer(int32(0)), + Completed: getPointer(int32(0)), + }, + { + Token: getPointer("alice_token2"), + UsesAllowed: getPointer(int32(10)), + ExpiryTime: getPointer(time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond)), + Pending: getPointer(int32(0)), + Completed: getPointer(int32(0)), + }, + } + for _, tkn := range tokens { + tkn := tkn + userAPI.PerformAdminCreateRegistrationToken(ctx, &tkn) + } + testCases := []struct { + name string + requestingUser *test.User + method string + token string + requestOpt test.HTTPRequestOpt + wantOK bool + withHeader bool + }{ + { + name: "Missing auth", + requestingUser: bob, + wantOK: false, + token: "alice_token1", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "uses_allowed": 10, + }, + ), + }, + { + name: "Bob is denied access", + requestingUser: bob, + wantOK: false, + withHeader: true, + token: "alice_token1", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "uses_allowed": 10, + }, + ), + }, + { + name: "Alice can UPDATE a token's uses_allowed property", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + token: "alice_token1", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "uses_allowed": 10, + }), + }, + { + name: "Alice can UPDATE a token's expiry_time property", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + token: "alice_token2", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "expiry_time": time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond), + }, + ), + }, + { + name: "Alice can UPDATE a token's uses_allowed and expiry_time property", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + token: "alice_token1", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "uses_allowed": 20, + "expiry_time": time.Now().Add(10*24*time.Hour).UnixNano() / int64(time.Millisecond), + }, + ), + }, + { + name: "Alice CANNOT update a token with invalid properties", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + token: "alice_token2", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "uses_allowed": -5, + "expiry_time": time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond), + }, + ), + }, + { + name: "Alice CANNOT UPDATE a token that does not exist", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + token: "alice_token9", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "uses_allowed": 100, + }, + ), + }, + { + name: "Alice can UPDATE token specifying uses_allowed as null - Valid for infinite uses", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + token: "alice_token1", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "uses_allowed": nil, + }, + ), + }, + { + name: "Alice can UPDATE token specifying expiry_time AS null - Valid for infinite time", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + token: "alice_token1", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "expiry_time": nil, + }, + ), + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + path := fmt.Sprintf("/_dendrite/admin/registrationTokens/%s", tc.token) + req := test.NewRequest(t, http.MethodPut, path) + if tc.requestOpt != nil { + req = test.NewRequest(t, http.MethodPut, path, tc.requestOpt) + } + if tc.withHeader { + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken) + } + rec := httptest.NewRecorder() + routers.DendriteAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + }) +} + +func getPointer[T any](s T) *T { + return &s +} + func TestAdminResetPassword(t *testing.T) { aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) diff --git a/clientapi/api/api.go b/clientapi/api/api.go index 23974c865..28ff593fc 100644 --- a/clientapi/api/api.go +++ b/clientapi/api/api.go @@ -21,3 +21,11 @@ type ExtraPublicRoomsProvider interface { // Rooms returns the extra rooms. This is called on-demand by clients, so cache appropriately. Rooms() []fclient.PublicRoom } + +type RegistrationToken struct { + Token *string `json:"token"` + UsesAllowed *int32 `json:"uses_allowed"` + Pending *int32 `json:"pending"` + Completed *int32 `json:"completed"` + ExpiryTime *int64 `json:"expiry_time"` +} diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 3d64454c4..519666076 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -6,6 +6,8 @@ import ( "errors" "fmt" "net/http" + "regexp" + "strconv" "time" "github.com/gorilla/mux" @@ -16,14 +18,254 @@ import ( "github.com/matrix-org/util" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" + "golang.org/x/exp/constraints" + clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/internal/httputil" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/userapi/api" + userapi "github.com/matrix-org/dendrite/userapi/api" ) +var validRegistrationTokenRegex = regexp.MustCompile("^[[:ascii:][:digit:]_]*$") + +func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { + if !cfg.RegistrationRequiresToken { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("Registration via tokens is not enabled on this homeserver"), + } + } + request := struct { + Token string `json:"token"` + UsesAllowed *int32 `json:"uses_allowed,omitempty"` + ExpiryTime *int64 `json:"expiry_time,omitempty"` + Length int32 `json:"length"` + }{} + + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON(fmt.Sprintf("Failed to decode request body: %s", err)), + } + } + + token := request.Token + usesAllowed := request.UsesAllowed + expiryTime := request.ExpiryTime + length := request.Length + + if len(token) == 0 { + if length == 0 { + // length not provided in request. Assign default value of 16. + length = 16 + } + // token not present in request body. Hence, generate a random token. + if length <= 0 || length > 64 { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("length must be greater than zero and not greater than 64"), + } + } + token = util.RandomString(int(length)) + } + + if len(token) > 64 { + //Token present in request body, but is too long. + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("token must not be longer than 64"), + } + } + + isTokenValid := validRegistrationTokenRegex.Match([]byte(token)) + if !isTokenValid { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("token must consist only of characters matched by the regex [A-Za-z0-9-_]"), + } + } + // At this point, we have a valid token, either through request body or through random generation. + if usesAllowed != nil && *usesAllowed < 0 { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"), + } + } + if expiryTime != nil && spec.Timestamp(*expiryTime).Time().Before(time.Now()) { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("expiry_time must not be in the past"), + } + } + pending := int32(0) + completed := int32(0) + // If usesAllowed or expiryTime is 0, it means they are not present in the request. NULL (indicating unlimited uses / no expiration will be persisted in DB) + registrationToken := &clientapi.RegistrationToken{ + Token: &token, + UsesAllowed: usesAllowed, + Pending: &pending, + Completed: &completed, + ExpiryTime: expiryTime, + } + created, err := userAPI.PerformAdminCreateRegistrationToken(req.Context(), registrationToken) + if !created { + return util.JSONResponse{ + Code: http.StatusConflict, + JSON: map[string]string{ + "error": fmt.Sprintf("token: %s already exists", token), + }, + } + } + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: err, + } + } + return util.JSONResponse{ + Code: 200, + JSON: map[string]interface{}{ + "token": token, + "uses_allowed": getReturnValue(usesAllowed), + "pending": pending, + "completed": completed, + "expiry_time": getReturnValue(expiryTime), + }, + } +} + +func getReturnValue[t constraints.Integer](in *t) any { + if in == nil { + return nil + } + return *in +} + +func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { + queryParams := req.URL.Query() + returnAll := true + valid := true + validQuery, ok := queryParams["valid"] + if ok { + returnAll = false + validValue, err := strconv.ParseBool(validQuery[0]) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("invalid 'valid' query parameter"), + } + } + valid = validValue + } + tokens, err := userAPI.PerformAdminListRegistrationTokens(req.Context(), returnAll, valid) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.ErrorUnknown, + } + } + return util.JSONResponse{ + Code: 200, + JSON: map[string]interface{}{ + "registration_tokens": tokens, + }, + } +} + +func AdminGetRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + tokenText := vars["token"] + token, err := userAPI.PerformAdminGetRegistrationToken(req.Context(), tokenText) + if err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)), + } + } + return util.JSONResponse{ + Code: 200, + JSON: token, + } +} + +func AdminDeleteRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + tokenText := vars["token"] + err = userAPI.PerformAdminDeleteRegistrationToken(req.Context(), tokenText) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: err, + } + } + return util.JSONResponse{ + Code: 200, + JSON: map[string]interface{}{}, + } +} + +func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + tokenText := vars["token"] + request := make(map[string]*int64) + if err = json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON(fmt.Sprintf("Failed to decode request body: %s", err)), + } + } + newAttributes := make(map[string]interface{}) + usesAllowed, ok := request["uses_allowed"] + if ok { + // Only add usesAllowed to newAtrributes if it is present and valid + if usesAllowed != nil && *usesAllowed < 0 { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"), + } + } + newAttributes["usesAllowed"] = usesAllowed + } + expiryTime, ok := request["expiry_time"] + if ok { + // Only add expiryTime to newAtrributes if it is present and valid + if expiryTime != nil && spec.Timestamp(*expiryTime).Time().Before(time.Now()) { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("expiry_time must not be in the past"), + } + } + newAttributes["expiryTime"] = expiryTime + } + if len(newAttributes) == 0 { + // No attributes to update. Return existing token + return AdminGetRegistrationToken(req, cfg, userAPI) + } + updatedToken, err := userAPI.PerformAdminUpdateRegistrationToken(req.Context(), tokenText, newAttributes) + if err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)), + } + } + return util.JSONResponse{ + Code: 200, + JSON: *updatedToken, + } +} + func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index d3f19cae1..ab4aefddd 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -162,6 +162,36 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) } + dendriteAdminRouter.Handle("/admin/registrationTokens/new", + httputil.MakeAdminAPI("admin_registration_tokens_new", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return AdminCreateNewRegistrationToken(req, cfg, userAPI) + }), + ).Methods(http.MethodPost, http.MethodOptions) + + dendriteAdminRouter.Handle("/admin/registrationTokens", + httputil.MakeAdminAPI("admin_list_registration_tokens", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return AdminListRegistrationTokens(req, cfg, userAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + dendriteAdminRouter.Handle("/admin/registrationTokens/{token}", + httputil.MakeAdminAPI("admin_get_registration_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + switch req.Method { + case http.MethodGet: + return AdminGetRegistrationToken(req, cfg, userAPI) + case http.MethodPut: + return AdminUpdateRegistrationToken(req, cfg, userAPI) + case http.MethodDelete: + return AdminDeleteRegistrationToken(req, cfg, userAPI) + default: + return util.MatrixErrorResponse( + 404, + string(spec.ErrorNotFound), + "unknown method", + ) + } + }), + ).Methods(http.MethodGet, http.MethodPut, http.MethodDelete, http.MethodOptions) dendriteAdminRouter.Handle("/admin/evacuateRoom/{roomID}", httputil.MakeAdminAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index b6c74a75f..44136e2a0 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -13,6 +13,10 @@ type ClientAPI struct { // secrets) RegistrationDisabled bool `yaml:"registration_disabled"` + // If set, requires users to submit a token during registration. + // Tokens can be managed using admin API. + RegistrationRequiresToken bool `yaml:"registration_requires_token"` + // Enable registration without captcha verification or shared secret. // This option is populated by the -really-enable-open-registration // command line parameter as it is not recommended. @@ -56,6 +60,7 @@ type ClientAPI struct { func (c *ClientAPI) Defaults(opts DefaultOpts) { c.RegistrationSharedSecret = "" + c.RegistrationRequiresToken = false c.RecaptchaPublicKey = "" c.RecaptchaPrivateKey = "" c.RecaptchaEnabled = false diff --git a/userapi/api/api.go b/userapi/api/api.go index 050402645..a0dce9758 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -27,6 +27,7 @@ import ( "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" + clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/pushrules" ) @@ -94,6 +95,11 @@ type ClientUserAPI interface { QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error) QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error + PerformAdminCreateRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error) + PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) + PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) + PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error + PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index 32f3d84b5..4305c13a9 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -33,6 +33,7 @@ import ( "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt" + clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/pushgateway" @@ -63,6 +64,37 @@ type UserInternalAPI struct { Updater *DeviceListUpdater } +func (a *UserInternalAPI) PerformAdminCreateRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error) { + exists, err := a.DB.RegistrationTokenExists(ctx, *registrationToken.Token) + if err != nil { + return false, err + } + if exists { + return false, fmt.Errorf("token: %s already exists", *registrationToken.Token) + } + _, err = a.DB.InsertRegistrationToken(ctx, registrationToken) + if err != nil { + return false, fmt.Errorf("Error creating token: %s"+err.Error(), *registrationToken.Token) + } + return true, nil +} + +func (a *UserInternalAPI) PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) { + return a.DB.ListRegistrationTokens(ctx, returnAll, valid) +} + +func (a *UserInternalAPI) PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) { + return a.DB.GetRegistrationToken(ctx, tokenString) +} + +func (a *UserInternalAPI) PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error { + return a.DB.DeleteRegistrationToken(ctx, tokenString) +} + +func (a *UserInternalAPI) PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) { + return a.DB.UpdateRegistrationToken(ctx, tokenString, newAttributes) +} + func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 4f5e99a8a..125b31585 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" + clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/userapi/api" @@ -30,6 +31,15 @@ import ( "github.com/matrix-org/dendrite/userapi/types" ) +type RegistrationTokens interface { + RegistrationTokenExists(ctx context.Context, token string) (bool, error) + InsertRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error) + ListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) + GetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) + DeleteRegistrationToken(ctx context.Context, tokenString string) error + UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) +} + type Profile interface { GetProfileByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (*authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) @@ -144,6 +154,7 @@ type UserDatabase interface { Pusher Statistics ThreePID + RegistrationTokens } type KeyChangeDatabase interface { diff --git a/userapi/storage/postgres/registration_tokens_table.go b/userapi/storage/postgres/registration_tokens_table.go new file mode 100644 index 000000000..3c3e3fdd9 --- /dev/null +++ b/userapi/storage/postgres/registration_tokens_table.go @@ -0,0 +1,222 @@ +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/clientapi/api" + internal "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "golang.org/x/exp/constraints" +) + +const registrationTokensSchema = ` +CREATE TABLE IF NOT EXISTS userapi_registration_tokens ( + token TEXT PRIMARY KEY, + pending BIGINT, + completed BIGINT, + uses_allowed BIGINT, + expiry_time BIGINT +); +` + +const selectTokenSQL = "" + + "SELECT token FROM userapi_registration_tokens WHERE token = $1" + +const insertTokenSQL = "" + + "INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)" + +const listAllTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens" + +const listValidTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens WHERE" + + "(uses_allowed > pending + completed OR uses_allowed IS NULL) AND" + + "(expiry_time > $1 OR expiry_time IS NULL)" + +const listInvalidTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens WHERE" + + "(uses_allowed <= pending + completed OR expiry_time <= $1)" + +const getTokenSQL = "" + + "SELECT pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens WHERE token = $1" + +const deleteTokenSQL = "" + + "DELETE FROM userapi_registration_tokens WHERE token = $1" + +const updateTokenUsesAllowedAndExpiryTimeSQL = "" + + "UPDATE userapi_registration_tokens SET uses_allowed = $2, expiry_time = $3 WHERE token = $1" + +const updateTokenUsesAllowedSQL = "" + + "UPDATE userapi_registration_tokens SET uses_allowed = $2 WHERE token = $1" + +const updateTokenExpiryTimeSQL = "" + + "UPDATE userapi_registration_tokens SET expiry_time = $2 WHERE token = $1" + +type registrationTokenStatements struct { + selectTokenStatement *sql.Stmt + insertTokenStatement *sql.Stmt + listAllTokensStatement *sql.Stmt + listValidTokensStatement *sql.Stmt + listInvalidTokenStatement *sql.Stmt + getTokenStatement *sql.Stmt + deleteTokenStatement *sql.Stmt + updateTokenUsesAllowedAndExpiryTimeStatement *sql.Stmt + updateTokenUsesAllowedStatement *sql.Stmt + updateTokenExpiryTimeStatement *sql.Stmt +} + +func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) { + s := ®istrationTokenStatements{} + _, err := db.Exec(registrationTokensSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.selectTokenStatement, selectTokenSQL}, + {&s.insertTokenStatement, insertTokenSQL}, + {&s.listAllTokensStatement, listAllTokensSQL}, + {&s.listValidTokensStatement, listValidTokensSQL}, + {&s.listInvalidTokenStatement, listInvalidTokensSQL}, + {&s.getTokenStatement, getTokenSQL}, + {&s.deleteTokenStatement, deleteTokenSQL}, + {&s.updateTokenUsesAllowedAndExpiryTimeStatement, updateTokenUsesAllowedAndExpiryTimeSQL}, + {&s.updateTokenUsesAllowedStatement, updateTokenUsesAllowedSQL}, + {&s.updateTokenExpiryTimeStatement, updateTokenExpiryTimeSQL}, + }.Prepare(db) +} + +func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) { + var existingToken string + stmt := sqlutil.TxStmt(tx, s.selectTokenStatement) + err := stmt.QueryRowContext(ctx, token).Scan(&existingToken) + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + return true, nil +} + +func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Context, tx *sql.Tx, registrationToken *api.RegistrationToken) (bool, error) { + stmt := sqlutil.TxStmt(tx, s.insertTokenStatement) + _, err := stmt.ExecContext( + ctx, + *registrationToken.Token, + getInsertValue(registrationToken.UsesAllowed), + getInsertValue(registrationToken.ExpiryTime), + *registrationToken.Pending, + *registrationToken.Completed) + if err != nil { + return false, err + } + return true, nil +} + +func getInsertValue[t constraints.Integer](in *t) any { + if in == nil { + return nil + } + return *in +} + +func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) { + var stmt *sql.Stmt + var tokens []api.RegistrationToken + var tokenString string + var pending, completed, usesAllowed *int32 + var expiryTime *int64 + var rows *sql.Rows + var err error + if returnAll { + stmt = sqlutil.TxStmt(tx, s.listAllTokensStatement) + rows, err = stmt.QueryContext(ctx) + } else if valid { + stmt = sqlutil.TxStmt(tx, s.listValidTokensStatement) + rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) + } else { + stmt = sqlutil.TxStmt(tx, s.listInvalidTokenStatement) + rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) + } + if err != nil { + return tokens, err + } + defer internal.CloseAndLogIfError(ctx, rows, "ListRegistrationTokens: rows.close() failed") + for rows.Next() { + err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime) + if err != nil { + return tokens, err + } + tokenString := tokenString + pending := pending + completed := completed + usesAllowed := usesAllowed + expiryTime := expiryTime + + tokenMap := api.RegistrationToken{ + Token: &tokenString, + Pending: pending, + Completed: completed, + UsesAllowed: usesAllowed, + ExpiryTime: expiryTime, + } + tokens = append(tokens, tokenMap) + } + return tokens, rows.Err() +} + +func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) { + stmt := sqlutil.TxStmt(tx, s.getTokenStatement) + var pending, completed, usesAllowed *int32 + var expiryTime *int64 + err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime) + if err != nil { + return nil, err + } + token := api.RegistrationToken{ + Token: &tokenString, + Pending: pending, + Completed: completed, + UsesAllowed: usesAllowed, + ExpiryTime: expiryTime, + } + return &token, nil +} + +func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error { + stmt := sqlutil.TxStmt(tx, s.deleteTokenStatement) + _, err := stmt.ExecContext(ctx, tokenString) + if err != nil { + return err + } + return nil +} + +func (s *registrationTokenStatements) UpdateRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*api.RegistrationToken, error) { + var stmt *sql.Stmt + usesAllowed, usesAllowedPresent := newAttributes["usesAllowed"] + expiryTime, expiryTimePresent := newAttributes["expiryTime"] + if usesAllowedPresent && expiryTimePresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedAndExpiryTimeStatement) + _, err := stmt.ExecContext(ctx, tokenString, usesAllowed, expiryTime) + if err != nil { + return nil, err + } + } else if usesAllowedPresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedStatement) + _, err := stmt.ExecContext(ctx, tokenString, usesAllowed) + if err != nil { + return nil, err + } + } else if expiryTimePresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenExpiryTimeStatement) + _, err := stmt.ExecContext(ctx, tokenString, expiryTime) + if err != nil { + return nil, err + } + } + return s.GetRegistrationToken(ctx, tx, tokenString) +} diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 72e7c9cd9..d01ccc776 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -53,6 +53,10 @@ func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties * return nil, err } + registationTokensTable, err := NewPostgresRegistrationTokensTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresRegistrationsTokenTable: %w", err) + } accountsTable, err := NewPostgresAccountsTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err) @@ -125,6 +129,7 @@ func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties * ThreePIDs: threePIDTable, Pushers: pusherTable, Notifications: notificationsTable, + RegistrationTokens: registationTokensTable, Stats: statsTable, ServerName: serverName, DB: db, diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 537bbbf4a..b7acb2035 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -31,6 +31,7 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" "golang.org/x/crypto/bcrypt" + clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -43,6 +44,7 @@ import ( type Database struct { DB *sql.DB Writer sqlutil.Writer + RegistrationTokens tables.RegistrationTokensTable Accounts tables.AccountsTable Profiles tables.ProfileTable AccountDatas tables.AccountDataTable @@ -78,6 +80,42 @@ const ( loginTokenByteLength = 32 ) +func (d *Database) RegistrationTokenExists(ctx context.Context, token string) (bool, error) { + return d.RegistrationTokens.RegistrationTokenExists(ctx, nil, token) +} + +func (d *Database) InsertRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (created bool, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + created, err = d.RegistrationTokens.InsertRegistrationToken(ctx, txn, registrationToken) + return err + }) + return +} + +func (d *Database) ListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) { + return d.RegistrationTokens.ListRegistrationTokens(ctx, nil, returnAll, valid) +} + +func (d *Database) GetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) { + return d.RegistrationTokens.GetRegistrationToken(ctx, nil, tokenString) +} + +func (d *Database) DeleteRegistrationToken(ctx context.Context, tokenString string) (err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + err = d.RegistrationTokens.DeleteRegistrationToken(ctx, nil, tokenString) + return err + }) + return +} + +func (d *Database) UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (updatedToken *clientapi.RegistrationToken, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + updatedToken, err = d.RegistrationTokens.UpdateRegistrationToken(ctx, txn, tokenString, newAttributes) + return err + }) + return +} + // GetAccountByPassword returns the account associated with the given localpart and password. // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByPassword( diff --git a/userapi/storage/sqlite3/registration_tokens_table.go b/userapi/storage/sqlite3/registration_tokens_table.go new file mode 100644 index 000000000..897954731 --- /dev/null +++ b/userapi/storage/sqlite3/registration_tokens_table.go @@ -0,0 +1,222 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/clientapi/api" + internal "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "golang.org/x/exp/constraints" +) + +const registrationTokensSchema = ` +CREATE TABLE IF NOT EXISTS userapi_registration_tokens ( + token TEXT PRIMARY KEY, + pending BIGINT, + completed BIGINT, + uses_allowed BIGINT, + expiry_time BIGINT +); +` + +const selectTokenSQL = "" + + "SELECT token FROM userapi_registration_tokens WHERE token = $1" + +const insertTokenSQL = "" + + "INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)" + +const listAllTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens" + +const listValidTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens WHERE" + + "(uses_allowed > pending + completed OR uses_allowed IS NULL) AND" + + "(expiry_time > $1 OR expiry_time IS NULL)" + +const listInvalidTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens WHERE" + + "(uses_allowed <= pending + completed OR expiry_time <= $1)" + +const getTokenSQL = "" + + "SELECT pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens WHERE token = $1" + +const deleteTokenSQL = "" + + "DELETE FROM userapi_registration_tokens WHERE token = $1" + +const updateTokenUsesAllowedAndExpiryTimeSQL = "" + + "UPDATE userapi_registration_tokens SET uses_allowed = $2, expiry_time = $3 WHERE token = $1" + +const updateTokenUsesAllowedSQL = "" + + "UPDATE userapi_registration_tokens SET uses_allowed = $2 WHERE token = $1" + +const updateTokenExpiryTimeSQL = "" + + "UPDATE userapi_registration_tokens SET expiry_time = $2 WHERE token = $1" + +type registrationTokenStatements struct { + selectTokenStatement *sql.Stmt + insertTokenStatement *sql.Stmt + listAllTokensStatement *sql.Stmt + listValidTokensStatement *sql.Stmt + listInvalidTokenStatement *sql.Stmt + getTokenStatement *sql.Stmt + deleteTokenStatement *sql.Stmt + updateTokenUsesAllowedAndExpiryTimeStatement *sql.Stmt + updateTokenUsesAllowedStatement *sql.Stmt + updateTokenExpiryTimeStatement *sql.Stmt +} + +func NewSQLiteRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) { + s := ®istrationTokenStatements{} + _, err := db.Exec(registrationTokensSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.selectTokenStatement, selectTokenSQL}, + {&s.insertTokenStatement, insertTokenSQL}, + {&s.listAllTokensStatement, listAllTokensSQL}, + {&s.listValidTokensStatement, listValidTokensSQL}, + {&s.listInvalidTokenStatement, listInvalidTokensSQL}, + {&s.getTokenStatement, getTokenSQL}, + {&s.deleteTokenStatement, deleteTokenSQL}, + {&s.updateTokenUsesAllowedAndExpiryTimeStatement, updateTokenUsesAllowedAndExpiryTimeSQL}, + {&s.updateTokenUsesAllowedStatement, updateTokenUsesAllowedSQL}, + {&s.updateTokenExpiryTimeStatement, updateTokenExpiryTimeSQL}, + }.Prepare(db) +} + +func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) { + var existingToken string + stmt := sqlutil.TxStmt(tx, s.selectTokenStatement) + err := stmt.QueryRowContext(ctx, token).Scan(&existingToken) + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + return true, nil +} + +func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Context, tx *sql.Tx, registrationToken *api.RegistrationToken) (bool, error) { + stmt := sqlutil.TxStmt(tx, s.insertTokenStatement) + _, err := stmt.ExecContext( + ctx, + *registrationToken.Token, + getInsertValue(registrationToken.UsesAllowed), + getInsertValue(registrationToken.ExpiryTime), + *registrationToken.Pending, + *registrationToken.Completed) + if err != nil { + return false, err + } + return true, nil +} + +func getInsertValue[t constraints.Integer](in *t) any { + if in == nil { + return nil + } + return *in +} + +func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) { + var stmt *sql.Stmt + var tokens []api.RegistrationToken + var tokenString string + var pending, completed, usesAllowed *int32 + var expiryTime *int64 + var rows *sql.Rows + var err error + if returnAll { + stmt = sqlutil.TxStmt(tx, s.listAllTokensStatement) + rows, err = stmt.QueryContext(ctx) + } else if valid { + stmt = sqlutil.TxStmt(tx, s.listValidTokensStatement) + rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) + } else { + stmt = sqlutil.TxStmt(tx, s.listInvalidTokenStatement) + rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) + } + if err != nil { + return tokens, err + } + defer internal.CloseAndLogIfError(ctx, rows, "ListRegistrationTokens: rows.close() failed") + for rows.Next() { + err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime) + if err != nil { + return tokens, err + } + tokenString := tokenString + pending := pending + completed := completed + usesAllowed := usesAllowed + expiryTime := expiryTime + + tokenMap := api.RegistrationToken{ + Token: &tokenString, + Pending: pending, + Completed: completed, + UsesAllowed: usesAllowed, + ExpiryTime: expiryTime, + } + tokens = append(tokens, tokenMap) + } + return tokens, rows.Err() +} + +func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) { + stmt := sqlutil.TxStmt(tx, s.getTokenStatement) + var pending, completed, usesAllowed *int32 + var expiryTime *int64 + err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime) + if err != nil { + return nil, err + } + token := api.RegistrationToken{ + Token: &tokenString, + Pending: pending, + Completed: completed, + UsesAllowed: usesAllowed, + ExpiryTime: expiryTime, + } + return &token, nil +} + +func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error { + stmt := sqlutil.TxStmt(tx, s.deleteTokenStatement) + _, err := stmt.ExecContext(ctx, tokenString) + if err != nil { + return err + } + return nil +} + +func (s *registrationTokenStatements) UpdateRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*api.RegistrationToken, error) { + var stmt *sql.Stmt + usesAllowed, usesAllowedPresent := newAttributes["usesAllowed"] + expiryTime, expiryTimePresent := newAttributes["expiryTime"] + if usesAllowedPresent && expiryTimePresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedAndExpiryTimeStatement) + _, err := stmt.ExecContext(ctx, tokenString, usesAllowed, expiryTime) + if err != nil { + return nil, err + } + } else if usesAllowedPresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedStatement) + _, err := stmt.ExecContext(ctx, tokenString, usesAllowed) + if err != nil { + return nil, err + } + } else if expiryTimePresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenExpiryTimeStatement) + _, err := stmt.ExecContext(ctx, tokenString, expiryTime) + if err != nil { + return nil, err + } + } + return s.GetRegistrationToken(ctx, tx, tokenString) +} diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index acd9678f2..48f5c842b 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -50,7 +50,10 @@ func NewUserDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperti if err = m.Up(ctx); err != nil { return nil, err } - + registationTokensTable, err := NewSQLiteRegistrationTokensTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteRegistrationsTokenTable: %w", err) + } accountsTable, err := NewSQLiteAccountsTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err) @@ -130,6 +133,7 @@ func NewUserDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperti LoginTokenLifetime: loginTokenLifetime, BcryptCost: bcryptCost, OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, + RegistrationTokens: registationTokensTable, }, nil } diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 3c6214e7c..3a0be73e4 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -25,10 +25,20 @@ import ( "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" + clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/userapi/types" ) +type RegistrationTokensTable interface { + RegistrationTokenExists(ctx context.Context, txn *sql.Tx, token string) (bool, error) + InsertRegistrationToken(ctx context.Context, txn *sql.Tx, registrationToken *clientapi.RegistrationToken) (bool, error) + ListRegistrationTokens(ctx context.Context, txn *sql.Tx, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) + GetRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) (*clientapi.RegistrationToken, error) + DeleteRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) error + UpdateRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) +} + type AccountDataTable interface { InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID, dataType string, content json.RawMessage) error SelectAccountData(ctx context.Context, localpart string, serverName spec.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error) From a5ea928d0fc52f0efb6607791ac59e18103b57de Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Wed, 28 Jun 2023 10:05:00 +0200 Subject: [PATCH 32/35] Fix syncAPI redactions (#3118) Previously we were setting `redacted_because` to the PDU event, but as per the spec it should really be a client event. This fixes it. --- internal/eventutil/events.go | 14 +++++- syncapi/consumers/roomserver.go | 2 +- syncapi/storage/interface.go | 2 +- syncapi/storage/shared/storage_consumer.go | 4 +- syncapi/storage/storage_test.go | 51 ++++++++++++++++++++++ 5 files changed, 67 insertions(+), 6 deletions(-) diff --git a/internal/eventutil/events.go b/internal/eventutil/events.go index 0f73db2d5..56ee576a0 100644 --- a/internal/eventutil/events.go +++ b/internal/eventutil/events.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" @@ -169,13 +170,22 @@ func truncateAuthAndPrevEvents(auth, prev []string) ( // RedactEvent redacts the given event and sets the unsigned field appropriately. This should be used by // downstream components to the roomserver when an OutputTypeRedactedEvent occurs. -func RedactEvent(redactionEvent, redactedEvent gomatrixserverlib.PDU) error { +func RedactEvent(ctx context.Context, redactionEvent, redactedEvent gomatrixserverlib.PDU, querier api.QuerySenderIDAPI) error { // sanity check if redactionEvent.Type() != spec.MRoomRedaction { return fmt.Errorf("RedactEvent: redactionEvent isn't a redaction event, is '%s'", redactionEvent.Type()) } redactedEvent.Redact() - if err := redactedEvent.SetUnsignedField("redacted_because", redactionEvent); err != nil { + validRoomID, err := spec.NewRoomID(redactionEvent.RoomID()) + if err != nil { + return err + } + senderID, err := querier.QueryUserIDForSender(ctx, *validRoomID, redactionEvent.SenderID()) + if err != nil { + return err + } + redactedBecause := synctypes.ToClientEvent(redactionEvent, synctypes.FormatSync, *senderID, redactionEvent.StateKey()) + if err := redactedEvent.SetUnsignedField("redacted_because", redactedBecause); err != nil { return err } // NOTSPEC: sytest relies on this unspecced field existing :( diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index d468dfc98..90f9ff67d 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -151,7 +151,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms func (s *OutputRoomEventConsumer) onRedactEvent( ctx context.Context, msg api.OutputRedactedEvent, ) error { - err := s.db.RedactEvent(ctx, msg.RedactedEventID, msg.RedactedBecause) + err := s.db.RedactEvent(ctx, msg.RedactedEventID, msg.RedactedBecause, s.rsAPI) if err != nil { log.WithError(err).Error("RedactEvent error'd") return err diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 8798b62ec..243b2592a 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -174,7 +174,7 @@ type Database interface { // goes wrong. PutFilter(ctx context.Context, localpart string, filter *synctypes.Filter) (string, error) // RedactEvent wipes an event in the database and sets the unsigned.redacted_because key to the redaction event - RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *rstypes.HeaderedEvent) error + RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *rstypes.HeaderedEvent, querier api.QuerySenderIDAPI) error // StoreReceipt stores new receipt events StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp spec.Timestamp) (pos types.StreamPosition, err error) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index 1827218b6..746a324fa 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -364,7 +364,7 @@ func (d *Database) PutFilter( return filterID, err } -func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *rstypes.HeaderedEvent) error { +func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *rstypes.HeaderedEvent, querier api.QuerySenderIDAPI) error { redactedEvents, err := d.Events(ctx, []string{redactedEventID}) if err != nil { return err @@ -375,7 +375,7 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda } eventToRedact := redactedEvents[0].PDU redactionEvent := redactedBecause.PDU - if err = eventutil.RedactEvent(redactionEvent, eventToRedact); err != nil { + if err = eventutil.RedactEvent(ctx, redactionEvent, eventToRedact, querier); err != nil { return err } diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index bc64aa50f..f56e44a30 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/api" rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage" @@ -19,6 +20,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" + "github.com/tidwall/gjson" ) var ctx = context.Background() @@ -978,3 +980,52 @@ func TestRecentEvents(t *testing.T) { } }) } + +type FakeQuerier struct { + api.QuerySenderIDAPI +} + +func (f *FakeQuerier) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) +} + +func TestRedaction(t *testing.T) { + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + + redactedEvent := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"}) + redactionEvent := room.CreateEvent(t, alice, spec.MRoomRedaction, map[string]string{"redacts": redactedEvent.EventID()}) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := MustCreateDatabase(t, dbType) + t.Cleanup(close) + MustWriteEvents(t, db, room.Events()) + + err := db.RedactEvent(context.Background(), redactedEvent.EventID(), redactionEvent, &FakeQuerier{}) + if err != nil { + t.Fatal(err) + } + + evs, err := db.Events(context.Background(), []string{redactedEvent.EventID()}) + if err != nil { + t.Fatal(err) + } + + if len(evs) != 1 { + t.Fatalf("expected 1 event, got %d", len(evs)) + } + + // check a few fields which shouldn't be there in unsigned + authEvs := gjson.GetBytes(evs[0].Unsigned(), "redacted_because.auth_events") + if authEvs.Exists() { + t.Error("unexpected auth_events in redacted event") + } + prevEvs := gjson.GetBytes(evs[0].Unsigned(), "redacted_because.prev_events") + if prevEvs.Exists() { + t.Error("unexpected auth_events in redacted event") + } + depth := gjson.GetBytes(evs[0].Unsigned(), "redacted_because.depth") + if depth.Exists() { + t.Error("unexpected auth_events in redacted event") + } + }) +} From 4722f12fab65f3247cd253825d86206bfbfc6f95 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Wed, 28 Jun 2023 20:18:07 +0200 Subject: [PATCH 33/35] Fix setting `displayname` and `avatar_url` (#3125) As per the spec, `displayname` and `avatar_url` may be empty. --- clientapi/routing/profile.go | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 8a44834e1..c89ece41f 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -104,12 +104,6 @@ func SetAvatarURL( if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil { return *resErr } - if r.AvatarURL == "" { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON("'avatar_url' must be supplied."), - } - } localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { @@ -199,12 +193,6 @@ func SetDisplayName( if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil { return *resErr } - if r.DisplayName == "" { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON("'displayname' must be supplied."), - } - } localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { From 23cd7877a14bca5315467591cd47a7d51aec22ce Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Wed, 28 Jun 2023 20:29:49 +0200 Subject: [PATCH 34/35] Add `MXIDMapping` for pseudoID rooms (#3112) Add `MXIDMapping` on membership events when creating/joining rooms. --- clientapi/routing/membership.go | 18 +-- clientapi/routing/profile.go | 20 +-- clientapi/routing/redaction.go | 4 +- clientapi/routing/sendevent.go | 53 +++++-- clientapi/routing/server_notices.go | 4 +- federationapi/consumers/roomserver.go | 15 +- federationapi/internal/perform.go | 22 ++- federationapi/routing/join.go | 12 ++ go.mod | 4 +- go.sum | 4 +- roomserver/api/api.go | 7 +- roomserver/api/query.go | 2 + roomserver/internal/alias.go | 11 +- roomserver/internal/api.go | 50 ++++++- roomserver/internal/input/input.go | 2 +- roomserver/internal/input/input_events.go | 19 ++- .../internal/perform/perform_backfill.go | 2 +- .../internal/perform/perform_create_room.go | 89 +++++++++--- roomserver/internal/perform/perform_invite.go | 17 +++ roomserver/internal/perform/perform_join.go | 39 +++++- roomserver/internal/perform/perform_leave.go | 9 +- roomserver/internal/query/query.go | 3 + roomserver/roomserver_test.go | 10 +- roomserver/storage/interface.go | 5 +- roomserver/storage/shared/storage.go | 104 +++++++------- .../storage/sqlite3/user_room_keys_table.go | 5 +- roomserver/types/headered_event.go | 5 + syncapi/consumers/roomserver.go | 72 +++++++--- syncapi/routing/search_test.go | 1 + .../postgres/current_room_state_table.go | 4 +- syncapi/storage/postgres/invites_table.go | 2 +- syncapi/storage/postgres/memberships_table.go | 2 +- .../postgres/output_room_events_table.go | 2 +- .../sqlite3/current_room_state_table.go | 4 +- syncapi/storage/sqlite3/invites_table.go | 2 +- syncapi/storage/sqlite3/memberships_table.go | 2 +- .../sqlite3/output_room_events_table.go | 2 +- syncapi/storage/storage_test.go | 1 + .../storage/tables/current_room_state_test.go | 8 +- syncapi/storage/tables/memberships_test.go | 2 + syncapi/streams/stream_pdu.go | 131 +++++++++++++++++- 41 files changed, 593 insertions(+), 177 deletions(-) diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index bafc37b67..60b120b9c 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -22,10 +22,6 @@ import ( "time" "github.com/getsentry/sentry-go" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/gomatrixserverlib/fclient" - "github.com/matrix-org/gomatrixserverlib/spec" - appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" @@ -36,6 +32,9 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -433,11 +432,6 @@ func buildMembershipEvent( return nil, err } - identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) - if err != nil { - return nil, err - } - userID, err := spec.NewUserID(device.UserID, true) if err != nil { return nil, err @@ -459,6 +453,12 @@ func buildMembershipEvent( if err != nil { return nil, err } + + identity, err := rsAPI.SigningIdentityFor(ctx, *validRoomID, *userID) + if err != nil { + return nil, err + } + return buildMembershipEventDirect(ctx, targetSenderID, reason, profile.DisplayName, profile.AvatarURL, senderID, device.UserDomain(), membership, roomID, isDirect, identity.KeyID, identity.PrivateKey, evTime, rsAPI) } diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index c89ece41f..35da15e0e 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -145,7 +145,7 @@ func SetAvatarURL( } } - response, err := updateProfile(req.Context(), rsAPI, device, profile, userID, cfg, evTime) + response, err := updateProfile(req.Context(), rsAPI, device, profile, userID, evTime) if err != nil { return response } @@ -234,7 +234,7 @@ func SetDisplayName( } } - response, err := updateProfile(req.Context(), rsAPI, device, profile, userID, cfg, evTime) + response, err := updateProfile(req.Context(), rsAPI, device, profile, userID, evTime) if err != nil { return response } @@ -248,7 +248,7 @@ func SetDisplayName( func updateProfile( ctx context.Context, rsAPI api.ClientRoomserverAPI, device *userapi.Device, profile *authtypes.Profile, - userID string, cfg *config.ClientAPI, evTime time.Time, + userID string, evTime time.Time, ) (util.JSONResponse, error) { var res api.QueryRoomsForUserResponse err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ @@ -273,7 +273,7 @@ func updateProfile( } events, err := buildMembershipEvents( - ctx, device, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI, + ctx, res.RoomIDs, *profile, userID, evTime, rsAPI, ) switch e := err.(type) { case nil: @@ -344,9 +344,8 @@ func getProfile( func buildMembershipEvents( ctx context.Context, - device *userapi.Device, roomIDs []string, - newProfile authtypes.Profile, userID string, cfg *config.ClientAPI, + newProfile authtypes.Profile, userID string, evTime time.Time, rsAPI api.ClientRoomserverAPI, ) ([]*types.HeaderedEvent, error) { evs := []*types.HeaderedEvent{} @@ -383,12 +382,17 @@ func buildMembershipEvents( return nil, err } - identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) + user, err := spec.NewUserID(userID, true) if err != nil { return nil, err } - event, err := eventutil.QueryAndBuildEvent(ctx, &proto, identity, evTime, rsAPI, nil) + identity, err := rsAPI.SigningIdentityFor(ctx, *validRoomID, *user) + if err != nil { + return nil, err + } + + event, err := eventutil.QueryAndBuildEvent(ctx, &proto, &identity, evTime, rsAPI, nil) if err != nil { return nil, err } diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 42f029395..1b9a5a818 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -150,7 +150,7 @@ func SendRedaction( } } - identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) + identity, err := rsAPI.SigningIdentityFor(req.Context(), *validRoomID, *deviceUserID) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -159,7 +159,7 @@ func SendRedaction( } var queryRes roomserverAPI.QueryLatestEventsAndStateResponse - e, err := eventutil.QueryAndBuildEvent(req.Context(), &proto, identity, time.Now(), rsAPI, &queryRes) + e, err := eventutil.QueryAndBuildEvent(req.Context(), &proto, &identity, time.Now(), rsAPI, &queryRes) if errors.Is(err, eventutil.ErrRoomNoExists{}) { return util.JSONResponse{ Code: http.StatusNotFound, diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index d51a570de..41a3793ae 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -23,12 +23,6 @@ import ( "sync" "time" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/gomatrixserverlib/spec" - "github.com/matrix-org/util" - "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/transactions" @@ -36,6 +30,11 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" ) // http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-send-eventtype-txnid @@ -68,6 +67,8 @@ var sendEventDuration = prometheus.NewHistogramVec( // /rooms/{roomID}/send/{eventType} // /rooms/{roomID}/send/{eventType}/{txnID} // /rooms/{roomID}/state/{eventType}/{stateKey} +// +// nolint: gocyclo func SendEvent( req *http.Request, device *userapi.Device, @@ -121,6 +122,17 @@ func SendEvent( delete(r, "join_authorised_via_users_server") } + // for power level events we need to replace the userID with the pseudoID + if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && eventType == spec.MRoomPowerLevels { + err = updatePowerLevels(req, r, roomID, rsAPI) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{Err: err.Error()}, + } + } + } + evTime, err := httputil.ParseTSParam(req) if err != nil { return util.JSONResponse{ @@ -129,7 +141,7 @@ func SendEvent( } } - e, resErr := generateSendEvent(req.Context(), r, device, roomID, eventType, stateKey, cfg, rsAPI, evTime) + e, resErr := generateSendEvent(req.Context(), r, device, roomID, eventType, stateKey, rsAPI, evTime) if resErr != nil { return *resErr } @@ -225,6 +237,28 @@ func SendEvent( return res } +func updatePowerLevels(req *http.Request, r map[string]interface{}, roomID string, rsAPI api.ClientRoomserverAPI) error { + userMap := r["users"].(map[string]interface{}) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return err + } + for user, level := range userMap { + uID, err := spec.NewUserID(user, true) + if err != nil { + continue // we're modifying the map in place, so we're going to have invalid userIDs after the first iteration + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *uID) + if err != nil { + return err + } + userMap[string(senderID)] = level + delete(userMap, user) + } + r["users"] = userMap + return nil +} + // stateEqual compares the new and the existing state event content. If they are equal, returns a *util.JSONResponse // with the existing event_id, making this an idempotent request. func stateEqual(ctx context.Context, rsAPI api.ClientRoomserverAPI, eventType, stateKey, roomID string, newContent map[string]interface{}) *util.JSONResponse { @@ -261,7 +295,6 @@ func generateSendEvent( r map[string]interface{}, device *userapi.Device, roomID, eventType string, stateKey *string, - cfg *config.ClientAPI, rsAPI api.ClientRoomserverAPI, evTime time.Time, ) (gomatrixserverlib.PDU, *util.JSONResponse) { @@ -304,7 +337,7 @@ func generateSendEvent( } } - identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) + identity, err := rsAPI.SigningIdentityFor(ctx, *validRoomID, *fullUserID) if err != nil { return nil, &util.JSONResponse{ Code: http.StatusInternalServerError, @@ -313,7 +346,7 @@ func generateSendEvent( } var queryRes api.QueryLatestEventsAndStateResponse - e, err := eventutil.QueryAndBuildEvent(ctx, &proto, identity, evTime, rsAPI, &queryRes) + e, err := eventutil.QueryAndBuildEvent(ctx, &proto, &identity, evTime, rsAPI, &queryRes) switch specificErr := err.(type) { case nil: case eventutil.ErrRoomNoExists: diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index 7006ced46..66258a68a 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -221,7 +221,7 @@ func SendServerNotice( "body": r.Content.Body, "msgtype": r.Content.MsgType, } - e, resErr := generateSendEvent(ctx, request, senderDevice, roomID, "m.room.message", nil, cfgClient, rsAPI, time.Now()) + e, resErr := generateSendEvent(ctx, request, senderDevice, roomID, "m.room.message", nil, rsAPI, time.Now()) if resErr != nil { logrus.Errorf("failed to send message: %+v", resErr) return *resErr @@ -350,7 +350,7 @@ func getSenderDevice( if len(deviceRes.Devices) > 0 { // If there were changes to the profile, create a new membership event if displayNameChanged || avatarChanged { - _, err = updateProfile(ctx, rsAPI, &deviceRes.Devices[0], profile, accRes.Account.UserID, cfg, time.Now()) + _, err = updateProfile(ctx, rsAPI, &deviceRes.Devices[0], profile, accRes.Account.UserID, time.Now()) if err != nil { return nil, err } diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index c6ad3f748..6dd2fd345 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -192,7 +192,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew evs[i] = addsStateEvents[i].PDU } - addsJoinedHosts, err := JoinedHostsFromEvents(evs) + addsJoinedHosts, err := JoinedHostsFromEvents(s.ctx, evs, s.rsAPI) if err != nil { return err } @@ -345,7 +345,7 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent( return nil, err } - combinedAddsJoinedHosts, err := JoinedHostsFromEvents(combinedAddsEvents) + combinedAddsJoinedHosts, err := JoinedHostsFromEvents(s.ctx, combinedAddsEvents, s.rsAPI) if err != nil { return nil, err } @@ -394,7 +394,7 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent( // JoinedHostsFromEvents turns a list of state events into a list of joined hosts. // This errors if one of the events was invalid. // It should be impossible for an invalid event to get this far in the pipeline. -func JoinedHostsFromEvents(evs []gomatrixserverlib.PDU) ([]types.JoinedHost, error) { +func JoinedHostsFromEvents(ctx context.Context, evs []gomatrixserverlib.PDU, rsAPI api.FederationRoomserverAPI) ([]types.JoinedHost, error) { var joinedHosts []types.JoinedHost for _, ev := range evs { if ev.Type() != "m.room.member" || ev.StateKey() == nil { @@ -407,12 +407,17 @@ func JoinedHostsFromEvents(evs []gomatrixserverlib.PDU) ([]types.JoinedHost, err if membership != spec.Join { continue } - _, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey()) + validRoomID, err := spec.NewRoomID(ev.RoomID()) if err != nil { return nil, err } + userID, err := rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*ev.StateKey())) + if err != nil { + return nil, err + } + joinedHosts = append(joinedHosts, types.JoinedHost{ - MemberEventID: ev.EventID(), ServerName: serverName, + MemberEventID: ev.EventID(), ServerName: userID.Domain(), }) } return joinedHosts, nil diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 7f61dba41..515b3377d 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -2,6 +2,7 @@ package internal import ( "context" + "crypto/ed25519" "encoding/json" "errors" "fmt" @@ -170,13 +171,24 @@ func (r *FederationInternalAPI) performJoinUsingServer( UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }, - SenderIDCreator: func(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (spec.SenderID, error) { + GetOrCreateSenderID: func(ctx context.Context, userID spec.UserID, roomID spec.RoomID, roomVersion string) (spec.SenderID, ed25519.PrivateKey, error) { + // assign a roomNID, otherwise we can't create a private key for the user + _, nidErr := r.rsAPI.AssignRoomNID(ctx, roomID, gomatrixserverlib.RoomVersion(roomVersion)) + if nidErr != nil { + return "", nil, nidErr + } key, keyErr := r.rsAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID) if keyErr != nil { - return "", keyErr + return "", nil, keyErr } - - return spec.SenderID(spec.Base64Bytes(key).Encode()), nil + return spec.SenderIDFromPseudoIDKey(key), key, nil + }, + StoreSenderIDFromPublicID: func(ctx context.Context, senderID spec.SenderID, userIDRaw string, roomID spec.RoomID) error { + storeUserID, userErr := spec.NewUserID(userIDRaw, true) + if userErr != nil { + return userErr + } + return r.rsAPI.StoreUserRoomPublicKey(ctx, senderID, *storeUserID, roomID) }, } response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput) @@ -200,7 +212,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( // joining a room, waiting for 200 OK then changing device keys and have those keys not be sent // to other servers (this was a cause of a flakey sytest "Local device key changes get to remote servers") // The events are trusted now as we performed auth checks above. - joinedHosts, err := consumers.JoinedHostsFromEvents(response.StateSnapshot.GetStateEvents().TrustedEvents(response.JoinEvent.Version(), false)) + joinedHosts, err := consumers.JoinedHostsFromEvents(ctx, response.StateSnapshot.GetStateEvents().TrustedEvents(response.JoinEvent.Version(), false), r.rsAPI) if err != nil { return fmt.Errorf("JoinedHostsFromEvents: failed to get joined hosts: %s", err) } diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 7aa50f65a..bfa1ba8b8 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -15,6 +15,7 @@ package routing import ( + "context" "fmt" "net/http" "sort" @@ -107,6 +108,10 @@ func MakeJoin( } } + if senderID == "" { + senderID = spec.SenderID(userID.String()) + } + input := gomatrixserverlib.HandleMakeJoinInput{ Context: httpReq.Context(), UserID: userID, @@ -218,6 +223,13 @@ func SendJoin( UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, + StoreSenderIDFromPublicID: func(ctx context.Context, senderID spec.SenderID, userIDRaw string, roomID spec.RoomID) error { + userID, userErr := spec.NewUserID(userIDRaw, true) + if userErr != nil { + return userErr + } + return rsAPI.StoreUserRoomPublicKey(ctx, senderID, *userID, roomID) + }, } response, joinErr := gomatrixserverlib.HandleSendJoin(input) switch e := joinErr.(type) { diff --git a/go.mod b/go.mod index 930db3958..f43760e31 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230614140620-4dea2171c8f1 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230628151943-f6e3c7f7b093 github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 @@ -43,6 +43,7 @@ require ( github.com/yggdrasil-network/yggdrasil-go v0.4.6 go.uber.org/atomic v1.10.0 golang.org/x/crypto v0.10.0 + golang.org/x/exp v0.0.0-20221205204356-47842c84f3db golang.org/x/image v0.5.0 golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e golang.org/x/sync v0.1.0 @@ -124,7 +125,6 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect go.etcd.io/bbolt v1.3.6 // indirect - golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect golang.org/x/mod v0.8.0 // indirect golang.org/x/net v0.10.0 // indirect golang.org/x/sys v0.9.0 // indirect diff --git a/go.sum b/go.sum index cf6993938..e261f551f 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230614140620-4dea2171c8f1 h1:k75Fy0iQVbDjvddip/x898+BdyopBNAfL1BMNx0awA0= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230614140620-4dea2171c8f1/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230628151943-f6e3c7f7b093 h1:FHd3SYhU2ZxZhkssZ/7ms5+M2j+g94lYp8ztvA1E6tA= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230628151943-f6e3c7f7b093/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= diff --git a/roomserver/api/api.go b/roomserver/api/api.go index e2dd5dd73..ab56529c5 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -5,6 +5,7 @@ import ( "crypto/ed25519" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" @@ -73,6 +74,7 @@ type RoomserverInternalAPI interface { type UserRoomPrivateKeyCreator interface { // GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created. GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error) + StoreUserRoomPublicKey(ctx context.Context, senderID spec.SenderID, userID spec.UserID, roomID spec.RoomID) error } type InputRoomEventsAPI interface { @@ -184,6 +186,7 @@ type ClientRoomserverAPI interface { QueryBulkStateContentAPI QueryEventsAPI QuerySenderIDAPI + UserRoomPrivateKeyCreator QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error @@ -213,6 +216,7 @@ type ClientRoomserverAPI interface { PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error SetRoomAlias(ctx context.Context, req *SetRoomAliasRequest, res *SetRoomAliasResponse) error RemoveRoomAlias(ctx context.Context, req *RemoveRoomAliasRequest, res *RemoveRoomAliasResponse) error + SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) } type UserRoomserverAPI interface { @@ -232,7 +236,8 @@ type FederationRoomserverAPI interface { QueryBulkStateContentAPI QuerySenderIDAPI UserRoomPrivateKeyCreator - + AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error) + SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 684a5b0e3..b6140afd5 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -174,6 +174,8 @@ type QueryServerJoinedToRoomResponse struct { RoomExists bool `json:"room_exists"` // True if we still believe that the server is participating in the room IsInRoom bool `json:"is_in_room"` + // The roomversion if joined to room + RoomVersion gomatrixserverlib.RoomVersion } // QueryServerAllowedToSeeEventRequest is a request to QueryServerAllowedToSeeEvent diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index e6fb73383..b04a56fe8 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -115,6 +115,7 @@ func (r *RoomserverInternalAPI) GetAliasesForRoomID( // nolint:gocyclo // RemoveRoomAlias implements alias.RoomserverInternalAPI +// nolint: gocyclo func (r *RoomserverInternalAPI) RemoveRoomAlias( ctx context.Context, request *api.RemoveRoomAliasRequest, @@ -188,9 +189,11 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return err } - senderDomain := sender.Domain() - - identity, err := r.Cfg.Global.SigningIdentityFor(senderDomain) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return err + } + identity, err := r.SigningIdentityFor(ctx, *validRoomID, *sender) if err != nil { return err } @@ -216,7 +219,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return err } - newEvent, err := eventutil.BuildEvent(ctx, proto, identity, time.Now(), &eventsNeeded, stateRes) + newEvent, err := eventutil.BuildEvent(ctx, proto, &identity, time.Now(), &eventsNeeded, stateRes) if err != nil { return err } diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 7943ae5c0..2e12671ff 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -6,6 +6,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/nats-io/nats.go" @@ -110,11 +111,6 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio r.fsAPI = fsAPI r.KeyRing = keyRing - identity, err := r.Cfg.Global.SigningIdentityFor(r.ServerName) - if err != nil { - logrus.Panic(err) - } - r.Inputer = &input.Inputer{ Cfg: &r.Cfg.RoomServer, ProcessContext: r.ProcessContext, @@ -125,7 +121,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio NATSClient: r.NATSClient, Durable: nats.Durable(r.Durable), ServerName: r.ServerName, - SigningIdentity: identity, + SigningIdentity: r.SigningIdentityFor, FSAPI: fsAPI, KeyRing: keyRing, ACLs: r.ServerACLs, @@ -292,3 +288,45 @@ func (r *RoomserverInternalAPI) GetOrCreateUserRoomPrivateKey(ctx context.Contex } return key, nil } + +func (r *RoomserverInternalAPI) StoreUserRoomPublicKey(ctx context.Context, senderID spec.SenderID, userID spec.UserID, roomID spec.RoomID) error { + pubKeyBytes, err := senderID.RawBytes() + if err != nil { + return err + } + _, err = r.DB.InsertUserRoomPublicKey(ctx, userID, roomID, ed25519.PublicKey(pubKeyBytes)) + return err +} + +func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) { + roomVersion, ok := r.Cache.GetRoomVersion(roomID.String()) + if !ok { + roomInfo, err := r.DB.RoomInfo(ctx, roomID.String()) + if err != nil { + return fclient.SigningIdentity{}, err + } + if roomInfo != nil { + roomVersion = roomInfo.RoomVersion + } + } + if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs { + privKey, err := r.GetOrCreateUserRoomPrivateKey(ctx, senderID, roomID) + if err != nil { + return fclient.SigningIdentity{}, err + } + return fclient.SigningIdentity{ + PrivateKey: privKey, + KeyID: "ed25519:1", + ServerName: "self", + }, nil + } + identity, err := r.Cfg.Global.SigningIdentityFor(senderID.Domain()) + if err != nil { + return fclient.SigningIdentity{}, err + } + return *identity, err +} + +func (r *RoomserverInternalAPI) AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error) { + return r.DB.AssignRoomNID(ctx, roomID, roomVersion) +} diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 3db2d0a67..dea8f8c87 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -81,7 +81,7 @@ type Inputer struct { JetStream nats.JetStreamContext Durable nats.SubOpt ServerName spec.ServerName - SigningIdentity *fclient.SigningIdentity + SigningIdentity func(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) FSAPI fedapi.RoomserverFederationAPI KeyRing gomatrixserverlib.JSONVerifier ACLs *acls.ServerACLs diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index aa05d9594..db3c95502 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -406,7 +406,7 @@ func (r *Inputer) processRoomEvent( ) if !isRejected && !isCreateEvent { resolver := state.NewStateResolution(r.DB, roomInfo, r.Queryer) - redactionEvent, redactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, &resolver) + redactionEvent, redactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, &resolver, r.Queryer) if err != nil { return err } @@ -895,7 +895,22 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r return err } - event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.SigningIdentity, time.Now(), &eventsNeeded, latestRes) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return err + } + + userID, err := spec.NewUserID(stateKey, true) + if err != nil { + return err + } + + signingIdentity, err := r.SigningIdentity(ctx, *validRoomID, *userID) + if err != nil { + return err + } + + event, err := eventutil.BuildEvent(ctx, fledglingEvent, &signingIdentity, time.Now(), &eventsNeeded, latestRes) if err != nil { return err } diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 3fdc8e4d0..33200e819 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -647,7 +647,7 @@ func persistEvents(ctx context.Context, db storage.Database, querier api.QuerySe resolver := state.NewStateResolution(db, roomInfo, querier) - _, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev, &resolver) + _, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev, &resolver, querier) if err != nil { logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event") continue diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index dcaf8dca6..8c9656453 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -31,6 +31,7 @@ import ( "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" ) const ( @@ -64,6 +65,16 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } } } + + _, err = c.DB.AssignRoomNID(ctx, roomID, createRequest.RoomVersion) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("failed to assign roomNID") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + var senderID spec.SenderID if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { // create user room key if needed @@ -75,7 +86,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo JSON: spec.InternalServerError{}, } } - senderID = spec.SenderID(spec.Base64Bytes(key.Public().(ed25519.PublicKey)).Encode()) + senderID = spec.SenderIDFromPseudoIDKey(key) } else { senderID = spec.SenderID(userID.String()) } @@ -138,13 +149,59 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo membershipEvent := gomatrixserverlib.FledglingEvent{ Type: spec.MRoomMember, StateKey: string(senderID), - Content: gomatrixserverlib.MemberContent{ - Membership: spec.Join, - DisplayName: createRequest.UserDisplayName, - AvatarURL: createRequest.UserAvatarURL, - }, } + memberContent := gomatrixserverlib.MemberContent{ + Membership: spec.Join, + DisplayName: createRequest.UserDisplayName, + AvatarURL: createRequest.UserAvatarURL, + } + + // get the signing identity + identity, err := c.Cfg.Matrix.SigningIdentityFor(userID.Domain()) // we MUST use the server signing mxid_mapping + if err != nil { + logrus.WithError(err).WithField("domain", userID.Domain()).Error("unable to find signing identity for domain") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + // If we are creating a room with pseudo IDs, create and sign the MXIDMapping + if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { + var pseudoIDKey ed25519.PrivateKey + pseudoIDKey, err = c.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("GetOrCreateUserRoomPrivateKey failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + mapping := &gomatrixserverlib.MXIDMapping{ + UserRoomKey: spec.SenderIDFromPseudoIDKey(pseudoIDKey), + UserID: userID.String(), + } + + // Sign the mapping with the server identity + if err = mapping.Sign(identity.ServerName, identity.KeyID, identity.PrivateKey); err != nil { + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + memberContent.MXIDMapping = mapping + + // sign all events with the pseudo ID key + identity = &fclient.SigningIdentity{ + ServerName: "self", + KeyID: "ed25519:1", + PrivateKey: pseudoIDKey, + } + } + membershipEvent.Content = memberContent + var nameEvent *gomatrixserverlib.FledglingEvent var topicEvent *gomatrixserverlib.FledglingEvent var guestAccessEvent *gomatrixserverlib.FledglingEvent @@ -322,7 +379,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo JSON: spec.InternalServerError{}, } } - ev, err = builder.Build(createRequest.EventTime, userID.Domain(), createRequest.KeyID, createRequest.PrivateKey) + ev, err = builder.Build(createRequest.EventTime, identity.ServerName, identity.KeyID, identity.PrivateKey) if err != nil { util.GetLogger(ctx).WithError(err).Error("buildEvent failed") return "", &util.JSONResponse{ @@ -363,17 +420,8 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo }) } - // first send the `m.room.create` event, so we have a roomNID - if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs[:1], false); err != nil { - util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") - return "", &util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } - } - - // send the remaining events - if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs[1:], false); err != nil { + // send the events to the roomserver + if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs, false); err != nil { util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") return "", &util.JSONResponse{ Code: http.StatusInternalServerError, @@ -483,11 +531,6 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } // Build the invite event. - identity := &fclient.SigningIdentity{ - ServerName: userID.Domain(), - KeyID: createRequest.KeyID, - PrivateKey: createRequest.PrivateKey, - } inviteEvent, err = eventutil.QueryAndBuildEvent(ctx, &proto, identity, createRequest.EventTime, c.RSAPI, nil) if err != nil { diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index babd5f812..f19a508a3 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -153,6 +153,23 @@ func (r *Inviter) PerformInvite( } isTargetLocal := r.Cfg.Matrix.IsLocalServerName(invitedUser.Domain()) + // If we're inviting a local user, we can generate the needed pseudoID key here. (if needed) + if isTargetLocal { + var roomVersion gomatrixserverlib.RoomVersion + roomVersion, err = r.DB.GetRoomVersion(ctx, event.RoomID()) + if err != nil { + return err + } + + switch roomVersion { + case gomatrixserverlib.RoomVersionPseudoIDs: + _, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *invitedUser, *validRoomID) + if err != nil { + return err + } + } + } + invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, *validRoomID, *invitedUser) if err != nil { return fmt.Errorf("failed looking up senderID for invited user") diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 5867ee6e0..c14554640 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -16,6 +16,7 @@ package perform import ( "context" + "crypto/ed25519" "database/sql" "errors" "fmt" @@ -24,6 +25,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -202,14 +204,15 @@ func (r *Joiner) performJoinRoomByID( senderID, err = r.Queryer.QuerySenderIDForUser(ctx, *roomID, *userID) if err == nil { checkInvitePending = true - } else { + } + if senderID == "" { // create user room key if needed key, keyErr := r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *userID, *roomID) if keyErr != nil { util.GetLogger(ctx).WithError(keyErr).Error("GetOrCreateUserRoomPrivateKey failed") return "", "", fmt.Errorf("GetOrCreateUserRoomPrivateKey failed: %w", keyErr) } - senderID = spec.SenderID(spec.Base64Bytes(key).Encode()) + senderID = spec.SenderIDFromPseudoIDKey(key) } default: checkInvitePending = true @@ -283,11 +286,39 @@ func (r *Joiner) performJoinRoomByID( // but everyone has since left. I suspect it does the wrong thing. var buildRes rsAPI.QueryLatestEventsAndStateResponse - identity, err := r.Cfg.Matrix.SigningIdentityFor(userDomain) + identity, err := r.RSAPI.SigningIdentityFor(ctx, *roomID, *userID) if err != nil { return "", "", fmt.Errorf("error joining local room: %q", err) } + // at this point we know we have an existing room + if inRoomRes.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { + var pseudoIDKey ed25519.PrivateKey + pseudoIDKey, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *userID, *roomID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("GetOrCreateUserRoomPrivateKey failed") + return "", "", err + } + + mapping := &gomatrixserverlib.MXIDMapping{ + UserRoomKey: spec.SenderIDFromPseudoIDKey(pseudoIDKey), + UserID: userID.String(), + } + + // Sign the mapping with the server identity + if err = mapping.Sign(identity.ServerName, identity.KeyID, identity.PrivateKey); err != nil { + return "", "", err + } + req.Content["mxid_mapping"] = mapping + + // sign the event with the pseudo ID key + identity = fclient.SigningIdentity{ + ServerName: "self", + KeyID: "ed25519:1", + PrivateKey: pseudoIDKey, + } + } + senderIDString := string(senderID) // Prepare the template for the join event. @@ -317,7 +348,7 @@ func (r *Joiner) performJoinRoomByID( if err = proto.SetContent(req.Content); err != nil { return "", "", fmt.Errorf("eb.SetContent: %w", err) } - event, err := eventutil.QueryAndBuildEvent(ctx, &proto, identity, time.Now(), r.RSAPI, &buildRes) + event, err := eventutil.QueryAndBuildEvent(ctx, &proto, &identity, time.Now(), r.RSAPI, &buildRes) switch err.(type) { case nil: diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index e1ddb9b50..a20896cf7 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -177,12 +177,17 @@ func (r *Leaver) performLeaveRoomByID( // TODO: Check what happens if the room exists on the server // but everyone has since left. I suspect it does the wrong thing. + validRoomID, err := spec.NewRoomID(req.RoomID) + if err != nil { + return nil, err + } + var buildRes rsAPI.QueryLatestEventsAndStateResponse - identity, err := r.Cfg.Matrix.SigningIdentityFor(req.Leaver.Domain()) + identity, err := r.RSAPI.SigningIdentityFor(ctx, *validRoomID, req.Leaver) if err != nil { return nil, fmt.Errorf("SigningIdentityFor: %w", err) } - event, err := eventutil.QueryAndBuildEvent(ctx, &proto, identity, time.Now(), r.RSAPI, &buildRes) + event, err := eventutil.QueryAndBuildEvent(ctx, &proto, &identity, time.Now(), r.RSAPI, &buildRes) if err != nil { return nil, fmt.Errorf("eventutil.QueryAndBuildEvent: %w", err) } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 19fd456b5..918619e5e 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -478,6 +478,9 @@ func (r *Queryer) QueryServerJoinedToRoom( if err != nil { return fmt.Errorf("r.DB.RoomInfo: %w", err) } + if info != nil { + response.RoomVersion = info.RoomVersion + } if info == nil || info.IsStub() { return nil } diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 077957fa1..76b21ad23 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -35,6 +35,14 @@ import ( "github.com/matrix-org/dendrite/test/testrig" ) +type FakeQuerier struct { + api.QuerySenderIDAPI +} + +func (f *FakeQuerier) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) +} + func TestUsers(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { cfg, processCtx, close := testrig.CreateConfig(t, dbType) @@ -566,7 +574,7 @@ func TestRedaction(t *testing.T) { err = updater.Commit() assert.NoError(t, err) - _, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev.PDU, &plResolver) + _, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev.PDU, &plResolver, &FakeQuerier{}) assert.NoError(t, err) if redactedEvent != nil { assert.Equal(t, ev.Redacts(), redactedEvent.EventID()) diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 7156c11cc..e9b4609ec 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -18,6 +18,7 @@ import ( "context" "crypto/ed25519" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" @@ -190,7 +191,7 @@ type Database interface { GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) MaybeRedactEvent( - ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, + ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, querier api.QuerySenderIDAPI, ) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error) } @@ -251,7 +252,7 @@ type EventDatabase interface { // MaybeRedactEvent returns the redaction event and the redacted event if this call resulted in a redaction, else an error // (nil if there was nothing to do) MaybeRedactEvent( - ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, + ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, querier api.QuerySenderIDAPI, ) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error) StoreEvent(ctx context.Context, event gomatrixserverlib.PDU, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error) } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 61a3520a4..fc3ace6a6 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -10,6 +10,7 @@ import ( "sort" "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" @@ -991,6 +992,7 @@ func extractRoomVersionFromCreateEvent(event gomatrixserverlib.PDU) ( // Returns the redaction event and the redacted event if this call resulted in a redaction. func (d *EventDatabase) MaybeRedactEvent( ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, + querier api.QuerySenderIDAPI, ) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error) { var ( redactionEvent, redactedEvent *types.Event @@ -1030,15 +1032,18 @@ func (d *EventDatabase) MaybeRedactEvent( return nil } - // TODO: Don't hack senderID into userID here (pseudoIDs) + var validRoomID *spec.RoomID + validRoomID, err = spec.NewRoomID(redactedEvent.RoomID()) + if err != nil { + return err + } sender1Domain := "" - sender1, err1 := spec.NewUserID(string(redactedEvent.SenderID()), true) + sender1, err1 := querier.QueryUserIDForSender(ctx, *validRoomID, redactedEvent.SenderID()) if err1 == nil { sender1Domain = string(sender1.Domain()) } - // TODO: Don't hack senderID into userID here (pseudoIDs) sender2Domain := "" - sender2, err2 := spec.NewUserID(string(redactionEvent.SenderID()), true) + sender2, err2 := querier.QueryUserIDForSender(ctx, *validRoomID, redactionEvent.SenderID()) if err2 == nil { sender2Domain = string(sender2.Domain()) } @@ -1698,6 +1703,7 @@ func (d *Database) InsertUserRoomPublicKey(ctx context.Context, userID spec.User // SelectUserRoomPrivateKey queries the users room private key. // If no key exists, returns no key and no error. Otherwise returns // the key and a database error, if any. +// TODO: Cache this? func (d *Database) SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error) { uID := userID.String() stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID}) @@ -1756,58 +1762,54 @@ func (d *Database) SelectUserRoomPublicKey(ctx context.Context, userID spec.User // SelectUserIDsForPublicKeys returns a map from roomID -> map from senderKey -> userID func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (result map[spec.RoomID]map[string]string, err error) { result = make(map[spec.RoomID]map[string]string, len(publicKeys)) - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - // map all roomIDs to roomNIDs - query := make(map[types.RoomNID][]ed25519.PublicKey) - rooms := make(map[types.RoomNID]spec.RoomID) - for roomID, keys := range publicKeys { - roomNID, ok := d.Cache.GetRoomServerRoomNID(roomID.String()) - if !ok { - roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String()) - if rErr != nil { - return rErr - } - if roomInfo == nil { - logrus.Warnf("missing room info for %s, there will be missing users in the response", roomID.String()) - continue - } - roomNID = roomInfo.RoomNID + // map all roomIDs to roomNIDs + query := make(map[types.RoomNID][]ed25519.PublicKey) + rooms := make(map[types.RoomNID]spec.RoomID) + for roomID, keys := range publicKeys { + roomNID, ok := d.Cache.GetRoomServerRoomNID(roomID.String()) + if !ok { + roomInfo, rErr := d.roomInfo(ctx, nil, roomID.String()) + if rErr != nil { + return nil, rErr } - - query[roomNID] = keys - rooms[roomNID] = roomID - } - - // get the user room key pars - userRoomKeyPairMap, sErr := d.UserRoomKeyTable.BulkSelectUserNIDs(ctx, txn, query) - if sErr != nil { - return sErr - } - nids := make([]types.EventStateKeyNID, 0, len(userRoomKeyPairMap)) - for _, nid := range userRoomKeyPairMap { - nids = append(nids, nid.EventStateKeyNID) - } - // get the userIDs - nidMap, seErr := d.EventStateKeys(ctx, nids) - if seErr != nil { - return seErr - } - - // build the result map (roomID -> map publicKey -> userID) - for publicKey, userRoomKeyPair := range userRoomKeyPairMap { - userID := nidMap[userRoomKeyPair.EventStateKeyNID] - roomID := rooms[userRoomKeyPair.RoomNID] - resMap, exists := result[roomID] - if !exists { - resMap = map[string]string{} + if roomInfo == nil { + logrus.Warnf("missing room info for %s, there will be missing users in the response", roomID.String()) + continue } - resMap[publicKey] = userID - result[roomID] = resMap + roomNID = roomInfo.RoomNID } - return nil - }) + query[roomNID] = keys + rooms[roomNID] = roomID + } + + // get the user room key pars + userRoomKeyPairMap, sErr := d.UserRoomKeyTable.BulkSelectUserNIDs(ctx, nil, query) + if sErr != nil { + return nil, sErr + } + nids := make([]types.EventStateKeyNID, 0, len(userRoomKeyPairMap)) + for _, nid := range userRoomKeyPairMap { + nids = append(nids, nid.EventStateKeyNID) + } + // get the userIDs + nidMap, seErr := d.EventStateKeys(ctx, nids) + if seErr != nil { + return nil, seErr + } + + // build the result map (roomID -> map publicKey -> userID) + for publicKey, userRoomKeyPair := range userRoomKeyPairMap { + userID := nidMap[userRoomKeyPair.EventStateKeyNID] + roomID := rooms[userRoomKeyPair.RoomNID] + resMap, exists := result[roomID] + if !exists { + resMap = map[string]string{} + } + resMap[publicKey] = userID + result[roomID] = resMap + } return result, err } diff --git a/roomserver/storage/sqlite3/user_room_keys_table.go b/roomserver/storage/sqlite3/user_room_keys_table.go index d58b8ac3f..5d6ddc9a8 100644 --- a/roomserver/storage/sqlite3/user_room_keys_table.go +++ b/roomserver/storage/sqlite3/user_room_keys_table.go @@ -57,6 +57,7 @@ const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_use const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)` type userRoomKeysStatements struct { + db *sql.DB insertUserRoomPrivateKeyStmt *sql.Stmt insertUserRoomPublicKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt @@ -70,7 +71,7 @@ func CreateUserRoomKeysTable(db *sql.DB) error { } func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) { - s := &userRoomKeysStatements{} + s := &userRoomKeysStatements{db: db} return s, sqlutil.StatementList{ {&s.insertUserRoomPrivateKeyStmt, insertUserRoomKeySQL}, {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, @@ -137,7 +138,7 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sq selectSQL := strings.Replace(selectUserNIDsSQL, "($2)", sqlutil.QueryVariadicOffset(len(senders), len(senderKeys)), 1) selectSQL = strings.Replace(selectSQL, "($1)", sqlutil.QueryVariadic(len(senderKeys)), 1) // replace $1 with the roomNIDs - selectStmt, err := txn.Prepare(selectSQL) + selectStmt, err := s.db.Prepare(selectSQL) if err != nil { return nil, err } diff --git a/roomserver/types/headered_event.go b/roomserver/types/headered_event.go index 52d006bd9..783999822 100644 --- a/roomserver/types/headered_event.go +++ b/roomserver/types/headered_event.go @@ -18,6 +18,7 @@ import ( "unsafe" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // HeaderedEvent is an Event which serialises to the headered form, which includes @@ -25,6 +26,10 @@ import ( type HeaderedEvent struct { gomatrixserverlib.PDU Visibility gomatrixserverlib.HistoryVisibility + // TODO: Remove this. This is a temporary workaround to store the userID in the syncAPI. + // It really should be the userKey instead. + UserID spec.UserID + StateKeyResolved *string } func (h *HeaderedEvent) CacheCost() int { diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 90f9ff67d..e6b5ddbb0 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -256,16 +256,19 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( } } - pduPos, err := s.db.WriteEvent( - ctx, - ev, - addsStateEvents, - msg.AddsStateEventIDs, - msg.RemovesStateEventIDs, - msg.TransactionID, - false, - msg.HistoryVisibility, - ) + validRoomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + return err + } + + userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, ev.SenderID()) + if err != nil { + return err + } + + ev.UserID = *userID + + pduPos, err := s.db.WriteEvent(ctx, ev, addsStateEvents, msg.AddsStateEventIDs, msg.RemovesStateEventIDs, msg.TransactionID, false, msg.HistoryVisibility) if err != nil { // panic rather than continue with an inconsistent database log.WithFields(log.Fields{ @@ -315,16 +318,19 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent( // hack but until we have some better strategy for dealing with // old events in the sync API, this should at least prevent us // from confusing clients into thinking they've joined/left rooms. - pduPos, err := s.db.WriteEvent( - ctx, - ev, - []*rstypes.HeaderedEvent{}, - []string{}, // adds no state - []string{}, // removes no state - nil, // no transaction - ev.StateKey() != nil, // exclude from sync?, - msg.HistoryVisibility, - ) + + validRoomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + return err + } + + userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, ev.SenderID()) + if err != nil { + return err + } + ev.UserID = *userID + + pduPos, err := s.db.WriteEvent(ctx, ev, []*rstypes.HeaderedEvent{}, []string{}, []string{}, nil, ev.StateKey() != nil, msg.HistoryVisibility) if err != nil { // panic rather than continue with an inconsistent database log.WithFields(log.Fields{ @@ -420,6 +426,8 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent( return } + msg.Event.UserID = *userID + pduPos, err := s.db.AddInviteEvent(ctx, msg.Event) if err != nil { sentry.CaptureException(err) @@ -537,6 +545,7 @@ func (s *OutputRoomEventConsumer) onPurgeRoom( } func (s *OutputRoomEventConsumer) updateStateEvent(event *rstypes.HeaderedEvent) (*rstypes.HeaderedEvent, error) { + event.StateKeyResolved = event.StateKey() if event.StateKey() == nil { return event, nil } @@ -556,6 +565,29 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *rstypes.HeaderedEvent) return event, err } + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return event, err + } + + if event.StateKey() != nil { + if *event.StateKey() != "" { + var sku *spec.UserID + sku, err = s.rsAPI.QueryUserIDForSender(s.ctx, *validRoomID, spec.SenderID(stateKey)) + if err == nil && sku != nil { + sKey := sku.String() + event.StateKeyResolved = &sKey + } + } + } + + userID, err := s.rsAPI.QueryUserIDForSender(s.ctx, *validRoomID, event.SenderID()) + if err != nil { + return event, err + } + + event.UserID = *userID + if prevEvent == nil || prevEvent.EventID() == event.EventID() { return event, nil } diff --git a/syncapi/routing/search_test.go b/syncapi/routing/search_test.go index f6d7fb4eb..905a9a1ac 100644 --- a/syncapi/routing/search_test.go +++ b/syncapi/routing/search_test.go @@ -230,6 +230,7 @@ func TestSearch(t *testing.T) { stateEvents = append(stateEvents, x) stateEventIDs = append(stateEventIDs, x.EventID()) } + x.StateKeyResolved = x.StateKey() sp, err = db.WriteEvent(processCtx.Context(), x, stateEvents, stateEventIDs, nil, nil, false, gomatrixserverlib.HistoryVisibilityShared) assert.NoError(t, err) if x.Type() != "m.room.message" { diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index bfe5e9bdd..112fa9d4a 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -343,9 +343,9 @@ func (s *currentRoomStateStatements) UpsertRoomState( event.RoomID(), event.EventID(), event.Type(), - event.SenderID(), + event.UserID.String(), containsURL, - *event.StateKey(), + *event.StateKeyResolved, headeredJSON, membership, addedAt, diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go index 267209bba..7b8d2d733 100644 --- a/syncapi/storage/postgres/invites_table.go +++ b/syncapi/storage/postgres/invites_table.go @@ -101,7 +101,7 @@ func (s *inviteEventsStatements) InsertInviteEvent( ctx, inviteEvent.RoomID(), inviteEvent.EventID(), - *inviteEvent.StateKey(), + inviteEvent.UserID.String(), headeredJSON, ).Scan(&streamPos) return diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index 3905f9abb..09b47432b 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -109,7 +109,7 @@ func (s *membershipsStatements) UpsertMembership( _, err = sqlutil.TxStmt(txn, s.upsertMembershipStmt).ExecContext( ctx, event.RoomID(), - *event.StateKey(), + event.StateKeyResolved, membership, event.EventID(), streamPos, diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index e068afab1..b58cf59f0 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -407,7 +407,7 @@ func (s *outputRoomEventsStatements) InsertEvent( event.EventID(), headeredJSON, event.Type(), - event.SenderID(), + event.UserID.String(), containsURL, pq.StringArray(addState), pq.StringArray(removeState), diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index e432e483b..3bd19b367 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -342,9 +342,9 @@ func (s *currentRoomStateStatements) UpsertRoomState( event.RoomID(), event.EventID(), event.Type(), - event.SenderID(), + event.UserID.String(), containsURL, - *event.StateKey(), + *event.StateKeyResolved, headeredJSON, membership, addedAt, diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index 347523cf7..7e0d895f1 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -108,7 +108,7 @@ func (s *inviteEventsStatements) InsertInviteEvent( streamPos, inviteEvent.RoomID(), inviteEvent.EventID(), - *inviteEvent.StateKey(), + inviteEvent.UserID.String(), headeredJSON, ) return diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go index c09fa1510..a9e880d2a 100644 --- a/syncapi/storage/sqlite3/memberships_table.go +++ b/syncapi/storage/sqlite3/memberships_table.go @@ -112,7 +112,7 @@ func (s *membershipsStatements) UpsertMembership( _, err = sqlutil.TxStmt(txn, s.upsertMembershipStmt).ExecContext( ctx, event.RoomID(), - *event.StateKey(), + event.StateKeyResolved, membership, event.EventID(), streamPos, diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 5a47aec44..06c65419a 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -348,7 +348,7 @@ func (s *outputRoomEventsStatements) InsertEvent( event.EventID(), headeredJSON, event.Type(), - event.SenderID(), + event.UserID.String(), containsURL, string(addStateJSON), string(removeStateJSON), diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index f56e44a30..f57b0d618 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -43,6 +43,7 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*rstypes.Header var addStateEventIDs []string var removeStateEventIDs []string if ev.StateKey() != nil { + ev.StateKeyResolved = ev.StateKey() addStateEvents = append(addStateEvents, ev) addStateEventIDs = append(addStateEventIDs, ev.EventID()) } diff --git a/syncapi/storage/tables/current_room_state_test.go b/syncapi/storage/tables/current_room_state_test.go index 7d4ec812c..2df111a26 100644 --- a/syncapi/storage/tables/current_room_state_test.go +++ b/syncapi/storage/tables/current_room_state_test.go @@ -54,7 +54,13 @@ func TestCurrentRoomStateTable(t *testing.T) { events := room.CurrentState() err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error { for i, ev := range events { - err := tab.UpsertRoomState(ctx, txn, ev, nil, types.StreamPosition(i)) + ev.StateKeyResolved = ev.StateKey() + userID, err := spec.NewUserID(string(ev.SenderID()), true) + if err != nil { + return err + } + ev.UserID = *userID + err = tab.UpsertRoomState(ctx, txn, ev, nil, types.StreamPosition(i)) if err != nil { return fmt.Errorf("failed to UpsertRoomState: %w", err) } diff --git a/syncapi/storage/tables/memberships_test.go b/syncapi/storage/tables/memberships_test.go index 4afa2ac5b..a421a9772 100644 --- a/syncapi/storage/tables/memberships_test.go +++ b/syncapi/storage/tables/memberships_test.go @@ -80,6 +80,7 @@ func TestMembershipsTable(t *testing.T) { defer cancel() for _, ev := range userEvents { + ev.StateKeyResolved = ev.StateKey() if err := table.UpsertMembership(ctx, nil, ev, types.StreamPosition(ev.Depth()), 1); err != nil { t.Fatalf("failed to upsert membership: %s", err) } @@ -134,6 +135,7 @@ func testUpsert(t *testing.T, ctx context.Context, table tables.Memberships, mem ev := room.CreateAndInsert(t, user, spec.MRoomMember, map[string]interface{}{ "membership": spec.Join, }, test.WithStateKey(user.ID)) + ev.StateKeyResolved = ev.StateKey() // Insert the same event again, but with different positions, which should get updated if err = table.UpsertMembership(ctx, nil, ev, 2, 2); err != nil { t.Fatalf("failed to upsert membership: %s", err) diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 7939dd8fa..1a4e5351d 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -3,6 +3,7 @@ package streams import ( "context" "database/sql" + "encoding/json" "fmt" "time" @@ -15,6 +16,8 @@ import ( "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/gomatrixserverlib" @@ -346,13 +349,40 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( // Now that we've filtered the timeline, work out which state events are still // left. Anything that appears in the filtered timeline will be removed from the // "state" section and kept in "timeline". + + // update the powerlevel event for timeline events + for i, ev := range events { + if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs { + continue + } + if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") { + continue + } + var newEvent gomatrixserverlib.PDU + newEvent, err = p.updatePowerLevelEvent(ctx, ev) + if err != nil { + return r.From, err + } + events[i] = &rstypes.HeaderedEvent{PDU: newEvent} + } + sEvents := gomatrixserverlib.HeaderedReverseTopologicalOrdering( gomatrixserverlib.ToPDUs(removeDuplicates(delta.StateEvents, events)), gomatrixserverlib.TopologicalOrderByAuthEvents, ) delta.StateEvents = make([]*rstypes.HeaderedEvent, len(sEvents)) for i := range sEvents { - delta.StateEvents[i] = sEvents[i].(*rstypes.HeaderedEvent) + ev := sEvents[i] + delta.StateEvents[i] = ev.(*rstypes.HeaderedEvent) + // update the powerlevel event for state events + if ev.Version() == gomatrixserverlib.RoomVersionPseudoIDs && ev.Type() == spec.MRoomPowerLevels && ev.StateKeyEquals("") { + var newEvent gomatrixserverlib.PDU + newEvent, err = p.updatePowerLevelEvent(ctx, ev.(*rstypes.HeaderedEvent)) + if err != nil { + return r.From, err + } + delta.StateEvents[i] = &rstypes.HeaderedEvent{PDU: newEvent} + } } if len(delta.StateEvents) > 0 { @@ -421,6 +451,75 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( return latestPosition, nil } +func (p *PDUStreamProvider) updatePowerLevelEvent(ctx context.Context, ev *rstypes.HeaderedEvent) (gomatrixserverlib.PDU, error) { + pls, err := gomatrixserverlib.NewPowerLevelContentFromEvent(ev) + if err != nil { + return nil, err + } + newPls := make(map[string]int64) + var userID *spec.UserID + for user, level := range pls.Users { + validRoomID, _ := spec.NewRoomID(ev.RoomID()) + userID, err = p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(user)) + if err != nil { + return nil, err + } + newPls[userID.String()] = level + } + var newPlBytes, newEv []byte + newPlBytes, err = json.Marshal(newPls) + if err != nil { + return nil, err + } + newEv, err = sjson.SetRawBytes(ev.JSON(), "content.users", newPlBytes) + if err != nil { + return nil, err + } + + // do the same for prev content + prevContent := gjson.GetBytes(ev.JSON(), "unsigned.prev_content") + if !prevContent.Exists() { + var evNew gomatrixserverlib.PDU + evNew, err = gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionPseudoIDs).NewEventFromTrustedJSON(newEv, false) + if err != nil { + return nil, err + } + + return evNew, err + } + pls = gomatrixserverlib.PowerLevelContent{} + err = json.Unmarshal([]byte(prevContent.Raw), &pls) + if err != nil { + return nil, err + } + + newPls = make(map[string]int64) + for user, level := range pls.Users { + validRoomID, _ := spec.NewRoomID(ev.RoomID()) + userID, err = p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(user)) + if err != nil { + return nil, err + } + newPls[userID.String()] = level + } + newPlBytes, err = json.Marshal(newPls) + if err != nil { + return nil, err + } + newEv, err = sjson.SetRawBytes(newEv, "unsigned.prev_content.users", newPlBytes) + if err != nil { + return nil, err + } + + var evNew gomatrixserverlib.PDU + evNew, err = gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionPseudoIDs).NewEventFromTrustedJSON(newEv, false) + if err != nil { + return nil, err + } + + return evNew, err +} + // applyHistoryVisibilityFilter gets the current room state and supplies it to ApplyHistoryVisibilityFilter, to make // sure we always return the required events in the timeline. func applyHistoryVisibilityFilter( @@ -470,6 +569,7 @@ func applyHistoryVisibilityFilter( return events, nil } +// nolint: gocyclo func (p *PDUStreamProvider) getJoinResponseForCompleteSync( ctx context.Context, snapshot storage.DatabaseTransaction, @@ -563,6 +663,35 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( prevBatch.Decrement() } + // Update powerlevel events for timeline events + for i, ev := range events { + if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs { + continue + } + if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") { + continue + } + newEvent, err := p.updatePowerLevelEvent(ctx, ev) + if err != nil { + return nil, err + } + events[i] = &rstypes.HeaderedEvent{PDU: newEvent} + } + // Update powerlevel events for state events + for i, ev := range stateEvents { + if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs { + continue + } + if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") { + continue + } + newEvent, err := p.updatePowerLevelEvent(ctx, ev) + if err != nil { + return nil, err + } + stateEvents[i] = &rstypes.HeaderedEvent{PDU: newEvent} + } + jr.Timeline.PrevBatch = prevBatch jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) From 939ee325f80c0c57704b8c34e3faa1c7a3927781 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Thu, 29 Jun 2023 18:02:11 +0200 Subject: [PATCH 35/35] Actually use the parameter --- roomserver/internal/query/query.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 918619e5e..626d3c13e 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -937,7 +937,7 @@ func (r *Queryer) QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types } func (r *Queryer) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) { - res, err := r.DB.GetStateEvent(ctx, roomID.String(), eventType, "") + res, err := r.DB.GetStateEvent(ctx, roomID.String(), eventType, stateKey) if res == nil { return nil, err }