From e388c9ca425641847731225019cf38c25165bca0 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 6 Jun 2023 09:44:16 -0600 Subject: [PATCH] Make senderID query use only roomID --- appservice/consumers/roomserver.go | 4 +-- clientapi/routing/directory.go | 14 ++++++-- clientapi/routing/sendevent.go | 4 +-- clientapi/routing/state.go | 12 +++---- cmd/resolve-state/main.go | 4 +-- federationapi/federationapi_test.go | 2 +- federationapi/internal/perform.go | 12 +++---- federationapi/routing/invite.go | 8 ++--- federationapi/routing/join.go | 8 ++--- federationapi/routing/leave.go | 4 +-- go.mod | 2 +- go.sum | 2 ++ internal/pushrules/evaluate_test.go | 2 +- internal/transactionrequest.go | 4 +-- internal/transactionrequest_test.go | 4 +-- roomserver/api/api.go | 5 ++- roomserver/internal/helpers/auth.go | 4 +-- roomserver/internal/input/input_events.go | 16 +++++----- .../internal/input/input_events_test.go | 2 +- roomserver/internal/input/input_missing.go | 20 ++++++------ roomserver/internal/perform/perform_admin.go | 8 ++--- .../internal/perform/perform_backfill.go | 8 ++--- .../internal/perform/perform_create_room.go | 4 +-- roomserver/internal/perform/perform_invite.go | 4 +-- .../internal/perform/perform_upgrade.go | 8 ++--- roomserver/internal/query/query.go | 24 +++++++------- roomserver/state/state.go | 10 +++--- roomserver/storage/interface.go | 6 ++-- roomserver/storage/shared/room_updater.go | 4 +-- roomserver/storage/shared/storage.go | 4 +-- setup/mscs/msc2836/msc2836.go | 4 +-- setup/mscs/msc2836/msc2836_test.go | 2 +- syncapi/routing/context.go | 20 ++++++------ syncapi/routing/getevent.go | 4 +-- syncapi/routing/memberships.go | 4 +-- syncapi/routing/messages.go | 8 ++--- syncapi/routing/relations.go | 4 +-- syncapi/routing/search.go | 16 +++++----- syncapi/routing/search_test.go | 2 +- syncapi/streams/stream_invite.go | 4 +-- syncapi/streams/stream_pdu.go | 32 +++++++++---------- syncapi/syncapi_test.go | 2 +- syncapi/synctypes/clientevent_test.go | 2 +- syncapi/types/types_test.go | 2 +- test/room.go | 2 +- userapi/consumers/roomserver.go | 12 +++---- userapi/consumers/roomserver_test.go | 2 +- userapi/util/notify_test.go | 2 +- 48 files changed, 174 insertions(+), 163 deletions(-) diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index 5650a26b7..06625ad7e 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -181,8 +181,8 @@ func (s *OutputRoomEventConsumer) sendEvents( // Create the transaction body. transaction, err := json.Marshal( ApplicationServiceTransaction{ - Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return s.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), }, ) diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index bc8ef6d65..c702b136d 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -189,7 +189,7 @@ func SetLocalAlias( } } - deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), alias, *userID) + deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), r.RoomID, *userID) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -238,7 +238,17 @@ func RemoveLocalAlias( } } - deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), alias, *userID) + roomIDReq := roomserverAPI.GetRoomIDForAliasRequest{Alias: alias} + roomIDRes := roomserverAPI.GetRoomIDForAliasResponse{} + err = rsAPI.GetRoomIDForAlias(req.Context(), &roomIDReq, &roomIDRes) + if err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound("The alias does not exist."), + } + } + + deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomIDRes.RoomID, *userID) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 5fb755f57..8b09f399a 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -331,8 +331,8 @@ func generateSendEvent( stateEvents[i] = queryRes.StateEvents[i].PDU } provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents)) - if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { return nil, &util.JSONResponse{ Code: http.StatusForbidden, diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index 5dda21abf..8567e617d 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -142,8 +142,8 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a for _, ev := range stateRes.StateEvents { stateEvents = append( stateEvents, - synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomAliasOrID string, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomID string, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), ) } @@ -166,8 +166,8 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a for _, ev := range stateAfterRes.StateEvents { stateEvents = append( stateEvents, - synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomAliasOrID string, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomID string, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), ) } @@ -339,8 +339,8 @@ func OnIncomingStateTypeRequest( } stateEvent := stateEventInStateResp{ - ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomAliasOrID string, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID string, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), } diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go index 1ebb6cd8d..360403094 100644 --- a/cmd/resolve-state/main.go +++ b/cmd/resolve-state/main.go @@ -183,8 +183,8 @@ func main() { fmt.Println("Resolving state") var resolved Events resolved, err = gomatrixserverlib.ResolveConflicts( - gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return roomserverDB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return roomserverDB.GetUserIDForSender(ctx, roomID, senderID) }, ) if err != nil { diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index fc09440da..a97bcdeab 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -36,7 +36,7 @@ type fedRoomserverAPI struct { queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error } -func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { +func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) } diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 4f0597983..ef1396f70 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -164,11 +164,11 @@ func (r *FederationInternalAPI) performJoinUsingServer( PrivateKey: r.cfg.Matrix.PrivateKey, KeyID: r.cfg.Matrix.KeyID, KeyRing: r.keyRing, - EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID, senderID string) (*spec.UserID, error) { + return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), - UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }, } response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput) @@ -363,8 +363,8 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( // authenticate the state returned (check its auth events etc) // the equivalent of CheckSendJoinResponse() - userIDProvider := func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + userIDProvider := func(roomID, senderID string) (*spec.UserID, error) { + return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) } authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse( ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedEventProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName, userIDProvider), userIDProvider, diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index 59f8dd1be..d792335b9 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -95,8 +95,8 @@ func InviteV2( StateQuerier: rsAPI.StateQuerier(), InviteEvent: inviteReq.Event(), StrippedState: inviteReq.InviteRoomState(), - UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(httpReq.Context(), roomAliasOrID, senderID) + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, } event, jsonErr := handleInvite(httpReq.Context(), input, rsAPI) @@ -188,8 +188,8 @@ func InviteV1( StateQuerier: rsAPI.StateQuerier(), InviteEvent: event, StrippedState: strippedState, - UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(httpReq.Context(), roomAliasOrID, senderID) + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, } event, jsonErr := handleInvite(httpReq.Context(), input, rsAPI) diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index f09c5aff3..cf6e43032 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -156,8 +156,8 @@ func MakeJoin( LocalServerName: cfg.Matrix.ServerName, LocalServerInRoom: res.RoomExists && res.IsInRoom, RoomQuerier: &roomQuerier, - UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(httpReq.Context(), roomAliasOrID, senderID) + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, BuildEventTemplate: createJoinTemplate, } @@ -253,8 +253,8 @@ func SendJoin( PrivateKey: cfg.Matrix.PrivateKey, Verifier: keys, MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI}, - UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(httpReq.Context(), roomAliasOrID, senderID) + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, } response, joinErr := gomatrixserverlib.HandleSendJoin(input) diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index de95330b4..ec7628d30 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -95,8 +95,8 @@ func MakeLeave( LocalServerName: cfg.Matrix.ServerName, LocalServerInRoom: res.RoomExists && res.IsInRoom, BuildEventTemplate: createLeaveTemplate, - UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(httpReq.Context(), roomAliasOrID, senderID) + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, } diff --git a/go.mod b/go.mod index 35f28c03f..1e13d4a10 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230606021710-b68a1b0eef30 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230606154326-77b5ce0c692d github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 diff --git a/go.sum b/go.sum index 8bac1e807..35c3d713a 100644 --- a/go.sum +++ b/go.sum @@ -325,6 +325,8 @@ github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrixserverlib v0.0.0-20230606021710-b68a1b0eef30 h1:G+Do1UoWazY0Fetq+eAX1h1+fimf19NGGyaS86hWg8s= github.com/matrix-org/gomatrixserverlib v0.0.0-20230606021710-b68a1b0eef30/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230606154326-77b5ce0c692d h1:AUrJcbgtIPtVYTTfV7DUyarW7hOMgsZQZUuy9r8fMv8= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230606154326-77b5ce0c692d/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go index 616dac894..34c1436f4 100644 --- a/internal/pushrules/evaluate_test.go +++ b/internal/pushrules/evaluate_test.go @@ -8,7 +8,7 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" ) -func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) { +func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) } diff --git a/internal/transactionrequest.go b/internal/transactionrequest.go index 5d9be08c5..0bbe0720c 100644 --- a/internal/transactionrequest.go +++ b/internal/transactionrequest.go @@ -167,8 +167,8 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut } continue } - if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return t.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID, senderID string) (*spec.UserID, error) { + return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) results[event.EventID()] = fclient.PDUResult{ diff --git a/internal/transactionrequest_test.go b/internal/transactionrequest_test.go index bcda6520e..6f3ce0b3b 100644 --- a/internal/transactionrequest_test.go +++ b/internal/transactionrequest_test.go @@ -70,7 +70,7 @@ type FakeRsAPI struct { bannedFromRoom bool } -func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { +func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) } @@ -642,7 +642,7 @@ type testRoomserverAPI struct { queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse } -func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { +func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) } diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 38b4d8051..8a115823c 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -67,9 +67,8 @@ type InputRoomEventsAPI interface { } type QuerySenderIDAPI interface { - // Accepts either roomID or alias - QuerySenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) - QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) + QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) + QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) } // Query the latest events and state for a room from the room server. diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index cb5bdcc40..932ce6155 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -76,8 +76,8 @@ func CheckForSoftFail( } // Check if the event is allowed. - if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return db.GetUserIDForSender(ctx, roomAliasOrID, senderID) + if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return db.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { // return true, nil return true, err diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index ffaa20afa..b692a9513 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -276,8 +276,8 @@ func (r *Inputer) processRoomEvent( // Check if the event is allowed by its auth events. If it isn't then // we consider the event to be "rejected" — it will still be persisted. - if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { isRejected = true rejectionErr = err @@ -581,8 +581,8 @@ func (r *Inputer) processStateBefore( stateBeforeAuth := gomatrixserverlib.NewAuthEvents( gomatrixserverlib.ToPDUs(stateBeforeEvent), ) - if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }); rejectionErr != nil { rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr) return @@ -694,8 +694,8 @@ nextAuthEvent: // Check the signatures of the event. If this fails then we'll simply // skip it, because gomatrixserverlib.Allowed() will notice a problem // if a critical event is missing anyway. - if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { continue nextAuthEvent } @@ -712,8 +712,8 @@ nextAuthEvent: } // Check if the auth event should be rejected. - err := gomatrixserverlib.Allowed(authEvent, auth, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }) if isRejected = err != nil; isRejected { logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID()) diff --git a/roomserver/internal/input/input_events_test.go b/roomserver/internal/input/input_events_test.go index 17a5aa129..0ba7d19f5 100644 --- a/roomserver/internal/input/input_events_test.go +++ b/roomserver/internal/input/input_events_test.go @@ -58,7 +58,7 @@ func Test_EventAuth(t *testing.T) { } // Finally check that the event is NOT allowed - if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomAliasOrID, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) }); err == nil { + if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) }); err == nil { t.Fatalf("event should not be allowed, but it was") } } diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 1a3c2c142..ac0670fc3 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -473,8 +473,8 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion stateEventList = append(stateEventList, state.StateEvents...) } resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts( - roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID) + roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) }, ) if err != nil { @@ -482,8 +482,8 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion } // apply the current event retryAllowedState: - if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID) + if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { switch missing := err.(type) { case gomatrixserverlib.MissingAuthEventError: @@ -569,8 +569,8 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver // will be added and duplicates will be removed. missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events)) for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID) + if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { continue } @@ -660,8 +660,8 @@ func (t *missingStateReq) lookupMissingStateViaState( authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{ StateEvents: state.GetStateEvents(), AuthEvents: state.GetAuthEvents(), - }, roomVersion, t.keys, nil, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }, roomVersion, t.keys, nil, func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) }) if err != nil { return nil, err @@ -897,8 +897,8 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers)) return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers)) } - if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID) + if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) return nil, verifySigError{event.EventID(), err} diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index 488a10648..ca736cb65 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -262,16 +262,16 @@ func (r *Admin) PerformAdminDownloadState( return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err) } for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { continue } authEventMap[authEvent.EventID()] = authEvent } for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { continue } diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index bdb9bf6c9..0f743f4e4 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -121,8 +121,8 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform // Specifically the test "Outbound federation can backfill events" events, err := gomatrixserverlib.RequestBackfill( ctx, req.VirtualHost, requester, - r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }, ) // Only return an error if we really couldn't get any events. @@ -212,8 +212,8 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom continue } loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false) - result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }) if err != nil { logger.WithError(err).Warn("failed to load and verify event") diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index 7448bca4f..bfa1f631e 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -308,8 +308,8 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } } - if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return c.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return c.DB.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed") return "", &util.JSONResponse{ diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 4050c294f..e8e20ede2 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -156,8 +156,8 @@ func (r *Inviter) PerformInvite( StrippedState: req.InviteRoomState, MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI}, StateQuerier: &QueryState{r.DB}, - UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }, } inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI) diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index 99ad21ccd..7513aab46 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -484,8 +484,8 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user } - if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.URSAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err) } @@ -569,8 +569,8 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user stateEvents[i] = queryRes.StateEvents[i].PDU } provider := gomatrixserverlib.NewAuthEvents(stateEvents) - if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.URSAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID, senderID string) (*spec.UserID, error) { + return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client? } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index fdc37c7b2..7828c8b8c 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -122,8 +122,8 @@ func (r *Queryer) QueryStateAfterEvents( } stateEvents, err = gomatrixserverlib.ResolveConflicts( - info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }, ) if err != nil { @@ -351,8 +351,8 @@ func (r *Queryer) QueryMembershipsForRoom( return fmt.Errorf("r.DB.Events: %w", err) } for _, event := range events { - clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }) response.JoinEvents = append(response.JoinEvents, clientEvent) } @@ -402,8 +402,8 @@ func (r *Queryer) QueryMembershipsForRoom( } for _, event := range events { - clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }) response.JoinEvents = append(response.JoinEvents, clientEvent) } @@ -594,8 +594,8 @@ func (r *Queryer) QueryStateAndAuthChain( if request.ResolveState { stateEvents, err = gomatrixserverlib.ResolveConflicts( - info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }, ) if err != nil { @@ -1043,10 +1043,10 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query return nil } -func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) { - return r.DB.GetSenderIDForUser(ctx, roomAliasOrID, userID) +func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) { + return r.DB.GetSenderIDForUser(ctx, roomID, userID) } -func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) +func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) } diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 1607971f7..3131cbff2 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -44,7 +44,7 @@ type StateResolutionStorage interface { AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) - GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) + GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) } type StateResolution struct { @@ -947,8 +947,8 @@ func (v *StateResolution) resolveConflictsV1( } // Resolve the conflicts. - resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return v.db.GetUserIDForSender(ctx, roomAliasOrID, senderID) + resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return v.db.GetUserIDForSender(ctx, roomID, senderID) }) // Map from the full events back to numeric state entries. @@ -1061,8 +1061,8 @@ func (v *StateResolution) resolveConflictsV2( conflictedEvents, nonConflictedEvents, authEvents, - func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return v.db.GetUserIDForSender(ctx, roomAliasOrID, senderID) + func(roomID, senderID string) (*spec.UserID, error) { + return v.db.GetUserIDForSender(ctx, roomID, senderID) }, ) }() diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 89e40b49b..2d007bed5 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -167,9 +167,9 @@ type Database interface { // GetKnownUsers searches all users that userID knows about. GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) // GetKnownUsers tries to obtain the current mxid for a given user. - GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) + GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) // GetKnownUsers tries to obtain the current senderID for a given user. - GetSenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) + GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) // GetKnownRooms returns a list of all rooms we know about. GetKnownRooms(ctx context.Context) ([]string, error) // ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room @@ -215,7 +215,7 @@ type RoomDatabase interface { GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error) - GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) + GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) } type EventDatabase interface { diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 4068caebc..735001383 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -252,6 +252,6 @@ func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, ta return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) } -func (u *RoomUpdater) GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { - return u.d.GetUserIDForSender(ctx, roomAliasOrID, senderID) +func (u *RoomUpdater) GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + return u.d.GetUserIDForSender(ctx, roomID, senderID) } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index c88f27295..406d7cf1c 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1524,12 +1524,12 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit) } -func (d *Database) GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { +func (d *Database) GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { // TODO: Use real logic once DB for pseudoIDs is in place return spec.NewUserID(senderID, true) } -func (d *Database) GetSenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) { +func (d *Database) GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) { // TODO: Use real logic once DB for pseudoIDs is in place return userID.String(), nil } diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 6e5bc3027..5ce3b430b 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -94,8 +94,8 @@ type MSC2836EventRelationshipsResponse struct { func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsResponse, rsAPI roomserver.RoomserverInternalAPI) *EventRelationshipResponse { out := &EventRelationshipResponse{ - Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), Limited: res.Limited, NextBatch: res.NextBatch, diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index c8bc8b509..c463fd72b 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -525,7 +525,7 @@ type testRoomserverAPI struct { events map[string]*types.HeaderedEvent } -func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { +func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) } diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 5046ba1cd..a646622fc 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -193,19 +193,19 @@ func Context( } } - eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) - eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) newState := state if filter.LazyLoadMembers { allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...) allEvents = append(allEvents, &requestedEvent) - evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache) if err != nil { @@ -217,15 +217,15 @@ func Context( } } - ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) response := ContextRespsonse{ Event: &ev, EventsAfter: eventsAfterClient, EventsBefore: eventsBeforeClient, - State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), } diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go index 878c5d90a..0d47477aa 100644 --- a/syncapi/routing/getevent.go +++ b/syncapi/routing/getevent.go @@ -103,8 +103,8 @@ func GetEvent( return util.JSONResponse{ Code: http.StatusOK, - JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID) + JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }), } } diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index 699e08015..9c2319dd9 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -153,8 +153,8 @@ func GetMemberships( } return util.JSONResponse{ Code: http.StatusOK, - JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID) + JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) })}, } } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 9e7c33d5d..879739d00 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -273,8 +273,8 @@ func OnIncomingMessagesRequest( JSON: spec.InternalServerError{}, } } - res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID) + res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) })...) } @@ -385,8 +385,8 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv "events_before": len(events), "events_after": len(filteredEvents), }).Debug("applied history visibility (messages)") - return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), start, end, err } diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go index 3913f2d86..30c02d293 100644 --- a/syncapi/routing/relations.go +++ b/syncapi/routing/relations.go @@ -116,8 +116,8 @@ func Relations( for _, event := range filteredEvents { res.Chunk = append( res.Chunk, - synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID) + synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }), ) } diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index 991545315..546624cbe 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -228,17 +228,17 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts Context: SearchContextResponse{ Start: startToken.String(), End: endToken.String(), - EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID) + EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }), - EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID) + EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }), ProfileInfo: profileInfos, }, Rank: eventScore[event.EventID()].Score, - Result: synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID) + Result: synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }), }) roomGroup := groups[event.RoomID()] @@ -254,8 +254,8 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts JSON: spec.InternalServerError{}, } } - stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID) + stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }) } } diff --git a/syncapi/routing/search_test.go b/syncapi/routing/search_test.go index 02410820b..b36be8238 100644 --- a/syncapi/routing/search_test.go +++ b/syncapi/routing/search_test.go @@ -25,7 +25,7 @@ import ( type FakeSyncRoomserverAPI struct{ rsapi.SyncRoomserverAPI } -func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { +func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) } diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index fd7b9dbb4..52bb380c8 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -68,8 +68,8 @@ func (p *InviteStreamProvider) IncrementalSync( if _, ok := req.IgnoredUsers.List[inviteEvent.SenderID()]; ok { continue } - ir := types.NewInviteResponse(inviteEvent, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + ir := types.NewInviteResponse(inviteEvent, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) req.Response.Rooms.Invite[roomID] = ir } diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 6bc4bfdf4..8f83a0896 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -376,14 +376,14 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } } jr.Timeline.PrevBatch = &prevBatch - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) req.Response.Rooms.Join[delta.RoomID] = jr @@ -391,12 +391,12 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( jr := types.NewJoinResponse() jr.Timeline.PrevBatch = &prevBatch // TODO: Apply history visibility on peeked rooms - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) jr.Timeline.Limited = limited - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) req.Response.Rooms.Peek[delta.RoomID] = jr @@ -406,14 +406,14 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( case spec.Ban: lr := types.NewLeaveResponse() lr.Timeline.PrevBatch = &prevBatch - lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. lr.Timeline.Limited = limited && len(events) == len(recentEvents) - lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) req.Response.Rooms.Leave[delta.RoomID] = lr } @@ -564,14 +564,14 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( } jr.Timeline.PrevBatch = prevBatch - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. jr.Timeline.Limited = limited && len(events) == len(recentEvents) - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) return jr, nil } diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index ac46e10eb..78c857ab9 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -40,7 +40,7 @@ type syncRoomserverAPI struct { rooms []*test.Room } -func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { +func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) } diff --git a/syncapi/synctypes/clientevent_test.go b/syncapi/synctypes/clientevent_test.go index 5ae353f79..b76b38c17 100644 --- a/syncapi/synctypes/clientevent_test.go +++ b/syncapi/synctypes/clientevent_test.go @@ -24,7 +24,7 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" ) -func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) { +func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) } diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index 2cd8e7472..a93e9122f 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -11,7 +11,7 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" ) -func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) { +func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) } diff --git a/test/room.go b/test/room.go index aefdd8769..4cdb73aa3 100644 --- a/test/room.go +++ b/test/room.go @@ -39,7 +39,7 @@ var ( roomIDCounter = int64(0) ) -func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) { +func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) } diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 2d7bc4814..edc48e52a 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -301,8 +301,8 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst switch { case event.Type() == spec.MRoomMember: - cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return s.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) var member *localMembership member, err = newLocalMembership(&cevent) @@ -536,8 +536,8 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype // UNSPEC: the spec doesn't say this is a ClientEvent, but the // fields seem to match. room_id should be missing, which // matches the behaviour of FormatSync. - Event: synctypes.ToClientEvent(event, synctypes.FormatSync, func(roomAliasOrID string, senderID string) (*spec.UserID, error) { - return s.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + Event: synctypes.ToClientEvent(event, synctypes.FormatSync, func(roomID string, senderID string) (*spec.UserID, error) { + return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), // TODO: this is per-device, but it's not part of the primary // key. So inserting one notification per profile tag doesn't @@ -659,8 +659,8 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * roomSize: roomSize, } eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global) - rule, err := eval.MatchEvent(event.PDU, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return s.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + rule, err := eval.MatchEvent(event.PDU, func(roomID, senderID string) (*spec.UserID, error) { + return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) if err != nil { return nil, err diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 09e57b895..d677a7ba9 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -47,7 +47,7 @@ func mustCreateEvent(t *testing.T, content string) *types.HeaderedEvent { type FakeUserRoomserverAPI struct{ rsapi.UserRoomserverAPI } -func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { +func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) } diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go index 778461161..86446b2db 100644 --- a/userapi/util/notify_test.go +++ b/userapi/util/notify_test.go @@ -23,7 +23,7 @@ import ( userUtil "github.com/matrix-org/dendrite/userapi/util" ) -func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) { +func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) }