diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index 1877de37a..e8b9211c4 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -128,7 +128,7 @@ func (s *OutputRoomEventConsumer) onMessage( if len(output.NewRoomEvent.AddsStateEventIDs) > 0 { newEventID := output.NewRoomEvent.Event.EventID() eventsReq := &api.QueryEventsByIDRequest{ - RoomID: output.NewRoomEvent.Event.RoomID(), + RoomID: output.NewRoomEvent.Event.RoomID().String(), EventIDs: make([]string, 0, len(output.NewRoomEvent.AddsStateEventIDs)), } eventsRes := &api.QueryEventsByIDResponse{} @@ -236,11 +236,7 @@ func (s *appserviceState) backoffAndPause(err error) error { // TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682 func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool { user := "" - validRoomID, err := spec.NewRoomID(event.RoomID()) - if err != nil { - return false - } - userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) + userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) if err == nil { user = userID.String() } @@ -250,7 +246,7 @@ func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Cont return false case appservice.IsInterestedInUserID(user): return true - case appservice.IsInterestedInRoomID(event.RoomID()): + case appservice.IsInterestedInRoomID(event.RoomID().String()): return true } @@ -261,7 +257,7 @@ func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Cont } // Check all known room aliases of the room the event came from - queryReq := api.GetAliasesForRoomIDRequest{RoomID: event.RoomID()} + queryReq := api.GetAliasesForRoomIDRequest{RoomID: event.RoomID().String()} var queryRes api.GetAliasesForRoomIDResponse if err := s.rsAPI.GetAliasesForRoomID(ctx, &queryReq, &queryRes); err == nil { for _, alias := range queryRes.Aliases { @@ -272,7 +268,7 @@ func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Cont } else { log.WithFields(log.Fields{ "appservice": appservice.ID, - "room_id": event.RoomID(), + "room_id": event.RoomID().String(), }).WithError(err).Errorf("Unable to get aliases for room") } @@ -288,7 +284,7 @@ func (s *OutputRoomEventConsumer) appserviceJoinedAtEvent(ctx context.Context, e // until we have a lighter way of checking the state before the event that // doesn't involve state res, then this is probably OK. membershipReq := &api.QueryMembershipsForRoomRequest{ - RoomID: event.RoomID(), + RoomID: event.RoomID().String(), JoinedOnly: true, } membershipRes := &api.QueryMembershipsForRoomResponse{} @@ -317,7 +313,7 @@ func (s *OutputRoomEventConsumer) appserviceJoinedAtEvent(ctx context.Context, e } else { log.WithFields(log.Fields{ "appservice": appservice.ID, - "room_id": event.RoomID(), + "room_id": event.RoomID().String(), }).WithError(err).Errorf("Unable to get membership for room") } return false diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 230c96d28..aa579db64 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -98,7 +98,7 @@ func SendRedaction( JSON: spec.NotFound("unknown event ID"), // TODO: is it ok to leak existence? } } - if ev.RoomID() != roomID { + if ev.RoomID().String() != roomID { return util.JSONResponse{ Code: 400, JSON: spec.NotFound("cannot redact event in another room"), diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index f81e9c1e4..69131966b 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -263,7 +263,11 @@ func SendEvent( } func updatePowerLevels(req *http.Request, r map[string]interface{}, roomID string, rsAPI api.ClientRoomserverAPI) error { - userMap := r["users"].(map[string]interface{}) + users, ok := r["users"] + if !ok { + return nil + } + userMap := users.(map[string]interface{}) validRoomID, err := spec.NewRoomID(roomID) if err != nil { return err @@ -277,7 +281,8 @@ func updatePowerLevels(req *http.Request, r map[string]interface{}, roomID strin if err != nil { return err } else if senderID == nil { - return fmt.Errorf("sender ID not found for %s in %s", uID, *validRoomID) + util.GetLogger(req.Context()).Warnf("sender ID not found for %s in %s", uID, *validRoomID) + continue } userMap[string(*senderID)] = level delete(userMap, user) @@ -437,7 +442,7 @@ func generateSendEvent( JSON: spec.BadJSON("Cannot unmarshal the event content."), } } - if content["replacement_room"] == e.RoomID() { + if content["replacement_room"] == e.RoomID().String() { return nil, &util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.InvalidParam("Cannot send tombstone event that points to the same room."), diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index 7648dc474..18f9a0e9c 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -172,28 +172,16 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a } } for _, ev := range stateAfterRes.StateEvents { - sender := spec.UserID{} - evRoomID, err := spec.NewRoomID(ev.RoomID()) + clientEvent, err := synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) if err != nil { - util.GetLogger(ctx).WithError(err).Error("Event roomID is invalid") + util.GetLogger(ctx).WithError(err).Error("Failed converting to ClientEvent") continue } - userID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, ev.SenderID()) - if err == nil && userID != nil { - sender = *userID - } - - sk := ev.StateKey() - if sk != nil && *sk != "" { - skUserID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, spec.SenderID(*ev.StateKey())) - if err == nil && skUserID != nil { - skString := skUserID.String() - sk = &skString - } - } stateEvents = append( stateEvents, - synctypes.ToClientEvent(ev, synctypes.FormatAll, sender, sk), + *clientEvent, ) } } diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 6c0580322..f1dcb1175 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -176,7 +176,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew // Finally, work out if there are any more events missing. if len(missingEventIDs) > 0 { eventsReq := &api.QueryEventsByIDRequest{ - RoomID: ore.Event.RoomID(), + RoomID: ore.Event.RoomID().String(), EventIDs: missingEventIDs, } eventsRes := &api.QueryEventsByIDResponse{} @@ -205,7 +205,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew // talking to the roomserver oldJoinedHosts, err := s.db.UpdateRoom( s.ctx, - ore.Event.RoomID(), + ore.Event.RoomID().String(), addsJoinedHosts, ore.RemovesStateEventIDs, rewritesState, // if we're re-writing state, nuke all joined hosts before adding @@ -218,7 +218,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew if s.cfg.Matrix.Presence.EnableOutbound && len(addsJoinedHosts) > 0 && ore.Event.Type() == spec.MRoomMember && ore.Event.StateKey() != nil { membership, _ := ore.Event.Membership() if membership == spec.Join { - s.sendPresence(ore.Event.RoomID(), addsJoinedHosts) + s.sendPresence(ore.Event.RoomID().String(), addsJoinedHosts) } } @@ -376,7 +376,7 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent( } // handle peeking hosts - inboundPeeks, err := s.db.GetInboundPeeks(s.ctx, ore.Event.PDU.RoomID()) + inboundPeeks, err := s.db.GetInboundPeeks(s.ctx, ore.Event.PDU.RoomID().String()) if err != nil { return nil, err } @@ -409,12 +409,8 @@ func JoinedHostsFromEvents(ctx context.Context, evs []gomatrixserverlib.PDU, rsA if membership != spec.Join { continue } - validRoomID, err := spec.NewRoomID(ev.RoomID()) - if err != nil { - return nil, err - } var domain spec.ServerName - userID, err := rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*ev.StateKey())) + userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey())) if err != nil { if errors.As(err, new(base64.CorruptInputError)) { // Fallback to using the "old" way of getting the user domain, avoids @@ -510,7 +506,7 @@ func (s *OutputRoomEventConsumer) lookupStateEvents( // At this point the missing events are neither the event itself nor are // they present in our local database. Our only option is to fetch them // from the roomserver using the query API. - eventReq := api.QueryEventsByIDRequest{EventIDs: missing, RoomID: event.RoomID()} + eventReq := api.QueryEventsByIDRequest{EventIDs: missing, RoomID: event.RoomID().String()} var eventResp api.QueryEventsByIDResponse if err := s.rsAPI.QueryEventsByID(s.ctx, &eventReq, &eventResp); err != nil { return nil, err diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index 4c2a99bbc..1ea8c40ea 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -146,7 +146,7 @@ func (f *fedClient) SendJoin(ctx context.Context, origin, s spec.ServerName, eve f.fedClientMutex.Lock() defer f.fedClientMutex.Unlock() for _, r := range f.allowJoins { - if r.ID == event.RoomID() { + if r.ID == event.RoomID().String() { r.InsertEvent(f.t, &types.HeaderedEvent{PDU: event}) f.t.Logf("Join event: %v", event.EventID()) res.StateEvents = types.NewEventJSONsFromHeaderedEvents(r.CurrentState()) diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 3bba3ea0d..0200cf69b 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -548,11 +548,7 @@ func (r *FederationInternalAPI) SendInvite( event gomatrixserverlib.PDU, strippedState []gomatrixserverlib.InviteStrippedState, ) (gomatrixserverlib.PDU, error) { - validRoomID, err := spec.NewRoomID(event.RoomID()) - if err != nil { - return nil, err - } - inviter, err := r.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) + inviter, err := r.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) if err != nil { return nil, err } @@ -575,7 +571,7 @@ func (r *FederationInternalAPI) SendInvite( logrus.WithFields(logrus.Fields{ "event_id": event.EventID(), "user_id": *event.StateKey(), - "room_id": event.RoomID(), + "room_id": event.RoomID().String(), "room_version": event.Version(), "destination": destination, }).Info("Sending invite") diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index 26305ed7a..24b3efd2d 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -218,7 +218,7 @@ func (oqs *OutgoingQueues) SendEvent( if api.IsServerBannedFromRoom( oqs.process.Context(), oqs.rsAPI, - ev.RoomID(), + ev.RoomID().String(), destination, ) { delete(destmap, destination) diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index cc38e136f..e75615e05 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -104,7 +104,7 @@ func (f *stubFederationClient) P2PSendTransactionToRelay(ctx context.Context, u func mustCreatePDU(t *testing.T) *types.HeaderedEvent { t.Helper() - content := `{"type":"m.room.message"}` + content := `{"type":"m.room.message", "room_id":"!room:a"}` ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV10).NewEventFromTrustedJSON([]byte(content), false) if err != nil { t.Fatalf("failed to create event: %v", err) diff --git a/federationapi/routing/backfill.go b/federationapi/routing/backfill.go index 552c4eac2..75a007265 100644 --- a/federationapi/routing/backfill.go +++ b/federationapi/routing/backfill.go @@ -109,7 +109,7 @@ func Backfill( var ev *types.HeaderedEvent for _, ev = range res.Events { - if ev.RoomID() == roomID { + if ev.RoomID().String() == roomID { evs = append(evs, ev.PDU) } } diff --git a/federationapi/routing/eventauth.go b/federationapi/routing/eventauth.go index c26aa2f15..2be3ecdb1 100644 --- a/federationapi/routing/eventauth.go +++ b/federationapi/routing/eventauth.go @@ -42,10 +42,10 @@ func GetEventAuth( return *resErr } - if event.RoomID() != roomID { + if event.RoomID().String() != roomID { return util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")} } - resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID()) + resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID().String()) if resErr != nil { return *resErr } diff --git a/federationapi/routing/events.go b/federationapi/routing/events.go index d3f0e81c3..f4659f528 100644 --- a/federationapi/routing/events.go +++ b/federationapi/routing/events.go @@ -42,7 +42,7 @@ func GetEvent( return *err } - err = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID()) + err = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID().String()) if err != nil { return *err } diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index f28c82115..7c86ba69b 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -211,7 +211,7 @@ func SendLeave( } // Check that the room ID is correct. - if event.RoomID() != roomID { + if event.RoomID().String() != roomID { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON("The room ID in the request path must match the room ID in the leave event JSON"), @@ -242,14 +242,7 @@ func SendLeave( // Check that the sender belongs to the server that is sending us // the request. By this point we've already asserted that the sender // and the state key are equal so we don't need to check both. - validRoomID, err := spec.NewRoomID(event.RoomID()) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON("Room ID is invalid."), - } - } - sender, err := rsAPI.QueryUserIDForSender(httpReq.Context(), *validRoomID, event.SenderID()) + sender, err := rsAPI.QueryUserIDForSender(httpReq.Context(), event.RoomID(), event.SenderID()) if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, diff --git a/federationapi/routing/missingevents.go b/federationapi/routing/missingevents.go index f57d30204..b1cefe7b4 100644 --- a/federationapi/routing/missingevents.go +++ b/federationapi/routing/missingevents.go @@ -87,7 +87,7 @@ func filterEvents( ) []*types.HeaderedEvent { ref := events[:0] for _, ev := range events { - if ev.RoomID() == roomID { + if ev.RoomID().String() == roomID { ref = append(ref, ev) } } diff --git a/federationapi/routing/state.go b/federationapi/routing/state.go index 11ad1ebfc..d10910573 100644 --- a/federationapi/routing/state.go +++ b/federationapi/routing/state.go @@ -113,10 +113,10 @@ func getState( return nil, nil, resErr } - if event.RoomID() != roomID { + if event.RoomID().String() != roomID { return nil, nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")} } - resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID()) + resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID().String()) if resErr != nil { return nil, nil, resErr } diff --git a/go.mod b/go.mod index 710b50376..78c1058e7 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230823153616-484e7693bb8d + github.com/matrix-org/gomatrixserverlib v0.0.0-20230926023021-d4830c9bfa49 github.com/matrix-org/pinecone v0.11.1-0.20230810010612-ea4c33717fd7 github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.17 @@ -36,18 +36,18 @@ require ( github.com/prometheus/client_golang v1.16.0 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.8.2 - github.com/tidwall/gjson v1.16.0 + github.com/tidwall/gjson v1.17.0 github.com/tidwall/sjson v1.2.5 github.com/uber/jaeger-client-go v2.30.0+incompatible github.com/uber/jaeger-lib v2.4.1+incompatible github.com/yggdrasil-network/yggdrasil-go v0.4.6 go.uber.org/atomic v1.10.0 - golang.org/x/crypto v0.12.0 + golang.org/x/crypto v0.13.0 golang.org/x/exp v0.0.0-20230809150735-7b3493d9a819 golang.org/x/image v0.5.0 golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e golang.org/x/sync v0.3.0 - golang.org/x/term v0.11.0 + golang.org/x/term v0.12.0 gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/yaml.v2 v2.4.0 gotest.tools/v3 v3.4.0 @@ -124,8 +124,8 @@ require ( go.etcd.io/bbolt v1.3.6 // indirect golang.org/x/mod v0.12.0 // indirect golang.org/x/net v0.14.0 // indirect - golang.org/x/sys v0.11.0 // indirect - golang.org/x/text v0.12.0 // indirect + golang.org/x/sys v0.12.0 // indirect + golang.org/x/text v0.13.0 // indirect golang.org/x/time v0.3.0 // indirect golang.org/x/tools v0.12.0 // indirect google.golang.org/protobuf v1.30.0 // indirect diff --git a/go.sum b/go.sum index 863caee72..3cf569497 100644 --- a/go.sum +++ b/go.sum @@ -208,8 +208,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230823153616-484e7693bb8d h1:yFoT2nyjD4TFrgYMJGgrotFbTLjaYKfZbRmnsj7lvZE= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230823153616-484e7693bb8d/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230926023021-d4830c9bfa49 h1:o4mdKYYIYCi/QplAjBAJ5kvu3NXXkutZF88gTTpZjj4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230926023021-d4830c9bfa49/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230810010612-ea4c33717fd7 h1:6t8kJr8i1/1I5nNttw6nn1ryQJgzVlBmSGgPiiaTdw4= github.com/matrix-org/pinecone v0.11.1-0.20230810010612-ea4c33717fd7/go.mod h1:ReWMS/LoVnOiRAdq9sNUC2NZnd1mZkMNB52QhpTRWjg= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= @@ -318,8 +318,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.16.0 h1:SyXa+dsSPpUlcwEDuKuEBJEz5vzTvOea+9rjyYodQFg= -github.com/tidwall/gjson v1.16.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM= +github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= @@ -354,8 +354,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= -golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= +golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -418,19 +418,19 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= -golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.11.0 h1:F9tnn/DA/Im8nCwm+fX+1/eBwi4qFjRT++MhtVC4ZX0= -golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= +golang.org/x/term v0.12.0 h1:/ZfYdc3zq+q02Rv9vGqTeSItdzZTSNDmfTi0mBAuidU= +golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc= -golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/internal/eventutil/events.go b/internal/eventutil/events.go index 56ee576a0..40d62fd68 100644 --- a/internal/eventutil/events.go +++ b/internal/eventutil/events.go @@ -176,15 +176,13 @@ func RedactEvent(ctx context.Context, redactionEvent, redactedEvent gomatrixserv return fmt.Errorf("RedactEvent: redactionEvent isn't a redaction event, is '%s'", redactionEvent.Type()) } redactedEvent.Redact() - validRoomID, err := spec.NewRoomID(redactionEvent.RoomID()) + clientEvent, err := synctypes.ToClientEvent(redactionEvent, synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return querier.QueryUserIDForSender(ctx, roomID, senderID) + }) if err != nil { return err } - senderID, err := querier.QueryUserIDForSender(ctx, *validRoomID, redactionEvent.SenderID()) - if err != nil { - return err - } - redactedBecause := synctypes.ToClientEvent(redactionEvent, synctypes.FormatSync, *senderID, redactionEvent.StateKey()) + redactedBecause := clientEvent if err := redactedEvent.SetUnsignedField("redacted_because", redactedBecause); err != nil { return err } diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go index 28dea97c4..6baef4347 100644 --- a/internal/pushrules/evaluate.go +++ b/internal/pushrules/evaluate.go @@ -111,15 +111,11 @@ func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec Evaluati return patternMatches("content.body", *rule.Pattern, event) case RoomKind: - return rule.RuleID == event.RoomID(), nil + return rule.RuleID == event.RoomID().String(), nil case SenderKind: userID := "" - validRoomID, err := spec.NewRoomID(event.RoomID()) - if err != nil { - return false, err - } - sender, err := userIDForSender(*validRoomID, event.SenderID()) + sender, err := userIDForSender(event.RoomID(), event.SenderID()) if err == nil { userID = sender.String() } diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go index a4ccc3d0f..fbc88b2e7 100644 --- a/internal/pushrules/evaluate_test.go +++ b/internal/pushrules/evaluate_test.go @@ -13,7 +13,7 @@ func UserIDForSender(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, } func TestRuleSetEvaluatorMatchEvent(t *testing.T) { - ev := mustEventFromJSON(t, `{}`) + ev := mustEventFromJSON(t, `{"room_id":"!room:a"}`) defaultEnabled := &Rule{ RuleID: ".default.enabled", Default: true, @@ -44,8 +44,8 @@ func TestRuleSetEvaluatorMatchEvent(t *testing.T) { {"overrideRoom", RuleSet{Override: []*Rule{userEnabled}, Room: []*Rule{userEnabled2}}, userEnabled, ev}, {"overrideSender", RuleSet{Override: []*Rule{userEnabled}, Sender: []*Rule{userEnabled2}}, userEnabled, ev}, {"overrideUnderride", RuleSet{Override: []*Rule{userEnabled}, Underride: []*Rule{userEnabled2}}, userEnabled, ev}, - {"reactions don't notify", *defaultRuleset, &mRuleReactionDefinition, mustEventFromJSON(t, `{"type":"m.reaction"}`)}, - {"receipts don't notify", *defaultRuleset, nil, mustEventFromJSON(t, `{"type":"m.receipt"}`)}, + {"reactions don't notify", *defaultRuleset, &mRuleReactionDefinition, mustEventFromJSON(t, `{"room_id":"!room:a","type":"m.reaction"}`)}, + {"receipts don't notify", *defaultRuleset, nil, mustEventFromJSON(t, `{"room_id":"!room:a","type":"m.receipt"}`)}, } for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { @@ -70,28 +70,27 @@ func TestRuleMatches(t *testing.T) { EventJSON string Want bool }{ - {"emptyOverride", OverrideKind, emptyRule, `{}`, true}, - {"emptyContent", ContentKind, emptyRule, `{}`, false}, - {"emptyRoom", RoomKind, emptyRule, `{}`, true}, + {"emptyOverride", OverrideKind, emptyRule, `{"room_id":"!room:example.com"}`, true}, + {"emptyContent", ContentKind, emptyRule, `{"room_id":"!room:example.com"}`, false}, {"emptySender", SenderKind, emptyRule, `{"room_id":"!room:example.com"}`, true}, - {"emptyUnderride", UnderrideKind, emptyRule, `{}`, true}, + {"emptyUnderride", UnderrideKind, emptyRule, `{"room_id":"!room:example.com"}`, true}, - {"disabled", OverrideKind, Rule{}, `{}`, false}, + {"disabled", OverrideKind, Rule{}, `{"room_id":"!room:example.com"}`, false}, - {"overrideConditionMatch", OverrideKind, Rule{Enabled: true}, `{}`, true}, - {"overrideConditionNoMatch", OverrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false}, + {"overrideConditionMatch", OverrideKind, Rule{Enabled: true}, `{"room_id":"!room:example.com"}`, true}, + {"overrideConditionNoMatch", OverrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{"room_id":"!room:example.com"}`, false}, - {"underrideConditionMatch", UnderrideKind, Rule{Enabled: true}, `{}`, true}, - {"underrideConditionNoMatch", UnderrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false}, + {"underrideConditionMatch", UnderrideKind, Rule{Enabled: true}, `{"room_id":"!room:example.com"}`, true}, + {"underrideConditionNoMatch", UnderrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{"room_id":"!room:example.com"}`, false}, - {"contentMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("b")}, `{"content":{"body":"abc"}}`, true}, - {"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("d")}, `{"content":{"body":"abc"}}`, false}, + {"contentMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("b")}, `{"room_id":"!room:example.com","content":{"body":"abc"}}`, true}, + {"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("d")}, `{"room_id":"!room:example.com","content":{"body":"abc"}}`, false}, {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!room:example.com"}`, true}, {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!otherroom:example.com"}`, false}, - {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@user:example.com","room_id":"!room:example.com"}`, true}, - {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@otheruser:example.com","room_id":"!room:example.com"}`, false}, + {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"room_id":"!room:example.com","sender":"@user:example.com","room_id":"!room:example.com"}`, true}, + {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"room_id":"!room:example.com","sender":"@otheruser:example.com","room_id":"!room:example.com"}`, false}, } for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { @@ -114,32 +113,32 @@ func TestConditionMatches(t *testing.T) { WantMatch bool WantErr bool }{ - {Name: "empty", Cond: Condition{}, EventJSON: `{}`, WantMatch: false, WantErr: false}, - {Name: "empty", Cond: Condition{Kind: "unknownstring"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "empty", Cond: Condition{}, EventJSON: `{"room_id":"!room:example.com"}`, WantMatch: false, WantErr: false}, + {Name: "empty", Cond: Condition{Kind: "unknownstring"}, EventJSON: `{"room_id":"!room:example.com"}`, WantMatch: false, WantErr: false}, // Neither of these should match because `content` is not a full string match, // and `content.body` is not a string value. - {Name: "eventMatch", Cond: Condition{Kind: EventMatchCondition, Key: "content", Pattern: pointer("")}, EventJSON: `{"content":{}}`, WantMatch: false, WantErr: false}, - {Name: "eventBodyMatch", Cond: Condition{Kind: EventMatchCondition, Key: "content.body", Is: "3", Pattern: pointer("")}, EventJSON: `{"content":{"body": "3"}}`, WantMatch: false, WantErr: false}, - {Name: "eventBodyMatch matches", Cond: Condition{Kind: EventMatchCondition, Key: "content.body", Pattern: pointer("world")}, EventJSON: `{"content":{"body": "hello world!"}}`, WantMatch: true, WantErr: false}, - {Name: "EventMatch missing pattern", Cond: Condition{Kind: EventMatchCondition, Key: "content.body"}, EventJSON: `{"content":{"body": "hello world!"}}`, WantMatch: false, WantErr: true}, + {Name: "eventMatch", Cond: Condition{Kind: EventMatchCondition, Key: "content", Pattern: pointer("")}, EventJSON: `{"room_id":"!room:example.com","content":{}}`, WantMatch: false, WantErr: false}, + {Name: "eventBodyMatch", Cond: Condition{Kind: EventMatchCondition, Key: "content.body", Is: "3", Pattern: pointer("")}, EventJSON: `{"room_id":"!room:example.com","content":{"body": "3"}}`, WantMatch: false, WantErr: false}, + {Name: "eventBodyMatch matches", Cond: Condition{Kind: EventMatchCondition, Key: "content.body", Pattern: pointer("world")}, EventJSON: `{"room_id":"!room:example.com","content":{"body": "hello world!"}}`, WantMatch: true, WantErr: false}, + {Name: "EventMatch missing pattern", Cond: Condition{Kind: EventMatchCondition, Key: "content.body"}, EventJSON: `{"room_id":"!room:example.com","content":{"body": "hello world!"}}`, WantMatch: false, WantErr: true}, - {Name: "displayNameNoMatch", Cond: Condition{Kind: ContainsDisplayNameCondition}, EventJSON: `{"content":{"body":"something without displayname"}}`, WantMatch: false, WantErr: false}, - {Name: "displayNameMatch", Cond: Condition{Kind: ContainsDisplayNameCondition}, EventJSON: `{"content":{"body":"hello Dear User, how are you?"}}`, WantMatch: true, WantErr: false}, + {Name: "displayNameNoMatch", Cond: Condition{Kind: ContainsDisplayNameCondition}, EventJSON: `{"room_id":"!room:example.com","content":{"body":"something without displayname"}}`, WantMatch: false, WantErr: false}, + {Name: "displayNameMatch", Cond: Condition{Kind: ContainsDisplayNameCondition}, EventJSON: `{"room_id":"!room:example.com","content":{"body":"hello Dear User, how are you?"}}`, WantMatch: true, WantErr: false}, - {Name: "roomMemberCountLessNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<2"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, - {Name: "roomMemberCountLessMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<3"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, - {Name: "roomMemberCountLessEqualNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, - {Name: "roomMemberCountLessEqualMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, - {Name: "roomMemberCountEqualNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "==1"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, - {Name: "roomMemberCountEqualMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "==2"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, - {Name: "roomMemberCountGreaterEqualNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, - {Name: "roomMemberCountGreaterEqualMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, - {Name: "roomMemberCountGreaterNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">2"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, - {Name: "roomMemberCountGreaterMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">1"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, + {Name: "roomMemberCountLessNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<2"}, EventJSON: `{"room_id":"!room:example.com"}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountLessMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<3"}, EventJSON: `{"room_id":"!room:example.com"}`, WantMatch: true, WantErr: false}, + {Name: "roomMemberCountLessEqualNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, EventJSON: `{"room_id":"!room:example.com"}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountLessEqualMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, EventJSON: `{"room_id":"!room:example.com"}`, WantMatch: true, WantErr: false}, + {Name: "roomMemberCountEqualNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "==1"}, EventJSON: `{"room_id":"!room:example.com"}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountEqualMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "==2"}, EventJSON: `{"room_id":"!room:example.com"}`, WantMatch: true, WantErr: false}, + {Name: "roomMemberCountGreaterEqualNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, EventJSON: `{"room_id":"!room:example.com"}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountGreaterEqualMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, EventJSON: `{"room_id":"!room:example.com"}`, WantMatch: true, WantErr: false}, + {Name: "roomMemberCountGreaterNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">2"}, EventJSON: `{"room_id":"!room:example.com"}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountGreaterMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">1"}, EventJSON: `{"room_id":"!room:example.com"}`, WantMatch: true, WantErr: false}, - {Name: "senderNotificationPermissionMatch", Cond: Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, EventJSON: `{"sender":"@poweruser:example.com"}`, WantMatch: true, WantErr: false}, - {Name: "senderNotificationPermissionNoMatch", Cond: Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, EventJSON: `{"sender":"@nobody:example.com"}`, WantMatch: false, WantErr: false}, + {Name: "senderNotificationPermissionMatch", Cond: Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, EventJSON: `{"room_id":"!room:example.com","sender":"@poweruser:example.com"}`, WantMatch: true, WantErr: false}, + {Name: "senderNotificationPermissionNoMatch", Cond: Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, EventJSON: `{"room_id":"!room:example.com","sender":"@nobody:example.com"}`, WantMatch: false, WantErr: false}, } for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { @@ -170,15 +169,15 @@ func TestPatternMatches(t *testing.T) { EventJSON string Want bool }{ - {"empty", "", "", `{}`, false}, + {"empty", "", "", `{"room_id":"!room:a"}`, false}, - {"patternEmpty", "content", "", `{"content":{}}`, false}, + {"patternEmpty", "content", "", `{"room_id":"!room:a","content":{}}`, false}, - {"literal", "content.creator", "acreator", `{"content":{"creator":"acreator"}}`, true}, - {"substring", "content.creator", "reat", `{"content":{"creator":"acreator"}}`, true}, - {"singlePattern", "content.creator", "acr?ator", `{"content":{"creator":"acreator"}}`, true}, - {"multiPattern", "content.creator", "a*ea*r", `{"content":{"creator":"acreator"}}`, true}, - {"patternNoSubstring", "content.creator", "r*t", `{"content":{"creator":"acreator"}}`, false}, + {"literal", "content.creator", "acreator", `{"room_id":"!room:a","content":{"creator":"acreator"}}`, true}, + {"substring", "content.creator", "reat", `{"room_id":"!room:a","content":{"creator":"acreator"}}`, true}, + {"singlePattern", "content.creator", "acr?ator", `{"room_id":"!room:a","content":{"creator":"acreator"}}`, true}, + {"multiPattern", "content.creator", "a*ea*r", `{"room_id":"!room:a","content":{"creator":"acreator"}}`, true}, + {"patternNoSubstring", "content.creator", "r*t", `{"room_id":"!room:a","content":{"creator":"acreator"}}`, false}, } for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { diff --git a/internal/transactionrequest.go b/internal/transactionrequest.go index 5bf7d819c..0663c2dcb 100644 --- a/internal/transactionrequest.go +++ b/internal/transactionrequest.go @@ -161,7 +161,7 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut if event.Type() == spec.MRoomCreate && event.StateKeyEquals("") { continue } - if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID(), t.Origin) { + if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID().String(), t.Origin) { results[event.EventID()] = fclient.PDUResult{ Error: "Forbidden by server ACLs", } diff --git a/roomserver/acls/acls.go b/roomserver/acls/acls.go index b04828b69..601ce9063 100644 --- a/roomserver/acls/acls.go +++ b/roomserver/acls/acls.go @@ -119,7 +119,7 @@ func (s *ServerACLs) OnServerACLUpdate(state gomatrixserverlib.PDU) { }).Debugf("Updating server ACLs for %q", state.RoomID()) s.aclsMutex.Lock() defer s.aclsMutex.Unlock() - s.acls[state.RoomID()] = acls + s.acls[state.RoomID().String()] = acls } func (s *ServerACLs) IsServerBannedFromRoom(serverName spec.ServerName, roomID string) bool { diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 2505a993b..0ad5d2013 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -75,7 +75,7 @@ func SendEventWithState( } logrus.WithContext(ctx).WithFields(logrus.Fields{ - "room_id": event.RoomID(), + "room_id": event.RoomID().String(), "event_id": event.EventID(), "outliers": len(ires), "state_ids": len(stateEventIDs), diff --git a/roomserver/auth/auth.go b/roomserver/auth/auth.go index df95851e3..d5172dab9 100644 --- a/roomserver/auth/auth.go +++ b/roomserver/auth/auth.go @@ -85,11 +85,7 @@ func IsAnyUserOnServerWithMembership(ctx context.Context, querier api.QuerySende continue } - validRoomID, err := spec.NewRoomID(ev.RoomID()) - if err != nil { - continue - } - userID, err := querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*stateKey)) + userID, err := querier.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*stateKey)) if err != nil { continue } diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index a7f0aab9c..5ceda7e01 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -189,7 +189,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(ctx context.Context, senderID sp proto := &gomatrixserverlib.ProtoEvent{ SenderID: string(canonicalSenderID), - RoomID: ev.RoomID(), + RoomID: ev.RoomID().String(), Type: ev.Type(), StateKey: ev.StateKey(), Content: res, diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 530147daa..1e08f6a3a 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -239,7 +239,7 @@ func (r *RoomserverInternalAPI) HandleInvite( if err != nil { return err } - return r.OutputProducer.ProduceRoomEvents(inviteEvent.RoomID(), outputEvents) + return r.OutputProducer.ProduceRoomEvents(inviteEvent.RoomID().String(), outputEvents) } func (r *RoomserverInternalAPI) PerformCreateRoom( diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 89fae244f..9da751b1a 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -218,9 +218,9 @@ func loadAuthEvents( roomID := "" for _, ev := range result.events { if roomID == "" { - roomID = ev.RoomID() + roomID = ev.RoomID().String() } - if ev.RoomID() != roomID { + if ev.RoomID().String() != roomID { result.valid = false break } diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index febabf411..b2e21bf54 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -54,7 +54,7 @@ func UpdateToInviteMembership( Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ EventID: eventID, - RoomID: add.RoomID(), + RoomID: add.RoomID().String(), Membership: spec.Join, RetiredByEventID: add.EventID(), TargetSenderID: spec.SenderID(*add.StateKey()), @@ -396,7 +396,7 @@ BFSLoop: // It's nasty that we have to extract the room ID from an event, but many federation requests // only talk in event IDs, no room IDs at all (!!!) ev := events[0] - isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, querier, serverName, ev.RoomID()) + isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, querier, serverName, ev.RoomID().String()) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.") } @@ -419,7 +419,7 @@ BFSLoop: // hasn't been seen before. if !visited[pre] { visited[pre] = true - allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom, querier) + allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID().String(), pre, serverName, isServerInRoom, querier) if err != nil { util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error( "Error checking if allowed to see event", diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 990563599..404751532 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -358,7 +358,7 @@ func (r *Inputer) queueInputRoomEvents( // For each event, marshal the input room event and then // send it into the input queue. for _, e := range request.InputRoomEvents { - roomID := e.Event.RoomID() + roomID := e.Event.RoomID().String() subj := r.Cfg.Matrix.JetStream.Prefixed(jetstream.InputRoomEventSubj(roomID)) msg := &nats.Msg{ Subject: subj, diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index bf3216623..77b50d0e2 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -87,7 +87,7 @@ func (r *Inputer) processRoomEvent( } trace, ctx := internal.StartRegion(ctx, "processRoomEvent") - trace.SetTag("room_id", input.Event.RoomID()) + trace.SetTag("room_id", input.Event.RoomID().String()) trace.SetTag("event_id", input.Event.EventID()) defer trace.EndRegion() @@ -96,7 +96,7 @@ func (r *Inputer) processRoomEvent( defer func() { timetaken := time.Since(started) processRoomEventDuration.With(prometheus.Labels{ - "room_id": input.Event.RoomID(), + "room_id": input.Event.RoomID().String(), }).Observe(float64(timetaken.Milliseconds())) }() @@ -105,7 +105,7 @@ func (r *Inputer) processRoomEvent( event := headered.PDU logger := util.GetLogger(ctx).WithFields(logrus.Fields{ "event_id": event.EventID(), - "room_id": event.RoomID(), + "room_id": event.RoomID().String(), "kind": input.Kind, "origin": input.Origin, "type": event.Type(), @@ -120,19 +120,15 @@ func (r *Inputer) processRoomEvent( // Don't waste time processing the event if the room doesn't exist. // A room entry locally will only be created in response to a create // event. - roomInfo, rerr := r.DB.RoomInfo(ctx, event.RoomID()) + roomInfo, rerr := r.DB.RoomInfo(ctx, event.RoomID().String()) if rerr != nil { return fmt.Errorf("r.DB.RoomInfo: %w", rerr) } isCreateEvent := event.Type() == spec.MRoomCreate && event.StateKeyEquals("") if roomInfo == nil && !isCreateEvent { - return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) + return fmt.Errorf("room %s does not exist for event %s", event.RoomID().String(), event.EventID()) } - validRoomID, err := spec.NewRoomID(event.RoomID()) - if err != nil { - return err - } - sender, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) + sender, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) if err != nil { return fmt.Errorf("failed getting userID for sender %q. %w", event.SenderID(), err) } @@ -179,7 +175,7 @@ func (r *Inputer) processRoomEvent( // If we have missing events (auth or prev), we build a list of servers to ask if missingAuth || missingPrev { serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{ - RoomID: event.RoomID(), + RoomID: event.RoomID().String(), ExcludeSelf: true, ExcludeBlacklisted: true, } @@ -395,12 +391,12 @@ func (r *Inputer) processRoomEvent( // Request the room info again — it's possible that the room has been // created by now if it didn't exist already. - roomInfo, err = r.DB.RoomInfo(ctx, event.RoomID()) + roomInfo, err = r.DB.RoomInfo(ctx, event.RoomID().String()) if err != nil { return fmt.Errorf("updater.RoomInfo: %w", err) } if roomInfo == nil { - return fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID()) + return fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID().String()) } if input.HasState || (!missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0) { @@ -459,7 +455,7 @@ func (r *Inputer) processRoomEvent( if userErr != nil { return userErr } - err = r.RSAPI.StoreUserRoomPublicKey(ctx, mapping.MXIDMapping.UserRoomKey, *storeUserID, *validRoomID) + err = r.RSAPI.StoreUserRoomPublicKey(ctx, mapping.MXIDMapping.UserRoomKey, *storeUserID, event.RoomID()) if err != nil { return fmt.Errorf("failed storing user room public key: %w", err) } @@ -481,7 +477,7 @@ func (r *Inputer) processRoomEvent( return fmt.Errorf("r.updateLatestEvents: %w", err) } case api.KindOld: - err = r.OutputProducer.ProduceRoomEvents(event.RoomID(), []api.OutputEvent{ + err = r.OutputProducer.ProduceRoomEvents(event.RoomID().String(), []api.OutputEvent{ { Type: api.OutputTypeOldRoomEvent, OldRoomEvent: &api.OutputOldRoomEvent{ @@ -507,7 +503,7 @@ func (r *Inputer) processRoomEvent( // so notify downstream components to redact this event - they should have it if they've // been tracking our output log. if redactedEventID != "" { - err = r.OutputProducer.ProduceRoomEvents(event.RoomID(), []api.OutputEvent{ + err = r.OutputProducer.ProduceRoomEvents(event.RoomID().String(), []api.OutputEvent{ { Type: api.OutputTypeRedactedEvent, RedactedEvent: &api.OutputRedactedEvent{ @@ -536,7 +532,7 @@ func (r *Inputer) processRoomEvent( // handleRemoteRoomUpgrade updates published rooms and room aliases func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event gomatrixserverlib.PDU) error { - oldRoomID := event.RoomID() + oldRoomID := event.RoomID().String() newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, string(event.SenderID())) } @@ -596,7 +592,7 @@ func (r *Inputer) processStateBefore( StateKey: "", }) stateBeforeReq := &api.QueryStateAfterEventsRequest{ - RoomID: event.RoomID(), + RoomID: event.RoomID().String(), PrevEventIDs: event.PrevEventIDs(), StateToFetch: tuplesNeeded, } @@ -606,7 +602,7 @@ func (r *Inputer) processStateBefore( } switch { case !stateBeforeRes.RoomExists: - rejectionErr = fmt.Errorf("room %q does not exist", event.RoomID()) + rejectionErr = fmt.Errorf("room %q does not exist", event.RoomID().String()) return case !stateBeforeRes.PrevEventsExist: rejectionErr = fmt.Errorf("prev events of %q are not known", event.EventID()) @@ -707,7 +703,7 @@ func (r *Inputer) fetchAuthEvents( // Request the entire auth chain for the event in question. This should // contain all of the auth events — including ones that we already know — // so we'll need to filter through those in the next section. - res, err = r.FSAPI.GetEventAuth(ctx, virtualHost, serverName, event.Version(), event.RoomID(), event.EventID()) + res, err = r.FSAPI.GetEventAuth(ctx, virtualHost, serverName, event.Version(), event.RoomID().String(), event.EventID()) if err != nil { logger.WithError(err).Warnf("Failed to get event auth from federation for %q: %s", event.EventID(), err) continue @@ -866,25 +862,20 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents)) latestReq := &api.QueryLatestEventsAndStateRequest{ - RoomID: event.RoomID(), + RoomID: event.RoomID().String(), } latestRes := &api.QueryLatestEventsAndStateResponse{} if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil { return err } - validRoomID, err := spec.NewRoomID(event.RoomID()) - if err != nil { - return err - } - prevEvents := latestRes.LatestEvents for _, memberEvent := range memberEvents { if memberEvent.StateKey() == nil { continue } - memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*memberEvent.StateKey())) + memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*memberEvent.StateKey())) if err != nil { continue } @@ -912,7 +903,7 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r stateKey := *memberEvent.StateKey() fledglingEvent := &gomatrixserverlib.ProtoEvent{ - RoomID: event.RoomID(), + RoomID: event.RoomID().String(), Type: spec.MRoomMember, StateKey: &stateKey, SenderID: stateKey, @@ -928,12 +919,7 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r return err } - validRoomID, err := spec.NewRoomID(event.RoomID()) - if err != nil { - return err - } - - signingIdentity, err := r.SigningIdentity(ctx, *validRoomID, *memberUserID) + signingIdentity, err := r.SigningIdentity(ctx, event.RoomID(), *memberUserID) if err != nil { return err } diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index 940783e03..ec03d6f13 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -197,7 +197,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // send the event asynchronously but we would need to ensure that 1) the events are written to the log in // the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the // necessary bookkeeping we'll keep the event sending synchronous for now. - if err = u.api.OutputProducer.ProduceRoomEvents(u.event.RoomID(), updates); err != nil { + if err = u.api.OutputProducer.ProduceRoomEvents(u.event.RoomID().String(), updates); err != nil { return fmt.Errorf("u.api.WriteOutputEvents: %w", err) } @@ -290,7 +290,7 @@ func (u *latestEventsUpdater) latestState() error { if removed := len(u.removed) - len(u.added); !u.rewritesState && removed > 0 { logrus.WithFields(logrus.Fields{ "event_id": u.event.EventID(), - "room_id": u.event.RoomID(), + "room_id": u.event.RoomID().String(), "old_state_nid": u.oldStateNID, "new_state_nid": u.newStateNID, "old_latest": u.oldLatest.EventIDs(), diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index c46f8dba1..4cfc2cda9 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -139,11 +139,7 @@ func (r *Inputer) updateMembership( func (r *Inputer) isLocalTarget(ctx context.Context, event *types.Event) bool { isTargetLocalUser := false if statekey := event.StateKey(); statekey != nil { - validRoomID, err := spec.NewRoomID(event.RoomID()) - if err != nil { - return isTargetLocalUser - } - userID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*statekey)) + userID, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*statekey)) if err != nil || userID == nil { return isTargetLocalUser } @@ -168,7 +164,7 @@ func updateToJoinMembership( Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ EventID: eventID, - RoomID: add.RoomID(), + RoomID: add.RoomID().String(), Membership: spec.Join, RetiredByEventID: add.EventID(), TargetSenderID: spec.SenderID(*add.StateKey()), @@ -195,7 +191,7 @@ func updateToLeaveMembership( Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ EventID: eventID, - RoomID: add.RoomID(), + RoomID: add.RoomID().String(), Membership: newMembership, RetiredByEventID: add.EventID(), TargetSenderID: spec.SenderID(*add.StateKey()), diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 5b4c0727b..d9ab291e9 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -84,7 +84,7 @@ func (t *missingStateReq) processEventWithMissingState( // need to fallback to /state. t.log = util.GetLogger(ctx).WithFields(map[string]interface{}{ "txn_event": e.EventID(), - "room_id": e.RoomID(), + "room_id": e.RoomID().String(), "txn_prev_events": e.PrevEventIDs(), }) @@ -264,7 +264,7 @@ func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e // Look up what the state is after the backward extremity. This will either // come from the roomserver, if we know all the required events, or it will // come from a remote server via /state_ids if not. - prevState, trustworthy, err := t.lookupStateAfterEvent(ctx, roomVersion, e.RoomID(), prevEventID) + prevState, trustworthy, err := t.lookupStateAfterEvent(ctx, roomVersion, e.RoomID().String(), prevEventID) switch err2 := err.(type) { case gomatrixserverlib.EventValidationError: if !err2.Persistable { @@ -316,9 +316,9 @@ func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e } // There's more than one previous state - run them all through state res var err error - t.roomsMu.Lock(e.RoomID()) + t.roomsMu.Lock(e.RoomID().String()) resolvedState, err = t.resolveStatesAndCheck(ctx, roomVersion, respStates, e) - t.roomsMu.Unlock(e.RoomID()) + t.roomsMu.Unlock(e.RoomID().String()) switch err2 := err.(type) { case gomatrixserverlib.EventValidationError: if !err2.Persistable { @@ -510,7 +510,7 @@ retryAllowedState: }); err != nil { switch missing := err.(type) { case gomatrixserverlib.MissingAuthEventError: - h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true) + h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID().String(), missing.AuthEventID, true) switch e := err2.(type) { case gomatrixserverlib.EventValidationError: if !e.Persistable { @@ -546,7 +546,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver trace, ctx := internal.StartRegion(ctx, "getMissingEvents") defer trace.EndRegion() - logger := t.log.WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) + logger := t.log.WithField("event_id", e.EventID()).WithField("room_id", e.RoomID().String()) latest, _, _, err := t.db.LatestEventIDs(ctx, t.roomInfo.RoomNID) if err != nil { return nil, false, false, fmt.Errorf("t.DB.LatestEventIDs: %w", err) @@ -560,7 +560,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver var missingResp *fclient.RespMissingEvents for _, server := range t.servers { var m fclient.RespMissingEvents - if m, err = t.federation.LookupMissingEvents(ctx, t.virtualHost, server, e.RoomID(), fclient.MissingEvents{ + if m, err = t.federation.LookupMissingEvents(ctx, t.virtualHost, server, e.RoomID().String(), fclient.MissingEvents{ Limit: 20, // The latest event IDs that the sender already has. These are skipped when retrieving the previous events of latest_events. EarliestEvents: latestEvents, diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 33200e819..dafa58736 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -301,7 +301,7 @@ func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent return ids, nil } if len(targetEvent.PrevEventIDs()) == 0 && targetEvent.Type() == "m.room.create" && targetEvent.StateKeyEquals("") { - util.GetLogger(ctx).WithField("room_id", targetEvent.RoomID()).Info("Backfilled to the beginning of the room") + util.GetLogger(ctx).WithField("room_id", targetEvent.RoomID().String()).Info("Backfilled to the beginning of the room") b.eventIDToBeforeStateIDs[targetEvent.EventID()] = []string{} return nil, nil } @@ -494,11 +494,7 @@ FindSuccessor: // Store the server names in a temporary map to avoid duplicates. serverSet := make(map[spec.ServerName]bool) for _, event := range memberEvents { - validRoomID, err := spec.NewRoomID(event.RoomID()) - if err != nil { - continue - } - if sender, err := b.querier.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()); err == nil { + if sender, err := b.querier.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()); err == nil { serverSet[sender.Domain()] = true } } diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index e07780d68..3abb69cb9 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -100,16 +100,12 @@ func (r *Inviter) ProcessInviteMembership( var outputUpdates []api.OutputEvent var updater *shared.MembershipUpdater - validRoomID, err := spec.NewRoomID(inviteEvent.RoomID()) - if err != nil { - return nil, err - } - userID, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*inviteEvent.StateKey())) + userID, err := r.RSAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey())) if err != nil { return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())} } isTargetLocal := r.Cfg.Matrix.IsLocalServerName(userID.Domain()) - if updater, err = r.DB.MembershipUpdater(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey(), isTargetLocal, inviteEvent.Version()); err != nil { + if updater, err = r.DB.MembershipUpdater(ctx, inviteEvent.RoomID().String(), *inviteEvent.StateKey(), isTargetLocal, inviteEvent.Version()); err != nil { return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err) } outputUpdates, err = helpers.UpdateToInviteMembership(updater, &types.Event{ diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 5c63a6684..5bea00445 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -93,11 +93,21 @@ func (r *Leaver) performLeaveRoomByID( isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, *leaver) if err == nil && isInvitePending { sender, serr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, senderUser) - if serr != nil || sender == nil { - return nil, fmt.Errorf("sender %q has no matching userID", senderUser) + if serr != nil { + return nil, fmt.Errorf("failed looking up userID for sender %q: %w", senderUser, serr) } - if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) { - return r.performFederatedRejectInvite(ctx, req, res, *sender, eventID, *leaver) + + var domain spec.ServerName + if sender == nil { + // TODO: Currently a federated invite has no way of knowing the mxid_mapping of the inviter. + // Should we add the inviter's m.room.member event (with mxid_mapping) to invite_room_state to allow + // the invited user to leave via the inviter's server? + domain = roomID.Domain() + } else { + domain = sender.Domain() + } + if !r.Cfg.Matrix.IsLocalServerName(domain) { + return r.performFederatedRejectInvite(ctx, req, res, domain, eventID, *leaver) } // check that this is not a "server notice room" accData := &userapi.QueryAccountDataResponse{} @@ -219,14 +229,14 @@ func (r *Leaver) performFederatedRejectInvite( ctx context.Context, req *api.PerformLeaveRequest, res *api.PerformLeaveResponse, // nolint:unparam - inviteSender spec.UserID, eventID string, + inviteDomain spec.ServerName, eventID string, leaver spec.SenderID, ) ([]api.OutputEvent, error) { // Ask the federation sender to perform a federated leave for us. leaveReq := fsAPI.PerformLeaveRequest{ RoomID: req.RoomID, UserID: req.Leaver.String(), - ServerNames: []spec.ServerName{inviteSender.Domain()}, + ServerNames: []spec.ServerName{inviteDomain}, } leaveRes := fsAPI.PerformLeaveResponse{} if err := r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil { diff --git a/roomserver/internal/query/query_room_hierarchy.go b/roomserver/internal/query/query_room_hierarchy.go index 7274be520..76eba12be 100644 --- a/roomserver/internal/query/query_room_hierarchy.go +++ b/roomserver/internal/query/query_room_hierarchy.go @@ -513,14 +513,14 @@ func restrictedJoinRuleAllowedRooms(ctx context.Context, joinRuleEv *types.Heade } var jrContent gomatrixserverlib.JoinRuleContent if err := json.Unmarshal(joinRuleEv.Content(), &jrContent); err != nil { - util.GetLogger(ctx).Warnf("failed to check join_rule on room %s: %s", joinRuleEv.RoomID(), err) + util.GetLogger(ctx).Warnf("failed to check join_rule on room %s: %s", joinRuleEv.RoomID().String(), err) return nil } for _, allow := range jrContent.Allow { if allow.Type == spec.MRoomMembership { allowedRoomID, err := spec.NewRoomID(allow.RoomID) if err != nil { - util.GetLogger(ctx).Warnf("invalid room ID '%s' found in join_rule on room %s: %s", allow.RoomID, joinRuleEv.RoomID(), err) + util.GetLogger(ctx).Warnf("invalid room ID '%s' found in join_rule on room %s: %s", allow.RoomID, joinRuleEv.RoomID().String(), err) } else { allows = append(allows, *allowedRoomID) } diff --git a/roomserver/internal/query/query_test.go b/roomserver/internal/query/query_test.go index 619d93030..296960b2f 100644 --- a/roomserver/internal/query/query_test.go +++ b/roomserver/internal/query/query_test.go @@ -49,6 +49,7 @@ func (db *getEventDB) addFakeEvent(eventID string, authIDs []string) error { } builder := map[string]interface{}{ "event_id": eventID, + "room_id": "!room:a", "auth_events": authEvents, } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index b09c5afbd..3331c6029 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -696,8 +696,8 @@ func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserver return nil, fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err) } - roomNID, nidOK := d.Cache.GetRoomServerRoomNID(event.RoomID()) - cachedRoomVersion, versionOK := d.Cache.GetRoomVersion(event.RoomID()) + roomNID, nidOK := d.Cache.GetRoomServerRoomNID(event.RoomID().String()) + cachedRoomVersion, versionOK := d.Cache.GetRoomVersion(event.RoomID().String()) // if we found both, the roomNID and version in our cache, no need to query the database if nidOK && versionOK { return &types.RoomInfo{ @@ -707,14 +707,14 @@ func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserver } err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion) + roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID().String(), roomVersion) if err != nil { return err } return nil }) if roomVersion != "" { - d.Cache.StoreRoomVersion(event.RoomID(), roomVersion) + d.Cache.StoreRoomVersion(event.RoomID().String(), roomVersion) } return &types.RoomInfo{ RoomVersion: roomVersion, @@ -1026,24 +1026,19 @@ func (d *EventDatabase) MaybeRedactEvent( case validated || redactedEvent == nil || redactionEvent == nil: // we've seen this redaction before or there is nothing to redact return nil - case redactedEvent.RoomID() != redactionEvent.RoomID(): + case redactedEvent.RoomID().String() != redactionEvent.RoomID().String(): // redactions across rooms aren't allowed ignoreRedaction = true return nil } - var validRoomID *spec.RoomID - validRoomID, err = spec.NewRoomID(redactedEvent.RoomID()) - if err != nil { - return err - } sender1Domain := "" - sender1, err1 := querier.QueryUserIDForSender(ctx, *validRoomID, redactedEvent.SenderID()) + sender1, err1 := querier.QueryUserIDForSender(ctx, redactedEvent.RoomID(), redactedEvent.SenderID()) if err1 == nil { sender1Domain = string(sender1.Domain()) } sender2Domain := "" - sender2, err2 := querier.QueryUserIDForSender(ctx, *validRoomID, redactionEvent.SenderID()) + sender2, err2 := querier.QueryUserIDForSender(ctx, redactedEvent.RoomID(), redactionEvent.SenderID()) if err2 == nil { sender2Domain = string(sender2.Domain()) } @@ -1522,7 +1517,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu } result[i] = tables.StrippedEvent{ EventType: ev.Type(), - RoomID: ev.RoomID(), + RoomID: ev.RoomID().String(), StateKey: *ev.StateKey(), ContentValue: tables.ExtractContentValue(&types.HeaderedEvent{PDU: ev}), } diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 7f8e2de03..15811710d 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -271,7 +271,7 @@ func (rc *reqCtx) process() (*MSC2836EventRelationshipsResponse, *util.JSONRespo event = rc.fetchUnknownEvent(rc.req.EventID, rc.req.RoomID) } if rc.req.RoomID == "" && event != nil { - rc.req.RoomID = event.RoomID() + rc.req.RoomID = event.RoomID().String() } if event == nil || !rc.authorisedToSeeEvent(event) { return nil, &util.JSONResponse{ @@ -526,7 +526,7 @@ func (rc *reqCtx) authorisedToSeeEvent(event *types.HeaderedEvent) bool { // make sure the server is in this room var res fs.QueryJoinedHostServerNamesInRoomResponse err := rc.fsAPI.QueryJoinedHostServerNamesInRoom(rc.ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{ - RoomID: event.RoomID(), + RoomID: event.RoomID().String(), }, &res) if err != nil { util.GetLogger(rc.ctx).WithError(err).Error("authorisedToSeeEvent: failed to QueryJoinedHostServerNamesInRoom") @@ -545,7 +545,7 @@ func (rc *reqCtx) authorisedToSeeEvent(event *types.HeaderedEvent) bool { // TODO: This does not honour m.room.create content var queryMembershipRes roomserver.QueryMembershipForUserResponse err := rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{ - RoomID: event.RoomID(), + RoomID: event.RoomID().String(), UserID: rc.userID, }, &queryMembershipRes) if err != nil { @@ -612,7 +612,7 @@ func (rc *reqCtx) lookForEvent(eventID string) *types.HeaderedEvent { // inject all the events into the roomserver then return the event in question rc.injectResponseToRoomserver(queryRes) for _, ev := range queryRes.ParsedEvents { - if ev.EventID() == eventID && rc.req.RoomID == ev.RoomID() { + if ev.EventID() == eventID && rc.req.RoomID == ev.RoomID().String() { return &types.HeaderedEvent{PDU: ev} } } @@ -629,7 +629,7 @@ func (rc *reqCtx) lookForEvent(eventID string) *types.HeaderedEvent { } } } - if rc.req.RoomID == event.RoomID() { + if rc.req.RoomID == event.RoomID().String() { return event } return nil diff --git a/setup/mscs/msc2836/storage.go b/setup/mscs/msc2836/storage.go index 73bd6ed4f..ade2a1616 100644 --- a/setup/mscs/msc2836/storage.go +++ b/setup/mscs/msc2836/storage.go @@ -239,7 +239,7 @@ func (p *DB) StoreRelation(ctx context.Context, ev *types.HeaderedEvent) error { return err } util.GetLogger(ctx).Infof("StoreRelation child=%s parent=%s rel_type=%s", child, parent, relType) - _, err = txn.Stmt(p.insertNodeStmt).ExecContext(ctx, ev.EventID(), ev.OriginServerTS(), ev.RoomID(), count, base64.RawStdEncoding.EncodeToString(hash), 0) + _, err = txn.Stmt(p.insertNodeStmt).ExecContext(ctx, ev.EventID(), ev.OriginServerTS(), ev.RoomID().String(), count, base64.RawStdEncoding.EncodeToString(hash), 0) return err }) } diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index 3ed455e9f..76b447133 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -113,7 +113,7 @@ func (s *OutputClientDataConsumer) Start() error { id = streamPos e := fulltext.IndexElement{ EventID: ev.EventID(), - RoomID: ev.RoomID(), + RoomID: ev.RoomID().String(), StreamPosition: streamPos, } e.SetContentType(ev.Type()) diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 1e87aee99..666f900d7 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -33,6 +33,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" @@ -165,9 +166,9 @@ func (s *OutputRoomEventConsumer) onRedactEvent( return err } - if err = s.db.RedactRelations(ctx, msg.RedactedBecause.RoomID(), msg.RedactedEventID); err != nil { + if err = s.db.RedactRelations(ctx, msg.RedactedBecause.RoomID().String(), msg.RedactedEventID); err != nil { log.WithFields(log.Fields{ - "room_id": msg.RedactedBecause.RoomID(), + "room_id": msg.RedactedBecause.RoomID().String(), "event_id": msg.RedactedBecause.EventID(), "redacted_event_id": msg.RedactedEventID, }).WithError(err).Warn("Failed to redact relations") @@ -221,7 +222,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( // Finally, work out if there are any more events missing. if len(missingEventIDs) > 0 { eventsReq := &api.QueryEventsByIDRequest{ - RoomID: ev.RoomID(), + RoomID: ev.RoomID().String(), EventIDs: missingEventIDs, } eventsRes := &api.QueryEventsByIDResponse{} @@ -256,17 +257,12 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( } if msg.RewritesState { - if err = s.db.PurgeRoomState(ctx, ev.RoomID()); err != nil { + if err = s.db.PurgeRoomState(ctx, ev.RoomID().String()); err != nil { return fmt.Errorf("s.db.PurgeRoom: %w", err) } } - validRoomID, err := spec.NewRoomID(ev.RoomID()) - if err != nil { - return err - } - - userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, ev.SenderID()) + userID, err := s.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID()) if err != nil { return err } @@ -306,7 +302,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( } s.pduStream.Advance(pduPos) - s.notifier.OnNewEvent(ev, ev.RoomID(), nil, types.StreamingToken{PDUPosition: pduPos}) + s.notifier.OnNewEvent(ev, ev.RoomID().String(), nil, types.StreamingToken{PDUPosition: pduPos}) return nil } @@ -323,12 +319,7 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent( // old events in the sync API, this should at least prevent us // from confusing clients into thinking they've joined/left rooms. - validRoomID, err := spec.NewRoomID(ev.RoomID()) - if err != nil { - return err - } - - userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, ev.SenderID()) + userID, err := s.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID()) if err != nil { return err } @@ -354,7 +345,7 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent( if err = s.db.UpdateRelations(ctx, ev); err != nil { log.WithFields(log.Fields{ - "room_id": ev.RoomID(), + "room_id": ev.RoomID().String(), "event_id": ev.EventID(), "type": ev.Type(), }).WithError(err).Warn("Failed to update relations") @@ -367,7 +358,7 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent( } s.pduStream.Advance(pduPos) - s.notifier.OnNewEvent(ev, ev.RoomID(), nil, types.StreamingToken{PDUPosition: pduPos}) + s.notifier.OnNewEvent(ev, ev.RoomID().String(), nil, types.StreamingToken{PDUPosition: pduPos}) return nil } @@ -387,11 +378,7 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *rst return sp, fmt.Errorf("unexpected nil state_key") } - validRoomID, err := spec.NewRoomID(ev.RoomID()) - if err != nil { - return sp, err - } - userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*ev.StateKey())) + userID, err := s.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey())) if err != nil || userID == nil { return sp, fmt.Errorf("failed getting userID for sender: %w", err) } @@ -400,7 +387,7 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *rst } // cancel any peeks for it - peekSP, peekErr := s.db.DeletePeeks(ctx, ev.RoomID(), *ev.StateKey()) + peekSP, peekErr := s.db.DeletePeeks(ctx, ev.RoomID().String(), *ev.StateKey()) if peekErr != nil { return sp, fmt.Errorf("s.db.DeletePeeks: %w", peekErr) } @@ -418,11 +405,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent( return } - validRoomID, err := spec.NewRoomID(msg.Event.RoomID()) - if err != nil { - return - } - userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*msg.Event.StateKey())) + userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.Event.RoomID(), spec.SenderID(*msg.Event.StateKey())) if err != nil || userID == nil { return } @@ -559,15 +542,10 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *rstypes.HeaderedEvent) var succeeded bool defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) - validRoomID, err := spec.NewRoomID(event.RoomID()) - if err != nil { - return event, err - } - sKeyUser := "" if stateKey != "" { var sku *spec.UserID - sku, err = s.rsAPI.QueryUserIDForSender(s.ctx, *validRoomID, spec.SenderID(stateKey)) + sku, err = s.rsAPI.QueryUserIDForSender(s.ctx, event.RoomID(), spec.SenderID(stateKey)) if err == nil && sku != nil { sKeyUser = sku.String() event.StateKeyResolved = &sKeyUser @@ -575,13 +553,13 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *rstypes.HeaderedEvent) } prevEvent, err := snapshot.GetStateEvent( - s.ctx, event.RoomID(), event.Type(), sKeyUser, + s.ctx, event.RoomID().String(), event.Type(), sKeyUser, ) if err != nil { return event, err } - userID, err := s.rsAPI.QueryUserIDForSender(s.ctx, *validRoomID, event.SenderID()) + userID, err := s.rsAPI.QueryUserIDForSender(s.ctx, event.RoomID(), event.SenderID()) if err != nil { return event, err } @@ -592,16 +570,10 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *rstypes.HeaderedEvent) return event, nil } - prevEventSender := string(prevEvent.SenderID()) - prevUser, err := s.rsAPI.QueryUserIDForSender(s.ctx, *validRoomID, prevEvent.SenderID()) - if err == nil && prevUser != nil { - prevEventSender = prevUser.String() - } - - prev := types.PrevEventRef{ + prev := synctypes.PrevEventRef{ PrevContent: prevEvent.Content(), ReplacesState: prevEvent.EventID(), - PrevSenderID: prevEventSender, + PrevSenderID: string(prevEvent.SenderID()), } event.PDU, err = event.SetUnsigned(prev) @@ -615,7 +587,7 @@ func (s *OutputRoomEventConsumer) writeFTS(ev *rstypes.HeaderedEvent, pduPositio } e := fulltext.IndexElement{ EventID: ev.EventID(), - RoomID: ev.RoomID(), + RoomID: ev.RoomID().String(), StreamPosition: int64(pduPosition), } e.SetContentType(ev.Type()) diff --git a/syncapi/internal/history_visibility.go b/syncapi/internal/history_visibility.go index 7aae9fd38..48475327d 100644 --- a/syncapi/internal/history_visibility.go +++ b/syncapi/internal/history_visibility.go @@ -118,26 +118,23 @@ func ApplyHistoryVisibilityFilter( start := time.Now() // try to get the current membership of the user - membershipCurrent, _, err := syncDB.SelectMembershipForUser(ctx, events[0].RoomID(), userID.String(), math.MaxInt64) + membershipCurrent, _, err := syncDB.SelectMembershipForUser(ctx, events[0].RoomID().String(), userID.String(), math.MaxInt64) if err != nil { return nil, err } // Get the mapping from eventID -> eventVisibility eventsFiltered := make([]*types.HeaderedEvent, 0, len(events)) - firstEvRoomID, err := spec.NewRoomID(events[0].RoomID()) + firstEvRoomID := events[0].RoomID() + senderID, err := rsAPI.QuerySenderIDForUser(ctx, firstEvRoomID, userID) if err != nil { return nil, err } - senderID, err := rsAPI.QuerySenderIDForUser(ctx, *firstEvRoomID, userID) - if err != nil { - return nil, err - } - visibilities := visibilityForEvents(ctx, rsAPI, events, senderID, *firstEvRoomID) + visibilities := visibilityForEvents(ctx, rsAPI, events, senderID, firstEvRoomID) for _, ev := range events { // Validate same room assumption - if ev.RoomID() != firstEvRoomID.String() { + if ev.RoomID().String() != firstEvRoomID.String() { return nil, fmt.Errorf("events from different rooms supplied to ApplyHistoryVisibilityFilter") } diff --git a/syncapi/notifier/notifier.go b/syncapi/notifier/notifier.go index a8733f6fe..07b80b165 100644 --- a/syncapi/notifier/notifier.go +++ b/syncapi/notifier/notifier.go @@ -101,20 +101,13 @@ func (n *Notifier) OnNewEvent( n._removeEmptyUserStreams() if ev != nil { - validRoomID, err := spec.NewRoomID(ev.RoomID()) - if err != nil { - log.WithError(err).WithField("event_id", ev.EventID()).Errorf( - "Notifier.OnNewEvent: RoomID is invalid", - ) - return - } // Map this event's room_id to a list of joined users, and wake them up. - usersToNotify := n._joinedUsers(ev.RoomID()) + usersToNotify := n._joinedUsers(ev.RoomID().String()) // Map this event's room_id to a list of peeking devices, and wake them up. - peekingDevicesToNotify := n._peekingDevices(ev.RoomID()) + peekingDevicesToNotify := n._peekingDevices(ev.RoomID().String()) // If this is an invite, also add in the invitee to this list. if ev.Type() == "m.room.member" && ev.StateKey() != nil { - targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), *validRoomID, spec.SenderID(*ev.StateKey())) + targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), ev.RoomID(), spec.SenderID(*ev.StateKey())) if err != nil || targetUserID == nil { log.WithError(err).WithField("event_id", ev.EventID()).Errorf( "Notifier.OnNewEvent: Failed to find the userID for this event", @@ -134,11 +127,11 @@ func (n *Notifier) OnNewEvent( // Manually append the new user's ID so they get notified // along all members in the room usersToNotify = append(usersToNotify, targetUserID.String()) - n._addJoinedUser(ev.RoomID(), targetUserID.String()) + n._addJoinedUser(ev.RoomID().String(), targetUserID.String()) case spec.Leave: fallthrough case spec.Ban: - n._removeJoinedUser(ev.RoomID(), targetUserID.String()) + n._removeJoinedUser(ev.RoomID().String(), targetUserID.String()) } } } diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go index 4fa282f3b..c089539f0 100644 --- a/syncapi/routing/getevent.go +++ b/syncapi/routing/getevent.go @@ -118,32 +118,19 @@ func GetEvent( } } - senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, events[0].SenderID()) - if err != nil || senderUserID == nil { - util.GetLogger(req.Context()).WithError(err).WithField("senderID", events[0].SenderID()).WithField("roomID", *roomID).Error("QueryUserIDForSender errored or returned nil-user ID when user should be part of a room") + clientEvent, err := synctypes.ToClientEvent(events[0], synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) + if err != nil { + util.GetLogger(req.Context()).WithError(err).WithField("senderID", events[0].SenderID()).WithField("roomID", *roomID).Error("Failed converting to ClientEvent") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: spec.Unknown("internal server error"), } } - sk := events[0].StateKey() - if sk != nil && *sk != "" { - evRoomID, err := spec.NewRoomID(events[0].RoomID()) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON("roomID is invalid"), - } - } - skUserID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, spec.SenderID(*events[0].StateKey())) - if err == nil && skUserID != nil { - skString := skUserID.String() - sk = &skString - } - } return util.JSONResponse{ Code: http.StatusOK, - JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, *senderUserID, sk), + JSON: *clientEvent, } } diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index 5e5d0125f..e849adf6d 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -152,15 +152,7 @@ func GetMemberships( } } - validRoomID, err := spec.NewRoomID(ev.RoomID()) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("roomID is invalid") - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } - } - userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, ev.SenderID()) + userID, err := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID()) if err != nil || userID == nil { util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryUserIDForSender failed") return util.JSONResponse{ diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go index e3d1069a0..935ba83b3 100644 --- a/syncapi/routing/relations.go +++ b/syncapi/routing/relations.go @@ -130,23 +130,16 @@ func Relations( // type if it was specified. res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents)) for _, event := range filteredEvents { - sender := spec.UserID{} - userID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, event.SenderID()) - if err == nil && userID != nil { - sender = *userID - } - - sk := event.StateKey() - if sk != nil && *sk != "" { - skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, spec.SenderID(*event.StateKey())) - if err == nil && skUserID != nil { - skString := skUserID.String() - sk = &skString - } + clientEvent, err := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) + }) + if err != nil { + util.GetLogger(req.Context()).WithError(err).WithField("senderID", events[0].SenderID()).WithField("roomID", *roomID).Error("Failed converting to ClientEvent") + continue } res.Chunk = append( res.Chunk, - synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender, sk), + *clientEvent, ) } diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index d892b604a..4a8be9f49 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -205,12 +205,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts profileInfos := make(map[string]ProfileInfoResponse) for _, ev := range append(eventsBefore, eventsAfter...) { - validRoomID, roomErr := spec.NewRoomID(ev.RoomID()) - if err != nil { - logrus.WithError(roomErr).WithField("room_id", ev.RoomID()).Warn("failed to query userprofile") - continue - } - userID, queryErr := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, ev.SenderID()) + userID, queryErr := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID()) if queryErr != nil { logrus.WithError(queryErr).WithField("sender_id", ev.SenderID()).Warn("failed to query userprofile") continue @@ -218,7 +213,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts profile, ok := knownUsersProfiles[userID.String()] if !ok { - stateEvent, stateErr := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, string(ev.SenderID())) + stateEvent, stateErr := snapshot.GetStateEvent(ctx, ev.RoomID().String(), spec.MRoomMember, string(ev.SenderID())) if stateErr != nil { logrus.WithError(stateErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") continue @@ -235,25 +230,14 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts profileInfos[userID.String()] = profile } - sender := spec.UserID{} - validRoomID, roomErr := spec.NewRoomID(event.RoomID()) + clientEvent, err := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) if err != nil { - logrus.WithError(roomErr).WithField("room_id", event.RoomID()).Warn("failed to query userprofile") + util.GetLogger(req.Context()).WithError(err).WithField("senderID", event.SenderID()).Error("Failed converting to ClientEvent") continue } - userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, event.SenderID()) - if err == nil && userID != nil { - sender = *userID - } - sk := event.StateKey() - if sk != nil && *sk != "" { - skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, spec.SenderID(*event.StateKey())) - if err == nil && skUserID != nil { - skString := skUserID.String() - sk = &skString - } - } results = append(results, Result{ Context: SearchContextResponse{ Start: startToken.String(), @@ -267,14 +251,14 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts ProfileInfo: profileInfos, }, Rank: eventScore[event.EventID()].Score, - Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk), + Result: *clientEvent, }) - roomGroup := groups[event.RoomID()] + roomGroup := groups[event.RoomID().String()] roomGroup.Results = append(roomGroup.Results, event.EventID()) - groups[event.RoomID()] = roomGroup - if _, ok := stateForRooms[event.RoomID()]; searchReq.SearchCategories.RoomEvents.IncludeState && !ok { + groups[event.RoomID().String()] = roomGroup + if _, ok := stateForRooms[event.RoomID().String()]; searchReq.SearchCategories.RoomEvents.IncludeState && !ok { stateFilter := synctypes.DefaultStateFilter() - state, err := snapshot.CurrentState(ctx, event.RoomID(), &stateFilter, nil) + state, err := snapshot.CurrentState(ctx, event.RoomID().String(), &stateFilter, nil) if err != nil { logrus.WithError(err).Error("unable to get current state") return util.JSONResponse{ @@ -282,7 +266,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts JSON: spec.InternalServerError{}, } } - stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + stateForRooms[event.RoomID().String()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }) } @@ -328,19 +312,19 @@ func contextEvents( roomFilter *synctypes.RoomEventFilter, searchReq SearchRequest, ) ([]*types.HeaderedEvent, []*types.HeaderedEvent, error) { - id, _, err := snapshot.SelectContextEvent(ctx, event.RoomID(), event.EventID()) + id, _, err := snapshot.SelectContextEvent(ctx, event.RoomID().String(), event.EventID()) if err != nil { logrus.WithError(err).Error("failed to query context event") return nil, nil, err } roomFilter.Limit = searchReq.SearchCategories.RoomEvents.EventContext.BeforeLimit - eventsBefore, err := snapshot.SelectContextBeforeEvent(ctx, id, event.RoomID(), roomFilter) + eventsBefore, err := snapshot.SelectContextBeforeEvent(ctx, id, event.RoomID().String(), roomFilter) if err != nil { logrus.WithError(err).Error("failed to query before context event") return nil, nil, err } roomFilter.Limit = searchReq.SearchCategories.RoomEvents.EventContext.AfterLimit - _, eventsAfter, err := snapshot.SelectContextAfterEvent(ctx, id, event.RoomID(), roomFilter) + _, eventsAfter, err := snapshot.SelectContextAfterEvent(ctx, id, event.RoomID().String(), roomFilter) if err != nil { logrus.WithError(err).Error("failed to query after context event") return nil, nil, err diff --git a/syncapi/routing/search_test.go b/syncapi/routing/search_test.go index 905a9a1ac..a983bb7b5 100644 --- a/syncapi/routing/search_test.go +++ b/syncapi/routing/search_test.go @@ -238,7 +238,7 @@ func TestSearch(t *testing.T) { } elements = append(elements, fulltext.IndexElement{ EventID: x.EventID(), - RoomID: x.RoomID(), + RoomID: x.RoomID().String(), Content: string(x.Content()), ContentType: x.Type(), StreamPosition: int64(sp), diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 112fa9d4a..b0148bef5 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -340,7 +340,7 @@ func (s *currentRoomStateStatements) UpsertRoomState( stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) _, err = stmt.ExecContext( ctx, - event.RoomID(), + event.RoomID().String(), event.EventID(), event.Type(), event.UserID.String(), diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go index 7b8d2d733..1f46cd09d 100644 --- a/syncapi/storage/postgres/invites_table.go +++ b/syncapi/storage/postgres/invites_table.go @@ -99,7 +99,7 @@ func (s *inviteEventsStatements) InsertInviteEvent( err = sqlutil.TxStmt(txn, s.insertInviteEventStmt).QueryRowContext( ctx, - inviteEvent.RoomID(), + inviteEvent.RoomID().String(), inviteEvent.EventID(), inviteEvent.UserID.String(), headeredJSON, diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index 09b47432b..fcbe14b16 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -108,7 +108,7 @@ func (s *membershipsStatements) UpsertMembership( } _, err = sqlutil.TxStmt(txn, s.upsertMembershipStmt).ExecContext( ctx, - event.RoomID(), + event.RoomID().String(), event.StateKeyResolved, membership, event.EventID(), diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index b58cf59f0..b2d191111 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -334,7 +334,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange( if err := json.Unmarshal(eventBytes, &ev); err != nil { return nil, nil, err } - needSet := stateNeeded[ev.RoomID()] + needSet := stateNeeded[ev.RoomID().String()] if needSet == nil { // make set if required needSet = make(map[string]bool) } @@ -344,7 +344,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange( for _, id := range addIDs { needSet[id] = true } - stateNeeded[ev.RoomID()] = needSet + stateNeeded[ev.RoomID().String()] = needSet ev.Visibility = historyVisibility eventIDToEvent[eventID] = types.StreamEvent{ @@ -403,7 +403,7 @@ func (s *outputRoomEventsStatements) InsertEvent( stmt := sqlutil.TxStmt(txn, s.insertEventStmt) err = stmt.QueryRowContext( ctx, - event.RoomID(), + event.RoomID().String(), event.EventID(), headeredJSON, event.Type(), diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index b281f3300..2158d99ec 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -107,7 +107,7 @@ func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( ctx context.Context, txn *sql.Tx, event *rstypes.HeaderedEvent, pos types.StreamPosition, ) (topoPos types.StreamPosition, err error) { err = sqlutil.TxStmt(txn, s.insertEventInTopologyStmt).QueryRowContext( - ctx, event.EventID(), event.Depth(), event.RoomID(), pos, + ctx, event.EventID(), event.Depth(), event.RoomID().String(), pos, ).Scan(&topoPos) return } diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index 69e64cc79..0f4080d53 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -114,14 +114,7 @@ func (d *Database) StreamEventsToEvents(ctx context.Context, device *userapi.Dev }).WithError(err).Warnf("Failed to add transaction ID to event") continue } - roomID, err := spec.NewRoomID(in[i].RoomID()) - if err != nil { - logrus.WithFields(logrus.Fields{ - "event_id": out[i].EventID(), - }).WithError(err).Warnf("Room ID is invalid") - continue - } - deviceSenderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID) + deviceSenderID, err := rsAPI.QuerySenderIDForUser(ctx, in[i].RoomID(), *userID) if err != nil || deviceSenderID == nil { logrus.WithFields(logrus.Fields{ "event_id": out[i].EventID(), @@ -236,7 +229,7 @@ func (d *Database) UpsertAccountData( // to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. // This function should always be called within a sqlutil.Writer for safety in SQLite. func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *rstypes.HeaderedEvent) error { - if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil { + if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID().String(), ev.EventID()); err != nil { return err } @@ -257,7 +250,7 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e // If the event is missing, consider it a backward extremity. if !found { - if err = d.BackwardExtremities.InsertsBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID(), eID); err != nil { + if err = d.BackwardExtremities.InsertsBackwardExtremity(ctx, txn, ev.RoomID().String(), ev.EventID(), eID); err != nil { return err } } @@ -426,7 +419,7 @@ func (d *Database) fetchStateEvents( } // we know we got them all otherwise an error would've been returned, so just loop the events for _, ev := range evs { - roomID := ev.RoomID() + roomID := ev.RoomID().String() stateBetween[roomID] = append(stateBetween[roomID], ev) } } @@ -522,11 +515,7 @@ func getMembershipFromEvent(ctx context.Context, ev gomatrixserverlib.PDU, userI if err != nil { return "", "" } - roomID, err := spec.NewRoomID(ev.RoomID()) - if err != nil { - return "", "" - } - senderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *fullUser) + senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *fullUser) if err != nil || senderID == nil { return "", "" } @@ -626,7 +615,7 @@ func (d *Database) UpdateRelations(ctx context.Context, event *rstypes.HeaderedE default: return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Relations.InsertRelation( - ctx, txn, event.RoomID(), content.Relations.EventID, + ctx, txn, event.RoomID().String(), content.Relations.EventID, event.EventID(), event.Type(), content.Relations.RelationType, ) }) diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 3bd19b367..78b2e397c 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -339,7 +339,7 @@ func (s *currentRoomStateStatements) UpsertRoomState( stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) _, err = stmt.ExecContext( ctx, - event.RoomID(), + event.RoomID().String(), event.EventID(), event.Type(), event.UserID.String(), diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index 7e0d895f1..ebb469d24 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -106,7 +106,7 @@ func (s *inviteEventsStatements) InsertInviteEvent( _, err = stmt.ExecContext( ctx, streamPos, - inviteEvent.RoomID(), + inviteEvent.RoomID().String(), inviteEvent.EventID(), inviteEvent.UserID.String(), headeredJSON, diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go index a9e880d2a..05f756fda 100644 --- a/syncapi/storage/sqlite3/memberships_table.go +++ b/syncapi/storage/sqlite3/memberships_table.go @@ -111,7 +111,7 @@ func (s *membershipsStatements) UpsertMembership( } _, err = sqlutil.TxStmt(txn, s.upsertMembershipStmt).ExecContext( ctx, - event.RoomID(), + event.RoomID().String(), event.StateKeyResolved, membership, event.EventID(), diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 06c65419a..93caee806 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -254,7 +254,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange( if err := json.Unmarshal(eventBytes, &ev); err != nil { return nil, nil, err } - needSet := stateNeeded[ev.RoomID()] + needSet := stateNeeded[ev.RoomID().String()] if needSet == nil { // make set if required needSet = make(map[string]bool) } @@ -264,7 +264,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange( for _, id := range addIDs { needSet[id] = true } - stateNeeded[ev.RoomID()] = needSet + stateNeeded[ev.RoomID().String()] = needSet ev.Visibility = historyVisibility eventIDToEvent[eventID] = types.StreamEvent{ @@ -344,7 +344,7 @@ func (s *outputRoomEventsStatements) InsertEvent( _, err = insertStmt.ExecContext( ctx, streamPos, - event.RoomID(), + event.RoomID().String(), event.EventID(), headeredJSON, event.Type(), diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index 614e1df9e..36967d1e7 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -106,7 +106,7 @@ func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( ctx context.Context, txn *sql.Tx, event *rstypes.HeaderedEvent, pos types.StreamPosition, ) (types.StreamPosition, error) { _, err := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt).ExecContext( - ctx, event.EventID(), event.Depth(), event.RoomID(), pos, + ctx, event.EventID(), event.Depth(), event.RoomID().String(), pos, ) return types.StreamPosition(event.Depth()), err } diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index 7c29d84ae..a3634c03f 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -63,31 +63,27 @@ func (p *InviteStreamProvider) IncrementalSync( return from } + eventFormat := synctypes.FormatSync + if req.Filter.EventFormat == synctypes.EventFormatFederation { + eventFormat = synctypes.FormatSyncFederation + } + for roomID, inviteEvent := range invites { user := spec.UserID{} - validRoomID, err := spec.NewRoomID(inviteEvent.RoomID()) - if err != nil { - continue - } - sender, err := p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, inviteEvent.SenderID()) + sender, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), inviteEvent.SenderID()) if err == nil && sender != nil { user = *sender } - sk := inviteEvent.StateKey() - if sk != nil && *sk != "" { - skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*inviteEvent.StateKey())) - if err == nil && skUserID != nil { - skString := skUserID.String() - sk = &skString - } - } - // skip ignored user events if _, ok := req.IgnoredUsers.List[user.String()]; ok { continue } - ir := types.NewInviteResponse(inviteEvent, user, sk) + ir, err := types.NewInviteResponse(ctx, p.rsAPI, inviteEvent, eventFormat) + if err != nil { + req.Log.WithError(err).Error("failed creating invite response") + continue + } req.Response.Rooms.Invite[roomID] = ir } diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 4622c21ad..3abb0b3c6 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -3,7 +3,6 @@ package streams import ( "context" "database/sql" - "encoding/json" "fmt" "time" @@ -16,8 +15,6 @@ import ( "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib/spec" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/gomatrixserverlib" @@ -88,6 +85,11 @@ func (p *PDUStreamProvider) CompleteSync( req.Log.WithError(err).Error("unable to update event filter with ignored users") } + eventFormat := synctypes.FormatSync + if req.Filter.EventFormat == synctypes.EventFormatFederation { + eventFormat = synctypes.FormatSyncFederation + } + recentEvents, err := snapshot.RecentEvents(ctx, joinedRoomIDs, r, &eventFilter, true, true) if err != nil { return from @@ -105,7 +107,7 @@ func (p *PDUStreamProvider) CompleteSync( // get the join response for each room jr, jerr := p.getJoinResponseForCompleteSync( ctx, snapshot, roomID, &stateFilter, req.WantFullState, req.Device, false, - events.Events, events.Limited, + events.Events, events.Limited, eventFormat, ) if jerr != nil { req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed") @@ -142,7 +144,7 @@ func (p *PDUStreamProvider) CompleteSync( events := recentEvents[roomID] jr, err = p.getJoinResponseForCompleteSync( ctx, snapshot, roomID, &stateFilter, req.WantFullState, req.Device, true, - events.Events, events.Limited, + events.Events, events.Limited, eventFormat, ) if err != nil { req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed") @@ -346,26 +348,14 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( return r.From, fmt.Errorf("p.DB.GetBackwardTopologyPos: %w", err) } + eventFormat := synctypes.FormatSync + if req.Filter.EventFormat == synctypes.EventFormatFederation { + eventFormat = synctypes.FormatSyncFederation + } + // Now that we've filtered the timeline, work out which state events are still // left. Anything that appears in the filtered timeline will be removed from the // "state" section and kept in "timeline". - - // update the powerlevel event for timeline events - for i, ev := range events { - if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs { - continue - } - if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") { - continue - } - var newEvent gomatrixserverlib.PDU - newEvent, err = p.updatePowerLevelEvent(ctx, ev) - if err != nil { - return r.From, err - } - events[i] = &rstypes.HeaderedEvent{PDU: newEvent} - } - sEvents := gomatrixserverlib.HeaderedReverseTopologicalOrdering( gomatrixserverlib.ToPDUs(removeDuplicates(delta.StateEvents, events)), gomatrixserverlib.TopologicalOrderByAuthEvents, @@ -380,15 +370,6 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( continue } delta.StateEvents[i-skipped] = he - // update the powerlevel event for state events - if ev.Version() == gomatrixserverlib.RoomVersionPseudoIDs && ev.Type() == spec.MRoomPowerLevels && ev.StateKeyEquals("") { - var newEvent gomatrixserverlib.PDU - newEvent, err = p.updatePowerLevelEvent(ctx, he) - if err != nil { - return r.From, err - } - delta.StateEvents[i-skipped] = &rstypes.HeaderedEvent{PDU: newEvent} - } } delta.StateEvents = delta.StateEvents[:len(sEvents)-skipped] @@ -413,13 +394,13 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } } jr.Timeline.PrevBatch = &prevBatch - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), eventFormat, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), eventFormat, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) req.Response.Rooms.Join[delta.RoomID] = jr @@ -428,11 +409,11 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( jr := types.NewJoinResponse() jr.Timeline.PrevBatch = &prevBatch // TODO: Apply history visibility on peeked rooms - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), eventFormat, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) jr.Timeline.Limited = limited - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), eventFormat, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) req.Response.Rooms.Peek[delta.RoomID] = jr @@ -443,13 +424,13 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( case spec.Ban: lr := types.NewLeaveResponse() lr.Timeline.PrevBatch = &prevBatch - lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), eventFormat, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. lr.Timeline.Limited = limited && len(events) == len(recentEvents) - lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), eventFormat, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) req.Response.Rooms.Leave[delta.RoomID] = lr @@ -458,75 +439,6 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( return latestPosition, nil } -func (p *PDUStreamProvider) updatePowerLevelEvent(ctx context.Context, ev *rstypes.HeaderedEvent) (gomatrixserverlib.PDU, error) { - pls, err := gomatrixserverlib.NewPowerLevelContentFromEvent(ev) - if err != nil { - return nil, err - } - newPls := make(map[string]int64) - var userID *spec.UserID - for user, level := range pls.Users { - validRoomID, _ := spec.NewRoomID(ev.RoomID()) - userID, err = p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(user)) - if err != nil { - return nil, err - } - newPls[userID.String()] = level - } - var newPlBytes, newEv []byte - newPlBytes, err = json.Marshal(newPls) - if err != nil { - return nil, err - } - newEv, err = sjson.SetRawBytes(ev.JSON(), "content.users", newPlBytes) - if err != nil { - return nil, err - } - - // do the same for prev content - prevContent := gjson.GetBytes(ev.JSON(), "unsigned.prev_content") - if !prevContent.Exists() { - var evNew gomatrixserverlib.PDU - evNew, err = gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionPseudoIDs).NewEventFromTrustedJSON(newEv, false) - if err != nil { - return nil, err - } - - return evNew, err - } - pls = gomatrixserverlib.PowerLevelContent{} - err = json.Unmarshal([]byte(prevContent.Raw), &pls) - if err != nil { - return nil, err - } - - newPls = make(map[string]int64) - for user, level := range pls.Users { - validRoomID, _ := spec.NewRoomID(ev.RoomID()) - userID, err = p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(user)) - if err != nil { - return nil, err - } - newPls[userID.String()] = level - } - newPlBytes, err = json.Marshal(newPls) - if err != nil { - return nil, err - } - newEv, err = sjson.SetRawBytes(newEv, "unsigned.prev_content.users", newPlBytes) - if err != nil { - return nil, err - } - - var evNew gomatrixserverlib.PDU - evNew, err = gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionPseudoIDs).NewEventFromTrustedJSONWithEventID(ev.EventID(), newEv, false) - if err != nil { - return nil, err - } - - return evNew, err -} - // applyHistoryVisibilityFilter gets the current room state and supplies it to ApplyHistoryVisibilityFilter, to make // sure we always return the required events in the timeline. func applyHistoryVisibilityFilter( @@ -592,6 +504,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( isPeek bool, recentStreamEvents []types.StreamEvent, limited bool, + eventFormat synctypes.ClientEventFormat, ) (jr *types.JoinResponse, err error) { jr = types.NewJoinResponse() // TODO: When filters are added, we may need to call this multiple times to get enough events. @@ -675,43 +588,14 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( prevBatch.Decrement() } - // Update powerlevel events for timeline events - for i, ev := range events { - if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs { - continue - } - if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") { - continue - } - newEvent, err := p.updatePowerLevelEvent(ctx, ev) - if err != nil { - return nil, err - } - events[i] = &rstypes.HeaderedEvent{PDU: newEvent} - } - // Update powerlevel events for state events - for i, ev := range stateEvents { - if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs { - continue - } - if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") { - continue - } - newEvent, err := p.updatePowerLevelEvent(ctx, ev) - if err != nil { - return nil, err - } - stateEvents[i] = &rstypes.HeaderedEvent{PDU: newEvent} - } - jr.Timeline.PrevBatch = prevBatch - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), eventFormat, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. jr.Timeline.Limited = limited && len(events) == len(recentEvents) - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), eventFormat, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) return jr, nil diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index ea1183cd2..ac5268511 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -209,6 +209,156 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { } } +func TestSyncAPIEventFormatPowerLevels(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + testSyncEventFormatPowerLevels(t, dbType) + }) +} + +func testSyncEventFormatPowerLevels(t *testing.T, dbType test.DBType) { + user := test.NewUser(t) + setRoomVersion := func(t *testing.T, r *test.Room) { r.Version = gomatrixserverlib.RoomVersionPseudoIDs } + room := test.NewRoom(t, user, setRoomVersion) + alice := userapi.Device{ + ID: "ALICEID", + UserID: user.ID, + AccessToken: "ALICE_BEARER_TOKEN", + DisplayName: "Alice", + AccountType: userapi.AccountTypeUser, + } + + room.CreateAndInsert(t, user, spec.MRoomPowerLevels, gomatrixserverlib.PowerLevelContent{ + Users: map[string]int64{ + user.ID: 100, + }, + }, test.WithStateKey("")) + + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + natsInstance := jetstream.NATSInstance{} + defer close() + + jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) + msgs := toNATSMsgs(t, cfg, room.Events()...) + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, caching.DisableMetrics) + testrig.MustPublishMsgs(t, jsctx, msgs...) + + testCases := []struct { + name string + wantCode int + wantJoinedRooms []string + eventFormat synctypes.ClientEventFormat + }{ + { + name: "Client format", + wantCode: 200, + wantJoinedRooms: []string{room.ID}, + eventFormat: synctypes.FormatSync, + }, + { + name: "Federation format", + wantCode: 200, + wantJoinedRooms: []string{room.ID}, + eventFormat: synctypes.FormatSyncFederation, + }, + } + + syncUntil(t, routers, alice.AccessToken, false, func(syncBody string) bool { + // wait for the last sent eventID to come down sync + path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(event_id=="%s")`, room.ID, room.Events()[len(room.Events())-1].EventID()) + return gjson.Get(syncBody, path).Exists() + }) + + for _, tc := range testCases { + format := "" + if tc.eventFormat == synctypes.FormatSyncFederation { + format = "federation" + } + + w := httptest.NewRecorder() + routers.Client.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ + "access_token": alice.AccessToken, + "timeout": "0", + "filter": fmt.Sprintf(`{"event_format":"%s"}`, format), + }))) + if w.Code != tc.wantCode { + t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode) + } + if tc.wantJoinedRooms != nil { + var res types.Response + if err := json.NewDecoder(w.Body).Decode(&res); err != nil { + t.Fatalf("%s: failed to decode response body: %s", tc.name, err) + } + if len(res.Rooms.Join) != len(tc.wantJoinedRooms) { + t.Errorf("%s: got %v joined rooms, want %v.\nResponse: %+v", tc.name, len(res.Rooms.Join), len(tc.wantJoinedRooms), res) + } + t.Logf("res: %+v", res.Rooms.Join[room.ID]) + + gotEventIDs := make([]string, len(res.Rooms.Join[room.ID].Timeline.Events)) + for i, ev := range res.Rooms.Join[room.ID].Timeline.Events { + gotEventIDs[i] = ev.EventID + } + test.AssertEventIDsEqual(t, gotEventIDs, room.Events()) + + event := room.CreateAndInsert(t, user, spec.MRoomPowerLevels, gomatrixserverlib.PowerLevelContent{ + Users: map[string]int64{ + user.ID: 100, + "@otheruser:localhost": 50, + }, + }, test.WithStateKey("")) + + msgs := toNATSMsgs(t, cfg, event) + testrig.MustPublishMsgs(t, jsctx, msgs...) + + syncUntil(t, routers, alice.AccessToken, false, func(syncBody string) bool { + // wait for the last sent eventID to come down sync + path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(event_id=="%s")`, room.ID, room.Events()[len(room.Events())-1].EventID()) + return gjson.Get(syncBody, path).Exists() + }) + + since := res.NextBatch.String() + w := httptest.NewRecorder() + routers.Client.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ + "access_token": alice.AccessToken, + "timeout": "0", + "filter": fmt.Sprintf(`{"event_format":"%s"}`, format), + "since": since, + }))) + if w.Code != 200 { + t.Errorf("since=%s got HTTP %d want 200", since, w.Code) + } + + res = *types.NewResponse() + if err := json.NewDecoder(w.Body).Decode(&res); err != nil { + t.Errorf("failed to decode response body: %s", err) + } + if len(res.Rooms.Join) != 1 { + t.Fatalf("since=%s got %d joined rooms, want 1", since, len(res.Rooms.Join)) + } + gotEventIDs = make([]string, len(res.Rooms.Join[room.ID].Timeline.Events)) + for j, ev := range res.Rooms.Join[room.ID].Timeline.Events { + gotEventIDs[j] = ev.EventID + if ev.Type == spec.MRoomPowerLevels { + content := gomatrixserverlib.PowerLevelContent{} + err := json.Unmarshal(ev.Content, &content) + if err != nil { + t.Errorf("failed to unmarshal power level content: %s", err) + } + otherUserLevel := content.UserLevel("@otheruser:localhost") + if otherUserLevel != 50 { + t.Errorf("Expected user PL of %d but got %d", 50, otherUserLevel) + } + } + } + events := []*rstypes.HeaderedEvent{room.Events()[len(room.Events())-1]} + test.AssertEventIDsEqual(t, gotEventIDs, events) + } + } +} + // Tests what happens when we create a room and then /sync before all events from /createRoom have // been sent to the syncapi func TestSyncAPICreateRoomSyncEarly(t *testing.T) { @@ -1251,7 +1401,7 @@ func toNATSMsgs(t *testing.T, cfg *config.Dendrite, input ...*rstypes.HeaderedEv if ev.StateKey() != nil { addsStateIDs = append(addsStateIDs, ev.EventID()) } - result[i] = testrig.NewOutputEventMsg(t, cfg, ev.RoomID(), api.OutputEvent{ + result[i] = testrig.NewOutputEventMsg(t, cfg, ev.RoomID().String(), api.OutputEvent{ Type: rsapi.OutputTypeNewRoomEvent, NewRoomEvent: &rsapi.OutputNewRoomEvent{ Event: ev, diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go index a78aea1c6..fe4f6c07f 100644 --- a/syncapi/synctypes/clientevent.go +++ b/syncapi/synctypes/clientevent.go @@ -16,12 +16,23 @@ package synctypes import ( + "encoding/json" "fmt" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) +// PrevEventRef represents a reference to a previous event in a state event upgrade +type PrevEventRef struct { + PrevContent json.RawMessage `json:"prev_content"` + ReplacesState string `json:"replaces_state"` + PrevSenderID string `json:"prev_sender"` +} + type ClientEventFormat int const ( @@ -30,8 +41,21 @@ const ( // FormatSync will include only the event keys required by the /sync API. Notably, this // means the 'room_id' will be missing from the events. FormatSync + // FormatSyncFederation will include all event keys normally included in federated events. + // This allows clients to request federated formatted events via the /sync API. + FormatSyncFederation ) +// ClientFederationFields extends a ClientEvent to contain the additional fields present in a +// federation event. Used when the client requests `event_format` of type `federation`. +type ClientFederationFields struct { + Depth int64 `json:"depth,omitempty"` + PrevEvents []string `json:"prev_events,omitempty"` + AuthEvents []string `json:"auth_events,omitempty"` + Signatures spec.RawJSON `json:"signatures,omitempty"` + Hashes spec.RawJSON `json:"hashes,omitempty"` +} + // ClientEvent is an event which is fit for consumption by clients, in accordance with the specification. type ClientEvent struct { Content spec.RawJSON `json:"content"` @@ -44,6 +68,9 @@ type ClientEvent struct { Type string `json:"type"` Unsigned spec.RawJSON `json:"unsigned,omitempty"` Redacts string `json:"redacts,omitempty"` + + // Only sent to clients when `event_format` == `federation`. + ClientFederationFields } // ToClientEvents converts server events to client events. @@ -53,72 +80,24 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat, if se == nil { continue // TODO: shouldn't happen? } - sender := spec.UserID{} - validRoomID, err := spec.NewRoomID(se.RoomID()) + ev, err := ToClientEvent(se, format, userIDForSender) if err != nil { + logrus.WithError(err).Warn("Failed converting event to ClientEvent") continue } - userID, err := userIDForSender(*validRoomID, se.SenderID()) - if err == nil && userID != nil { - sender = *userID - } - - sk := se.StateKey() - if sk != nil && *sk != "" { - skUserID, err := userIDForSender(*validRoomID, spec.SenderID(*sk)) - if err == nil && skUserID != nil { - skString := skUserID.String() - sk = &skString - } - } - evs = append(evs, ToClientEvent(se, format, sender, sk)) + evs = append(evs, *ev) } return evs } -// ToClientEvent converts a single server event to a client event. -func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID, stateKey *string) ClientEvent { - ce := ClientEvent{ - Content: spec.RawJSON(se.Content()), - Sender: sender.String(), - Type: se.Type(), - StateKey: stateKey, - Unsigned: spec.RawJSON(se.Unsigned()), - OriginServerTS: se.OriginServerTS(), - EventID: se.EventID(), - Redacts: se.Redacts(), - } - if format == FormatAll { - ce.RoomID = se.RoomID() - } - if se.Version() == gomatrixserverlib.RoomVersionPseudoIDs { - ce.SenderKey = se.SenderID() - } - return ce -} - -// ToClientEvent converts a single server event to a client event. +// ToClientEventDefault converts a single server event to a client event. // It provides default logic for event.SenderID & event.StateKey -> userID conversions. func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent { - sender := spec.UserID{} - validRoomID, err := spec.NewRoomID(event.RoomID()) + ev, err := ToClientEvent(event, FormatAll, userIDQuery) if err != nil { return ClientEvent{} } - userID, err := userIDQuery(*validRoomID, event.SenderID()) - if err == nil && userID != nil { - sender = *userID - } - - sk := event.StateKey() - if sk != nil && *sk != "" { - skUserID, err := userIDQuery(*validRoomID, spec.SenderID(*event.StateKey())) - if err == nil && skUserID != nil { - skString := skUserID.String() - sk = &skString - } - } - return ToClientEvent(event, FormatAll, sender, sk) + return *ev } // If provided state key is a user ID (state keys beginning with @ are reserved for this purpose) @@ -132,11 +111,11 @@ func FromClientStateKey(roomID spec.RoomID, stateKey string, senderIDQuery spec. parsedStateKey, err := spec.NewUserID(stateKey, true) if err != nil { // If invalid user ID, then there is no associated state event. - return nil, fmt.Errorf("Provided state key begins with @ but is not a valid user ID: %s", err.Error()) + return nil, fmt.Errorf("Provided state key begins with @ but is not a valid user ID: %w", err) } senderID, err := senderIDQuery(roomID, *parsedStateKey) if err != nil { - return nil, fmt.Errorf("Failed to query sender ID: %s", err.Error()) + return nil, fmt.Errorf("Failed to query sender ID: %w", err) } if senderID == nil { // If no sender ID, then there is no associated state event. @@ -148,3 +127,304 @@ func FromClientStateKey(roomID spec.RoomID, stateKey string, senderIDQuery spec. return &stateKey, nil } } + +// ToClientEvent converts a single server event to a client event. +func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, userIDForSender spec.UserIDForSender) (*ClientEvent, error) { + ce := ClientEvent{ + Content: se.Content(), + Sender: string(se.SenderID()), + Type: se.Type(), + StateKey: se.StateKey(), + Unsigned: se.Unsigned(), + OriginServerTS: se.OriginServerTS(), + EventID: se.EventID(), + Redacts: se.Redacts(), + } + + switch format { + case FormatAll: + ce.RoomID = se.RoomID().String() + case FormatSync: + case FormatSyncFederation: + ce.RoomID = se.RoomID().String() + ce.AuthEvents = se.AuthEventIDs() + ce.PrevEvents = se.PrevEventIDs() + ce.Depth = se.Depth() + // TODO: Set Signatures & Hashes fields + } + + if format != FormatSyncFederation && se.Version() == gomatrixserverlib.RoomVersionPseudoIDs { + err := updatePseudoIDs(&ce, se, userIDForSender, format) + if err != nil { + return nil, err + } + } + + return &ce, nil +} + +func updatePseudoIDs(ce *ClientEvent, se gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender, format ClientEventFormat) error { + ce.SenderKey = se.SenderID() + + userID, err := userIDForSender(se.RoomID(), se.SenderID()) + if err == nil && userID != nil { + ce.Sender = userID.String() + } + + sk := se.StateKey() + if sk != nil && *sk != "" { + skUserID, err := userIDForSender(se.RoomID(), spec.SenderID(*sk)) + if err == nil && skUserID != nil { + skString := skUserID.String() + ce.StateKey = &skString + } + } + + var prev PrevEventRef + if err := json.Unmarshal(se.Unsigned(), &prev); err == nil && prev.PrevSenderID != "" { + prevUserID, err := userIDForSender(se.RoomID(), spec.SenderID(prev.PrevSenderID)) + if err == nil && userID != nil { + prev.PrevSenderID = prevUserID.String() + } else { + errString := "userID unknown" + if err != nil { + errString = err.Error() + } + logrus.Warnf("Failed to find userID for prev_sender in ClientEvent: %s", errString) + // NOTE: Not much can be done here, so leave the previous value in place. + } + ce.Unsigned, err = json.Marshal(prev) + if err != nil { + err = fmt.Errorf("Failed to marshal unsigned content for ClientEvent: %w", err) + return err + } + } + + switch se.Type() { + case spec.MRoomCreate: + updatedContent, err := updateCreateEvent(se.Content(), userIDForSender, se.RoomID()) + if err != nil { + err = fmt.Errorf("Failed to update m.room.create event for ClientEvent: %w", err) + return err + } + ce.Content = updatedContent + case spec.MRoomMember: + updatedEvent, err := updateInviteEvent(userIDForSender, se, format) + if err != nil { + err = fmt.Errorf("Failed to update m.room.member event for ClientEvent: %w", err) + return err + } + if updatedEvent != nil { + ce.Unsigned = updatedEvent.Unsigned() + } + case spec.MRoomPowerLevels: + updatedEvent, err := updatePowerLevelEvent(userIDForSender, se, format) + if err != nil { + err = fmt.Errorf("Failed update m.room.power_levels event for ClientEvent: %w", err) + return err + } + if updatedEvent != nil { + ce.Content = updatedEvent.Content() + ce.Unsigned = updatedEvent.Unsigned() + } + } + + return nil +} + +func updateCreateEvent(content spec.RawJSON, userIDForSender spec.UserIDForSender, roomID spec.RoomID) (spec.RawJSON, error) { + if creator := gjson.GetBytes(content, "creator"); creator.Exists() { + oldCreator := creator.Str + userID, err := userIDForSender(roomID, spec.SenderID(oldCreator)) + if err != nil { + err = fmt.Errorf("Failed to find userID for creator in ClientEvent: %w", err) + return nil, err + } + + if userID != nil { + var newCreatorBytes, newContent []byte + newCreatorBytes, err = json.Marshal(userID.String()) + if err != nil { + err = fmt.Errorf("Failed to marshal new creator for ClientEvent: %w", err) + return nil, err + } + + newContent, err = sjson.SetRawBytes([]byte(content), "creator", newCreatorBytes) + if err != nil { + err = fmt.Errorf("Failed to set new creator for ClientEvent: %w", err) + return nil, err + } + + return newContent, nil + } + } + + return content, nil +} + +func updateInviteEvent(userIDForSender spec.UserIDForSender, ev gomatrixserverlib.PDU, eventFormat ClientEventFormat) (gomatrixserverlib.PDU, error) { + if inviteRoomState := gjson.GetBytes(ev.Unsigned(), "invite_room_state"); inviteRoomState.Exists() { + userID, err := userIDForSender(ev.RoomID(), ev.SenderID()) + if err != nil || userID == nil { + if err != nil { + err = fmt.Errorf("invalid userID found when updating invite_room_state: %w", err) + } + return nil, err + } + + newState, err := GetUpdatedInviteRoomState(userIDForSender, inviteRoomState, ev, ev.RoomID(), eventFormat) + if err != nil { + return nil, err + } + + var newEv []byte + newEv, err = sjson.SetRawBytes(ev.JSON(), "unsigned.invite_room_state", newState) + if err != nil { + return nil, err + } + + return gomatrixserverlib.MustGetRoomVersion(ev.Version()).NewEventFromTrustedJSON(newEv, false) + } + + return ev, nil +} + +type InviteRoomStateEvent struct { + Content spec.RawJSON `json:"content"` + SenderID string `json:"sender"` + StateKey *string `json:"state_key"` + Type string `json:"type"` +} + +func GetUpdatedInviteRoomState(userIDForSender spec.UserIDForSender, inviteRoomState gjson.Result, event gomatrixserverlib.PDU, roomID spec.RoomID, eventFormat ClientEventFormat) (spec.RawJSON, error) { + var res spec.RawJSON + inviteStateEvents := []InviteRoomStateEvent{} + err := json.Unmarshal([]byte(inviteRoomState.Raw), &inviteStateEvents) + if err != nil { + return nil, err + } + + if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && eventFormat != FormatSyncFederation { + for i, ev := range inviteStateEvents { + userID, userIDErr := userIDForSender(roomID, spec.SenderID(ev.SenderID)) + if userIDErr != nil { + return nil, userIDErr + } + if userID != nil { + inviteStateEvents[i].SenderID = userID.String() + } + + if ev.StateKey != nil && *ev.StateKey != "" { + userID, senderErr := userIDForSender(roomID, spec.SenderID(*ev.StateKey)) + if senderErr != nil { + return nil, senderErr + } + if userID != nil { + user := userID.String() + inviteStateEvents[i].StateKey = &user + } + } + + updatedContent, updateErr := updateCreateEvent(ev.Content, userIDForSender, roomID) + if updateErr != nil { + updateErr = fmt.Errorf("Failed to update m.room.create event for ClientEvent: %w", userIDErr) + return nil, updateErr + } + inviteStateEvents[i].Content = updatedContent + } + } + + res, err = json.Marshal(inviteStateEvents) + if err != nil { + return nil, err + } + + return res, nil +} + +func updatePowerLevelEvent(userIDForSender spec.UserIDForSender, se gomatrixserverlib.PDU, eventFormat ClientEventFormat) (gomatrixserverlib.PDU, error) { + if !se.StateKeyEquals("") { + return se, nil + } + + newEv := se.JSON() + + usersField := gjson.GetBytes(se.JSON(), "content.users") + if usersField.Exists() { + pls, err := gomatrixserverlib.NewPowerLevelContentFromEvent(se) + if err != nil { + return nil, err + } + + newPls := make(map[string]int64) + var userID *spec.UserID + for user, level := range pls.Users { + if eventFormat != FormatSyncFederation { + userID, err = userIDForSender(se.RoomID(), spec.SenderID(user)) + if err != nil { + return nil, err + } + user = userID.String() + } + newPls[user] = level + } + + var newPlBytes []byte + newPlBytes, err = json.Marshal(newPls) + if err != nil { + return nil, err + } + newEv, err = sjson.SetRawBytes(se.JSON(), "content.users", newPlBytes) + if err != nil { + return nil, err + } + } + + // do the same for prev content + prevUsersField := gjson.GetBytes(se.JSON(), "unsigned.prev_content.users") + if prevUsersField.Exists() { + prevContent := gjson.GetBytes(se.JSON(), "unsigned.prev_content") + if !prevContent.Exists() { + evNew, err := gomatrixserverlib.MustGetRoomVersion(se.Version()).NewEventFromTrustedJSON(newEv, false) + if err != nil { + return nil, err + } + + return evNew, err + } + pls := gomatrixserverlib.PowerLevelContent{} + err := json.Unmarshal([]byte(prevContent.Raw), &pls) + if err != nil { + return nil, err + } + + newPls := make(map[string]int64) + for user, level := range pls.Users { + if eventFormat != FormatSyncFederation { + userID, userErr := userIDForSender(se.RoomID(), spec.SenderID(user)) + if userErr != nil { + return nil, userErr + } + user = userID.String() + } + newPls[user] = level + } + + var newPlBytes []byte + newPlBytes, err = json.Marshal(newPls) + if err != nil { + return nil, err + } + newEv, err = sjson.SetRawBytes(newEv, "unsigned.prev_content.users", newPlBytes) + if err != nil { + return nil, err + } + } + + evNew, err := gomatrixserverlib.MustGetRoomVersion(se.Version()).NewEventFromTrustedJSONWithEventID(se.EventID(), newEv, false) + if err != nil { + return nil, err + } + + return evNew, err +} diff --git a/syncapi/synctypes/clientevent_test.go b/syncapi/synctypes/clientevent_test.go index 63c65b2af..662f9ea43 100644 --- a/syncapi/synctypes/clientevent_test.go +++ b/syncapi/synctypes/clientevent_test.go @@ -18,12 +18,77 @@ package synctypes import ( "bytes" "encoding/json" + "fmt" + "reflect" "testing" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" ) +func queryUserIDForSender(senderID spec.SenderID) (*spec.UserID, error) { + if senderID == "" { + return nil, nil + } + + return spec.NewUserID(string(senderID), true) +} + +const testSenderID = "testSenderID" +const testUserID = "@test:localhost" + +type EventFieldsToVerify struct { + EventID string + Type string + OriginServerTS spec.Timestamp + StateKey *string + Content spec.RawJSON + Unsigned spec.RawJSON + Sender string + Depth int64 + PrevEvents []string + AuthEvents []string +} + +func verifyEventFields(t *testing.T, got EventFieldsToVerify, want EventFieldsToVerify) { + if got.EventID != want.EventID { + t.Errorf("ClientEvent.EventID: wanted %s, got %s", want.EventID, got.EventID) + } + if got.OriginServerTS != want.OriginServerTS { + t.Errorf("ClientEvent.OriginServerTS: wanted %d, got %d", want.OriginServerTS, got.OriginServerTS) + } + if got.StateKey == nil && want.StateKey != nil { + t.Errorf("ClientEvent.StateKey: no state key present when one was wanted: %s", *want.StateKey) + } + if got.StateKey != nil && want.StateKey == nil { + t.Errorf("ClientEvent.StateKey: state key present when one was not wanted: %s", *got.StateKey) + } + if got.StateKey != nil && want.StateKey != nil && *got.StateKey != *want.StateKey { + t.Errorf("ClientEvent.StateKey: wanted %s, got %s", *want.StateKey, *got.StateKey) + } + if got.Type != want.Type { + t.Errorf("ClientEvent.Type: wanted %s, got %s", want.Type, got.Type) + } + if !bytes.Equal(got.Content, want.Content) { + t.Errorf("ClientEvent.Content: wanted %s, got %s", string(want.Content), string(got.Content)) + } + if !bytes.Equal(got.Unsigned, want.Unsigned) { + t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(want.Unsigned), string(got.Unsigned)) + } + if got.Sender != want.Sender { + t.Errorf("ClientEvent.Sender: wanted %s, got %s", want.Sender, got.Sender) + } + if got.Depth != want.Depth { + t.Errorf("ClientEvent.Depth: wanted %d, got %d", want.Depth, got.Depth) + } + if !reflect.DeepEqual(got.PrevEvents, want.PrevEvents) { + t.Errorf("ClientEvent.PrevEvents: wanted %v, got %v", want.PrevEvents, got.PrevEvents) + } + if !reflect.DeepEqual(got.AuthEvents, want.AuthEvents) { + t.Errorf("ClientEvent.AuthEvents: wanted %v, got %v", want.AuthEvents, got.AuthEvents) + } +} + func TestToClientEvent(t *testing.T) { // nolint: gocyclo ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV1).NewEventFromTrustedJSON([]byte(`{ "type": "m.room.name", @@ -49,28 +114,33 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo t.Fatalf("failed to create userID: %s", err) } sk := "" - ce := ToClientEvent(ev, FormatAll, *userID, &sk) - if ce.EventID != ev.EventID() { - t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID) - } - if ce.OriginServerTS != ev.OriginServerTS() { - t.Errorf("ClientEvent.OriginServerTS: wanted %d, got %d", ev.OriginServerTS(), ce.OriginServerTS) - } - if ce.StateKey == nil || *ce.StateKey != "" { - t.Errorf("ClientEvent.StateKey: wanted '', got %v", ce.StateKey) - } - if ce.Type != ev.Type() { - t.Errorf("ClientEvent.Type: wanted %s, got %s", ev.Type(), ce.Type) - } - if !bytes.Equal(ce.Content, ev.Content()) { - t.Errorf("ClientEvent.Content: wanted %s, got %s", string(ev.Content()), string(ce.Content)) - } - if !bytes.Equal(ce.Unsigned, ev.Unsigned()) { - t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned)) - } - if ce.Sender != userID.String() { - t.Errorf("ClientEvent.Sender: wanted %s, got %s", userID.String(), ce.Sender) + ce, err := ToClientEvent(ev, FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return queryUserIDForSender(senderID) + }) + if err != nil { + t.Fatalf("failed to create ClientEvent: %s", err) } + + verifyEventFields(t, + EventFieldsToVerify{ + EventID: ce.EventID, + Type: ce.Type, + OriginServerTS: ce.OriginServerTS, + StateKey: ce.StateKey, + Content: ce.Content, + Unsigned: ce.Unsigned, + Sender: ce.Sender, + }, + EventFieldsToVerify{ + EventID: ev.EventID(), + Type: ev.Type(), + OriginServerTS: ev.OriginServerTS(), + StateKey: &sk, + Content: ev.Content(), + Unsigned: ev.Unsigned(), + Sender: userID.String(), + }) + j, err := json.Marshal(ce) if err != nil { t.Fatalf("failed to Marshal ClientEvent: %s", err) @@ -104,13 +174,388 @@ func TestToClientFormatSync(t *testing.T) { if err != nil { t.Fatalf("failed to create Event: %s", err) } + ce, err := ToClientEvent(ev, FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return queryUserIDForSender(senderID) + }) + if err != nil { + t.Fatalf("failed to create ClientEvent: %s", err) + } + if ce.RoomID != "" { + t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID) + } +} + +func TestToClientEventFormatSyncFederation(t *testing.T) { // nolint: gocyclo + ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV10).NewEventFromTrustedJSON([]byte(`{ + "type": "m.room.name", + "state_key": "", + "event_id": "$test:localhost", + "room_id": "!test:localhost", + "sender": "@test:localhost", + "content": { + "name": "Hello World" + }, + "origin_server_ts": 123456, + "unsigned": { + "prev_content": { + "name": "Goodbye World" + } + }, + "depth": 8, + "prev_events": [ + "$f597Tp0Mm1PPxEgiprzJc2cZAjVhxCxACOGuwJb33Oo" + ], + "auth_events": [ + "$Bj0ZGgX6VTqAQdqKH4ZG3l6rlbxY3rZlC5D3MeuK1OQ", + "$QsMs6A1PUVUhgSvmHBfpqEYJPgv4DXt96r8P2AK7iXQ", + "$tBteKtlnFiwlmPJsv0wkKTMEuUVWpQH89H7Xskxve1Q" + ] + }`), false) + if err != nil { + t.Fatalf("failed to create Event: %s", err) + } userID, err := spec.NewUserID("@test:localhost", true) if err != nil { t.Fatalf("failed to create userID: %s", err) } sk := "" - ce := ToClientEvent(ev, FormatSync, *userID, &sk) - if ce.RoomID != "" { - t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID) + ce, err := ToClientEvent(ev, FormatSyncFederation, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return queryUserIDForSender(senderID) + }) + if err != nil { + t.Fatalf("failed to create ClientEvent: %s", err) } + + verifyEventFields(t, + EventFieldsToVerify{ + EventID: ce.EventID, + Type: ce.Type, + OriginServerTS: ce.OriginServerTS, + StateKey: ce.StateKey, + Content: ce.Content, + Unsigned: ce.Unsigned, + Sender: ce.Sender, + Depth: ce.Depth, + PrevEvents: ce.PrevEvents, + AuthEvents: ce.AuthEvents, + }, + EventFieldsToVerify{ + EventID: ev.EventID(), + Type: ev.Type(), + OriginServerTS: ev.OriginServerTS(), + StateKey: &sk, + Content: ev.Content(), + Unsigned: ev.Unsigned(), + Sender: userID.String(), + Depth: ev.Depth(), + PrevEvents: ev.PrevEventIDs(), + AuthEvents: ev.AuthEventIDs(), + }) +} + +func userIDForSender(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + if senderID == "unknownSenderID" { + return nil, fmt.Errorf("Cannot find userID") + } + return spec.NewUserID(testUserID, true) +} + +func TestToClientEventsFormatSyncFederation(t *testing.T) { // nolint: gocyclo + ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionPseudoIDs).NewEventFromTrustedJSON([]byte(`{ + "type": "m.room.name", + "state_key": "testSenderID", + "event_id": "$test:localhost", + "room_id": "!test:localhost", + "sender": "testSenderID", + "content": { + "name": "Hello World" + }, + "origin_server_ts": 123456, + "unsigned": { + "prev_content": { + "name": "Goodbye World" + } + }, + "depth": 8, + "prev_events": [ + "$f597Tp0Mm1PPxEgiprzJc2cZAjVhxCxACOGuwJb33Oo" + ], + "auth_events": [ + "$Bj0ZGgX6VTqAQdqKH4ZG3l6rlbxY3rZlC5D3MeuK1OQ", + "$QsMs6A1PUVUhgSvmHBfpqEYJPgv4DXt96r8P2AK7iXQ", + "$tBteKtlnFiwlmPJsv0wkKTMEuUVWpQH89H7Xskxve1Q" + ] + }`), false) + if err != nil { + t.Fatalf("failed to create Event: %s", err) + } + ev2, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionPseudoIDs).NewEventFromTrustedJSON([]byte(`{ + "type": "m.room.name", + "state_key": "testSenderID", + "event_id": "$test2:localhost", + "room_id": "!test:localhost", + "sender": "testSenderID", + "content": { + "name": "Hello World 2" + }, + "origin_server_ts": 1234567, + "unsigned": { + "prev_content": { + "name": "Goodbye World 2" + }, + "prev_sender": "testSenderID" + }, + "depth": 9, + "prev_events": [ + "$f597Tp0Mm1PPxEgiprzJc2cZAjVhxCxACOGuwJb33Oo" + ], + "auth_events": [ + "$Bj0ZGgX6VTqAQdqKH4ZG3l6rlbxY3rZlC5D3MeuK1OQ", + "$QsMs6A1PUVUhgSvmHBfpqEYJPgv4DXt96r8P2AK7iXQ", + "$tBteKtlnFiwlmPJsv0wkKTMEuUVWpQH89H7Xskxve1Q" + ] + }`), false) + if err != nil { + t.Fatalf("failed to create Event: %s", err) + } + + clientEvents := ToClientEvents([]gomatrixserverlib.PDU{ev, ev2}, FormatSyncFederation, userIDForSender) + ce := clientEvents[0] + sk := testSenderID + verifyEventFields(t, + EventFieldsToVerify{ + EventID: ce.EventID, + Type: ce.Type, + OriginServerTS: ce.OriginServerTS, + StateKey: ce.StateKey, + Content: ce.Content, + Unsigned: ce.Unsigned, + Sender: ce.Sender, + Depth: ce.Depth, + PrevEvents: ce.PrevEvents, + AuthEvents: ce.AuthEvents, + }, + EventFieldsToVerify{ + EventID: ev.EventID(), + Type: ev.Type(), + OriginServerTS: ev.OriginServerTS(), + StateKey: &sk, + Content: ev.Content(), + Unsigned: ev.Unsigned(), + Sender: testSenderID, + Depth: ev.Depth(), + PrevEvents: ev.PrevEventIDs(), + AuthEvents: ev.AuthEventIDs(), + }) + + ce2 := clientEvents[1] + verifyEventFields(t, + EventFieldsToVerify{ + EventID: ce2.EventID, + Type: ce2.Type, + OriginServerTS: ce2.OriginServerTS, + StateKey: ce2.StateKey, + Content: ce2.Content, + Unsigned: ce2.Unsigned, + Sender: ce2.Sender, + Depth: ce2.Depth, + PrevEvents: ce2.PrevEvents, + AuthEvents: ce2.AuthEvents, + }, + EventFieldsToVerify{ + EventID: ev2.EventID(), + Type: ev2.Type(), + OriginServerTS: ev2.OriginServerTS(), + StateKey: &sk, + Content: ev2.Content(), + Unsigned: ev2.Unsigned(), + Sender: testSenderID, + Depth: ev2.Depth(), + PrevEvents: ev2.PrevEventIDs(), + AuthEvents: ev2.AuthEventIDs(), + }) +} + +func TestToClientEventsFormatSync(t *testing.T) { // nolint: gocyclo + ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionPseudoIDs).NewEventFromTrustedJSON([]byte(`{ + "type": "m.room.name", + "state_key": "testSenderID", + "event_id": "$test:localhost", + "room_id": "!test:localhost", + "sender": "testSenderID", + "content": { + "name": "Hello World" + }, + "origin_server_ts": 123456, + "unsigned": { + "prev_content": { + "name": "Goodbye World" + } + } + }`), false) + if err != nil { + t.Fatalf("failed to create Event: %s", err) + } + ev2, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionPseudoIDs).NewEventFromTrustedJSON([]byte(`{ + "type": "m.room.name", + "state_key": "testSenderID", + "event_id": "$test2:localhost", + "room_id": "!test:localhost", + "sender": "testSenderID", + "content": { + "name": "Hello World 2" + }, + "origin_server_ts": 1234567, + "unsigned": { + "prev_content": { + "name": "Goodbye World 2" + }, + "prev_sender": "testSenderID" + }, + "depth": 9 + }`), false) + if err != nil { + t.Fatalf("failed to create Event: %s", err) + } + + clientEvents := ToClientEvents([]gomatrixserverlib.PDU{ev, ev2}, FormatSync, userIDForSender) + ce := clientEvents[0] + sk := testUserID + verifyEventFields(t, + EventFieldsToVerify{ + EventID: ce.EventID, + Type: ce.Type, + OriginServerTS: ce.OriginServerTS, + StateKey: ce.StateKey, + Content: ce.Content, + Unsigned: ce.Unsigned, + Sender: ce.Sender, + }, + EventFieldsToVerify{ + EventID: ev.EventID(), + Type: ev.Type(), + OriginServerTS: ev.OriginServerTS(), + StateKey: &sk, + Content: ev.Content(), + Unsigned: ev.Unsigned(), + Sender: testUserID, + }) + + var prev PrevEventRef + prev.PrevContent = []byte(`{"name": "Goodbye World 2"}`) + prev.PrevSenderID = testUserID + expectedUnsigned, _ := json.Marshal(prev) + + ce2 := clientEvents[1] + verifyEventFields(t, + EventFieldsToVerify{ + EventID: ce2.EventID, + Type: ce2.Type, + OriginServerTS: ce2.OriginServerTS, + StateKey: ce2.StateKey, + Content: ce2.Content, + Unsigned: ce2.Unsigned, + Sender: ce2.Sender, + }, + EventFieldsToVerify{ + EventID: ev2.EventID(), + Type: ev2.Type(), + OriginServerTS: ev2.OriginServerTS(), + StateKey: &sk, + Content: ev2.Content(), + Unsigned: expectedUnsigned, + Sender: testUserID, + }) +} + +func TestToClientEventsFormatSyncUnknownPrevSender(t *testing.T) { // nolint: gocyclo + ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionPseudoIDs).NewEventFromTrustedJSON([]byte(`{ + "type": "m.room.name", + "state_key": "testSenderID", + "event_id": "$test:localhost", + "room_id": "!test:localhost", + "sender": "testSenderID", + "content": { + "name": "Hello World" + }, + "origin_server_ts": 123456, + "unsigned": { + "prev_content": { + "name": "Goodbye World" + } + } + }`), false) + if err != nil { + t.Fatalf("failed to create Event: %s", err) + } + ev2, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionPseudoIDs).NewEventFromTrustedJSON([]byte(`{ + "type": "m.room.name", + "state_key": "testSenderID", + "event_id": "$test2:localhost", + "room_id": "!test:localhost", + "sender": "testSenderID", + "content": { + "name": "Hello World 2" + }, + "origin_server_ts": 1234567, + "unsigned": { + "prev_content": { + "name": "Goodbye World 2" + }, + "prev_sender": "unknownSenderID" + }, + "depth": 9 + }`), false) + if err != nil { + t.Fatalf("failed to create Event: %s", err) + } + + clientEvents := ToClientEvents([]gomatrixserverlib.PDU{ev, ev2}, FormatSync, userIDForSender) + ce := clientEvents[0] + sk := testUserID + verifyEventFields(t, + EventFieldsToVerify{ + EventID: ce.EventID, + Type: ce.Type, + OriginServerTS: ce.OriginServerTS, + StateKey: ce.StateKey, + Content: ce.Content, + Unsigned: ce.Unsigned, + Sender: ce.Sender, + }, + EventFieldsToVerify{ + EventID: ev.EventID(), + Type: ev.Type(), + OriginServerTS: ev.OriginServerTS(), + StateKey: &sk, + Content: ev.Content(), + Unsigned: ev.Unsigned(), + Sender: testUserID, + }) + + var prev PrevEventRef + prev.PrevContent = []byte(`{"name": "Goodbye World 2"}`) + prev.PrevSenderID = "unknownSenderID" + expectedUnsigned, _ := json.Marshal(prev) + + ce2 := clientEvents[1] + verifyEventFields(t, + EventFieldsToVerify{ + EventID: ce2.EventID, + Type: ce2.Type, + OriginServerTS: ce2.OriginServerTS, + StateKey: ce2.StateKey, + Content: ce2.Content, + Unsigned: ce2.Unsigned, + Sender: ce2.Sender, + }, + EventFieldsToVerify{ + EventID: ev2.EventID(), + Type: ev2.Type(), + OriginServerTS: ev2.OriginServerTS(), + StateKey: &sk, + Content: ev2.Content(), + Unsigned: expectedUnsigned, + Sender: testUserID, + }) } diff --git a/syncapi/synctypes/filter.go b/syncapi/synctypes/filter.go index c994ddb96..8998d4433 100644 --- a/syncapi/synctypes/filter.go +++ b/syncapi/synctypes/filter.go @@ -78,9 +78,14 @@ type RoomEventFilter struct { ContainsURL *bool `json:"contains_url,omitempty"` } +const ( + EventFormatClient = "client" + EventFormatFederation = "federation" +) + // Validate checks if the filter contains valid property values func (filter *Filter) Validate() error { - if filter.EventFormat != "" && filter.EventFormat != "client" && filter.EventFormat != "federation" { + if filter.EventFormat != "" && filter.EventFormat != EventFormatClient && filter.EventFormat != EventFormatFederation { return errors.New("Bad event_format value. Must be one of [\"client\", \"federation\"]") } return nil diff --git a/syncapi/types/types.go b/syncapi/types/types.go index cb3c362d5..bca11855c 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -15,6 +15,7 @@ package types import ( + "context" "encoding/json" "errors" "fmt" @@ -339,13 +340,6 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { return token, nil } -// PrevEventRef represents a reference to a previous event in a state event upgrade -type PrevEventRef struct { - PrevContent json.RawMessage `json:"prev_content"` - ReplacesState string `json:"replaces_state"` - PrevSenderID string `json:"prev_sender"` -} - type DeviceLists struct { Changed []string `json:"changed,omitempty"` Left []string `json:"left,omitempty"` @@ -539,7 +533,7 @@ type InviteResponse struct { } // NewInviteResponse creates an empty response with initialised arrays. -func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID, stateKey *string) *InviteResponse { +func NewInviteResponse(ctx context.Context, rsAPI api.QuerySenderIDAPI, event *types.HeaderedEvent, eventFormat synctypes.ClientEventFormat) (*InviteResponse, error) { res := InviteResponse{} res.InviteState.Events = []json.RawMessage{} @@ -547,18 +541,42 @@ func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID, stateKey // If there is then unmarshal it into the response. This will contain the // partial room state such as join rules, room name etc. if inviteRoomState := gjson.GetBytes(event.Unsigned(), "invite_room_state"); inviteRoomState.Exists() { - _ = json.Unmarshal([]byte(inviteRoomState.Raw), &res.InviteState.Events) + if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && eventFormat != synctypes.FormatSyncFederation { + updatedInvite, err := synctypes.GetUpdatedInviteRoomState(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, inviteRoomState, event.PDU, event.RoomID(), eventFormat) + if err != nil { + return nil, err + } + _ = json.Unmarshal(updatedInvite, &res.InviteState.Events) + } else { + _ = json.Unmarshal([]byte(inviteRoomState.Raw), &res.InviteState.Events) + } + } + + // Clear unsigned so it doesn't have pseudoIDs converted during ToClientEvent + eventNoUnsigned, err := event.SetUnsigned(nil) + if err != nil { + return nil, err } // Then we'll see if we can create a partial of the invite event itself. // This is needed for clients to work out *who* sent the invite. - inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID, stateKey) + inviteEvent, err := synctypes.ToClientEvent(eventNoUnsigned, eventFormat, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) + if err != nil { + return nil, err + } + + // Ensure unsigned field is empty so it isn't marshalled into the final JSON inviteEvent.Unsigned = nil - if ev, err := json.Marshal(inviteEvent); err == nil { + + if ev, err := json.Marshal(*inviteEvent); err == nil { res.InviteState.Events = append(res.InviteState.Events, ev) } - return &res + return &res, nil } // LeaveResponse represents a /sync response for a room which is under the 'leave' key. diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index c1b7f70bd..35e1882cb 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -1,6 +1,7 @@ package types import ( + "context" "encoding/json" "reflect" "testing" @@ -11,8 +12,19 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" ) -func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { - return spec.NewUserID(senderID, true) +type FakeRoomserverAPI struct{} + +func (f *FakeRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + if senderID == "" { + return nil, nil + } + + return spec.NewUserID(string(senderID), true) +} + +func (f *FakeRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (*spec.SenderID, error) { + sender := spec.SenderID(userID.String()) + return &sender, nil } func TestSyncTokens(t *testing.T) { @@ -61,25 +73,18 @@ func TestNewInviteResponse(t *testing.T) { t.Fatal(err) } - sender, err := spec.NewUserID("@neilalexander:matrix.org", true) + rsAPI := FakeRoomserverAPI{} + res, err := NewInviteResponse(context.Background(), &rsAPI, &types.HeaderedEvent{PDU: ev}, synctypes.FormatSync) if err != nil { t.Fatal(err) } - skUserID, err := spec.NewUserID("@neilalexander:dendrite.neilalexander.dev", true) - if err != nil { - t.Fatal(err) - } - skString := skUserID.String() - sk := &skString - - res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender, sk) j, err := json.Marshal(res) if err != nil { t.Fatal(err) } if string(j) != expected { - t.Fatalf("Invite response didn't contain correct info") + t.Fatalf("Invite response didn't contain correct info, \nexpected: %s \ngot: %s", expected, string(j)) } } diff --git a/sytest-whitelist b/sytest-whitelist index c61e0bc3c..60ba02302 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -783,4 +783,8 @@ Invited user can reject invite for empty room Invited user can reject local invite after originator leaves Guest users can join guest_access rooms Forgotten room messages cannot be paginated -Local device key changes get to remote servers with correct prev_id \ No newline at end of file +Local device key changes get to remote servers with correct prev_id +HS provides query metadata +HS can provide query metadata on a single protocol +Invites over federation are correctly pushed +Invites over federation are correctly pushed with name \ No newline at end of file diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index a88b2129d..d5baa074c 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -92,30 +92,43 @@ func (s *OutputRoomEventConsumer) Start() error { func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { msg := msgs[0] // Guaranteed to exist if onMessage is called // Only handle events we care about - if rsapi.OutputType(msg.Header.Get(jetstream.RoomEventType)) != rsapi.OutputTypeNewRoomEvent { - return true - } - var output rsapi.OutputEvent - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("roomserver output log: message parse failure") - return true - } - event := output.NewRoomEvent.Event - if event == nil { - log.Errorf("userapi consumer: expected event") + + var event *rstypes.HeaderedEvent + var isNewRoomEvent bool + switch rsapi.OutputType(msg.Header.Get(jetstream.RoomEventType)) { + case rsapi.OutputTypeNewRoomEvent: + isNewRoomEvent = true + fallthrough + case rsapi.OutputTypeNewInviteEvent: + var output rsapi.OutputEvent + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("roomserver output log: message parse failure") + return true + } + if isNewRoomEvent { + event = output.NewRoomEvent.Event + } else { + event = output.NewInviteEvent.Event + } + + if event == nil { + log.Errorf("userapi consumer: expected event") + return true + } + + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "event_type": event.Type(), + }).Tracef("Received message from roomserver: %#v", output) + default: return true } if s.cfg.Matrix.ReportStats.Enabled { - go s.storeMessageStats(ctx, event.Type(), string(event.SenderID()), event.RoomID()) + go s.storeMessageStats(ctx, event.Type(), string(event.SenderID()), event.RoomID().String()) } - log.WithFields(log.Fields{ - "event_id": event.EventID(), - "event_type": event.Type(), - }).Tracef("Received message from roomserver: %#v", output) - metadata, err := msg.Metadata() if err != nil { return true @@ -294,36 +307,21 @@ func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRo } func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rstypes.HeaderedEvent, streamPos uint64) error { - members, roomSize, err := s.localRoomMembers(ctx, event.RoomID()) + members, roomSize, err := s.localRoomMembers(ctx, event.RoomID().String()) if err != nil { return fmt.Errorf("s.localRoomMembers: %w", err) } switch { case event.Type() == spec.MRoomMember: - sender := spec.UserID{} - validRoomID, roomErr := spec.NewRoomID(event.RoomID()) - if roomErr != nil { - return roomErr + cevent, clientEvErr := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) + if clientEvErr != nil { + return clientEvErr } - userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) - if queryErr == nil && userID != nil { - sender = *userID - } - - sk := event.StateKey() - if sk != nil && *sk != "" { - skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*sk)) - if queryErr == nil && skUserID != nil { - skString := skUserID.String() - sk = &skString - } else { - return fmt.Errorf("queryUserIDForSender: userID unknown for %s", *sk) - } - } - cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk) var member *localMembership - member, err = newLocalMembership(&cevent) + member, err = newLocalMembership(cevent) if err != nil { return fmt.Errorf("newLocalMembership: %w", err) } @@ -334,7 +332,7 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst } case event.Type() == "m.room.tombstone" && event.StateKeyEquals(""): // Handle room upgrades - oldRoomID := event.RoomID() + oldRoomID := event.RoomID().String() newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str if err = s.handleRoomUpgrade(ctx, oldRoomID, newRoomID, members, roomSize); err != nil { // while inconvenient, this shouldn't stop us from sending push notifications @@ -351,7 +349,7 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst log.WithFields(log.Fields{ "event_id": event.EventID(), - "room_id": event.RoomID(), + "room_id": event.RoomID().String(), "num_members": len(members), "room_size": roomSize, }).Tracef("Notifying members") @@ -463,8 +461,21 @@ func (s *OutputRoomEventConsumer) roomName(ctx context.Context, event *rstypes.H } } + // Special case for invites, as we don't store any "current state" for these events, + // we need to make sure that, if present, the m.room.name is sent as well. + if event.Type() == spec.MRoomMember && + gjson.GetBytes(event.Content(), "membership").Str == "invite" { + invState := gjson.GetBytes(event.JSON(), "unsigned.invite_room_state") + for _, ev := range invState.Array() { + if ev.Get("type").Str == spec.MRoomName { + name := ev.Get("content.name").Str + return name, nil + } + } + } + req := &rsapi.QueryCurrentStateRequest{ - RoomID: event.RoomID(), + RoomID: event.RoomID().String(), StateTuples: []gomatrixserverlib.StateKeyTuple{roomNameTuple, canonicalAliasTuple}, } var res rsapi.QueryCurrentStateResponse @@ -532,7 +543,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype if a != pushrules.NotifyAction && a != pushrules.CoalesceAction { log.WithFields(log.Fields{ "event_id": event.EventID(), - "room_id": event.RoomID(), + "room_id": event.RoomID().String(), "localpart": mem.Localpart, }).Tracef("Push rule evaluation rejected the event") return nil @@ -542,44 +553,32 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype if err != nil { return fmt.Errorf("s.localPushDevices: %w", err) } - - sender := spec.UserID{} - validRoomID, err := spec.NewRoomID(event.RoomID()) + clientEvent, err := synctypes.ToClientEvent(event, synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) if err != nil { return err } - userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) - if err == nil && userID != nil { - sender = *userID - } - sk := event.StateKey() - if sk != nil && *sk != "" { - skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey())) - if queryErr == nil && skUserID != nil { - skString := skUserID.String() - sk = &skString - } - } n := &api.Notification{ Actions: actions, // UNSPEC: the spec doesn't say this is a ClientEvent, but the // fields seem to match. room_id should be missing, which // matches the behaviour of FormatSync. - Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender, sk), + Event: *clientEvent, // TODO: this is per-device, but it's not part of the primary // key. So inserting one notification per profile tag doesn't // make sense. What is this supposed to be? Sytests require it // to "work", but they only use a single device. ProfileTag: profileTag, - RoomID: event.RoomID(), + RoomID: event.RoomID().String(), TS: spec.AsTimestamp(time.Now()), } if err = s.db.InsertNotification(ctx, mem.Localpart, mem.Domain, event.EventID(), streamPos, tweaks, n); err != nil { return fmt.Errorf("s.db.InsertNotification: %w", err) } - if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil { + if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID().String()); err != nil { return fmt.Errorf("s.syncProducer.GetAndSendNotificationData: %w", err) } @@ -591,7 +590,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype log.WithFields(log.Fields{ "event_id": event.EventID(), - "room_id": event.RoomID(), + "room_id": event.RoomID().String(), "localpart": mem.Localpart, "num_urls": len(devicesByURLAndFormat), "num_unread": userNumUnreadNotifs, @@ -648,11 +647,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype // user. Returns actions (including dont_notify). func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { user := "" - validRoomID, err := spec.NewRoomID(event.RoomID()) - if err != nil { - return nil, err - } - sender, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) + sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) if err == nil { user = sender.String() } @@ -686,7 +681,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * ctx: ctx, rsAPI: s.rsAPI, mem: mem, - roomID: event.RoomID(), + roomID: event.RoomID().String(), roomSize: roomSize, } eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global) @@ -704,7 +699,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * log.WithFields(log.Fields{ "event_id": event.EventID(), - "room_id": event.RoomID(), + "room_id": event.RoomID().String(), "localpart": mem.Localpart, "rule_id": rule.RuleID, }).Trace("Matched a push rule") @@ -793,16 +788,12 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes }, Devices: devices, EventID: event.EventID(), - RoomID: event.RoomID(), + RoomID: event.RoomID().String(), }, } default: - validRoomID, err := spec.NewRoomID(event.RoomID()) - if err != nil { - return nil, err - } - sender, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) + sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) if err != nil { logger.WithError(err).Errorf("Failed to get userID for sender %s", event.SenderID()) return nil, err @@ -816,7 +807,7 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes Devices: devices, EventID: event.EventID(), ID: event.EventID(), - RoomID: event.RoomID(), + RoomID: event.RoomID().String(), RoomName: roomName, Sender: sender.String(), Type: event.Type(), @@ -830,19 +821,13 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes logger.WithError(err).Errorf("Failed to convert local user to userID %s", localpart) return nil, err } - roomID, err := spec.NewRoomID(event.RoomID()) + localSender, err := s.rsAPI.QuerySenderIDForUser(ctx, event.RoomID(), *userID) if err != nil { - logger.WithError(err).Errorf("event roomID is invalid %s", event.RoomID()) - return nil, err - } - - localSender, err := s.rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID) - if err != nil { - logger.WithError(err).Errorf("Failed to get local user senderID for room %s: %s", userID.String(), event.RoomID()) + logger.WithError(err).Errorf("Failed to get local user senderID for room %s: %s", userID.String(), event.RoomID().String()) return nil, err } else if localSender == nil { - logger.WithError(err).Errorf("Failed to get local user senderID for room %s: %s", userID.String(), event.RoomID()) - return nil, fmt.Errorf("no sender ID for user %s in %s", userID.String(), roomID.String()) + logger.WithError(err).Errorf("Failed to get local user senderID for room %s: %s", userID.String(), event.RoomID().String()) + return nil, fmt.Errorf("no sender ID for user %s in %s", userID.String(), event.RoomID().String()) } if event.StateKey() != nil && *event.StateKey() == string(*localSender) { req.Notification.UserIsTarget = true diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index 4305c13a9..4e3c2671a 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -563,12 +563,15 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAccessTokenRequest, res *api.QueryAccessTokenResponse) error { if req.AppServiceUserID != "" { appServiceDevice, err := a.queryAppServiceToken(ctx, req.AccessToken, req.AppServiceUserID) - if err != nil { - res.Err = err.Error() - } - res.Device = appServiceDevice + if err != nil || appServiceDevice != nil { + if err != nil { + res.Err = err.Error() + } + res.Device = appServiceDevice - return nil + return nil + } + // If the provided token wasn't an as_token (both err and appServiceDevice are nil), continue with normal auth. } device, err := a.DB.GetDeviceByAccessToken(ctx, req.AccessToken) if err != nil { diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go index 3017069bc..2ea978d69 100644 --- a/userapi/util/notify_test.go +++ b/userapi/util/notify_test.go @@ -23,6 +23,14 @@ import ( userUtil "github.com/matrix-org/dendrite/userapi/util" ) +func queryUserIDForSender(senderID spec.SenderID) (*spec.UserID, error) { + if senderID == "" { + return nil, nil + } + + return spec.NewUserID(string(senderID), true) +} + func TestNotifyUserCountsAsync(t *testing.T) { alice := test.NewUser(t) aliceLocalpart, serverName, err := gomatrixserverlib.SplitID('@', alice.ID) @@ -100,13 +108,14 @@ func TestNotifyUserCountsAsync(t *testing.T) { } // Insert a dummy event - sender, err := spec.NewUserID(alice.ID, true) + ev, err := synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return queryUserIDForSender(senderID) + }) if err != nil { t.Error(err) } - sk := "" if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{ - Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender, &sk), + Event: *ev, }); err != nil { t.Error(err) }