Merge 0.13.4 upstream

This commit is contained in:
Dan Peleg 2023-10-30 22:28:31 +02:00
parent d44dc3bd98
commit 12d623149f
137 changed files with 3746 additions and 2290 deletions

View file

@ -128,7 +128,7 @@ func (s *OutputRoomEventConsumer) onMessage(
if len(output.NewRoomEvent.AddsStateEventIDs) > 0 { if len(output.NewRoomEvent.AddsStateEventIDs) > 0 {
newEventID := output.NewRoomEvent.Event.EventID() newEventID := output.NewRoomEvent.Event.EventID()
eventsReq := &api.QueryEventsByIDRequest{ eventsReq := &api.QueryEventsByIDRequest{
RoomID: output.NewRoomEvent.Event.RoomID(), RoomID: output.NewRoomEvent.Event.RoomID().String(),
EventIDs: make([]string, 0, len(output.NewRoomEvent.AddsStateEventIDs)), EventIDs: make([]string, 0, len(output.NewRoomEvent.AddsStateEventIDs)),
} }
eventsRes := &api.QueryEventsByIDResponse{} eventsRes := &api.QueryEventsByIDResponse{}
@ -236,11 +236,7 @@ func (s *appserviceState) backoffAndPause(err error) error {
// TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682 // TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682
func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool { func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool {
user := "" user := ""
validRoomID, err := spec.NewRoomID(event.RoomID()) userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err != nil {
return false
}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err == nil { if err == nil {
user = userID.String() user = userID.String()
} }
@ -250,7 +246,7 @@ func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Cont
return false return false
case appservice.IsInterestedInUserID(user): case appservice.IsInterestedInUserID(user):
return true return true
case appservice.IsInterestedInRoomID(event.RoomID()): case appservice.IsInterestedInRoomID(event.RoomID().String()):
return true return true
} }
@ -261,7 +257,7 @@ func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Cont
} }
// Check all known room aliases of the room the event came from // Check all known room aliases of the room the event came from
queryReq := api.GetAliasesForRoomIDRequest{RoomID: event.RoomID()} queryReq := api.GetAliasesForRoomIDRequest{RoomID: event.RoomID().String()}
var queryRes api.GetAliasesForRoomIDResponse var queryRes api.GetAliasesForRoomIDResponse
if err := s.rsAPI.GetAliasesForRoomID(ctx, &queryReq, &queryRes); err == nil { if err := s.rsAPI.GetAliasesForRoomID(ctx, &queryReq, &queryRes); err == nil {
for _, alias := range queryRes.Aliases { for _, alias := range queryRes.Aliases {
@ -272,7 +268,7 @@ func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Cont
} else { } else {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"appservice": appservice.ID, "appservice": appservice.ID,
"room_id": event.RoomID(), "room_id": event.RoomID().String(),
}).WithError(err).Errorf("Unable to get aliases for room") }).WithError(err).Errorf("Unable to get aliases for room")
} }
@ -288,7 +284,7 @@ func (s *OutputRoomEventConsumer) appserviceJoinedAtEvent(ctx context.Context, e
// until we have a lighter way of checking the state before the event that // until we have a lighter way of checking the state before the event that
// doesn't involve state res, then this is probably OK. // doesn't involve state res, then this is probably OK.
membershipReq := &api.QueryMembershipsForRoomRequest{ membershipReq := &api.QueryMembershipsForRoomRequest{
RoomID: event.RoomID(), RoomID: event.RoomID().String(),
JoinedOnly: true, JoinedOnly: true,
} }
membershipRes := &api.QueryMembershipsForRoomResponse{} membershipRes := &api.QueryMembershipsForRoomResponse{}
@ -317,7 +313,7 @@ func (s *OutputRoomEventConsumer) appserviceJoinedAtEvent(ctx context.Context, e
} else { } else {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"appservice": appservice.ID, "appservice": appservice.ID,
"room_id": event.RoomID(), "room_id": event.RoomID().String(),
}).WithError(err).Errorf("Unable to get membership for room") }).WithError(err).Errorf("Unable to get membership for room")
} }
return false return false

View file

@ -181,13 +181,39 @@ func SetLocalAlias(
return *resErr return *resErr
} }
queryReq := roomserverAPI.SetRoomAliasRequest{ roomID, err := spec.NewRoomID(r.RoomID)
UserID: device.UserID, if err != nil {
RoomID: r.RoomID, return util.JSONResponse{
Alias: alias, Code: http.StatusBadRequest,
JSON: spec.InvalidParam("invalid room ID"),
}
} }
var queryRes roomserverAPI.SetRoomAliasResponse
if err := rsAPI.SetRoomAlias(req.Context(), &queryReq, &queryRes); err != nil { userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown("internal server error"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *roomID, *userID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QuerySenderIDForUser failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown("internal server error"),
}
} else if senderID == nil {
util.GetLogger(req.Context()).WithField("roomID", *roomID).WithField("userID", *userID).Error("Sender ID not found")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown("internal server error"),
}
}
aliasAlreadyExists, err := rsAPI.SetRoomAlias(req.Context(), *senderID, *roomID, alias)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.SetRoomAlias failed") util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.SetRoomAlias failed")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
@ -195,7 +221,7 @@ func SetLocalAlias(
} }
} }
if queryRes.AliasExists { if aliasAlreadyExists {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusConflict, Code: http.StatusConflict,
JSON: spec.Unknown("The alias " + alias + " already exists."), JSON: spec.Unknown("The alias " + alias + " already exists."),
@ -240,6 +266,31 @@ func RemoveLocalAlias(
JSON: spec.NotFound("The alias does not exist."), JSON: spec.NotFound("The alias does not exist."),
} }
} }
// This seems like the kind of auth check that should be done in the roomserver, but
// if this check fails (user is not in the room), then there will be no SenderID for the user
// for pseudo-ID rooms - it will just return "". However, we can't use lack of a sender ID
// as meaning they are not in the room, since lacking a sender ID could be caused by other bugs.
// TODO: maybe have QuerySenderIDForUser return richer errors?
var queryResp roomserverAPI.QueryMembershipForUserResponse
err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{
RoomID: validRoomID.String(),
UserID: *userID,
}, &queryResp)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("roomserverAPI.QueryMembershipForUser failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown("internal server error"),
}
}
if !queryResp.IsInRoom {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You do not have permission to remove this alias."),
}
}
deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *userID) deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *userID)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -247,28 +298,31 @@ func RemoveLocalAlias(
JSON: spec.NotFound("The alias does not exist."), JSON: spec.NotFound("The alias does not exist."),
} }
} }
// TODO: how to handle this case? missing user/room keys seem to be a whole new class of errors
queryReq := roomserverAPI.RemoveRoomAliasRequest{ if deviceSenderID == nil {
Alias: alias,
SenderID: deviceSenderID,
}
var queryRes roomserverAPI.RemoveRoomAliasResponse
if err := rsAPI.RemoveRoomAlias(req.Context(), &queryReq, &queryRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.RemoveRoomAlias failed")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{}, JSON: spec.Unknown("internal server error"),
} }
} }
if !queryRes.Found { aliasFound, aliasRemoved, err := rsAPI.RemoveRoomAlias(req.Context(), *deviceSenderID, alias)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.RemoveRoomAlias failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown("internal server error"),
}
}
if !aliasFound {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,
JSON: spec.NotFound("The alias does not exist."), JSON: spec.NotFound("The alias does not exist."),
} }
} }
if !queryRes.Removed { if !aliasRemoved {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: spec.Forbidden("You do not have permission to remove this alias."), JSON: spec.Forbidden("You do not have permission to remove this alias."),
@ -337,7 +391,7 @@ func SetVisibility(
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID) senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID)
if err != nil { if err != nil || senderID == nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: spec.Unknown("failed to find senderID for this user"), JSON: spec.Unknown("failed to find senderID for this user"),
@ -368,7 +422,7 @@ func SetVisibility(
// NOTSPEC: Check if the user's power is greater than power required to change m.room.canonical_alias event // NOTSPEC: Check if the user's power is greater than power required to change m.room.canonical_alias event
power, _ := gomatrixserverlib.NewPowerLevelContentFromEvent(queryEventsRes.StateEvents[0].PDU) power, _ := gomatrixserverlib.NewPowerLevelContentFromEvent(queryEventsRes.StateEvents[0].PDU)
if power.UserLevel(senderID) < power.EventLevel(spec.MRoomCanonicalAlias, true) { if power.UserLevel(*senderID) < power.EventLevel(spec.MRoomCanonicalAlias, true) {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to change visibility"), JSON: spec.Forbidden("userID doesn't have power level to change visibility"),

View file

@ -33,23 +33,36 @@ func GetJoinedRooms(
device *userapi.Device, device *userapi.Device,
rsAPI api.ClientRoomserverAPI, rsAPI api.ClientRoomserverAPI,
) util.JSONResponse { ) util.JSONResponse {
var res api.QueryRoomsForUserResponse deviceUserID, err := spec.NewUserID(device.UserID, true)
err := rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ if err != nil {
UserID: device.UserID, util.GetLogger(req.Context()).WithError(err).Error("Invalid device user ID")
WantMembership: "join", return util.JSONResponse{
}, &res) Code: http.StatusInternalServerError,
JSON: spec.Unknown("internal server error"),
}
}
rooms, err := rsAPI.QueryRoomsForUser(req.Context(), *deviceUserID, "join")
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed") util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{}, JSON: spec.Unknown("internal server error"),
} }
} }
if res.RoomIDs == nil {
res.RoomIDs = []string{} var roomIDStrs []string
if rooms == nil {
roomIDStrs = []string{}
} else {
roomIDStrs = make([]string, len(rooms))
for i, roomID := range rooms {
roomIDStrs[i] = roomID.String()
}
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: getJoinedRoomsResponse{res.RoomIDs}, JSON: getJoinedRoomsResponse{roomIDStrs},
} }
} }

View file

@ -71,7 +71,7 @@ func SendBan(
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID) senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID)
if err != nil { if err != nil || senderID == nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to ban this user, unknown senderID"), JSON: spec.Forbidden("You don't have permission to ban this user, unknown senderID"),
@ -87,7 +87,7 @@ func SendBan(
if errRes != nil { if errRes != nil {
return *errRes return *errRes
} }
allowedToBan := pl.UserLevel(senderID) >= pl.Ban allowedToBan := pl.UserLevel(*senderID) >= pl.Ban
if !allowedToBan { if !allowedToBan {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
@ -169,7 +169,7 @@ func SendKick(
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID) senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID)
if err != nil { if err != nil || senderID == nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"), JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"),
@ -185,7 +185,7 @@ func SendKick(
if errRes != nil { if errRes != nil {
return *errRes return *errRes
} }
allowedToKick := pl.UserLevel(senderID) >= pl.Kick allowedToKick := pl.UserLevel(*senderID) >= pl.Kick
if !allowedToKick { if !allowedToKick {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
@ -337,22 +337,55 @@ func sendInvite(
rsAPI roomserverAPI.ClientRoomserverAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
asAPI appserviceAPI.AppServiceInternalAPI, evTime time.Time, asAPI appserviceAPI.AppServiceInternalAPI, evTime time.Time,
) (util.JSONResponse, error) { ) (util.JSONResponse, error) {
event, err := buildMembershipEvent( validRoomID, err := spec.NewRoomID(roomID)
ctx, userID, reason, profileAPI, device, spec.Invite, if err != nil {
roomID, false, cfg, evTime, rsAPI, asAPI, return util.JSONResponse{
) Code: http.StatusBadRequest,
JSON: spec.InvalidParam("RoomID is invalid"),
}, err
}
inviter, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}, err
}
invitee, err := spec.NewUserID(userID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam("UserID is invalid"),
}, err
}
profile, err := loadProfile(ctx, userID, cfg, profileAPI, asAPI)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}, err
}
identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain())
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
}, err }, err
} }
err = rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{ err = rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{
Event: event, InviteInput: roomserverAPI.InviteInput{
RoomID: *validRoomID,
Inviter: *inviter,
Invitee: *invitee,
DisplayName: profile.DisplayName,
AvatarURL: profile.AvatarURL,
Reason: reason,
IsDirect: false,
KeyID: identity.KeyID,
PrivateKey: identity.PrivateKey,
EventTime: evTime,
},
InviteRoomState: nil, // ask the roomserver to draw up invite room state for us InviteRoomState: nil, // ask the roomserver to draw up invite room state for us
RoomVersion: event.Version(),
SendAsServer: string(device.UserDomain()), SendAsServer: string(device.UserDomain()),
}) })
@ -443,6 +476,8 @@ func buildMembershipEvent(
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *userID) senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *userID)
if err != nil { if err != nil {
return nil, err return nil, err
} else if senderID == nil {
return nil, fmt.Errorf("no sender ID for %s in %s", *userID, *validRoomID)
} }
targetID, err := spec.NewUserID(targetUserID, true) targetID, err := spec.NewUserID(targetUserID, true)
@ -452,6 +487,8 @@ func buildMembershipEvent(
targetSenderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *targetID) targetSenderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *targetID)
if err != nil { if err != nil {
return nil, err return nil, err
} else if targetSenderID == nil {
return nil, fmt.Errorf("no sender ID for %s in %s", *targetID, *validRoomID)
} }
identity, err := rsAPI.SigningIdentityFor(ctx, *validRoomID, *userID) identity, err := rsAPI.SigningIdentityFor(ctx, *validRoomID, *userID)
@ -459,8 +496,8 @@ func buildMembershipEvent(
return nil, err return nil, err
} }
return buildMembershipEventDirect(ctx, targetSenderID, reason, profile.DisplayName, profile.AvatarURL, return buildMembershipEventDirect(ctx, *targetSenderID, reason, profile.DisplayName, profile.AvatarURL,
senderID, device.UserDomain(), membership, roomID, isDirect, identity.KeyID, identity.PrivateKey, evTime, rsAPI) *senderID, device.UserDomain(), membership, roomID, isDirect, identity.KeyID, identity.PrivateKey, evTime, rsAPI)
} }
// loadProfile lookups the profile of a given user from the database and returns // loadProfile lookups the profile of a given user from the database and returns

View file

@ -16,6 +16,7 @@ package routing
import ( import (
"context" "context"
"fmt"
"net/http" "net/http"
"time" "time"
@ -250,11 +251,15 @@ func updateProfile(
profile *authtypes.Profile, profile *authtypes.Profile,
userID string, evTime time.Time, userID string, evTime time.Time,
) (util.JSONResponse, error) { ) (util.JSONResponse, error) {
var res api.QueryRoomsForUserResponse deviceUserID, err := spec.NewUserID(device.UserID, true)
err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ if err != nil {
UserID: device.UserID, return util.JSONResponse{
WantMembership: "join", Code: http.StatusInternalServerError,
}, &res) JSON: spec.Unknown("internal server error"),
}, err
}
rooms, err := rsAPI.QueryRoomsForUser(ctx, *deviceUserID, "join")
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("QueryRoomsForUser failed") util.GetLogger(ctx).WithError(err).Error("QueryRoomsForUser failed")
return util.JSONResponse{ return util.JSONResponse{
@ -263,6 +268,11 @@ func updateProfile(
}, err }, err
} }
roomIDStrs := make([]string, len(rooms))
for i, room := range rooms {
roomIDStrs[i] = room.String()
}
_, domain, err := gomatrixserverlib.SplitID('@', userID) _, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
@ -273,7 +283,7 @@ func updateProfile(
} }
events, err := buildMembershipEvents( events, err := buildMembershipEvents(
ctx, res.RoomIDs, *profile, userID, evTime, rsAPI, ctx, roomIDStrs, *profile, userID, evTime, rsAPI,
) )
switch e := err.(type) { switch e := err.(type) {
case nil: case nil:
@ -362,8 +372,10 @@ func buildMembershipEvents(
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID) senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} else if senderID == nil {
return nil, fmt.Errorf("sender ID not found for %s in %s", *fullUserID, *validRoomID)
} }
senderIDString := string(senderID) senderIDString := string(*senderID)
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
SenderID: senderIDString, SenderID: senderIDString,
RoomID: roomID, RoomID: roomID,

View file

@ -34,7 +34,8 @@ import (
) )
type redactionContent struct { type redactionContent struct {
Reason string `json:"reason"` Reason string `json:"reason"`
Redacts string `json:"redacts"`
} }
type redactionResponse struct { type redactionResponse struct {
@ -74,6 +75,16 @@ func SendRedaction(
return *resErr return *resErr
} }
// if user is member of room, and sender ID is nil, then this user doesn't have a pseudo ID for some reason,
// which is unexpected.
if senderID == nil {
util.GetLogger(req.Context()).WithField("userID", *deviceUserID).WithField("roomID", roomID).Error("missing sender ID for user, despite having membership")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown("internal server error"),
}
}
if txnID != nil { if txnID != nil {
// Try to fetch response from transactionsCache // Try to fetch response from transactionsCache
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok { if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok {
@ -88,7 +99,7 @@ func SendRedaction(
JSON: spec.NotFound("unknown event ID"), // TODO: is it ok to leak existence? JSON: spec.NotFound("unknown event ID"), // TODO: is it ok to leak existence?
} }
} }
if ev.RoomID() != roomID { if ev.RoomID().String() != roomID {
return util.JSONResponse{ return util.JSONResponse{
Code: 400, Code: 400,
JSON: spec.NotFound("cannot redact event in another room"), JSON: spec.NotFound("cannot redact event in another room"),
@ -98,7 +109,7 @@ func SendRedaction(
// "Users may redact their own events, and any user with a power level greater than or equal // "Users may redact their own events, and any user with a power level greater than or equal
// to the redact power level of the room may redact events there" // to the redact power level of the room may redact events there"
// https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid // https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid
allowedToRedact := ev.SenderID() == senderID allowedToRedact := ev.SenderID() == *senderID
if !allowedToRedact { if !allowedToRedact {
plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{ plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomPowerLevels, EventType: spec.MRoomPowerLevels,
@ -119,7 +130,7 @@ func SendRedaction(
), ),
} }
} }
allowedToRedact = pl.UserLevel(senderID) >= pl.Redact allowedToRedact = pl.UserLevel(*senderID) >= pl.Redact
} }
if !allowedToRedact { if !allowedToRedact {
return util.JSONResponse{ return util.JSONResponse{
@ -136,11 +147,16 @@ func SendRedaction(
// create the new event and set all the fields we can // create the new event and set all the fields we can
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
SenderID: string(senderID), SenderID: string(*senderID),
RoomID: roomID, RoomID: roomID,
Type: spec.MRoomRedaction, Type: spec.MRoomRedaction,
Redacts: eventID, Redacts: eventID,
} }
// Room version 11 expects the "redacts" field on the
// content field, so add it here as well
r.Redacts = eventID
err = proto.SetContent(r) err = proto.SetContent(r)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("proto.SetContent failed") util.GetLogger(req.Context()).WithError(err).Error("proto.SetContent failed")

View file

@ -29,6 +29,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/synctypes"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
@ -92,6 +93,30 @@ func SendEvent(
} }
} }
// Translate user ID state keys to room keys in pseudo ID rooms
if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && stateKey != nil {
parsedRoomID, innerErr := spec.NewRoomID(roomID)
if innerErr != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam("invalid room ID"),
}
}
newStateKey, innerErr := synctypes.FromClientStateKey(*parsedRoomID, *stateKey, func(roomID spec.RoomID, userID spec.UserID) (*spec.SenderID, error) {
return rsAPI.QuerySenderIDForUser(req.Context(), roomID, userID)
})
if innerErr != nil {
// TODO: work out better logic for failure cases (e.g. sender ID not found)
util.GetLogger(req.Context()).WithError(innerErr).Error("synctypes.FromClientStateKey failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown("internal server error"),
}
}
stateKey = newStateKey
}
// create a mutex for the specific user in the specific room // create a mutex for the specific user in the specific room
// this avoids a situation where events that are received in quick succession are sent to the roomserver in a jumbled order // this avoids a situation where events that are received in quick succession are sent to the roomserver in a jumbled order
userID := device.UserID userID := device.UserID
@ -238,7 +263,11 @@ func SendEvent(
} }
func updatePowerLevels(req *http.Request, r map[string]interface{}, roomID string, rsAPI api.ClientRoomserverAPI) error { func updatePowerLevels(req *http.Request, r map[string]interface{}, roomID string, rsAPI api.ClientRoomserverAPI) error {
userMap := r["users"].(map[string]interface{}) users, ok := r["users"]
if !ok {
return nil
}
userMap := users.(map[string]interface{})
validRoomID, err := spec.NewRoomID(roomID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil { if err != nil {
return err return err
@ -251,8 +280,11 @@ func updatePowerLevels(req *http.Request, r map[string]interface{}, roomID strin
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *uID) senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *uID)
if err != nil { if err != nil {
return err return err
} else if senderID == nil {
util.GetLogger(req.Context()).Warnf("sender ID not found for %s in %s", uID, *validRoomID)
continue
} }
userMap[string(senderID)] = level userMap[string(*senderID)] = level
delete(userMap, user) delete(userMap, user)
} }
r["users"] = userMap r["users"] = userMap
@ -316,14 +348,21 @@ func generateSendEvent(
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID) senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID)
if err != nil { if err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusInternalServerError,
JSON: spec.NotFound("Unable to find senderID for user"), JSON: spec.NotFound("internal server error"),
}
} else if senderID == nil {
// TODO: is it always the case that lack of a sender ID means they're not joined?
// And should this logic be deferred to the roomserver somehow?
return nil, &util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("not joined to room"),
} }
} }
// create the new event and set all the fields we can // create the new event and set all the fields we can
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
SenderID: string(senderID), SenderID: string(*senderID),
RoomID: roomID, RoomID: roomID,
Type: eventType, Type: eventType,
StateKey: stateKey, StateKey: stateKey,
@ -403,7 +442,7 @@ func generateSendEvent(
JSON: spec.BadJSON("Cannot unmarshal the event content."), JSON: spec.BadJSON("Cannot unmarshal the event content."),
} }
} }
if content["replacement_room"] == e.RoomID() { if content["replacement_room"] == e.RoomID().String() {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: spec.InvalidParam("Cannot send tombstone event that points to the same room."), JSON: spec.InvalidParam("Cannot send tombstone event that points to the same room."),

View file

@ -28,7 +28,6 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/roomserver/version"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
@ -95,34 +94,42 @@ func SendServerNotice(
} }
} }
userID, err := spec.NewUserID(r.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam("invalid user ID"),
}
}
// get rooms for specified user // get rooms for specified user
allUserRooms := []string{} allUserRooms := []spec.RoomID{}
userRooms := api.QueryRoomsForUserResponse{}
// Get rooms the user is either joined, invited or has left. // Get rooms the user is either joined, invited or has left.
for _, membership := range []string{"join", "invite", "leave"} { for _, membership := range []string{"join", "invite", "leave"} {
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ userRooms, queryErr := rsAPI.QueryRoomsForUser(ctx, *userID, membership)
UserID: r.UserID, if queryErr != nil {
WantMembership: membership,
}, &userRooms); err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
allUserRooms = append(allUserRooms, userRooms.RoomIDs...) allUserRooms = append(allUserRooms, userRooms...)
} }
// get rooms of the sender // get rooms of the sender
senderUserID := fmt.Sprintf("@%s:%s", cfgNotices.LocalPart, cfgClient.Matrix.ServerName) senderUserID, err := spec.NewUserID(fmt.Sprintf("@%s:%s", cfgNotices.LocalPart, cfgClient.Matrix.ServerName), true)
senderRooms := api.QueryRoomsForUserResponse{} if err != nil {
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ return util.JSONResponse{
UserID: senderUserID, Code: http.StatusInternalServerError,
WantMembership: "join", JSON: spec.Unknown("internal server error"),
}, &senderRooms); err != nil { }
}
senderRooms, err := rsAPI.QueryRoomsForUser(ctx, *senderUserID, "join")
if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
// check if we have rooms in common // check if we have rooms in common
commonRooms := []string{} commonRooms := []spec.RoomID{}
for _, userRoomID := range allUserRooms { for _, userRoomID := range allUserRooms {
for _, senderRoomID := range senderRooms.RoomIDs { for _, senderRoomID := range senderRooms {
if userRoomID == senderRoomID { if userRoomID == senderRoomID {
commonRooms = append(commonRooms, senderRoomID) commonRooms = append(commonRooms, senderRoomID)
} }
@ -135,12 +142,12 @@ func SendServerNotice(
var ( var (
roomID string roomID string
roomVersion = version.DefaultRoomVersion() roomVersion = rsAPI.DefaultRoomVersion()
) )
// create a new room for the user // create a new room for the user
if len(commonRooms) == 0 { if len(commonRooms) == 0 {
powerLevelContent := eventutil.InitialPowerLevelsContent(senderUserID) powerLevelContent := eventutil.InitialPowerLevelsContent(senderUserID.String())
powerLevelContent.Users[r.UserID] = -10 // taken from Synapse powerLevelContent.Users[r.UserID] = -10 // taken from Synapse
pl, err := json.Marshal(powerLevelContent) pl, err := json.Marshal(powerLevelContent)
if err != nil { if err != nil {
@ -196,7 +203,7 @@ func SendServerNotice(
} }
} }
roomID = commonRooms[0] roomID = commonRooms[0].String()
membershipRes := api.QueryMembershipForUserResponse{} membershipRes := api.QueryMembershipForUserResponse{}
err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: *deviceUserID, RoomID: roomID}, &membershipRes) err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: *deviceUserID, RoomID: roomID}, &membershipRes)
if err != nil { if err != nil {

View file

@ -172,28 +172,16 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
} }
} }
for _, ev := range stateAfterRes.StateEvents { for _, ev := range stateAfterRes.StateEvents {
sender := spec.UserID{} clientEvent, err := synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
evRoomID, err := spec.NewRoomID(ev.RoomID()) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("Event roomID is invalid") util.GetLogger(ctx).WithError(err).Error("Failed converting to ClientEvent")
continue continue
} }
userID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, ev.SenderID())
if err == nil && userID != nil {
sender = *userID
}
sk := ev.StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, spec.SenderID(*ev.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
stateEvents = append( stateEvents = append(
stateEvents, stateEvents,
synctypes.ToClientEvent(ev, synctypes.FormatAll, sender, sk), *clientEvent,
) )
} }
} }
@ -217,6 +205,37 @@ func OnIncomingStateTypeRequest(
var worldReadable bool var worldReadable bool
var wantLatestState bool var wantLatestState bool
roomVer, err := rsAPI.QueryRoomVersionForRoom(ctx, roomID)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden(fmt.Sprintf("Unknown room %q or user %q has never joined this room", roomID, device.UserID)),
}
}
// Translate user ID state keys to room keys in pseudo ID rooms
if roomVer == gomatrixserverlib.RoomVersionPseudoIDs {
parsedRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.InvalidParam("invalid room ID"),
}
}
newStateKey, err := synctypes.FromClientStateKey(*parsedRoomID, stateKey, func(roomID spec.RoomID, userID spec.UserID) (*spec.SenderID, error) {
return rsAPI.QuerySenderIDForUser(ctx, roomID, userID)
})
if err != nil {
// TODO: work out better logic for failure cases (e.g. sender ID not found)
util.GetLogger(ctx).WithError(err).Error("synctypes.FromClientStateKey failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown("internal server error"),
}
}
stateKey = *newStateKey
}
// Always fetch visibility so that we can work out whether to show // Always fetch visibility so that we can work out whether to show
// the latest events or the last event from when the user was joined. // the latest events or the last event from when the user was joined.
// Then include the requested event type and state key, assuming it // Then include the requested event type and state key, assuming it
@ -301,7 +320,7 @@ func OnIncomingStateTypeRequest(
} }
// If the user has never been in the room then stop at this point. // If the user has never been in the room then stop at this point.
// We won't tell the user about a room they have never joined. // We won't tell the user about a room they have never joined.
if !membershipRes.HasBeenInRoom || membershipRes.Membership == spec.Ban { if (!membershipRes.HasBeenInRoom && membershipRes.Membership != spec.Invite) || membershipRes.Membership == spec.Ban {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: spec.Forbidden(fmt.Sprintf("Unknown room %q or user %q has never joined this room", roomID, device.UserID)), JSON: spec.Forbidden(fmt.Sprintf("Unknown room %q or user %q has never joined this room", roomID, device.UserID)),

View file

@ -366,9 +366,11 @@ func emit3PIDInviteEvent(
sender, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *userID) sender, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *userID)
if err != nil { if err != nil {
return err return err
} else if sender == nil {
return fmt.Errorf("sender ID not found for %s in %s", *userID, *validRoomID)
} }
proto := &gomatrixserverlib.ProtoEvent{ proto := &gomatrixserverlib.ProtoEvent{
SenderID: string(sender), SenderID: string(*sender),
RoomID: roomID, RoomID: roomID,
Type: "m.room.third_party_invite", Type: "m.room.third_party_invite",
StateKey: &res.Token, StateKey: &res.Token,

View file

@ -134,7 +134,7 @@ func main() {
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", filepath.Join(*instanceDir, *instanceName))) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", filepath.Join(*instanceDir, *instanceName)))
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", filepath.Join(*instanceDir, *instanceName))) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", filepath.Join(*instanceDir, *instanceName)))
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", filepath.Join(*instanceDir, *instanceName))) cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", filepath.Join(*instanceDir, *instanceName)))
cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} cfg.MSCs.MSCs = []string{"msc2836"}
cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", filepath.Join(*instanceDir, *instanceName))) cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", filepath.Join(*instanceDir, *instanceName)))
cfg.ClientAPI.RegistrationDisabled = false cfg.ClientAPI.RegistrationDisabled = false
cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true

View file

@ -26,6 +26,7 @@ import (
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/appservice"
@ -156,13 +157,14 @@ func main() {
keyRing := fsAPI.KeyRing() keyRing := fsAPI.KeyRing()
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, federationClient)
asAPI := appservice.NewInternalAPI(processCtx, cfg, &natsInstance, userAPI, rsAPI)
// The underlying roomserver implementation needs to be able to call the fedsender. // The underlying roomserver implementation needs to be able to call the fedsender.
// This is different to rsAPI which can be the http client which doesn't need this // This is different to rsAPI which can be the http client which doesn't need this
// dependency. Other components also need updating after their dependencies are up. // dependency. Other components also need updating after their dependencies are up.
rsAPI.SetFederationAPI(fsAPI, keyRing) rsAPI.SetFederationAPI(fsAPI, keyRing)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, federationClient)
asAPI := appservice.NewInternalAPI(processCtx, cfg, &natsInstance, userAPI, rsAPI)
rsAPI.SetAppserviceAPI(asAPI) rsAPI.SetAppserviceAPI(asAPI)
rsAPI.SetUserAPI(userAPI) rsAPI.SetUserAPI(userAPI)
@ -187,6 +189,16 @@ func main() {
} }
} }
upCounter := prometheus.NewCounter(prometheus.CounterOpts{
Namespace: "dendrite",
Name: "up",
ConstLabels: map[string]string{
"version": internal.VersionString(),
},
})
upCounter.Add(1)
prometheus.MustRegister(upCounter)
// Expose the matrix APIs directly rather than putting them under a /api path. // Expose the matrix APIs directly rather than putting them under a /api path.
go func() { go func() {
basepkg.SetupAndServeHTTP(processCtx, cfg, routers, httpAddr, nil, nil) basepkg.SetupAndServeHTTP(processCtx, cfg, routers, httpAddr, nil, nil)

View file

@ -202,12 +202,25 @@ func main() {
authEvents[i] = authEventEntries[i].PDU authEvents[i] = authEventEntries[i].PDU
} }
// Get the roomNID
roomInfo, err = roomserverDB.RoomInfo(ctx, authEvents[0].RoomID().String())
if err != nil {
panic(err)
}
fmt.Println("Resolving state") fmt.Println("Resolving state")
var resolved Events var resolved Events
resolved, err = gomatrixserverlib.ResolveConflicts( resolved, err = gomatrixserverlib.ResolveConflicts(
gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, },
func(eventID string) bool {
isRejected, rejectedErr := roomserverDB.IsEventRejected(ctx, roomInfo.RoomNID, eventID)
if rejectedErr != nil {
return true
}
return isRejected
},
) )
if err != nil { if err != nil {
panic(err) panic(err)

View file

@ -27,7 +27,6 @@ type FederationInternalAPI interface {
QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error
LookupServerKeys(ctx context.Context, s spec.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp) ([]gomatrixserverlib.ServerKeys, error) LookupServerKeys(ctx context.Context, s spec.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
MSC2836EventRelationships(ctx context.Context, origin, dst spec.ServerName, r fclient.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res fclient.MSC2836EventRelationshipsResponse, err error) MSC2836EventRelationships(ctx context.Context, origin, dst spec.ServerName, r fclient.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res fclient.MSC2836EventRelationshipsResponse, err error)
MSC2946Spaces(ctx context.Context, origin, dst spec.ServerName, roomID string, suggestedOnly bool) (res fclient.MSC2946SpacesResponse, err error)
// Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos. // Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos.
PerformBroadcastEDU( PerformBroadcastEDU(
@ -63,6 +62,8 @@ type RoomserverFederationAPI interface {
PerformLeave(ctx context.Context, request *PerformLeaveRequest, response *PerformLeaveResponse) error PerformLeave(ctx context.Context, request *PerformLeaveRequest, response *PerformLeaveResponse) error
// Handle sending an invite to a remote server. // Handle sending an invite to a remote server.
SendInvite(ctx context.Context, event gomatrixserverlib.PDU, strippedState []gomatrixserverlib.InviteStrippedState) (gomatrixserverlib.PDU, error) SendInvite(ctx context.Context, event gomatrixserverlib.PDU, strippedState []gomatrixserverlib.InviteStrippedState) (gomatrixserverlib.PDU, error)
// Handle sending an invite to a remote server.
SendInviteV3(ctx context.Context, event gomatrixserverlib.ProtoEvent, invitee spec.UserID, version gomatrixserverlib.RoomVersion, strippedState []gomatrixserverlib.InviteStrippedState) (gomatrixserverlib.PDU, error)
// Handle an instruction to peek a room on a remote server. // Handle an instruction to peek a room on a remote server.
PerformOutboundPeek(ctx context.Context, request *PerformOutboundPeekRequest, response *PerformOutboundPeekResponse) error PerformOutboundPeek(ctx context.Context, request *PerformOutboundPeekRequest, response *PerformOutboundPeekResponse) error
// Query the server names of the joined hosts in a room. // Query the server names of the joined hosts in a room.
@ -73,6 +74,8 @@ type RoomserverFederationAPI interface {
GetEventAuth(ctx context.Context, origin, s spec.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res fclient.RespEventAuth, err error) GetEventAuth(ctx context.Context, origin, s spec.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res fclient.RespEventAuth, err error)
GetEvent(ctx context.Context, origin, s spec.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) GetEvent(ctx context.Context, origin, s spec.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
LookupMissingEvents(ctx context.Context, origin, s spec.ServerName, roomID string, missing fclient.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res fclient.RespMissingEvents, err error) LookupMissingEvents(ctx context.Context, origin, s spec.ServerName, roomID string, missing fclient.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res fclient.RespMissingEvents, err error)
RoomHierarchies(ctx context.Context, origin, dst spec.ServerName, roomID string, suggestedOnly bool) (res fclient.RoomHierarchyResponse, err error)
} }
type P2PFederationAPI interface { type P2PFederationAPI interface {

View file

@ -117,19 +117,27 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool {
return true return true
} }
var queryRes roomserverAPI.QueryRoomsForUserResponse userID, err := spec.NewUserID(m.UserID, true)
err = t.rsAPI.QueryRoomsForUser(t.ctx, &roomserverAPI.QueryRoomsForUserRequest{ if err != nil {
UserID: m.UserID, sentry.CaptureException(err)
WantMembership: "join", logger.WithError(err).Error("invalid user ID")
}, &queryRes) return true
}
roomIDs, err := t.rsAPI.QueryRoomsForUser(t.ctx, *userID, "join")
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
logger.WithError(err).Error("failed to calculate joined rooms for user") logger.WithError(err).Error("failed to calculate joined rooms for user")
return true return true
} }
roomIDStrs := make([]string, len(roomIDs))
for i, room := range roomIDs {
roomIDStrs[i] = room.String()
}
// send this key change to all servers who share rooms with this user. // send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true) destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, roomIDStrs, true, true)
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in") logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in")
@ -179,18 +187,27 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool {
} }
logger := logrus.WithField("user_id", output.UserID) logger := logrus.WithField("user_id", output.UserID)
var queryRes roomserverAPI.QueryRoomsForUserResponse outputUserID, err := spec.NewUserID(output.UserID, true)
err = t.rsAPI.QueryRoomsForUser(t.ctx, &roomserverAPI.QueryRoomsForUserRequest{ if err != nil {
UserID: output.UserID, sentry.CaptureException(err)
WantMembership: "join", logrus.WithError(err).Errorf("invalid user ID")
}, &queryRes) return true
}
rooms, err := t.rsAPI.QueryRoomsForUser(t.ctx, *outputUserID, "join")
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined rooms for user") logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined rooms for user")
return true return true
} }
roomIDStrs := make([]string, len(rooms))
for i, room := range rooms {
roomIDStrs[i] = room.String()
}
// send this key change to all servers who share rooms with this user. // send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true) destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, roomIDStrs, true, true)
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in") logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in")

View file

@ -29,6 +29,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -94,16 +95,23 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg
return true return true
} }
var queryRes roomserverAPI.QueryRoomsForUserResponse parsedUserID, err := spec.NewUserID(userID, true)
err = t.rsAPI.QueryRoomsForUser(t.ctx, &roomserverAPI.QueryRoomsForUserRequest{ if err != nil {
UserID: userID, util.GetLogger(ctx).WithError(err).WithField("user_id", userID).Error("invalid user ID")
WantMembership: "join", return true
}, &queryRes) }
roomIDs, err := t.rsAPI.QueryRoomsForUser(t.ctx, *parsedUserID, "join")
if err != nil { if err != nil {
log.WithError(err).Error("failed to calculate joined rooms for user") log.WithError(err).Error("failed to calculate joined rooms for user")
return true return true
} }
roomIDStrs := make([]string, len(roomIDs))
for i, roomID := range roomIDs {
roomIDStrs[i] = roomID.String()
}
presence := msg.Header.Get("presence") presence := msg.Header.Get("presence")
ts, err := strconv.Atoi(msg.Header.Get("last_active_ts")) ts, err := strconv.Atoi(msg.Header.Get("last_active_ts"))
@ -112,7 +120,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg
} }
// send this presence to all servers who share rooms with this user. // send this presence to all servers who share rooms with this user.
joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true) joined, err := t.db.GetJoinedHostsForRooms(t.ctx, roomIDStrs, true, true)
if err != nil { if err != nil {
log.WithError(err).Error("failed to get joined hosts") log.WithError(err).Error("failed to get joined hosts")
return true return true

View file

@ -16,7 +16,9 @@ package consumers
import ( import (
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"strconv" "strconv"
"time" "time"
@ -174,7 +176,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew
// Finally, work out if there are any more events missing. // Finally, work out if there are any more events missing.
if len(missingEventIDs) > 0 { if len(missingEventIDs) > 0 {
eventsReq := &api.QueryEventsByIDRequest{ eventsReq := &api.QueryEventsByIDRequest{
RoomID: ore.Event.RoomID(), RoomID: ore.Event.RoomID().String(),
EventIDs: missingEventIDs, EventIDs: missingEventIDs,
} }
eventsRes := &api.QueryEventsByIDResponse{} eventsRes := &api.QueryEventsByIDResponse{}
@ -203,7 +205,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew
// talking to the roomserver // talking to the roomserver
oldJoinedHosts, err := s.db.UpdateRoom( oldJoinedHosts, err := s.db.UpdateRoom(
s.ctx, s.ctx,
ore.Event.RoomID(), ore.Event.RoomID().String(),
addsJoinedHosts, addsJoinedHosts,
ore.RemovesStateEventIDs, ore.RemovesStateEventIDs,
rewritesState, // if we're re-writing state, nuke all joined hosts before adding rewritesState, // if we're re-writing state, nuke all joined hosts before adding
@ -216,7 +218,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew
if s.cfg.Matrix.Presence.EnableOutbound && len(addsJoinedHosts) > 0 && ore.Event.Type() == spec.MRoomMember && ore.Event.StateKey() != nil { if s.cfg.Matrix.Presence.EnableOutbound && len(addsJoinedHosts) > 0 && ore.Event.Type() == spec.MRoomMember && ore.Event.StateKey() != nil {
membership, _ := ore.Event.Membership() membership, _ := ore.Event.Membership()
if membership == spec.Join { if membership == spec.Join {
s.sendPresence(ore.Event.RoomID(), addsJoinedHosts) s.sendPresence(ore.Event.RoomID().String(), addsJoinedHosts)
} }
} }
@ -374,7 +376,7 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent(
} }
// handle peeking hosts // handle peeking hosts
inboundPeeks, err := s.db.GetInboundPeeks(s.ctx, ore.Event.PDU.RoomID()) inboundPeeks, err := s.db.GetInboundPeeks(s.ctx, ore.Event.PDU.RoomID().String())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -407,17 +409,26 @@ func JoinedHostsFromEvents(ctx context.Context, evs []gomatrixserverlib.PDU, rsA
if membership != spec.Join { if membership != spec.Join {
continue continue
} }
validRoomID, err := spec.NewRoomID(ev.RoomID()) var domain spec.ServerName
userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey()))
if err != nil { if err != nil {
return nil, err if errors.As(err, new(base64.CorruptInputError)) {
} // Fallback to using the "old" way of getting the user domain, avoids
userID, err := rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*ev.StateKey())) // "illegal base64 data at input byte 0" errors
if err != nil { // FIXME: we should do this in QueryUserIDForSender instead
return nil, err _, domain, err = gomatrixserverlib.SplitID('@', *ev.StateKey())
if err != nil {
return nil, err
}
} else {
return nil, err
}
} else {
domain = userID.Domain()
} }
joinedHosts = append(joinedHosts, types.JoinedHost{ joinedHosts = append(joinedHosts, types.JoinedHost{
MemberEventID: ev.EventID(), ServerName: userID.Domain(), MemberEventID: ev.EventID(), ServerName: domain,
}) })
} }
return joinedHosts, nil return joinedHosts, nil
@ -495,7 +506,7 @@ func (s *OutputRoomEventConsumer) lookupStateEvents(
// At this point the missing events are neither the event itself nor are // At this point the missing events are neither the event itself nor are
// they present in our local database. Our only option is to fetch them // they present in our local database. Our only option is to fetch them
// from the roomserver using the query API. // from the roomserver using the query API.
eventReq := api.QueryEventsByIDRequest{EventIDs: missing, RoomID: event.RoomID()} eventReq := api.QueryEventsByIDRequest{EventIDs: missing, RoomID: event.RoomID().String()}
var eventResp api.QueryEventsByIDResponse var eventResp api.QueryEventsByIDResponse
if err := s.rsAPI.QueryEventsByID(s.ctx, &eventReq, &eventResp); err != nil { if err := s.rsAPI.QueryEventsByID(s.ctx, &eventReq, &eventResp); err != nil {
return nil, err return nil, err

View file

@ -95,7 +95,7 @@ func AddPublicRoutes(
func NewInternalAPI( func NewInternalAPI(
processContext *process.ProcessContext, processContext *process.ProcessContext,
dendriteCfg *config.Dendrite, dendriteCfg *config.Dendrite,
cm sqlutil.Connections, cm *sqlutil.Connections,
natsInstance *jetstream.NATSInstance, natsInstance *jetstream.NATSInstance,
federation fclient.FederationClient, federation fclient.FederationClient,
rsAPI roomserverAPI.FederationRoomserverAPI, rsAPI roomserverAPI.FederationRoomserverAPI,

View file

@ -33,15 +33,16 @@ import (
type fedRoomserverAPI struct { type fedRoomserverAPI struct {
rsapi.FederationRoomserverAPI rsapi.FederationRoomserverAPI
inputRoomEvents func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) inputRoomEvents func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse)
queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error queryRoomsForUser func(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error)
} }
func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }
func (f *fedRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) { func (f *fedRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (*spec.SenderID, error) {
return spec.SenderID(userID.String()), nil senderID := spec.SenderID(userID.String())
return &senderID, nil
} }
// PerformJoin will call this function // PerformJoin will call this function
@ -53,11 +54,11 @@ func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.Input
} }
// keychange consumer calls this // keychange consumer calls this
func (f *fedRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error { func (f *fedRoomserverAPI) QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error) {
if f.queryRoomsForUser == nil { if f.queryRoomsForUser == nil {
return nil return nil, nil
} }
return f.queryRoomsForUser(ctx, req, res) return f.queryRoomsForUser(ctx, userID, desiredMembership)
} }
// TODO: This struct isn't generic, only works for TestFederationAPIJoinThenKeyUpdate // TODO: This struct isn't generic, only works for TestFederationAPIJoinThenKeyUpdate
@ -145,7 +146,7 @@ func (f *fedClient) SendJoin(ctx context.Context, origin, s spec.ServerName, eve
f.fedClientMutex.Lock() f.fedClientMutex.Lock()
defer f.fedClientMutex.Unlock() defer f.fedClientMutex.Unlock()
for _, r := range f.allowJoins { for _, r := range f.allowJoins {
if r.ID == event.RoomID() { if r.ID == event.RoomID().String() {
r.InsertEvent(f.t, &types.HeaderedEvent{PDU: event}) r.InsertEvent(f.t, &types.HeaderedEvent{PDU: event})
f.t.Logf("Join event: %v", event.EventID()) f.t.Logf("Join event: %v", event.EventID())
res.StateEvents = types.NewEventJSONsFromHeaderedEvents(r.CurrentState()) res.StateEvents = types.NewEventJSONsFromHeaderedEvents(r.CurrentState())
@ -198,18 +199,22 @@ func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) {
fmt.Printf("creator: %v joining user: %v\n", creator.ID, joiningUser.ID) fmt.Printf("creator: %v joining user: %v\n", creator.ID, joiningUser.ID)
room := test.NewRoom(t, creator) room := test.NewRoom(t, creator)
roomID, err := spec.NewRoomID(room.ID)
if err != nil {
t.Fatalf("Invalid room ID: %q", roomID)
}
rsapi := &fedRoomserverAPI{ rsapi := &fedRoomserverAPI{
inputRoomEvents: func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) { inputRoomEvents: func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) {
if req.Asynchronous { if req.Asynchronous {
t.Errorf("InputRoomEvents from PerformJoin MUST be synchronous") t.Errorf("InputRoomEvents from PerformJoin MUST be synchronous")
} }
}, },
queryRoomsForUser: func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error { queryRoomsForUser: func(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error) {
if req.UserID == joiningUser.ID && req.WantMembership == "join" { if userID.String() == joiningUser.ID && desiredMembership == "join" {
res.RoomIDs = []string{room.ID} return []spec.RoomID{*roomID}, nil
return nil
} }
return fmt.Errorf("unexpected queryRoomsForUser: %+v", *req) return nil, fmt.Errorf("unexpected queryRoomsForUser: %v, %v", userID, desiredMembership)
}, },
} }
fc := &fedClient{ fc := &fedClient{

View file

@ -29,7 +29,7 @@ func (a *FederationInternalAPI) MakeJoin(
func (a *FederationInternalAPI) SendJoin( func (a *FederationInternalAPI) SendJoin(
ctx context.Context, origin, s spec.ServerName, event gomatrixserverlib.PDU, ctx context.Context, origin, s spec.ServerName, event gomatrixserverlib.PDU,
) (res gomatrixserverlib.SendJoinResponse, err error) { ) (res gomatrixserverlib.SendJoinResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout) ctx, cancel := context.WithTimeout(ctx, time.Minute*5)
defer cancel() defer cancel()
ires, err := a.federation.SendJoin(ctx, origin, s, event) ires, err := a.federation.SendJoin(ctx, origin, s, event)
if err != nil { if err != nil {
@ -194,16 +194,16 @@ func (a *FederationInternalAPI) MSC2836EventRelationships(
return ires.(fclient.MSC2836EventRelationshipsResponse), nil return ires.(fclient.MSC2836EventRelationshipsResponse), nil
} }
func (a *FederationInternalAPI) MSC2946Spaces( func (a *FederationInternalAPI) RoomHierarchies(
ctx context.Context, origin, s spec.ServerName, roomID string, suggestedOnly bool, ctx context.Context, origin, s spec.ServerName, roomID string, suggestedOnly bool,
) (res fclient.MSC2946SpacesResponse, err error) { ) (res fclient.RoomHierarchyResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute) ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.MSC2946Spaces(ctx, origin, s, roomID, suggestedOnly) return a.federation.RoomHierarchy(ctx, origin, s, roomID, suggestedOnly)
}) })
if err != nil { if err != nil {
return res, err return res, err
} }
return ires.(fclient.MSC2946SpacesResponse), nil return ires.(fclient.RoomHierarchyResponse), nil
} }

View file

@ -481,8 +481,10 @@ func (r *FederationInternalAPI) PerformLeave(
senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID) senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID)
if err != nil { if err != nil {
return err return err
} else if senderID == nil {
return fmt.Errorf("sender ID not found for %s in %s", *userID, *roomID)
} }
senderIDString := string(senderID) senderIDString := string(*senderID)
respMakeLeave.LeaveEvent.Type = spec.MRoomMember respMakeLeave.LeaveEvent.Type = spec.MRoomMember
respMakeLeave.LeaveEvent.SenderID = senderIDString respMakeLeave.LeaveEvent.SenderID = senderIDString
respMakeLeave.LeaveEvent.StateKey = &senderIDString respMakeLeave.LeaveEvent.StateKey = &senderIDString
@ -546,11 +548,7 @@ func (r *FederationInternalAPI) SendInvite(
event gomatrixserverlib.PDU, event gomatrixserverlib.PDU,
strippedState []gomatrixserverlib.InviteStrippedState, strippedState []gomatrixserverlib.InviteStrippedState,
) (gomatrixserverlib.PDU, error) { ) (gomatrixserverlib.PDU, error) {
validRoomID, err := spec.NewRoomID(event.RoomID()) inviter, err := r.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err != nil {
return nil, err
}
inviter, err := r.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -573,7 +571,7 @@ func (r *FederationInternalAPI) SendInvite(
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"event_id": event.EventID(), "event_id": event.EventID(),
"user_id": *event.StateKey(), "user_id": *event.StateKey(),
"room_id": event.RoomID(), "room_id": event.RoomID().String(),
"room_version": event.Version(), "room_version": event.Version(),
"destination": destination, "destination": destination,
}).Info("Sending invite") }).Info("Sending invite")
@ -599,6 +597,58 @@ func (r *FederationInternalAPI) SendInvite(
return inviteEvent, nil return inviteEvent, nil
} }
// SendInviteV3 implements api.FederationInternalAPI
func (r *FederationInternalAPI) SendInviteV3(
ctx context.Context,
event gomatrixserverlib.ProtoEvent,
invitee spec.UserID,
version gomatrixserverlib.RoomVersion,
strippedState []gomatrixserverlib.InviteStrippedState,
) (gomatrixserverlib.PDU, error) {
validRoomID, err := spec.NewRoomID(event.RoomID)
if err != nil {
return nil, err
}
verImpl, err := gomatrixserverlib.GetRoomVersion(version)
if err != nil {
return nil, err
}
inviter, err := r.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(event.SenderID))
if err != nil {
return nil, err
}
// TODO (devon): This should be allowed via a relay. Currently only transactions
// can be sent to relays. Would need to extend relays to handle invites.
if !r.shouldAttemptDirectFederation(invitee.Domain()) {
return nil, fmt.Errorf("relay servers have no meaningful response for invite.")
}
logrus.WithFields(logrus.Fields{
"user_id": invitee.String(),
"room_id": event.RoomID,
"room_version": version,
"destination": invitee.Domain(),
}).Info("Sending invite")
inviteReq, err := fclient.NewInviteV3Request(event, version, strippedState)
if err != nil {
return nil, fmt.Errorf("gomatrixserverlib.NewInviteV3Request: %w", err)
}
inviteRes, err := r.federation.SendInviteV3(ctx, inviter.Domain(), invitee.Domain(), inviteReq, invitee)
if err != nil {
return nil, fmt.Errorf("r.federation.SendInviteV3: failed to send invite: %w", err)
}
inviteEvent, err := verImpl.NewEventFromUntrustedJSON(inviteRes.Event)
if err != nil {
return nil, fmt.Errorf("r.federation.SendInviteV3 failed to decode event response: %w", err)
}
return inviteEvent, nil
}
// PerformServersAlive implements api.FederationInternalAPI // PerformServersAlive implements api.FederationInternalAPI
func (r *FederationInternalAPI) PerformBroadcastEDU( func (r *FederationInternalAPI) PerformBroadcastEDU(
ctx context.Context, ctx context.Context,

View file

@ -218,7 +218,7 @@ func (oqs *OutgoingQueues) SendEvent(
if api.IsServerBannedFromRoom( if api.IsServerBannedFromRoom(
oqs.process.Context(), oqs.process.Context(),
oqs.rsAPI, oqs.rsAPI,
ev.RoomID(), ev.RoomID().String(),
destination, destination,
) { ) {
delete(destmap, destination) delete(destmap, destination)

View file

@ -109,7 +109,7 @@ func Backfill(
var ev *types.HeaderedEvent var ev *types.HeaderedEvent
for _, ev = range res.Events { for _, ev = range res.Events {
if ev.RoomID() == roomID { if ev.RoomID().String() == roomID {
evs = append(evs, ev.PDU) evs = append(evs, ev.PDU)
} }
} }

View file

@ -42,10 +42,10 @@ func GetEventAuth(
return *resErr return *resErr
} }
if event.RoomID() != roomID { if event.RoomID().String() != roomID {
return util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")} return util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")}
} }
resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID()) resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID().String())
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }

View file

@ -42,7 +42,7 @@ func GetEvent(
return *err return *err
} }
err = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID()) err = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID().String())
if err != nil { if err != nil {
return *err return *err
} }

View file

@ -99,7 +99,7 @@ func MakeJoin(
Roomserver: rsAPI, Roomserver: rsAPI,
} }
senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID, userID) senderIDPtr, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID, userID)
if err != nil { if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed") util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed")
return util.JSONResponse{ return util.JSONResponse{
@ -108,8 +108,11 @@ func MakeJoin(
} }
} }
if senderID == "" { var senderID spec.SenderID
if senderIDPtr == nil {
senderID = spec.SenderID(userID.String()) senderID = spec.SenderID(userID.String())
} else {
senderID = *senderIDPtr
} }
input := gomatrixserverlib.HandleMakeJoinInput{ input := gomatrixserverlib.HandleMakeJoinInput{
@ -187,9 +190,6 @@ func MakeJoin(
} }
// SendJoin implements the /send_join API // SendJoin implements the /send_join API
// The make-join send-join dance makes much more sense as a single
// flow so the cyclomatic complexity is high:
// nolint:gocyclo
func SendJoin( func SendJoin(
httpReq *http.Request, httpReq *http.Request,
request *fclient.FederationRequest, request *fclient.FederationRequest,

View file

@ -94,11 +94,17 @@ func MakeLeave(
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
} }
} else if senderID == nil {
util.GetLogger(httpReq.Context()).WithField("roomID", roomID).WithField("userID", userID).Error("rsAPI.QuerySenderIDForUser returned nil sender ID")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
} }
input := gomatrixserverlib.HandleMakeLeaveInput{ input := gomatrixserverlib.HandleMakeLeaveInput{
UserID: userID, UserID: userID,
SenderID: senderID, SenderID: *senderID,
RoomID: roomID, RoomID: roomID,
RoomVersion: roomVersion, RoomVersion: roomVersion,
RequestOrigin: request.Origin(), RequestOrigin: request.Origin(),
@ -205,7 +211,7 @@ func SendLeave(
} }
// Check that the room ID is correct. // Check that the room ID is correct.
if event.RoomID() != roomID { if event.RoomID().String() != roomID {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: spec.BadJSON("The room ID in the request path must match the room ID in the leave event JSON"), JSON: spec.BadJSON("The room ID in the request path must match the room ID in the leave event JSON"),
@ -236,14 +242,7 @@ func SendLeave(
// Check that the sender belongs to the server that is sending us // Check that the sender belongs to the server that is sending us
// the request. By this point we've already asserted that the sender // the request. By this point we've already asserted that the sender
// and the state key are equal so we don't need to check both. // and the state key are equal so we don't need to check both.
validRoomID, err := spec.NewRoomID(event.RoomID()) sender, err := rsAPI.QueryUserIDForSender(httpReq.Context(), event.RoomID(), event.SenderID())
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("Room ID is invalid."),
}
}
sender, err := rsAPI.QueryUserIDForSender(httpReq.Context(), *validRoomID, event.SenderID())
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,

View file

@ -87,7 +87,7 @@ func filterEvents(
) []*types.HeaderedEvent { ) []*types.HeaderedEvent {
ref := events[:0] ref := events[:0]
for _, ev := range events { for _, ev := range events {
if ev.RoomID() == roomID { if ev.RoomID().String() == roomID {
ref = append(ref, ev) ref = append(ref, ev)
} }
} }

View file

@ -113,10 +113,10 @@ func getState(
return nil, nil, resErr return nil, nil, resErr
} }
if event.RoomID() != roomID { if event.RoomID().String() != roomID {
return nil, nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")} return nil, nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")}
} }
resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID()) resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID().String())
if resErr != nil { if resErr != nil {
return nil, nil, resErr return nil, nil, resErr
} }

View file

@ -36,7 +36,7 @@ type Database struct {
} }
// NewDatabase opens a new database // NewDatabase opens a new database
func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(spec.ServerName) bool) (*Database, error) { func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(spec.ServerName) bool) (*Database, error) {
var d Database var d Database
var err error var err error
if d.db, d.writer, err = conMan.Connection(dbProperties); err != nil { if d.db, d.writer, err = conMan.Connection(dbProperties); err != nil {

View file

@ -34,7 +34,7 @@ type Database struct {
} }
// NewDatabase opens a new database // NewDatabase opens a new database
func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(spec.ServerName) bool) (*Database, error) { func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(spec.ServerName) bool) (*Database, error) {
var d Database var d Database
var err error var err error
if d.db, d.writer, err = conMan.Connection(dbProperties); err != nil { if d.db, d.writer, err = conMan.Connection(dbProperties); err != nil {

View file

@ -30,7 +30,7 @@ import (
) )
// NewDatabase opens a new database // NewDatabase opens a new database
func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(spec.ServerName) bool) (Database, error) { func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(spec.ServerName) bool) (Database, error) {
switch { switch {
case dbProperties.ConnectionString.IsSQLite(): case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(ctx, conMan, dbProperties, cache, isLocalServerName) return sqlite3.NewDatabase(ctx, conMan, dbProperties, cache, isLocalServerName)

View file

@ -22,11 +22,11 @@ import (
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec"
) )
// NewDatabase opens a new database // NewDatabase opens a new database
func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(spec.ServerName) bool) (Database, error) { func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(spec.ServerName) bool) (Database, error) {
switch { switch {
case dbProperties.ConnectionString.IsSQLite(): case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(ctx, conMan, dbProperties, cache, isLocalServerName) return sqlite3.NewDatabase(ctx, conMan, dbProperties, cache, isLocalServerName)

8
go.mod
View file

@ -19,7 +19,7 @@ require (
github.com/lib/pq v1.10.9 github.com/lib/pq v1.10.9
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20231023121512-16e7431168be github.com/matrix-org/gomatrixserverlib v0.0.0-20231024124730-58af9a2712ca
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
github.com/mattn/go-sqlite3 v1.14.17 github.com/mattn/go-sqlite3 v1.14.17
github.com/nats-io/nats-server/v2 v2.9.23 github.com/nats-io/nats-server/v2 v2.9.23
@ -115,12 +115,12 @@ require (
github.com/prometheus/common v0.42.0 // indirect github.com/prometheus/common v0.42.0 // indirect
github.com/prometheus/procfs v0.10.1 // indirect github.com/prometheus/procfs v0.10.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rogpeppe/go-internal v1.9.0 // indirect github.com/rogpeppe/go-internal v1.11.0 // indirect
github.com/rs/zerolog v1.29.1 // indirect github.com/rs/zerolog v1.29.1 // indirect
github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect
go.etcd.io/bbolt v1.3.6 // indirect go.etcd.io/bbolt v1.3.5 // indirect
golang.org/x/mod v0.8.0 // indirect golang.org/x/mod v0.9.0 // indirect
golang.org/x/net v0.17.0 // indirect golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.13.0 // indirect golang.org/x/sys v0.13.0 // indirect
golang.org/x/text v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect

17
go.sum
View file

@ -184,8 +184,8 @@ github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e h1:DP5RC0Z3XdyBE
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e/go.mod h1:NgPCr+UavRGH6n5jmdX8DuqFZ4JiCWIJoZiuhTRLSUg= github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e/go.mod h1:NgPCr+UavRGH6n5jmdX8DuqFZ4JiCWIJoZiuhTRLSUg=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20231023121512-16e7431168be h1:bZP16ydP8uRoRBo1p/7WHMexjg7JJGj81fKzZ1FULb4= github.com/matrix-org/gomatrixserverlib v0.0.0-20231024124730-58af9a2712ca h1:JCP72vU4Vcmur2071RwYVOSoekR+ZjbC03wZD5lAAK0=
github.com/matrix-org/gomatrixserverlib v0.0.0-20231023121512-16e7431168be/go.mod h1:M8m7seOroO5ePlgxA7AFZymnG90Cnh94rYQyngSrZkk= github.com/matrix-org/gomatrixserverlib v0.0.0-20231024124730-58af9a2712ca/go.mod h1:M8m7seOroO5ePlgxA7AFZymnG90Cnh94rYQyngSrZkk=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66/go.mod h1:iBI1foelCqA09JJgPV0FYz4qA5dUXYOxMi57FxKBdd4= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66/go.mod h1:iBI1foelCqA09JJgPV0FYz4qA5dUXYOxMi57FxKBdd4=
github.com/matryer/is v1.4.0 h1:sosSmIWwkYITGrxZ25ULNDeKiMNzFSr4V/eqBQP0PeE= github.com/matryer/is v1.4.0 h1:sosSmIWwkYITGrxZ25ULNDeKiMNzFSr4V/eqBQP0PeE=
@ -257,8 +257,8 @@ github.com/prometheus/procfs v0.10.1/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPH
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.29.1 h1:cO+d60CHkknCbvzEWxP0S9K6KqyTjrCNUy1LdQLCGPc= github.com/rs/zerolog v1.29.1 h1:cO+d60CHkknCbvzEWxP0S9K6KqyTjrCNUy1LdQLCGPc=
github.com/rs/zerolog v1.29.1/go.mod h1:Le6ESbR7hc+DP6Lt1THiV8CQSdkkNrd3R0XbEgp3ZBU= github.com/rs/zerolog v1.29.1/go.mod h1:Le6ESbR7hc+DP6Lt1THiV8CQSdkkNrd3R0XbEgp3ZBU=
@ -303,8 +303,8 @@ github.com/yggdrasil-network/yggdrasil-go v0.4.6/go.mod h1:PBMoAOvQjA9geNEeGyMXA
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.etcd.io/bbolt v1.3.6 h1:/ecaJf0sk1l4l6V4awd65v2C3ILy7MSj+s/x1ADCIMU= go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0=
go.etcd.io/bbolt v1.3.6/go.mod h1:qXsaaIqmgQH0T+OPdb99Bf+PKfBBQVAdyD6TY9G8XM4= go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ=
go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
golang.org/x/crypto v0.0.0-20180723164146-c126467f60eb/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180723164146-c126467f60eb/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
@ -336,8 +336,9 @@ golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs=
golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@ -362,7 +363,7 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200923182605-d9f96fdee20d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

View file

@ -8,6 +8,7 @@ type RoomServerCaches interface {
RoomServerNIDsCache RoomServerNIDsCache
RoomVersionCache RoomVersionCache
RoomServerEventsCache RoomServerEventsCache
RoomHierarchyCache
EventStateKeyCache EventStateKeyCache
EventTypeCache EventTypeCache
} }

View file

@ -2,15 +2,16 @@ package caching
import "github.com/matrix-org/gomatrixserverlib/fclient" import "github.com/matrix-org/gomatrixserverlib/fclient"
type SpaceSummaryRoomsCache interface { // RoomHierarchy cache caches responses to federated room hierarchy requests (A.K.A. 'space summaries')
GetSpaceSummary(roomID string) (r fclient.MSC2946SpacesResponse, ok bool) type RoomHierarchyCache interface {
StoreSpaceSummary(roomID string, r fclient.MSC2946SpacesResponse) GetRoomHierarchy(roomID string) (r fclient.RoomHierarchyResponse, ok bool)
StoreRoomHierarchy(roomID string, r fclient.RoomHierarchyResponse)
} }
func (c Caches) GetSpaceSummary(roomID string) (r fclient.MSC2946SpacesResponse, ok bool) { func (c Caches) GetRoomHierarchy(roomID string) (r fclient.RoomHierarchyResponse, ok bool) {
return c.SpaceSummaryRooms.Get(roomID) return c.RoomHierarchies.Get(roomID)
} }
func (c Caches) StoreSpaceSummary(roomID string, r fclient.MSC2946SpacesResponse) { func (c Caches) StoreRoomHierarchy(roomID string, r fclient.RoomHierarchyResponse) {
c.SpaceSummaryRooms.Set(roomID, r) c.RoomHierarchies.Set(roomID, r)
} }

View file

@ -35,7 +35,7 @@ type Caches struct {
RoomServerEventTypes Cache[types.EventTypeNID, string] // eventType NID -> eventType RoomServerEventTypes Cache[types.EventTypeNID, string] // eventType NID -> eventType
FederationPDUs Cache[int64, *types.HeaderedEvent] // queue NID -> PDU FederationPDUs Cache[int64, *types.HeaderedEvent] // queue NID -> PDU
FederationEDUs Cache[int64, *gomatrixserverlib.EDU] // queue NID -> EDU FederationEDUs Cache[int64, *gomatrixserverlib.EDU] // queue NID -> EDU
SpaceSummaryRooms Cache[string, fclient.MSC2946SpacesResponse] // room ID -> space response RoomHierarchies Cache[string, fclient.RoomHierarchyResponse] // room ID -> space response
LazyLoading Cache[lazyLoadingCacheKey, string] // composite key -> event ID LazyLoading Cache[lazyLoadingCacheKey, string] // composite key -> event ID
} }

View file

@ -147,7 +147,7 @@ func NewRistrettoCache(maxCost config.DataUnit, maxAge time.Duration, enableProm
MaxAge: lesserOf(time.Hour/2, maxAge), MaxAge: lesserOf(time.Hour/2, maxAge),
}, },
}, },
SpaceSummaryRooms: &RistrettoCachePartition[string, fclient.MSC2946SpacesResponse]{ // room ID -> space response RoomHierarchies: &RistrettoCachePartition[string, fclient.RoomHierarchyResponse]{ // room ID -> space response
cache: cache, cache: cache,
Prefix: spaceSummaryRoomsCache, Prefix: spaceSummaryRoomsCache,
Mutable: true, Mutable: true,

View file

@ -176,15 +176,13 @@ func RedactEvent(ctx context.Context, redactionEvent, redactedEvent gomatrixserv
return fmt.Errorf("RedactEvent: redactionEvent isn't a redaction event, is '%s'", redactionEvent.Type()) return fmt.Errorf("RedactEvent: redactionEvent isn't a redaction event, is '%s'", redactionEvent.Type())
} }
redactedEvent.Redact() redactedEvent.Redact()
validRoomID, err := spec.NewRoomID(redactionEvent.RoomID()) clientEvent, err := synctypes.ToClientEvent(redactionEvent, synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return querier.QueryUserIDForSender(ctx, roomID, senderID)
})
if err != nil { if err != nil {
return err return err
} }
senderID, err := querier.QueryUserIDForSender(ctx, *validRoomID, redactionEvent.SenderID()) redactedBecause := clientEvent
if err != nil {
return err
}
redactedBecause := synctypes.ToClientEvent(redactionEvent, synctypes.FormatSync, *senderID, redactionEvent.StateKey())
if err := redactedEvent.SetUnsignedField("redacted_because", redactedBecause); err != nil { if err := redactedEvent.SetUnsignedField("redacted_because", redactedBecause); err != nil {
return err return err
} }

View file

@ -111,15 +111,11 @@ func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec Evaluati
return patternMatches("content.body", *rule.Pattern, event) return patternMatches("content.body", *rule.Pattern, event)
case RoomKind: case RoomKind:
return rule.RuleID == event.RoomID(), nil return rule.RuleID == event.RoomID().String(), nil
case SenderKind: case SenderKind:
userID := "" userID := ""
validRoomID, err := spec.NewRoomID(event.RoomID()) sender, err := userIDForSender(event.RoomID(), event.SenderID())
if err != nil {
return false, err
}
sender, err := userIDForSender(*validRoomID, event.SenderID())
if err == nil { if err == nil {
userID = sender.String() userID = sender.String()
} }

View file

@ -17,59 +17,70 @@ package sqlutil
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"sync"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
) )
type Connections struct { type Connections struct {
db *sql.DB globalConfig config.DatabaseOptions
writer Writer processContext *process.ProcessContext
globalConfig config.DatabaseOptions existingConnections sync.Map
processContext *process.ProcessContext
} }
func NewConnectionManager(processCtx *process.ProcessContext, globalConfig config.DatabaseOptions) Connections { type con struct {
return Connections{ db *sql.DB
writer Writer
}
func NewConnectionManager(processCtx *process.ProcessContext, globalConfig config.DatabaseOptions) *Connections {
return &Connections{
globalConfig: globalConfig, globalConfig: globalConfig,
processContext: processCtx, processContext: processCtx,
} }
} }
func (c *Connections) Connection(dbProperties *config.DatabaseOptions) (*sql.DB, Writer, error) { func (c *Connections) Connection(dbProperties *config.DatabaseOptions) (*sql.DB, Writer, error) {
var err error
// If no connectionString was provided, try the global one
if dbProperties.ConnectionString == "" {
dbProperties = &c.globalConfig
// If we still don't have a connection string, that's a problem
if dbProperties.ConnectionString == "" {
return nil, nil, fmt.Errorf("no database connections configured")
}
}
writer := NewDummyWriter() writer := NewDummyWriter()
if dbProperties.ConnectionString.IsSQLite() { if dbProperties.ConnectionString.IsSQLite() {
writer = NewExclusiveWriter() writer = NewExclusiveWriter()
} }
var err error
if dbProperties.ConnectionString == "" { existing, loaded := c.existingConnections.LoadOrStore(dbProperties.ConnectionString, &con{})
// if no connectionString was provided, try the global one if loaded {
dbProperties = &c.globalConfig // We found an existing connection
ex := existing.(*con)
return ex.db, ex.writer, nil
} }
if dbProperties.ConnectionString != "" || c.db == nil {
// Open a new database connection using the supplied config. // Open a new database connection using the supplied config.
c.db, err = Open(dbProperties, writer) db, err := Open(dbProperties, writer)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
}
c.existingConnections.Store(dbProperties.ConnectionString, &con{db: db, writer: writer})
go func() {
if c.processContext == nil {
return
} }
c.writer = writer // If we have a ProcessContext, start a component and wait for
go func() { // Dendrite to shut down to cleanly close the database connection.
if c.processContext == nil { c.processContext.ComponentStarted()
return <-c.processContext.WaitForShutdown()
} _ = db.Close()
// If we have a ProcessContext, start a component and wait for c.processContext.ComponentFinished()
// Dendrite to shut down to cleanly close the database connection. }()
c.processContext.ComponentStarted() return db, writer, nil
<-c.processContext.WaitForShutdown()
_ = c.db.Close()
c.processContext.ComponentFinished()
}()
return c.db, c.writer, nil
}
if c.db != nil && c.writer != nil {
// Ignore the supplied config and return the global pool and
// writer.
return c.db, c.writer, nil
}
return nil, nil, fmt.Errorf("no database connections configured")
} }

View file

@ -6,51 +6,135 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
) )
func TestConnectionManager(t *testing.T) { func TestConnectionManager(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
conStr, close := test.PrepareDBConnectionString(t, dbType)
t.Cleanup(close)
cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})
dbProps := &config.DatabaseOptions{ConnectionString: config.DataSource(conStr)} t.Run("component defined connection string", func(t *testing.T) {
db, writer, err := cm.Connection(dbProps) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
if err != nil { conStr, close := test.PrepareDBConnectionString(t, dbType)
t.Fatal(err) t.Cleanup(close)
} cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})
switch dbType { dbProps := &config.DatabaseOptions{ConnectionString: config.DataSource(conStr)}
case test.DBTypeSQLite: db, writer, err := cm.Connection(dbProps)
_, ok := writer.(*sqlutil.ExclusiveWriter) if err != nil {
if !ok { t.Fatal(err)
t.Fatalf("expected exclusive writer")
} }
case test.DBTypePostgres:
_, ok := writer.(*sqlutil.DummyWriter) switch dbType {
if !ok { case test.DBTypeSQLite:
t.Fatalf("expected dummy writer") _, ok := writer.(*sqlutil.ExclusiveWriter)
if !ok {
t.Fatalf("expected exclusive writer")
}
case test.DBTypePostgres:
_, ok := writer.(*sqlutil.DummyWriter)
if !ok {
t.Fatalf("expected dummy writer")
}
} }
}
// test global db pool // reuse existing connection
dbGlobal, writerGlobal, err := cm.Connection(&config.DatabaseOptions{}) db2, writer2, err := cm.Connection(dbProps)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !reflect.DeepEqual(db, dbGlobal) { if !reflect.DeepEqual(db, db2) {
t.Fatalf("expected database connection to be reused") t.Fatalf("expected database connection to be reused")
} }
if !reflect.DeepEqual(writer, writerGlobal) { if !reflect.DeepEqual(writer, writer2) {
t.Fatalf("expected database writer to be reused") t.Fatalf("expected database writer to be reused")
} }
// test invalid connection string configured // This test does not work with Postgres, because we can't just simply append
cm2 := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) // "x" or replace the database to use.
_, _, err = cm2.Connection(&config.DatabaseOptions{ConnectionString: "http://"}) if dbType == test.DBTypePostgres {
if err == nil { return
t.Fatal("expected an error but got none") }
}
// Test different connection string
dbProps = &config.DatabaseOptions{ConnectionString: config.DataSource(conStr + "x")}
db3, _, err := cm.Connection(dbProps)
if err != nil {
t.Fatal(err)
}
if reflect.DeepEqual(db, db3) {
t.Fatalf("expected different database connection")
}
})
}) })
t.Run("global connection pool", func(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
conStr, close := test.PrepareDBConnectionString(t, dbType)
t.Cleanup(close)
cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{ConnectionString: config.DataSource(conStr)})
dbProps := &config.DatabaseOptions{}
db, writer, err := cm.Connection(dbProps)
if err != nil {
t.Fatal(err)
}
switch dbType {
case test.DBTypeSQLite:
_, ok := writer.(*sqlutil.ExclusiveWriter)
if !ok {
t.Fatalf("expected exclusive writer")
}
case test.DBTypePostgres:
_, ok := writer.(*sqlutil.DummyWriter)
if !ok {
t.Fatalf("expected dummy writer")
}
}
// reuse existing connection
db2, writer2, err := cm.Connection(dbProps)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(db, db2) {
t.Fatalf("expected database connection to be reused")
}
if !reflect.DeepEqual(writer, writer2) {
t.Fatalf("expected database writer to be reused")
}
})
})
t.Run("shutdown", func(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
conStr, close := test.PrepareDBConnectionString(t, dbType)
t.Cleanup(close)
processCtx := process.NewProcessContext()
cm := sqlutil.NewConnectionManager(processCtx, config.DatabaseOptions{ConnectionString: config.DataSource(conStr)})
dbProps := &config.DatabaseOptions{}
_, _, err := cm.Connection(dbProps)
if err != nil {
t.Fatal(err)
}
processCtx.ShutdownDendrite()
processCtx.WaitForComponentsToFinish()
})
})
// test invalid connection string configured
cm2 := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})
_, _, err := cm2.Connection(&config.DatabaseOptions{ConnectionString: "http://"})
if err == nil {
t.Fatal("expected an error but got none")
}
// empty connection string is not allowed
_, _, err = cm2.Connection(&config.DatabaseOptions{})
if err == nil {
t.Fatal("expected an error but got none")
}
} }

View file

@ -161,7 +161,7 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut
if event.Type() == spec.MRoomCreate && event.StateKeyEquals("") { if event.Type() == spec.MRoomCreate && event.StateKeyEquals("") {
continue continue
} }
if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID(), t.Origin) { if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID().String(), t.Origin) {
results[event.EventID()] = fclient.PDUResult{ results[event.EventID()] = fclient.PDUResult{
Error: "Forbidden by server ACLs", Error: "Forbidden by server ACLs",
} }

View file

@ -28,7 +28,7 @@ import (
// AddPublicRoutes sets up and registers HTTP handlers for the MediaAPI component. // AddPublicRoutes sets up and registers HTTP handlers for the MediaAPI component.
func AddPublicRoutes( func AddPublicRoutes(
mediaRouter *mux.Router, mediaRouter *mux.Router,
cm sqlutil.Connections, cm *sqlutil.Connections,
cfg *config.Dendrite, cfg *config.Dendrite,
userAPI userapi.MediaUserAPI, userAPI userapi.MediaUserAPI,
client *fclient.Client, client *fclient.Client,

View file

@ -24,7 +24,7 @@ import (
) )
// NewDatabase opens a postgres database. // NewDatabase opens a postgres database.
func NewDatabase(conMan sqlutil.Connections, dbProperties *config.DatabaseOptions) (*shared.Database, error) { func NewDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions) (*shared.Database, error) {
db, writer, err := conMan.Connection(dbProperties) db, writer, err := conMan.Connection(dbProperties)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -23,7 +23,7 @@ import (
) )
// NewDatabase opens a SQLIte database. // NewDatabase opens a SQLIte database.
func NewDatabase(conMan sqlutil.Connections, dbProperties *config.DatabaseOptions) (*shared.Database, error) { func NewDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions) (*shared.Database, error) {
db, writer, err := conMan.Connection(dbProperties) db, writer, err := conMan.Connection(dbProperties)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -27,7 +27,7 @@ import (
) )
// NewMediaAPIDatasource opens a database connection. // NewMediaAPIDatasource opens a database connection.
func NewMediaAPIDatasource(conMan sqlutil.Connections, dbProperties *config.DatabaseOptions) (Database, error) { func NewMediaAPIDatasource(conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions) (Database, error) {
switch { switch {
case dbProperties.ConnectionString.IsSQLite(): case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(conMan, dbProperties) return sqlite3.NewDatabase(conMan, dbProperties)

View file

@ -23,7 +23,7 @@ import (
) )
// Open opens a postgres database. // Open opens a postgres database.
func NewMediaAPIDatasource(conMan sqlutil.Connections, dbProperties *config.DatabaseOptions) (Database, error) { func NewMediaAPIDatasource(conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions) (Database, error) {
switch { switch {
case dbProperties.ConnectionString.IsSQLite(): case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(conMan, dbProperties) return sqlite3.NewDatabase(conMan, dbProperties)

View file

@ -25,7 +25,7 @@ import (
// NewDatabase opens a new database // NewDatabase opens a new database
func NewDatabase( func NewDatabase(
conMan sqlutil.Connections, conMan *sqlutil.Connections,
dbProperties *config.DatabaseOptions, dbProperties *config.DatabaseOptions,
cache caching.FederationCache, cache caching.FederationCache,
isLocalServerName func(spec.ServerName) bool, isLocalServerName func(spec.ServerName) bool,

View file

@ -119,7 +119,7 @@ func (s *ServerACLs) OnServerACLUpdate(state gomatrixserverlib.PDU) {
}).Debugf("Updating server ACLs for %q", state.RoomID()) }).Debugf("Updating server ACLs for %q", state.RoomID())
s.aclsMutex.Lock() s.aclsMutex.Lock()
defer s.aclsMutex.Unlock() defer s.aclsMutex.Unlock()
s.acls[state.RoomID()] = acls s.acls[state.RoomID().String()] = acls
} }
func (s *ServerACLs) IsServerBannedFromRoom(serverName spec.ServerName, roomID string) bool { func (s *ServerACLs) IsServerBannedFromRoom(serverName spec.ServerName, roomID string) bool {

View file

@ -16,26 +16,8 @@ package api
import ( import (
"regexp" "regexp"
"github.com/matrix-org/gomatrixserverlib/spec"
) )
// SetRoomAliasRequest is a request to SetRoomAlias
type SetRoomAliasRequest struct {
// ID of the user setting the alias
UserID string `json:"user_id"`
// New alias for the room
Alias string `json:"alias"`
// The room ID the alias is referring to
RoomID string `json:"room_id"`
}
// SetRoomAliasResponse is a response to SetRoomAlias
type SetRoomAliasResponse struct {
// Does the alias already refer to a room?
AliasExists bool `json:"alias_exists"`
}
// GetRoomIDForAliasRequest is a request to GetRoomIDForAlias // GetRoomIDForAliasRequest is a request to GetRoomIDForAlias
type GetRoomIDForAliasRequest struct { type GetRoomIDForAliasRequest struct {
// Alias we want to lookup // Alias we want to lookup
@ -63,22 +45,6 @@ type GetAliasesForRoomIDResponse struct {
Aliases []string `json:"aliases"` Aliases []string `json:"aliases"`
} }
// RemoveRoomAliasRequest is a request to RemoveRoomAlias
type RemoveRoomAliasRequest struct {
// ID of the user removing the alias
SenderID spec.SenderID `json:"user_id"`
// The room alias to remove
Alias string `json:"alias"`
}
// RemoveRoomAliasResponse is a response to RemoveRoomAlias
type RemoveRoomAliasResponse struct {
// Did the alias exist before?
Found bool `json:"found"`
// Did we remove it?
Removed bool `json:"removed"`
}
type AliasEvent struct { type AliasEvent struct {
Alias string `json:"alias"` Alias string `json:"alias"`
AltAliases []string `json:"alt_aliases"` AltAliases []string `json:"alt_aliases"`

View file

@ -34,6 +34,17 @@ func (e ErrNotAllowed) Error() string {
return e.Err.Error() return e.Err.Error()
} }
// ErrRoomUnknownOrNotAllowed is an error return if either the provided
// room ID does not exist, or points to a room that the requester does
// not have access to.
type ErrRoomUnknownOrNotAllowed struct {
Err error
}
func (e ErrRoomUnknownOrNotAllowed) Error() string {
return e.Err.Error()
}
type RestrictedJoinAPI interface { type RestrictedJoinAPI interface {
CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error)
InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error) InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error)
@ -44,6 +55,11 @@ type RestrictedJoinAPI interface {
LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error)
} }
type DefaultRoomVersionAPI interface {
// Returns the default room version used.
DefaultRoomVersion() gomatrixserverlib.RoomVersion
}
// RoomserverInputAPI is used to write events to the room server. // RoomserverInputAPI is used to write events to the room server.
type RoomserverInternalAPI interface { type RoomserverInternalAPI interface {
SyncRoomserverAPI SyncRoomserverAPI
@ -53,6 +69,7 @@ type RoomserverInternalAPI interface {
FederationRoomserverAPI FederationRoomserverAPI
QuerySenderIDAPI QuerySenderIDAPI
UserRoomPrivateKeyCreator UserRoomPrivateKeyCreator
DefaultRoomVersionAPI
// needed to avoid chicken and egg scenario when setting up the // needed to avoid chicken and egg scenario when setting up the
// interdependencies between the roomserver and other input APIs // interdependencies between the roomserver and other input APIs
@ -86,7 +103,7 @@ type InputRoomEventsAPI interface {
} }
type QuerySenderIDAPI interface { type QuerySenderIDAPI interface {
QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (*spec.SenderID, error)
QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error)
} }
@ -113,11 +130,39 @@ type QueryEventsAPI interface {
QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error
} }
type QueryRoomHierarchyAPI interface {
// Traverse the room hierarchy using the provided walker up to the provided limit,
// returning a new walker which can be used to fetch the next page.
//
// If limit is -1, this is treated as no limit, and the entire hierarchy will be traversed.
//
// If returned walker is nil, then there are no more rooms left to traverse. This method does not modify the provided walker, so it
// can be cached.
QueryNextRoomHierarchyPage(ctx context.Context, walker RoomHierarchyWalker, limit int) ([]fclient.RoomHierarchyRoom, *RoomHierarchyWalker, error)
}
type QueryMembershipAPI interface {
QueryMembershipForSenderID(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, res *QueryMembershipForUserResponse) error
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
// QueryMembershipAtEvent queries the memberships at the given events.
// Returns a map from eventID to *types.HeaderedEvent of membership events.
QueryMembershipAtEvent(
ctx context.Context,
roomID spec.RoomID,
eventIDs []string,
senderID spec.SenderID,
) (map[string]*types.HeaderedEvent, error)
}
// API functions required by the syncapi // API functions required by the syncapi
type SyncRoomserverAPI interface { type SyncRoomserverAPI interface {
QueryLatestEventsAndStateAPI QueryLatestEventsAndStateAPI
QueryBulkStateContentAPI QueryBulkStateContentAPI
QuerySenderIDAPI QuerySenderIDAPI
QueryMembershipAPI
// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
@ -127,12 +172,6 @@ type SyncRoomserverAPI interface {
req *QueryEventsByIDRequest, req *QueryEventsByIDRequest,
res *QueryEventsByIDResponse, res *QueryEventsByIDResponse,
) error ) error
// Query the membership event for an user for a room.
QueryMembershipForUser(
ctx context.Context,
req *QueryMembershipForUserRequest,
res *QueryMembershipForUserResponse,
) error
// Query the state after a list of events in a room from the room server. // Query the state after a list of events in a room from the room server.
QueryStateAfterEvents( QueryStateAfterEvents(
@ -147,14 +186,6 @@ type SyncRoomserverAPI interface {
req *PerformBackfillRequest, req *PerformBackfillRequest,
res *PerformBackfillResponse, res *PerformBackfillResponse,
) error ) error
// QueryMembershipAtEvent queries the memberships at the given events.
// Returns a map from eventID to a slice of types.HeaderedEvent.
QueryMembershipAtEvent(
ctx context.Context,
request *QueryMembershipAtEventRequest,
response *QueryMembershipAtEventResponse,
) error
} }
type AppserviceRoomserverAPI interface { type AppserviceRoomserverAPI interface {
@ -187,9 +218,11 @@ type ClientRoomserverAPI interface {
QueryEventsAPI QueryEventsAPI
QuerySenderIDAPI QuerySenderIDAPI
UserRoomPrivateKeyCreator UserRoomPrivateKeyCreator
QueryRoomHierarchyAPI
DefaultRoomVersionAPI
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error)
QueryStateAfterEvents(ctx context.Context, req *QueryStateAfterEventsRequest, res *QueryStateAfterEventsResponse) error QueryStateAfterEvents(ctx context.Context, req *QueryStateAfterEventsRequest, res *QueryStateAfterEventsResponse) error
// QueryKnownUsers returns a list of users that we know about from our joined rooms. // QueryKnownUsers returns a list of users that we know about from our joined rooms.
QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error
@ -214,8 +247,19 @@ type ClientRoomserverAPI interface {
PerformPublish(ctx context.Context, req *PerformPublishRequest) error PerformPublish(ctx context.Context, req *PerformPublishRequest) error
// PerformForget forgets a rooms history for a specific user // PerformForget forgets a rooms history for a specific user
PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error
SetRoomAlias(ctx context.Context, req *SetRoomAliasRequest, res *SetRoomAliasResponse) error
RemoveRoomAlias(ctx context.Context, req *RemoveRoomAliasRequest, res *RemoveRoomAliasResponse) error // Sets a room alias, as provided sender, pointing to the provided room ID.
//
// If err is nil, then the returned boolean indicates if the alias is already in use.
// If true, then the alias has not been set to the provided room, as it already in use.
SetRoomAlias(ctx context.Context, senderID spec.SenderID, roomID spec.RoomID, alias string) (aliasAlreadyExists bool, err error)
//RemoveRoomAlias(ctx context.Context, req *RemoveRoomAliasRequest, res *RemoveRoomAliasResponse) error
// Removes a room alias, as provided sender.
//
// Returns whether the alias was found, whether it was removed, and an error (if any occurred)
RemoveRoomAlias(ctx context.Context, senderID spec.SenderID, alias string) (aliasFound bool, aliasRemoved bool, err error)
SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error)
} }
@ -227,6 +271,7 @@ type UserRoomserverAPI interface {
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error) PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error)
PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error) PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error)
JoinedUserCount(ctx context.Context, roomID string) (int, error)
} }
type FederationRoomserverAPI interface { type FederationRoomserverAPI interface {
@ -235,15 +280,13 @@ type FederationRoomserverAPI interface {
QueryLatestEventsAndStateAPI QueryLatestEventsAndStateAPI
QueryBulkStateContentAPI QueryBulkStateContentAPI
QuerySenderIDAPI QuerySenderIDAPI
QueryRoomHierarchyAPI
QueryMembershipAPI
UserRoomPrivateKeyCreator UserRoomPrivateKeyCreator
AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error) AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error)
SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error)
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error
QueryMembershipForSenderID(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, res *QueryMembershipForUserResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
// which room to use by querying the first events roomID. // which room to use by querying the first events roomID.
@ -257,7 +300,7 @@ type FederationRoomserverAPI interface {
QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error
// Query whether a server is allowed to see an event // Query whether a server is allowed to see an event
QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string, roomID string) (allowed bool, err error) QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string, roomID string) (allowed bool, err error)
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error)
QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error)
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
HandleInvite(ctx context.Context, event *types.HeaderedEvent) error HandleInvite(ctx context.Context, event *types.HeaderedEvent) error

View file

@ -50,9 +50,21 @@ type PerformLeaveResponse struct {
Message interface{} `json:"message,omitempty"` Message interface{} `json:"message,omitempty"`
} }
type InviteInput struct {
RoomID spec.RoomID
Inviter spec.UserID
Invitee spec.UserID
DisplayName string
AvatarURL string
Reason string
IsDirect bool
KeyID gomatrixserverlib.KeyID
PrivateKey ed25519.PrivateKey
EventTime time.Time
}
type PerformInviteRequest struct { type PerformInviteRequest struct {
RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` InviteInput InviteInput
Event *types.HeaderedEvent `json:"event"`
InviteRoomState []gomatrixserverlib.InviteStrippedState `json:"invite_room_state"` InviteRoomState []gomatrixserverlib.InviteStrippedState `json:"invite_room_state"`
SendAsServer string `json:"send_as_server"` SendAsServer string `json:"send_as_server"`
TransactionID *TransactionID `json:"transaction_id"` TransactionID *TransactionID `json:"transaction_id"`

View file

@ -132,6 +132,8 @@ type QueryMembershipForUserResponse struct {
// True if the user asked to forget this room. // True if the user asked to forget this room.
IsRoomForgotten bool `json:"is_room_forgotten"` IsRoomForgotten bool `json:"is_room_forgotten"`
RoomExists bool `json:"room_exists"` RoomExists bool `json:"room_exists"`
// The sender ID of the user in the room, if it exists
SenderID *spec.SenderID
} }
// QueryMembershipsForRoomRequest is a request to QueryMembershipsForRoom // QueryMembershipsForRoomRequest is a request to QueryMembershipsForRoom
@ -289,16 +291,6 @@ type QuerySharedUsersResponse struct {
UserIDsToCount map[string]int UserIDsToCount map[string]int
} }
type QueryRoomsForUserRequest struct {
UserID string
// The desired membership of the user. If this is the empty string then no rooms are returned.
WantMembership string
}
type QueryRoomsForUserResponse struct {
RoomIDs []string
}
type QueryBulkStateContentRequest struct { type QueryBulkStateContentRequest struct {
// Returns state events in these rooms // Returns state events in these rooms
RoomIDs []string RoomIDs []string
@ -414,22 +406,6 @@ func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error {
return nil return nil
} }
// QueryMembershipAtEventRequest requests the membership event for a user
// for a list of eventIDs.
type QueryMembershipAtEventRequest struct {
RoomID string
EventIDs []string
UserID string
}
// QueryMembershipAtEventResponse is the response to QueryMembershipAtEventRequest.
type QueryMembershipAtEventResponse struct {
// Membership is a map from eventID to membership event. Events that
// do not have known state will return a nil event, resulting in a "leave" membership
// when calculating history visibility.
Membership map[string]*types.HeaderedEvent `json:"membership"`
}
// QueryLeftUsersRequest is a request to calculate users that we (the server) don't share a // QueryLeftUsersRequest is a request to calculate users that we (the server) don't share a
// a room with anymore. This is used to cleanup stale device list entries, where we would // a room with anymore. This is used to cleanup stale device list entries, where we would
// otherwise keep on trying to get device lists. // otherwise keep on trying to get device lists.
@ -503,3 +479,79 @@ func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.
} }
return membership, err return membership, err
} }
type QueryRoomHierarchyRequest struct {
SuggestedOnly bool `json:"suggested_only"`
Limit int `json:"limit"`
MaxDepth int `json:"max_depth"`
From int `json:"json"`
}
// A struct storing the intermediate state of a room hierarchy query for pagination purposes.
//
// Used for implementing space summaries / room hierarchies
//
// Use NewRoomHierarchyWalker to construct this, and QueryNextRoomHierarchyPage on the roomserver API
// to traverse the room hierarchy.
type RoomHierarchyWalker struct {
RootRoomID spec.RoomID
Caller types.DeviceOrServerName
SuggestedOnly bool
MaxDepth int
Processed RoomSet
Unvisited []RoomHierarchyWalkerQueuedRoom
}
type RoomHierarchyWalkerQueuedRoom struct {
RoomID spec.RoomID
ParentRoomID *spec.RoomID
Depth int
Vias []string // vias to query this room by
}
// Create a new room hierarchy walker, starting from the provided root room ID.
//
// Use the resulting struct with QueryNextRoomHierarchyPage on the roomserver API to traverse the room hierarchy.
func NewRoomHierarchyWalker(caller types.DeviceOrServerName, roomID spec.RoomID, suggestedOnly bool, maxDepth int) RoomHierarchyWalker {
walker := RoomHierarchyWalker{
RootRoomID: roomID,
Caller: caller,
SuggestedOnly: suggestedOnly,
MaxDepth: maxDepth,
Unvisited: []RoomHierarchyWalkerQueuedRoom{{
RoomID: roomID,
ParentRoomID: nil,
Depth: 0,
}},
Processed: NewRoomSet(),
}
return walker
}
// A set of room IDs.
type RoomSet map[spec.RoomID]struct{}
// Create a new empty room set.
func NewRoomSet() RoomSet {
return RoomSet{}
}
// Check if a room ID is in a room set.
func (s RoomSet) Contains(val spec.RoomID) bool {
_, ok := s[val]
return ok
}
// Add a room ID to a room set.
func (s RoomSet) Add(val spec.RoomID) {
s[val] = struct{}{}
}
func (s RoomSet) Copy() RoomSet {
copied := make(RoomSet, len(s))
for k := range s {
copied.Add(k)
}
return copied
}

View file

@ -75,7 +75,7 @@ func SendEventWithState(
} }
logrus.WithContext(ctx).WithFields(logrus.Fields{ logrus.WithContext(ctx).WithFields(logrus.Fields{
"room_id": event.RoomID(), "room_id": event.RoomID().String(),
"event_id": event.EventID(), "event_id": event.EventID(),
"outliers": len(ires), "outliers": len(ires),
"state_ids": len(stateEventIDs), "state_ids": len(stateEventIDs),

View file

@ -85,11 +85,7 @@ func IsAnyUserOnServerWithMembership(ctx context.Context, querier api.QuerySende
continue continue
} }
validRoomID, err := spec.NewRoomID(ev.RoomID()) userID, err := querier.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*stateKey))
if err != nil {
continue
}
userID, err := querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*stateKey))
if err != nil { if err != nil {
continue continue
} }

View file

@ -35,27 +35,27 @@ import (
// SetRoomAlias implements alias.RoomserverInternalAPI // SetRoomAlias implements alias.RoomserverInternalAPI
func (r *RoomserverInternalAPI) SetRoomAlias( func (r *RoomserverInternalAPI) SetRoomAlias(
ctx context.Context, ctx context.Context,
request *api.SetRoomAliasRequest, senderID spec.SenderID,
response *api.SetRoomAliasResponse, roomID spec.RoomID,
) error { alias string,
) (aliasAlreadyUsed bool, err error) {
// Check if the alias isn't already referring to a room // Check if the alias isn't already referring to a room
roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias) existingRoomID, err := r.DB.GetRoomIDForAlias(ctx, alias)
if err != nil { if err != nil {
return err return false, err
} }
if len(roomID) > 0 {
if len(existingRoomID) > 0 {
// If the alias already exists, stop the process // If the alias already exists, stop the process
response.AliasExists = true return true, nil
return nil
} }
response.AliasExists = false
// Save the new alias // Save the new alias
if err := r.DB.SetRoomAlias(ctx, request.Alias, request.RoomID, request.UserID); err != nil { if err := r.DB.SetRoomAlias(ctx, alias, roomID.String(), string(senderID)); err != nil {
return err return false, err
} }
return nil return false, nil
} }
// GetRoomIDForAlias implements alias.RoomserverInternalAPI // GetRoomIDForAlias implements alias.RoomserverInternalAPI
@ -116,91 +116,80 @@ func (r *RoomserverInternalAPI) GetAliasesForRoomID(
// nolint:gocyclo // nolint:gocyclo
// RemoveRoomAlias implements alias.RoomserverInternalAPI // RemoveRoomAlias implements alias.RoomserverInternalAPI
// nolint: gocyclo // nolint: gocyclo
func (r *RoomserverInternalAPI) RemoveRoomAlias( func (r *RoomserverInternalAPI) RemoveRoomAlias(ctx context.Context, senderID spec.SenderID, alias string) (aliasFound bool, aliasRemoved bool, err error) {
ctx context.Context, roomID, err := r.DB.GetRoomIDForAlias(ctx, alias)
request *api.RemoveRoomAliasRequest,
response *api.RemoveRoomAliasResponse,
) error {
roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.GetRoomIDForAlias: %w", err) return false, false, fmt.Errorf("r.DB.GetRoomIDForAlias: %w", err)
} }
if roomID == "" { if roomID == "" {
response.Found = false return false, false, nil
response.Removed = false
return nil
} }
validRoomID, err := spec.NewRoomID(roomID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil { if err != nil {
return err return true, false, err
} }
sender, err := r.QueryUserIDForSender(ctx, *validRoomID, request.SenderID) sender, err := r.QueryUserIDForSender(ctx, *validRoomID, senderID)
if err != nil || sender == nil { if err != nil || sender == nil {
return fmt.Errorf("r.QueryUserIDForSender: %w", err) return true, false, fmt.Errorf("r.QueryUserIDForSender: %w", err)
} }
virtualHost := sender.Domain() virtualHost := sender.Domain()
response.Found = true creatorID, err := r.DB.GetCreatorIDForAlias(ctx, alias)
creatorID, err := r.DB.GetCreatorIDForAlias(ctx, request.Alias)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.GetCreatorIDForAlias: %w", err) return true, false, fmt.Errorf("r.DB.GetCreatorIDForAlias: %w", err)
} }
if spec.SenderID(creatorID) != request.SenderID { if spec.SenderID(creatorID) != senderID {
var plEvent *types.HeaderedEvent var plEvent *types.HeaderedEvent
var pls *gomatrixserverlib.PowerLevelContent var pls *gomatrixserverlib.PowerLevelContent
plEvent, err = r.DB.GetStateEvent(ctx, roomID, spec.MRoomPowerLevels, "") plEvent, err = r.DB.GetStateEvent(ctx, roomID, spec.MRoomPowerLevels, "")
if err != nil { if err != nil {
return fmt.Errorf("r.DB.GetStateEvent: %w", err) return true, false, fmt.Errorf("r.DB.GetStateEvent: %w", err)
} }
pls, err = plEvent.PowerLevels() pls, err = plEvent.PowerLevels()
if err != nil { if err != nil {
return fmt.Errorf("plEvent.PowerLevels: %w", err) return true, false, fmt.Errorf("plEvent.PowerLevels: %w", err)
} }
if pls.UserLevel(request.SenderID) < pls.EventLevel(spec.MRoomCanonicalAlias, true) { if pls.UserLevel(senderID) < pls.EventLevel(spec.MRoomCanonicalAlias, true) {
response.Removed = false return true, false, nil
return nil
} }
} }
ev, err := r.DB.GetStateEvent(ctx, roomID, spec.MRoomCanonicalAlias, "") ev, err := r.DB.GetStateEvent(ctx, roomID, spec.MRoomCanonicalAlias, "")
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return err return true, false, err
} else if ev != nil { } else if ev != nil {
stateAlias := gjson.GetBytes(ev.Content(), "alias").Str stateAlias := gjson.GetBytes(ev.Content(), "alias").Str
// the alias to remove is currently set as the canonical alias, remove it // the alias to remove is currently set as the canonical alias, remove it
if stateAlias == request.Alias { if stateAlias == alias {
res, err := sjson.DeleteBytes(ev.Content(), "alias") res, err := sjson.DeleteBytes(ev.Content(), "alias")
if err != nil { if err != nil {
return err return true, false, err
} }
senderID := request.SenderID canonicalSenderID := ev.SenderID()
if request.SenderID != ev.SenderID() { canonicalSender, err := r.QueryUserIDForSender(ctx, *validRoomID, canonicalSenderID)
senderID = ev.SenderID() if err != nil || canonicalSender == nil {
} return true, false, err
sender, err := r.QueryUserIDForSender(ctx, *validRoomID, senderID)
if err != nil || sender == nil {
return err
} }
validRoomID, err := spec.NewRoomID(roomID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil { if err != nil {
return err return true, false, err
} }
identity, err := r.SigningIdentityFor(ctx, *validRoomID, *sender) identity, err := r.SigningIdentityFor(ctx, *validRoomID, *canonicalSender)
if err != nil { if err != nil {
return err return true, false, err
} }
proto := &gomatrixserverlib.ProtoEvent{ proto := &gomatrixserverlib.ProtoEvent{
SenderID: string(senderID), SenderID: string(canonicalSenderID),
RoomID: ev.RoomID(), RoomID: ev.RoomID().String(),
Type: ev.Type(), Type: ev.Type(),
StateKey: ev.StateKey(), StateKey: ev.StateKey(),
Content: res, Content: res,
@ -208,34 +197,33 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
eventsNeeded, err := gomatrixserverlib.StateNeededForProtoEvent(proto) eventsNeeded, err := gomatrixserverlib.StateNeededForProtoEvent(proto)
if err != nil { if err != nil {
return fmt.Errorf("gomatrixserverlib.StateNeededForEventBuilder: %w", err) return true, false, fmt.Errorf("gomatrixserverlib.StateNeededForEventBuilder: %w", err)
} }
if len(eventsNeeded.Tuples()) == 0 { if len(eventsNeeded.Tuples()) == 0 {
return errors.New("expecting state tuples for event builder, got none") return true, false, errors.New("expecting state tuples for event builder, got none")
} }
stateRes := &api.QueryLatestEventsAndStateResponse{} stateRes := &api.QueryLatestEventsAndStateResponse{}
if err = helpers.QueryLatestEventsAndState(ctx, r.DB, r, &api.QueryLatestEventsAndStateRequest{RoomID: roomID, StateToFetch: eventsNeeded.Tuples()}, stateRes); err != nil { if err = helpers.QueryLatestEventsAndState(ctx, r.DB, r, &api.QueryLatestEventsAndStateRequest{RoomID: roomID, StateToFetch: eventsNeeded.Tuples()}, stateRes); err != nil {
return err return true, false, err
} }
newEvent, err := eventutil.BuildEvent(ctx, proto, &identity, time.Now(), &eventsNeeded, stateRes) newEvent, err := eventutil.BuildEvent(ctx, proto, &identity, time.Now(), &eventsNeeded, stateRes)
if err != nil { if err != nil {
return err return true, false, err
} }
err = api.SendEvents(ctx, r, api.KindNew, []*types.HeaderedEvent{newEvent}, virtualHost, r.ServerName, r.ServerName, nil, false) err = api.SendEvents(ctx, r, api.KindNew, []*types.HeaderedEvent{newEvent}, virtualHost, r.ServerName, r.ServerName, nil, false)
if err != nil { if err != nil {
return err return true, false, err
} }
} }
} }
// Remove the alias from the database // Remove the alias from the database
if err := r.DB.RemoveRoomAlias(ctx, request.Alias); err != nil { if err := r.DB.RemoveRoomAlias(ctx, alias); err != nil {
return err return true, false, err
} }
response.Removed = true return true, true, nil
return nil
} }

View file

@ -61,6 +61,7 @@ type RoomserverInternalAPI struct {
OutputProducer *producers.RoomEventProducer OutputProducer *producers.RoomEventProducer
PerspectiveServerNames []spec.ServerName PerspectiveServerNames []spec.ServerName
enableMetrics bool enableMetrics bool
defaultRoomVersion gomatrixserverlib.RoomVersion
} }
func NewRoomserverAPI( func NewRoomserverAPI(
@ -91,15 +92,9 @@ func NewRoomserverAPI(
NATSClient: nc, NATSClient: nc,
Durable: dendriteCfg.Global.JetStream.Durable("RoomserverInputConsumer"), Durable: dendriteCfg.Global.JetStream.Durable("RoomserverInputConsumer"),
ServerACLs: serverACLs, ServerACLs: serverACLs,
Queryer: &query.Queryer{ enableMetrics: enableMetrics,
DB: roomserverDB, defaultRoomVersion: dendriteCfg.RoomServer.DefaultRoomVersion,
Cache: caches, // perform-er structs + queryer struct get initialised when we have a federation sender to use
IsLocalServerName: dendriteCfg.Global.IsLocalServerName,
ServerACLs: serverACLs,
Cfg: dendriteCfg,
},
enableMetrics: enableMetrics,
// perform-er structs get initialised when we have a federation sender to use
} }
return a return a
} }
@ -111,6 +106,15 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
r.fsAPI = fsAPI r.fsAPI = fsAPI
r.KeyRing = keyRing r.KeyRing = keyRing
r.Queryer = &query.Queryer{
DB: r.DB,
Cache: r.Cache,
IsLocalServerName: r.Cfg.Global.IsLocalServerName,
ServerACLs: r.ServerACLs,
Cfg: r.Cfg,
FSAPI: fsAPI,
}
r.Inputer = &input.Inputer{ r.Inputer = &input.Inputer{
Cfg: &r.Cfg.RoomServer, Cfg: &r.Cfg.RoomServer,
ProcessContext: r.ProcessContext, ProcessContext: r.ProcessContext,
@ -123,6 +127,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
ServerName: r.ServerName, ServerName: r.ServerName,
SigningIdentity: r.SigningIdentityFor, SigningIdentity: r.SigningIdentityFor,
FSAPI: fsAPI, FSAPI: fsAPI,
RSAPI: r,
KeyRing: keyRing, KeyRing: keyRing,
ACLs: r.ServerACLs, ACLs: r.ServerACLs,
Queryer: r.Queryer, Queryer: r.Queryer,
@ -215,6 +220,10 @@ func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalA
r.asAPI = asAPI r.asAPI = asAPI
} }
func (r *RoomserverInternalAPI) DefaultRoomVersion() gomatrixserverlib.RoomVersion {
return r.defaultRoomVersion
}
func (r *RoomserverInternalAPI) IsKnownRoom(ctx context.Context, roomID spec.RoomID) (bool, error) { func (r *RoomserverInternalAPI) IsKnownRoom(ctx context.Context, roomID spec.RoomID) (bool, error) {
return r.Inviter.IsKnownRoom(ctx, roomID) return r.Inviter.IsKnownRoom(ctx, roomID)
} }
@ -230,7 +239,7 @@ func (r *RoomserverInternalAPI) HandleInvite(
if err != nil { if err != nil {
return err return err
} }
return r.OutputProducer.ProduceRoomEvents(inviteEvent.RoomID(), outputEvents) return r.OutputProducer.ProduceRoomEvents(inviteEvent.RoomID().String(), outputEvents)
} }
func (r *RoomserverInternalAPI) PerformCreateRoom( func (r *RoomserverInternalAPI) PerformCreateRoom(
@ -318,7 +327,7 @@ func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID s
return fclient.SigningIdentity{ return fclient.SigningIdentity{
PrivateKey: privKey, PrivateKey: privKey,
KeyID: "ed25519:1", KeyID: "ed25519:1",
ServerName: "self", ServerName: spec.ServerName(spec.SenderIDFromPseudoIDKey(privKey)),
}, nil }, nil
} }
identity, err := r.Cfg.Global.SigningIdentityFor(senderID.Domain()) identity, err := r.Cfg.Global.SigningIdentityFor(senderID.Domain())

View file

@ -218,9 +218,9 @@ func loadAuthEvents(
roomID := "" roomID := ""
for _, ev := range result.events { for _, ev := range result.events {
if roomID == "" { if roomID == "" {
roomID = ev.RoomID() roomID = ev.RoomID().String()
} }
if ev.RoomID() != roomID { if ev.RoomID().String() != roomID {
result.valid = false result.valid = false
break break
} }

View file

@ -54,7 +54,7 @@ func UpdateToInviteMembership(
Type: api.OutputTypeRetireInviteEvent, Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &api.OutputRetireInviteEvent{ RetireInviteEvent: &api.OutputRetireInviteEvent{
EventID: eventID, EventID: eventID,
RoomID: add.RoomID(), RoomID: add.RoomID().String(),
Membership: spec.Join, Membership: spec.Join,
RetiredByEventID: add.EventID(), RetiredByEventID: add.EventID(),
TargetSenderID: spec.SenderID(*add.StateKey()), TargetSenderID: spec.SenderID(*add.StateKey()),
@ -396,7 +396,7 @@ BFSLoop:
// It's nasty that we have to extract the room ID from an event, but many federation requests // It's nasty that we have to extract the room ID from an event, but many federation requests
// only talk in event IDs, no room IDs at all (!!!) // only talk in event IDs, no room IDs at all (!!!)
ev := events[0] ev := events[0]
isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, querier, serverName, ev.RoomID()) isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, querier, serverName, ev.RoomID().String())
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.") util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.")
} }
@ -419,7 +419,7 @@ BFSLoop:
// hasn't been seen before. // hasn't been seen before.
if !visited[pre] { if !visited[pre] {
visited[pre] = true visited[pre] = true
allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom, querier) allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID().String(), pre, serverName, isServerInRoom, querier)
if err != nil { if err != nil {
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error( util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
"Error checking if allowed to see event", "Error checking if allowed to see event",

View file

@ -83,6 +83,7 @@ type Inputer struct {
ServerName spec.ServerName ServerName spec.ServerName
SigningIdentity func(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) SigningIdentity func(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error)
FSAPI fedapi.RoomserverFederationAPI FSAPI fedapi.RoomserverFederationAPI
RSAPI api.RoomserverInternalAPI
KeyRing gomatrixserverlib.JSONVerifier KeyRing gomatrixserverlib.JSONVerifier
ACLs *acls.ServerACLs ACLs *acls.ServerACLs
InputRoomEventTopic string InputRoomEventTopic string
@ -357,7 +358,7 @@ func (r *Inputer) queueInputRoomEvents(
// For each event, marshal the input room event and then // For each event, marshal the input room event and then
// send it into the input queue. // send it into the input queue.
for _, e := range request.InputRoomEvents { for _, e := range request.InputRoomEvents {
roomID := e.Event.RoomID() roomID := e.Event.RoomID().String()
subj := r.Cfg.Matrix.JetStream.Prefixed(jetstream.InputRoomEventSubj(roomID)) subj := r.Cfg.Matrix.JetStream.Prefixed(jetstream.InputRoomEventSubj(roomID))
msg := &nats.Msg{ msg := &nats.Msg{
Subject: subj, Subject: subj,

View file

@ -87,7 +87,7 @@ func (r *Inputer) processRoomEvent(
} }
trace, ctx := internal.StartRegion(ctx, "processRoomEvent") trace, ctx := internal.StartRegion(ctx, "processRoomEvent")
trace.SetTag("room_id", input.Event.RoomID()) trace.SetTag("room_id", input.Event.RoomID().String())
trace.SetTag("event_id", input.Event.EventID()) trace.SetTag("event_id", input.Event.EventID())
defer trace.EndRegion() defer trace.EndRegion()
@ -96,7 +96,7 @@ func (r *Inputer) processRoomEvent(
defer func() { defer func() {
timetaken := time.Since(started) timetaken := time.Since(started)
processRoomEventDuration.With(prometheus.Labels{ processRoomEventDuration.With(prometheus.Labels{
"room_id": input.Event.RoomID(), "room_id": input.Event.RoomID().String(),
}).Observe(float64(timetaken.Milliseconds())) }).Observe(float64(timetaken.Milliseconds()))
}() }()
@ -105,7 +105,7 @@ func (r *Inputer) processRoomEvent(
event := headered.PDU event := headered.PDU
logger := util.GetLogger(ctx).WithFields(logrus.Fields{ logger := util.GetLogger(ctx).WithFields(logrus.Fields{
"event_id": event.EventID(), "event_id": event.EventID(),
"room_id": event.RoomID(), "room_id": event.RoomID().String(),
"kind": input.Kind, "kind": input.Kind,
"origin": input.Origin, "origin": input.Origin,
"type": event.Type(), "type": event.Type(),
@ -120,19 +120,15 @@ func (r *Inputer) processRoomEvent(
// Don't waste time processing the event if the room doesn't exist. // Don't waste time processing the event if the room doesn't exist.
// A room entry locally will only be created in response to a create // A room entry locally will only be created in response to a create
// event. // event.
roomInfo, rerr := r.DB.RoomInfo(ctx, event.RoomID()) roomInfo, rerr := r.DB.RoomInfo(ctx, event.RoomID().String())
if rerr != nil { if rerr != nil {
return fmt.Errorf("r.DB.RoomInfo: %w", rerr) return fmt.Errorf("r.DB.RoomInfo: %w", rerr)
} }
isCreateEvent := event.Type() == spec.MRoomCreate && event.StateKeyEquals("") isCreateEvent := event.Type() == spec.MRoomCreate && event.StateKeyEquals("")
if roomInfo == nil && !isCreateEvent { if roomInfo == nil && !isCreateEvent {
return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) return fmt.Errorf("room %s does not exist for event %s", event.RoomID().String(), event.EventID())
} }
validRoomID, err := spec.NewRoomID(event.RoomID()) sender, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err != nil {
return err
}
sender, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err != nil { if err != nil {
return fmt.Errorf("failed getting userID for sender %q. %w", event.SenderID(), err) return fmt.Errorf("failed getting userID for sender %q. %w", event.SenderID(), err)
} }
@ -179,7 +175,7 @@ func (r *Inputer) processRoomEvent(
// If we have missing events (auth or prev), we build a list of servers to ask // If we have missing events (auth or prev), we build a list of servers to ask
if missingAuth || missingPrev { if missingAuth || missingPrev {
serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{ serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{
RoomID: event.RoomID(), RoomID: event.RoomID().String(),
ExcludeSelf: true, ExcludeSelf: true,
ExcludeBlacklisted: true, ExcludeBlacklisted: true,
} }
@ -250,6 +246,21 @@ func (r *Inputer) processRoomEvent(
// really do anything with the event other than reject it at this point. // really do anything with the event other than reject it at this point.
isRejected = true isRejected = true
rejectionErr = fmt.Errorf("missingState.processEventWithMissingState: %w", err) rejectionErr = fmt.Errorf("missingState.processEventWithMissingState: %w", err)
switch e := err.(type) {
case gomatrixserverlib.EventValidationError:
if e.Persistable && stateSnapshot != nil {
// We retrieved some state and we ended up having to call /state_ids for
// the new event in question (probably because closing the gap by using
// /get_missing_events didn't do what we hoped) so we'll instead overwrite
// the state snapshot with the newly resolved state.
missingPrev = false
input.HasState = true
input.StateEventIDs = make([]string, 0, len(stateSnapshot.StateEvents))
for _, se := range stateSnapshot.StateEvents {
input.StateEventIDs = append(input.StateEventIDs, se.EventID())
}
}
}
} else if stateSnapshot != nil { } else if stateSnapshot != nil {
// We retrieved some state and we ended up having to call /state_ids for // We retrieved some state and we ended up having to call /state_ids for
// the new event in question (probably because closing the gap by using // the new event in question (probably because closing the gap by using
@ -380,12 +391,12 @@ func (r *Inputer) processRoomEvent(
// Request the room info again — it's possible that the room has been // Request the room info again — it's possible that the room has been
// created by now if it didn't exist already. // created by now if it didn't exist already.
roomInfo, err = r.DB.RoomInfo(ctx, event.RoomID()) roomInfo, err = r.DB.RoomInfo(ctx, event.RoomID().String())
if err != nil { if err != nil {
return fmt.Errorf("updater.RoomInfo: %w", err) return fmt.Errorf("updater.RoomInfo: %w", err)
} }
if roomInfo == nil { if roomInfo == nil {
return fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID()) return fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID().String())
} }
if input.HasState || (!missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0) { if input.HasState || (!missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0) {
@ -433,6 +444,24 @@ func (r *Inputer) processRoomEvent(
return nil return nil
} }
// TODO: Revist this to ensure we don't replace a current state mxid_mapping with an older one.
if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && event.Type() == spec.MRoomMember {
mapping := gomatrixserverlib.MemberContent{}
if err = json.Unmarshal(event.Content(), &mapping); err != nil {
return err
}
if mapping.MXIDMapping != nil {
storeUserID, userErr := spec.NewUserID(mapping.MXIDMapping.UserID, true)
if userErr != nil {
return userErr
}
err = r.RSAPI.StoreUserRoomPublicKey(ctx, mapping.MXIDMapping.UserRoomKey, *storeUserID, event.RoomID())
if err != nil {
return fmt.Errorf("failed storing user room public key: %w", err)
}
}
}
switch input.Kind { switch input.Kind {
case api.KindNew: case api.KindNew:
if err = r.updateLatestEvents( if err = r.updateLatestEvents(
@ -448,7 +477,7 @@ func (r *Inputer) processRoomEvent(
return fmt.Errorf("r.updateLatestEvents: %w", err) return fmt.Errorf("r.updateLatestEvents: %w", err)
} }
case api.KindOld: case api.KindOld:
err = r.OutputProducer.ProduceRoomEvents(event.RoomID(), []api.OutputEvent{ err = r.OutputProducer.ProduceRoomEvents(event.RoomID().String(), []api.OutputEvent{
{ {
Type: api.OutputTypeOldRoomEvent, Type: api.OutputTypeOldRoomEvent,
OldRoomEvent: &api.OutputOldRoomEvent{ OldRoomEvent: &api.OutputOldRoomEvent{
@ -474,7 +503,7 @@ func (r *Inputer) processRoomEvent(
// so notify downstream components to redact this event - they should have it if they've // so notify downstream components to redact this event - they should have it if they've
// been tracking our output log. // been tracking our output log.
if redactedEventID != "" { if redactedEventID != "" {
err = r.OutputProducer.ProduceRoomEvents(event.RoomID(), []api.OutputEvent{ err = r.OutputProducer.ProduceRoomEvents(event.RoomID().String(), []api.OutputEvent{
{ {
Type: api.OutputTypeRedactedEvent, Type: api.OutputTypeRedactedEvent,
RedactedEvent: &api.OutputRedactedEvent{ RedactedEvent: &api.OutputRedactedEvent{
@ -503,7 +532,7 @@ func (r *Inputer) processRoomEvent(
// handleRemoteRoomUpgrade updates published rooms and room aliases // handleRemoteRoomUpgrade updates published rooms and room aliases
func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event gomatrixserverlib.PDU) error { func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event gomatrixserverlib.PDU) error {
oldRoomID := event.RoomID() oldRoomID := event.RoomID().String()
newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str
return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, string(event.SenderID())) return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, string(event.SenderID()))
} }
@ -563,7 +592,7 @@ func (r *Inputer) processStateBefore(
StateKey: "", StateKey: "",
}) })
stateBeforeReq := &api.QueryStateAfterEventsRequest{ stateBeforeReq := &api.QueryStateAfterEventsRequest{
RoomID: event.RoomID(), RoomID: event.RoomID().String(),
PrevEventIDs: event.PrevEventIDs(), PrevEventIDs: event.PrevEventIDs(),
StateToFetch: tuplesNeeded, StateToFetch: tuplesNeeded,
} }
@ -573,7 +602,7 @@ func (r *Inputer) processStateBefore(
} }
switch { switch {
case !stateBeforeRes.RoomExists: case !stateBeforeRes.RoomExists:
rejectionErr = fmt.Errorf("room %q does not exist", event.RoomID()) rejectionErr = fmt.Errorf("room %q does not exist", event.RoomID().String())
return return
case !stateBeforeRes.PrevEventsExist: case !stateBeforeRes.PrevEventsExist:
rejectionErr = fmt.Errorf("prev events of %q are not known", event.EventID()) rejectionErr = fmt.Errorf("prev events of %q are not known", event.EventID())
@ -674,7 +703,7 @@ func (r *Inputer) fetchAuthEvents(
// Request the entire auth chain for the event in question. This should // Request the entire auth chain for the event in question. This should
// contain all of the auth events — including ones that we already know — // contain all of the auth events — including ones that we already know —
// so we'll need to filter through those in the next section. // so we'll need to filter through those in the next section.
res, err = r.FSAPI.GetEventAuth(ctx, virtualHost, serverName, event.Version(), event.RoomID(), event.EventID()) res, err = r.FSAPI.GetEventAuth(ctx, virtualHost, serverName, event.Version(), event.RoomID().String(), event.EventID())
if err != nil { if err != nil {
logger.WithError(err).Warnf("Failed to get event auth from federation for %q: %s", event.EventID(), err) logger.WithError(err).Warnf("Failed to get event auth from federation for %q: %s", event.EventID(), err)
continue continue
@ -833,25 +862,20 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r
inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents)) inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents))
latestReq := &api.QueryLatestEventsAndStateRequest{ latestReq := &api.QueryLatestEventsAndStateRequest{
RoomID: event.RoomID(), RoomID: event.RoomID().String(),
} }
latestRes := &api.QueryLatestEventsAndStateResponse{} latestRes := &api.QueryLatestEventsAndStateResponse{}
if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil { if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil {
return err return err
} }
validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return err
}
prevEvents := latestRes.LatestEvents prevEvents := latestRes.LatestEvents
for _, memberEvent := range memberEvents { for _, memberEvent := range memberEvents {
if memberEvent.StateKey() == nil { if memberEvent.StateKey() == nil {
continue continue
} }
memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*memberEvent.StateKey())) memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*memberEvent.StateKey()))
if err != nil { if err != nil {
continue continue
} }
@ -879,7 +903,7 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r
stateKey := *memberEvent.StateKey() stateKey := *memberEvent.StateKey()
fledglingEvent := &gomatrixserverlib.ProtoEvent{ fledglingEvent := &gomatrixserverlib.ProtoEvent{
RoomID: event.RoomID(), RoomID: event.RoomID().String(),
Type: spec.MRoomMember, Type: spec.MRoomMember,
StateKey: &stateKey, StateKey: &stateKey,
SenderID: stateKey, SenderID: stateKey,
@ -895,17 +919,7 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r
return err return err
} }
validRoomID, err := spec.NewRoomID(event.RoomID()) signingIdentity, err := r.SigningIdentity(ctx, event.RoomID(), *memberUserID)
if err != nil {
return err
}
userID, err := spec.NewUserID(stateKey, true)
if err != nil {
return err
}
signingIdentity, err := r.SigningIdentity(ctx, *validRoomID, *userID)
if err != nil { if err != nil {
return err return err
} }

View file

@ -197,7 +197,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
// send the event asynchronously but we would need to ensure that 1) the events are written to the log in // send the event asynchronously but we would need to ensure that 1) the events are written to the log in
// the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the // the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the
// necessary bookkeeping we'll keep the event sending synchronous for now. // necessary bookkeeping we'll keep the event sending synchronous for now.
if err = u.api.OutputProducer.ProduceRoomEvents(u.event.RoomID(), updates); err != nil { if err = u.api.OutputProducer.ProduceRoomEvents(u.event.RoomID().String(), updates); err != nil {
return fmt.Errorf("u.api.WriteOutputEvents: %w", err) return fmt.Errorf("u.api.WriteOutputEvents: %w", err)
} }
@ -290,7 +290,7 @@ func (u *latestEventsUpdater) latestState() error {
if removed := len(u.removed) - len(u.added); !u.rewritesState && removed > 0 { if removed := len(u.removed) - len(u.added); !u.rewritesState && removed > 0 {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"event_id": u.event.EventID(), "event_id": u.event.EventID(),
"room_id": u.event.RoomID(), "room_id": u.event.RoomID().String(),
"old_state_nid": u.oldStateNID, "old_state_nid": u.oldStateNID,
"new_state_nid": u.newStateNID, "new_state_nid": u.newStateNID,
"old_latest": u.oldLatest.EventIDs(), "old_latest": u.oldLatest.EventIDs(),

View file

@ -139,11 +139,7 @@ func (r *Inputer) updateMembership(
func (r *Inputer) isLocalTarget(ctx context.Context, event *types.Event) bool { func (r *Inputer) isLocalTarget(ctx context.Context, event *types.Event) bool {
isTargetLocalUser := false isTargetLocalUser := false
if statekey := event.StateKey(); statekey != nil { if statekey := event.StateKey(); statekey != nil {
validRoomID, err := spec.NewRoomID(event.RoomID()) userID, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*statekey))
if err != nil {
return isTargetLocalUser
}
userID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*statekey))
if err != nil || userID == nil { if err != nil || userID == nil {
return isTargetLocalUser return isTargetLocalUser
} }
@ -168,7 +164,7 @@ func updateToJoinMembership(
Type: api.OutputTypeRetireInviteEvent, Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &api.OutputRetireInviteEvent{ RetireInviteEvent: &api.OutputRetireInviteEvent{
EventID: eventID, EventID: eventID,
RoomID: add.RoomID(), RoomID: add.RoomID().String(),
Membership: spec.Join, Membership: spec.Join,
RetiredByEventID: add.EventID(), RetiredByEventID: add.EventID(),
TargetSenderID: spec.SenderID(*add.StateKey()), TargetSenderID: spec.SenderID(*add.StateKey()),
@ -195,7 +191,7 @@ func updateToLeaveMembership(
Type: api.OutputTypeRetireInviteEvent, Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &api.OutputRetireInviteEvent{ RetireInviteEvent: &api.OutputRetireInviteEvent{
EventID: eventID, EventID: eventID,
RoomID: add.RoomID(), RoomID: add.RoomID().String(),
Membership: newMembership, Membership: newMembership,
RetiredByEventID: add.EventID(), RetiredByEventID: add.EventID(),
TargetSenderID: spec.SenderID(*add.StateKey()), TargetSenderID: spec.SenderID(*add.StateKey()),

View file

@ -84,7 +84,7 @@ func (t *missingStateReq) processEventWithMissingState(
// need to fallback to /state. // need to fallback to /state.
t.log = util.GetLogger(ctx).WithFields(map[string]interface{}{ t.log = util.GetLogger(ctx).WithFields(map[string]interface{}{
"txn_event": e.EventID(), "txn_event": e.EventID(),
"room_id": e.RoomID(), "room_id": e.RoomID().String(),
"txn_prev_events": e.PrevEventIDs(), "txn_prev_events": e.PrevEventIDs(),
}) })
@ -259,12 +259,20 @@ func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e
// Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query // Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query
// the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event. // the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event.
var states []*respState var states []*respState
var validationError error
for _, prevEventID := range e.PrevEventIDs() { for _, prevEventID := range e.PrevEventIDs() {
// Look up what the state is after the backward extremity. This will either // Look up what the state is after the backward extremity. This will either
// come from the roomserver, if we know all the required events, or it will // come from the roomserver, if we know all the required events, or it will
// come from a remote server via /state_ids if not. // come from a remote server via /state_ids if not.
prevState, trustworthy, err := t.lookupStateAfterEvent(ctx, roomVersion, e.RoomID(), prevEventID) prevState, trustworthy, err := t.lookupStateAfterEvent(ctx, roomVersion, e.RoomID().String(), prevEventID)
if err != nil { switch err2 := err.(type) {
case gomatrixserverlib.EventValidationError:
if !err2.Persistable {
return nil, err2
}
validationError = err2
case nil:
default:
return nil, fmt.Errorf("t.lookupStateAfterEvent: %w", err) return nil, fmt.Errorf("t.lookupStateAfterEvent: %w", err)
} }
// Append the state onto the collected state. We'll run this through the // Append the state onto the collected state. We'll run this through the
@ -308,15 +316,22 @@ func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e
} }
// There's more than one previous state - run them all through state res // There's more than one previous state - run them all through state res
var err error var err error
t.roomsMu.Lock(e.RoomID()) t.roomsMu.Lock(e.RoomID().String())
resolvedState, err = t.resolveStatesAndCheck(ctx, roomVersion, respStates, e) resolvedState, err = t.resolveStatesAndCheck(ctx, roomVersion, respStates, e)
t.roomsMu.Unlock(e.RoomID()) t.roomsMu.Unlock(e.RoomID().String())
if err != nil { switch err2 := err.(type) {
case gomatrixserverlib.EventValidationError:
if !err2.Persistable {
return nil, err2
}
validationError = err2
case nil:
default:
return nil, fmt.Errorf("t.resolveStatesAndCheck: %w", err) return nil, fmt.Errorf("t.resolveStatesAndCheck: %w", err)
} }
} }
return resolvedState, nil return resolvedState, validationError
} }
// lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) // lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event)
@ -339,8 +354,15 @@ func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion
} }
// fetch the event we're missing and add it to the pile // fetch the event we're missing and add it to the pile
var validationError error
h, err := t.lookupEvent(ctx, roomVersion, roomID, eventID, false) h, err := t.lookupEvent(ctx, roomVersion, roomID, eventID, false)
switch err.(type) { switch e := err.(type) {
case gomatrixserverlib.EventValidationError:
if !e.Persistable {
logrus.WithContext(ctx).WithError(err).Errorf("Failed to look up event %s", eventID)
return nil, false, e
}
validationError = e
case verifySigError: case verifySigError:
return respState, false, nil return respState, false, nil
case nil: case nil:
@ -365,7 +387,7 @@ func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion
} }
} }
return respState, false, nil return respState, false, validationError
} }
func (t *missingStateReq) cacheAndReturn(ev gomatrixserverlib.PDU) gomatrixserverlib.PDU { func (t *missingStateReq) cacheAndReturn(ev gomatrixserverlib.PDU) gomatrixserverlib.PDU {
@ -476,19 +498,32 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion
roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}, },
func(eventID string) bool {
isRejected, err := t.db.IsEventRejected(ctx, t.roomInfo.RoomNID, eventID)
if err != nil {
return true
}
return isRejected
},
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
// apply the current event // apply the current event
var validationError error
retryAllowedState: retryAllowedState:
if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
switch missing := err.(type) { switch missing := err.(type) {
case gomatrixserverlib.MissingAuthEventError: case gomatrixserverlib.MissingAuthEventError:
h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true) h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID().String(), missing.AuthEventID, true)
switch err2.(type) { switch e := err2.(type) {
case gomatrixserverlib.EventValidationError:
if !e.Persistable {
return nil, e
}
validationError = e
case verifySigError: case verifySigError:
return &parsedRespState{ return &parsedRespState{
AuthEvents: authEventList, AuthEvents: authEventList,
@ -509,7 +544,7 @@ retryAllowedState:
return &parsedRespState{ return &parsedRespState{
AuthEvents: authEventList, AuthEvents: authEventList,
StateEvents: resolvedStateEvents, StateEvents: resolvedStateEvents,
}, nil }, validationError
} }
// get missing events for `e`. If `isGapFilled`=true then `newEvents` contains all the events to inject, // get missing events for `e`. If `isGapFilled`=true then `newEvents` contains all the events to inject,
@ -518,7 +553,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver
trace, ctx := internal.StartRegion(ctx, "getMissingEvents") trace, ctx := internal.StartRegion(ctx, "getMissingEvents")
defer trace.EndRegion() defer trace.EndRegion()
logger := t.log.WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) logger := t.log.WithField("event_id", e.EventID()).WithField("room_id", e.RoomID().String())
latest, _, _, err := t.db.LatestEventIDs(ctx, t.roomInfo.RoomNID) latest, _, _, err := t.db.LatestEventIDs(ctx, t.roomInfo.RoomNID)
if err != nil { if err != nil {
return nil, false, false, fmt.Errorf("t.DB.LatestEventIDs: %w", err) return nil, false, false, fmt.Errorf("t.DB.LatestEventIDs: %w", err)
@ -532,7 +567,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver
var missingResp *fclient.RespMissingEvents var missingResp *fclient.RespMissingEvents
for _, server := range t.servers { for _, server := range t.servers {
var m fclient.RespMissingEvents var m fclient.RespMissingEvents
if m, err = t.federation.LookupMissingEvents(ctx, t.virtualHost, server, e.RoomID(), fclient.MissingEvents{ if m, err = t.federation.LookupMissingEvents(ctx, t.virtualHost, server, e.RoomID().String(), fclient.MissingEvents{
Limit: 20, Limit: 20,
// The latest event IDs that the sender already has. These are skipped when retrieving the previous events of latest_events. // The latest event IDs that the sender already has. These are skipped when retrieving the previous events of latest_events.
EarliestEvents: latestEvents, EarliestEvents: latestEvents,
@ -779,7 +814,11 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo
// Define what we'll do in order to fetch the missing event ID. // Define what we'll do in order to fetch the missing event ID.
fetch := func(missingEventID string) { fetch := func(missingEventID string) {
h, herr := t.lookupEvent(ctx, roomVersion, roomID, missingEventID, false) h, herr := t.lookupEvent(ctx, roomVersion, roomID, missingEventID, false)
switch herr.(type) { switch e := herr.(type) {
case gomatrixserverlib.EventValidationError:
if !e.Persistable {
return
}
case verifySigError: case verifySigError:
return return
case nil: case nil:
@ -869,6 +908,8 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
} }
var event gomatrixserverlib.PDU var event gomatrixserverlib.PDU
found := false found := false
var validationError error
serverLoop:
for _, serverName := range t.servers { for _, serverName := range t.servers {
reqctx, cancel := context.WithTimeout(ctx, time.Second*30) reqctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
@ -886,12 +927,25 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
continue continue
} }
event, err = verImpl.NewEventFromUntrustedJSON(txn.PDUs[0]) event, err = verImpl.NewEventFromUntrustedJSON(txn.PDUs[0])
if err != nil { switch e := err.(type) {
case gomatrixserverlib.EventValidationError:
// If the event is persistable, e.g. failed validation for exceeding
// byte sizes, we can "accept" the event.
if e.Persistable {
validationError = e
found = true
break serverLoop
}
// If we can't persist the event, we probably can't do so with results
// from other servers, so also break the loop.
break serverLoop
case nil:
found = true
break serverLoop
default:
t.log.WithError(err).WithField("missing_event_id", missingEventID).Warnf("Failed to parse event JSON of event returned from /event") t.log.WithError(err).WithField("missing_event_id", missingEventID).Warnf("Failed to parse event JSON of event returned from /event")
continue continue
} }
found = true
break
} }
if !found { if !found {
t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers)) t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers))
@ -903,7 +957,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())
return nil, verifySigError{event.EventID(), err} return nil, verifySigError{event.EventID(), err}
} }
return t.cacheAndReturn(event), nil return t.cacheAndReturn(event), validationError
} }
func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender) error { func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender) error {

View file

@ -161,12 +161,12 @@ func (r *Admin) PerformAdminEvacuateUser(
return nil, fmt.Errorf("can only evacuate local users using this endpoint") return nil, fmt.Errorf("can only evacuate local users using this endpoint")
} }
roomIDs, err := r.DB.GetRoomsByMembership(ctx, userID, spec.Join) roomIDs, err := r.DB.GetRoomsByMembership(ctx, *fullUserID, spec.Join)
if err != nil { if err != nil {
return nil, err return nil, err
} }
inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, userID, spec.Invite) inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, *fullUserID, spec.Invite)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return nil, err return nil, err
} }
@ -204,18 +204,6 @@ func (r *Admin) PerformAdminPurgeRoom(
return err return err
} }
// Evacuate the room before purging it from the database
evacAffected, err := r.PerformAdminEvacuateRoom(ctx, roomID)
if err != nil {
logrus.WithField("room_id", roomID).WithError(err).Warn("Failed to evacuate room before purging")
return err
}
logrus.WithFields(logrus.Fields{
"room_id": roomID,
"evacuated_users": len(evacAffected),
}).Warn("Evacuated room, purging room from roomserver now")
logrus.WithField("room_id", roomID).Warn("Purging room from roomserver") logrus.WithField("room_id", roomID).Warn("Purging room from roomserver")
if err := r.DB.PurgeRoom(ctx, roomID); err != nil { if err := r.DB.PurgeRoom(ctx, roomID); err != nil {
logrus.WithField("room_id", roomID).WithError(err).Warn("Failed to purge room from roomserver") logrus.WithField("room_id", roomID).WithError(err).Warn("Failed to purge room from roomserver")
@ -304,10 +292,12 @@ func (r *Admin) PerformAdminDownloadState(
senderID, err := r.Queryer.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID) senderID, err := r.Queryer.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID)
if err != nil { if err != nil {
return err return err
} else if senderID == nil {
return fmt.Errorf("sender ID not found for %s in %s", *fullUserID, *validRoomID)
} }
proto := &gomatrixserverlib.ProtoEvent{ proto := &gomatrixserverlib.ProtoEvent{
Type: "org.matrix.dendrite.state_download", Type: "org.matrix.dendrite.state_download",
SenderID: string(senderID), SenderID: string(*senderID),
RoomID: roomID, RoomID: roomID,
Content: spec.RawJSON("{}"), Content: spec.RawJSON("{}"),
} }

View file

@ -301,7 +301,7 @@ func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent
return ids, nil return ids, nil
} }
if len(targetEvent.PrevEventIDs()) == 0 && targetEvent.Type() == "m.room.create" && targetEvent.StateKeyEquals("") { if len(targetEvent.PrevEventIDs()) == 0 && targetEvent.Type() == "m.room.create" && targetEvent.StateKeyEquals("") {
util.GetLogger(ctx).WithField("room_id", targetEvent.RoomID()).Info("Backfilled to the beginning of the room") util.GetLogger(ctx).WithField("room_id", targetEvent.RoomID().String()).Info("Backfilled to the beginning of the room")
b.eventIDToBeforeStateIDs[targetEvent.EventID()] = []string{} b.eventIDToBeforeStateIDs[targetEvent.EventID()] = []string{}
return nil, nil return nil, nil
} }
@ -494,11 +494,7 @@ FindSuccessor:
// Store the server names in a temporary map to avoid duplicates. // Store the server names in a temporary map to avoid duplicates.
serverSet := make(map[spec.ServerName]bool) serverSet := make(map[spec.ServerName]bool)
for _, event := range memberEvents { for _, event := range memberEvents {
validRoomID, err := spec.NewRoomID(event.RoomID()) if sender, err := b.querier.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()); err == nil {
if err != nil {
continue
}
if sender, err := b.querier.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()); err == nil {
serverSet[sender.Domain()] = true serverSet[sender.Domain()] = true
} }
} }

View file

@ -90,7 +90,16 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
} else { } else {
senderID = spec.SenderID(userID.String()) senderID = spec.SenderID(userID.String())
} }
createContent["creator"] = senderID
// TODO: Maybe, at some point, GMSL should return the events to create, so we can define the version
// entirely there.
switch createRequest.RoomVersion {
case gomatrixserverlib.RoomVersionV11:
// RoomVersionV11 removed the creator field from the create content: https://github.com/matrix-org/matrix-spec-proposals/pull/2175
default:
createContent["creator"] = senderID
}
createContent["room_version"] = createRequest.RoomVersion createContent["room_version"] = createRequest.RoomVersion
powerLevelContent := eventutil.InitialPowerLevelsContent(string(senderID)) powerLevelContent := eventutil.InitialPowerLevelsContent(string(senderID))
joinRuleContent := gomatrixserverlib.JoinRuleContent{ joinRuleContent := gomatrixserverlib.JoinRuleContent{
@ -195,7 +204,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
// sign all events with the pseudo ID key // sign all events with the pseudo ID key
identity = &fclient.SigningIdentity{ identity = &fclient.SigningIdentity{
ServerName: "self", ServerName: spec.ServerName(spec.SenderIDFromPseudoIDKey(pseudoIDKey)),
KeyID: "ed25519:1", KeyID: "ed25519:1",
PrivateKey: pseudoIDKey, PrivateKey: pseudoIDKey,
} }
@ -433,23 +442,16 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
// from creating the room but still failing due to the alias having already // from creating the room but still failing due to the alias having already
// been taken. // been taken.
if roomAlias != "" { if roomAlias != "" {
aliasReq := api.SetRoomAliasRequest{ aliasAlreadyExists, aliasErr := c.RSAPI.SetRoomAlias(ctx, senderID, roomID, roomAlias)
Alias: roomAlias, if aliasErr != nil {
RoomID: roomID.String(), util.GetLogger(ctx).WithError(aliasErr).Error("aliasAPI.SetRoomAlias failed")
UserID: userID.String(),
}
var aliasResp api.SetRoomAliasResponse
err = c.RSAPI.SetRoomAlias(ctx, &aliasReq, &aliasResp)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("aliasAPI.SetRoomAlias failed")
return "", &util.JSONResponse{ return "", &util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
} }
} }
if aliasResp.AliasExists { if aliasAlreadyExists {
return "", &util.JSONResponse{ return "", &util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: spec.RoomInUse("Room alias already exists."), JSON: spec.RoomInUse("Room alias already exists."),
@ -489,7 +491,6 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
} }
// Process the invites. // Process the invites.
var inviteEvent *types.HeaderedEvent
for _, invitee := range createRequest.InvitedUsers { for _, invitee := range createRequest.InvitedUsers {
inviteeUserID, userIDErr := spec.NewUserID(invitee, true) inviteeUserID, userIDErr := spec.NewUserID(invitee, true)
if userIDErr != nil { if userIDErr != nil {
@ -499,54 +500,21 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
} }
} }
inviteeSenderID, queryErr := c.RSAPI.QuerySenderIDForUser(ctx, roomID, *inviteeUserID)
if queryErr != nil {
util.GetLogger(ctx).WithError(queryErr).Error("rsapi.QuerySenderIDForUser failed")
return "", &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
inviteeString := string(inviteeSenderID)
proto := gomatrixserverlib.ProtoEvent{
SenderID: string(senderID),
RoomID: roomID.String(),
Type: "m.room.member",
StateKey: &inviteeString,
}
content := gomatrixserverlib.MemberContent{
Membership: spec.Invite,
DisplayName: createRequest.UserDisplayName,
AvatarURL: createRequest.UserAvatarURL,
Reason: "",
IsDirect: createRequest.IsDirect,
}
if err = proto.SetContent(content); err != nil {
return "", &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
// Build the invite event.
inviteEvent, err = eventutil.QueryAndBuildEvent(ctx, &proto, identity, createRequest.EventTime, c.RSAPI, nil)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed")
continue
}
inviteStrippedState := append(
globalStrippedState,
gomatrixserverlib.NewInviteStrippedState(inviteEvent.PDU),
)
// Send the invite event to the roomserver.
event := inviteEvent
err = c.RSAPI.PerformInvite(ctx, &api.PerformInviteRequest{ err = c.RSAPI.PerformInvite(ctx, &api.PerformInviteRequest{
Event: event, InviteInput: api.InviteInput{
InviteRoomState: inviteStrippedState, RoomID: roomID,
RoomVersion: event.Version(), Inviter: userID,
Invitee: *inviteeUserID,
DisplayName: createRequest.UserDisplayName,
AvatarURL: createRequest.UserAvatarURL,
Reason: "",
IsDirect: createRequest.IsDirect,
KeyID: createRequest.KeyID,
PrivateKey: createRequest.PrivateKey,
EventTime: createRequest.EventTime,
},
InviteRoomState: globalStrippedState,
SendAsServer: string(userID.Domain()), SendAsServer: string(userID.Domain()),
}) })
switch e := err.(type) { switch e := err.(type) {

View file

@ -16,6 +16,7 @@ package perform
import ( import (
"context" "context"
"crypto/ed25519"
"fmt" "fmt"
federationAPI "github.com/matrix-org/dendrite/federationapi/api" federationAPI "github.com/matrix-org/dendrite/federationapi/api"
@ -99,16 +100,12 @@ func (r *Inviter) ProcessInviteMembership(
var outputUpdates []api.OutputEvent var outputUpdates []api.OutputEvent
var updater *shared.MembershipUpdater var updater *shared.MembershipUpdater
validRoomID, err := spec.NewRoomID(inviteEvent.RoomID()) userID, err := r.RSAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey()))
if err != nil {
return nil, err
}
userID, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*inviteEvent.StateKey()))
if err != nil { if err != nil {
return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())} return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())}
} }
isTargetLocal := r.Cfg.Matrix.IsLocalServerName(userID.Domain()) isTargetLocal := r.Cfg.Matrix.IsLocalServerName(userID.Domain())
if updater, err = r.DB.MembershipUpdater(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey(), isTargetLocal, inviteEvent.Version()); err != nil { if updater, err = r.DB.MembershipUpdater(ctx, inviteEvent.RoomID().String(), *inviteEvent.StateKey(), isTargetLocal, inviteEvent.Version()); err != nil {
return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err) return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err)
} }
outputUpdates, err = helpers.UpdateToInviteMembership(updater, &types.Event{ outputUpdates, err = helpers.UpdateToInviteMembership(updater, &types.Event{
@ -129,65 +126,104 @@ func (r *Inviter) PerformInvite(
ctx context.Context, ctx context.Context,
req *api.PerformInviteRequest, req *api.PerformInviteRequest,
) error { ) error {
event := req.Event senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.InviteInput.RoomID, req.InviteInput.Inviter)
if err != nil {
validRoomID, err := spec.NewRoomID(event.RoomID()) return err
} else if senderID == nil {
return fmt.Errorf("sender ID not found for %s in %s", req.InviteInput.Inviter, req.InviteInput.RoomID)
}
info, err := r.DB.RoomInfo(ctx, req.InviteInput.RoomID.String())
if err != nil { if err != nil {
return err return err
} }
sender, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) proto := gomatrixserverlib.ProtoEvent{
if err != nil { SenderID: string(*senderID),
return spec.InvalidParam("The sender user ID is invalid") RoomID: req.InviteInput.RoomID.String(),
Type: "m.room.member",
} }
if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) {
content := gomatrixserverlib.MemberContent{
Membership: spec.Invite,
DisplayName: req.InviteInput.DisplayName,
AvatarURL: req.InviteInput.AvatarURL,
Reason: req.InviteInput.Reason,
IsDirect: req.InviteInput.IsDirect,
}
if err = proto.SetContent(content); err != nil {
return err
}
if !r.Cfg.Matrix.IsLocalServerName(req.InviteInput.Inviter.Domain()) {
return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")} return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")}
} }
if event.StateKey() == nil || *event.StateKey() == "" { isTargetLocal := r.Cfg.Matrix.IsLocalServerName(req.InviteInput.Invitee.Domain())
return fmt.Errorf("invite must be a state event")
}
invitedUser, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey()))
if err != nil || invitedUser == nil {
return spec.InvalidParam("Could not find the matching senderID for this user")
}
isTargetLocal := r.Cfg.Matrix.IsLocalServerName(invitedUser.Domain())
// If we're inviting a local user, we can generate the needed pseudoID key here. (if needed) signingKey := req.InviteInput.PrivateKey
if isTargetLocal { if info.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
var roomVersion gomatrixserverlib.RoomVersion signingKey, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, req.InviteInput.Inviter, req.InviteInput.RoomID)
roomVersion, err = r.DB.GetRoomVersion(ctx, event.RoomID())
if err != nil { if err != nil {
return err return err
} }
switch roomVersion {
case gomatrixserverlib.RoomVersionPseudoIDs:
_, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *invitedUser, *validRoomID)
if err != nil {
return err
}
}
}
invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, *validRoomID, *invitedUser)
if err != nil {
return fmt.Errorf("failed looking up senderID for invited user")
} }
input := gomatrixserverlib.PerformInviteInput{ input := gomatrixserverlib.PerformInviteInput{
RoomID: *validRoomID, RoomID: req.InviteInput.RoomID,
InviteEvent: event.PDU, RoomVersion: info.RoomVersion,
InvitedUser: *invitedUser, Inviter: req.InviteInput.Inviter,
InvitedSenderID: invitedSenderID, Invitee: req.InviteInput.Invitee,
IsTargetLocal: isTargetLocal, IsTargetLocal: isTargetLocal,
EventTemplate: proto,
StrippedState: req.InviteRoomState, StrippedState: req.InviteRoomState,
KeyID: req.InviteInput.KeyID,
SigningKey: signingKey,
EventTime: req.InviteInput.EventTime,
MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI}, MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI},
StateQuerier: &QueryState{r.DB, r.RSAPI}, StateQuerier: &QueryState{r.DB, r.RSAPI},
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.RSAPI.QueryUserIDForSender(ctx, roomID, senderID) return r.RSAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, },
SenderIDQuerier: func(roomID spec.RoomID, userID spec.UserID) (*spec.SenderID, error) {
return r.RSAPI.QuerySenderIDForUser(ctx, roomID, userID)
},
SenderIDCreator: func(ctx context.Context, userID spec.UserID, roomID spec.RoomID, roomVersion string) (spec.SenderID, ed25519.PrivateKey, error) {
key, keyErr := r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID)
if keyErr != nil {
return "", nil, keyErr
}
return spec.SenderIDFromPseudoIDKey(key), key, nil
},
EventQuerier: func(ctx context.Context, roomID spec.RoomID, eventsNeeded []gomatrixserverlib.StateKeyTuple) (gomatrixserverlib.LatestEvents, error) {
req := api.QueryLatestEventsAndStateRequest{RoomID: roomID.String(), StateToFetch: eventsNeeded}
res := api.QueryLatestEventsAndStateResponse{}
err = r.RSAPI.QueryLatestEventsAndState(ctx, &req, &res)
if err != nil {
return gomatrixserverlib.LatestEvents{}, nil
}
stateEvents := []gomatrixserverlib.PDU{}
for _, event := range res.StateEvents {
stateEvents = append(stateEvents, event.PDU)
}
return gomatrixserverlib.LatestEvents{
RoomExists: res.RoomExists,
StateEvents: stateEvents,
PrevEventIDs: res.LatestEvents,
Depth: res.Depth,
}, nil
},
StoreSenderIDFromPublicID: func(ctx context.Context, senderID spec.SenderID, userIDRaw string, roomID spec.RoomID) error {
storeUserID, userErr := spec.NewUserID(userIDRaw, true)
if userErr != nil {
return userErr
}
return r.RSAPI.StoreUserRoomPublicKey(ctx, senderID, *storeUserID, roomID)
},
} }
inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI) inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI)
if err != nil { if err != nil {
switch e := err.(type) { switch e := err.(type) {
@ -199,20 +235,6 @@ func (r *Inviter) PerformInvite(
return err return err
} }
// Use the returned event if there was one (due to federation), otherwise
// send the original invite event to the roomserver.
if inviteEvent == nil {
inviteEvent = event
}
// if we invited a local user, we can also create a user room key, if it doesn't exist yet.
if isTargetLocal && event.Version() == gomatrixserverlib.RoomVersionPseudoIDs {
_, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *invitedUser, *validRoomID)
if err != nil {
return fmt.Errorf("failed to get user room private key: %w", err)
}
}
// Send the invite event to the roomserver input stream. This will // Send the invite event to the roomserver input stream. This will
// notify existing users in the room about the invite, update the // notify existing users in the room about the invite, update the
// membership table and ensure that the event is ready and available // membership table and ensure that the event is ready and available
@ -223,7 +245,7 @@ func (r *Inviter) PerformInvite(
{ {
Kind: api.KindNew, Kind: api.KindNew,
Event: &types.HeaderedEvent{PDU: inviteEvent}, Event: &types.HeaderedEvent{PDU: inviteEvent},
Origin: sender.Domain(), Origin: req.InviteInput.Inviter.Domain(),
SendAsServer: req.SendAsServer, SendAsServer: req.SendAsServer,
}, },
}, },
@ -231,7 +253,7 @@ func (r *Inviter) PerformInvite(
inputRes := &api.InputRoomEventsResponse{} inputRes := &api.InputRoomEventsResponse{}
r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes) r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes)
if err := inputRes.Err(); err != nil { if err := inputRes.Err(); err != nil {
util.GetLogger(ctx).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed") util.GetLogger(ctx).WithField("event_id", inviteEvent.EventID()).Error("r.InputRoomEvents failed")
return api.ErrNotAllowed{Err: err} return api.ErrNotAllowed{Err: err}
} }

View file

@ -201,11 +201,11 @@ func (r *Joiner) performJoinRoomByID(
if err == nil && info != nil { if err == nil && info != nil {
switch info.RoomVersion { switch info.RoomVersion {
case gomatrixserverlib.RoomVersionPseudoIDs: case gomatrixserverlib.RoomVersionPseudoIDs:
senderID, err = r.Queryer.QuerySenderIDForUser(ctx, *roomID, *userID) senderIDPtr, queryErr := r.Queryer.QuerySenderIDForUser(ctx, *roomID, *userID)
if err == nil { if queryErr == nil {
checkInvitePending = true checkInvitePending = true
} }
if senderID == "" { if senderIDPtr == nil {
// create user room key if needed // create user room key if needed
key, keyErr := r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *userID, *roomID) key, keyErr := r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *userID, *roomID)
if keyErr != nil { if keyErr != nil {
@ -213,6 +213,8 @@ func (r *Joiner) performJoinRoomByID(
return "", "", fmt.Errorf("GetOrCreateUserRoomPrivateKey failed: %w", keyErr) return "", "", fmt.Errorf("GetOrCreateUserRoomPrivateKey failed: %w", keyErr)
} }
senderID = spec.SenderIDFromPseudoIDKey(key) senderID = spec.SenderIDFromPseudoIDKey(key)
} else {
senderID = *senderIDPtr
} }
default: default:
checkInvitePending = true checkInvitePending = true
@ -274,7 +276,6 @@ func (r *Joiner) performJoinRoomByID(
// If we should do a forced federated join then do that. // If we should do a forced federated join then do that.
var joinedVia spec.ServerName var joinedVia spec.ServerName
if forceFederatedJoin { if forceFederatedJoin {
// TODO : pseudoIDs - pass through userID here since we don't know what the senderID should be yet
joinedVia, err = r.performFederatedJoinRoomByID(ctx, req) joinedVia, err = r.performFederatedJoinRoomByID(ctx, req)
return req.RoomIDOrAlias, joinedVia, err return req.RoomIDOrAlias, joinedVia, err
} }
@ -286,10 +287,7 @@ func (r *Joiner) performJoinRoomByID(
// but everyone has since left. I suspect it does the wrong thing. // but everyone has since left. I suspect it does the wrong thing.
var buildRes rsAPI.QueryLatestEventsAndStateResponse var buildRes rsAPI.QueryLatestEventsAndStateResponse
identity, err := r.RSAPI.SigningIdentityFor(ctx, *roomID, *userID) identity := r.Cfg.Matrix.SigningIdentity
if err != nil {
return "", "", fmt.Errorf("error joining local room: %q", err)
}
// at this point we know we have an existing room // at this point we know we have an existing room
if inRoomRes.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { if inRoomRes.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
@ -313,7 +311,7 @@ func (r *Joiner) performJoinRoomByID(
// sign the event with the pseudo ID key // sign the event with the pseudo ID key
identity = fclient.SigningIdentity{ identity = fclient.SigningIdentity{
ServerName: "self", ServerName: spec.ServerName(spec.SenderIDFromPseudoIDKey(pseudoIDKey)),
KeyID: "ed25519:1", KeyID: "ed25519:1",
PrivateKey: pseudoIDKey, PrivateKey: pseudoIDKey,
} }

View file

@ -73,6 +73,7 @@ func (r *Leaver) PerformLeave(
return nil, fmt.Errorf("room ID %q is invalid", req.RoomID) return nil, fmt.Errorf("room ID %q is invalid", req.RoomID)
} }
// nolint:gocyclo
func (r *Leaver) performLeaveRoomByID( func (r *Leaver) performLeaveRoomByID(
ctx context.Context, ctx context.Context,
req *api.PerformLeaveRequest, req *api.PerformLeaveRequest,
@ -83,20 +84,30 @@ func (r *Leaver) performLeaveRoomByID(
return nil, err return nil, err
} }
leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, *roomID, req.Leaver) leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, *roomID, req.Leaver)
if err != nil { if err != nil || leaver == nil {
return nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String()) return nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String())
} }
// If there's an invite outstanding for the room then respond to // If there's an invite outstanding for the room then respond to
// that. // that.
isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, leaver) isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, *leaver)
if err == nil && isInvitePending { if err == nil && isInvitePending {
sender, serr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, senderUser) sender, serr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, senderUser)
if serr != nil || sender == nil { if serr != nil {
return nil, fmt.Errorf("sender %q has no matching userID", senderUser) return nil, fmt.Errorf("failed looking up userID for sender %q: %w", senderUser, serr)
} }
if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) {
return r.performFederatedRejectInvite(ctx, req, res, *sender, eventID, leaver) var domain spec.ServerName
if sender == nil {
// TODO: Currently a federated invite has no way of knowing the mxid_mapping of the inviter.
// Should we add the inviter's m.room.member event (with mxid_mapping) to invite_room_state to allow
// the invited user to leave via the inviter's server?
domain = roomID.Domain()
} else {
domain = sender.Domain()
}
if !r.Cfg.Matrix.IsLocalServerName(domain) {
return r.performFederatedRejectInvite(ctx, req, res, domain, eventID, *leaver)
} }
// check that this is not a "server notice room" // check that this is not a "server notice room"
accData := &userapi.QueryAccountDataResponse{} accData := &userapi.QueryAccountDataResponse{}
@ -132,7 +143,7 @@ func (r *Leaver) performLeaveRoomByID(
StateToFetch: []gomatrixserverlib.StateKeyTuple{ StateToFetch: []gomatrixserverlib.StateKeyTuple{
{ {
EventType: spec.MRoomMember, EventType: spec.MRoomMember,
StateKey: string(leaver), StateKey: string(*leaver),
}, },
}, },
} }
@ -157,7 +168,7 @@ func (r *Leaver) performLeaveRoomByID(
} }
// Prepare the template for the leave event. // Prepare the template for the leave event.
senderIDString := string(leaver) senderIDString := string(*leaver)
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
Type: spec.MRoomMember, Type: spec.MRoomMember,
SenderID: senderIDString, SenderID: senderIDString,
@ -218,14 +229,14 @@ func (r *Leaver) performFederatedRejectInvite(
ctx context.Context, ctx context.Context,
req *api.PerformLeaveRequest, req *api.PerformLeaveRequest,
res *api.PerformLeaveResponse, // nolint:unparam res *api.PerformLeaveResponse, // nolint:unparam
inviteSender spec.UserID, eventID string, inviteDomain spec.ServerName, eventID string,
leaver spec.SenderID, leaver spec.SenderID,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
// Ask the federation sender to perform a federated leave for us. // Ask the federation sender to perform a federated leave for us.
leaveReq := fsAPI.PerformLeaveRequest{ leaveReq := fsAPI.PerformLeaveRequest{
RoomID: req.RoomID, RoomID: req.RoomID,
UserID: req.Leaver.String(), UserID: req.Leaver.String(),
ServerNames: []spec.ServerName{inviteSender.Domain()}, ServerNames: []spec.ServerName{inviteDomain},
} }
leaveRes := fsAPI.PerformLeaveResponse{} leaveRes := fsAPI.PerformLeaveResponse{}
if err := r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil { if err := r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil {

View file

@ -62,10 +62,13 @@ func (r *Upgrader) performRoomUpgrade(
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user") util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user")
return "", err return "", err
} else if senderID == nil {
util.GetLogger(ctx).WithField("userID", userID).WithField("roomID", *fullRoomID).Error("No senderID for user")
return "", fmt.Errorf("No sender ID for %s in %s", userID, *fullRoomID)
} }
// 1. Check if the user is authorized to actually perform the upgrade (can send m.room.tombstone) // 1. Check if the user is authorized to actually perform the upgrade (can send m.room.tombstone)
if !r.userIsAuthorized(ctx, senderID, roomID) { if !r.userIsAuthorized(ctx, *senderID, roomID) {
return "", api.ErrNotAllowed{Err: fmt.Errorf("You don't have permission to upgrade the room, power level too low.")} return "", api.ErrNotAllowed{Err: fmt.Errorf("You don't have permission to upgrade the room, power level too low.")}
} }
@ -83,20 +86,20 @@ func (r *Upgrader) performRoomUpgrade(
} }
// Make the tombstone event // Make the tombstone event
tombstoneEvent, pErr := r.makeTombstoneEvent(ctx, evTime, senderID, userID.Domain(), roomID, newRoomID) tombstoneEvent, pErr := r.makeTombstoneEvent(ctx, evTime, *senderID, userID.Domain(), roomID, newRoomID)
if pErr != nil { if pErr != nil {
return "", pErr return "", pErr
} }
// Generate the initial events we need to send into the new room. This includes copied state events and bans // Generate the initial events we need to send into the new room. This includes copied state events and bans
// as well as the power level events needed to set up the room // as well as the power level events needed to set up the room
eventsToMake, pErr := r.generateInitialEvents(ctx, oldRoomRes, senderID, roomID, roomVersion, tombstoneEvent) eventsToMake, pErr := r.generateInitialEvents(ctx, oldRoomRes, *senderID, roomID, roomVersion, tombstoneEvent)
if pErr != nil { if pErr != nil {
return "", pErr return "", pErr
} }
// Send the setup events to the new room // Send the setup events to the new room
if pErr = r.sendInitialEvents(ctx, evTime, senderID, userID.Domain(), newRoomID, roomVersion, eventsToMake); pErr != nil { if pErr = r.sendInitialEvents(ctx, evTime, *senderID, userID.Domain(), newRoomID, roomVersion, eventsToMake); pErr != nil {
return "", pErr return "", pErr
} }
@ -111,17 +114,17 @@ func (r *Upgrader) performRoomUpgrade(
} }
// If the old room had a canonical alias event, it should be deleted in the old room // If the old room had a canonical alias event, it should be deleted in the old room
if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, senderID, userID.Domain(), roomID); pErr != nil { if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, *senderID, userID.Domain(), roomID); pErr != nil {
return "", pErr return "", pErr
} }
// 4. Move local aliases to the new room // 4. Move local aliases to the new room
if pErr = moveLocalAliases(ctx, roomID, newRoomID, senderID, userID, r.URSAPI); pErr != nil { if pErr = moveLocalAliases(ctx, roomID, newRoomID, *senderID, r.URSAPI); pErr != nil {
return "", pErr return "", pErr
} }
// 6. Restrict power levels in the old room // 6. Restrict power levels in the old room
if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, senderID, userID.Domain(), roomID); pErr != nil { if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, *senderID, userID.Domain(), roomID); pErr != nil {
return "", pErr return "", pErr
} }
@ -171,7 +174,7 @@ func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.T
} }
func moveLocalAliases(ctx context.Context, func moveLocalAliases(ctx context.Context,
roomID, newRoomID string, senderID spec.SenderID, userID spec.UserID, roomID, newRoomID string, senderID spec.SenderID,
URSAPI api.RoomserverInternalAPI, URSAPI api.RoomserverInternalAPI,
) (err error) { ) (err error) {
@ -181,17 +184,27 @@ func moveLocalAliases(ctx context.Context,
return fmt.Errorf("Failed to get old room aliases: %w", err) return fmt.Errorf("Failed to get old room aliases: %w", err)
} }
// TODO: this should be spec.RoomID further up the call stack
parsedNewRoomID, err := spec.NewRoomID(newRoomID)
if err != nil {
return err
}
for _, alias := range aliasRes.Aliases { for _, alias := range aliasRes.Aliases {
removeAliasReq := api.RemoveRoomAliasRequest{SenderID: senderID, Alias: alias} aliasFound, aliasRemoved, err := URSAPI.RemoveRoomAlias(ctx, senderID, alias)
removeAliasRes := api.RemoveRoomAliasResponse{} if err != nil {
if err = URSAPI.RemoveRoomAlias(ctx, &removeAliasReq, &removeAliasRes); err != nil {
return fmt.Errorf("Failed to remove old room alias: %w", err) return fmt.Errorf("Failed to remove old room alias: %w", err)
} else if !aliasFound {
return fmt.Errorf("Failed to remove old room alias: alias not found, possible race")
} else if !aliasRemoved {
return fmt.Errorf("Failed to remove old alias")
} }
setAliasReq := api.SetRoomAliasRequest{UserID: userID.String(), Alias: alias, RoomID: newRoomID} aliasAlreadyExists, err := URSAPI.SetRoomAlias(ctx, senderID, *parsedNewRoomID, alias)
setAliasRes := api.SetRoomAliasResponse{} if err != nil {
if err = URSAPI.SetRoomAlias(ctx, &setAliasReq, &setAliasRes); err != nil {
return fmt.Errorf("Failed to set new room alias: %w", err) return fmt.Errorf("Failed to set new room alias: %w", err)
} else if aliasAlreadyExists {
return fmt.Errorf("Failed to set new room alias: alias exists when it should have just been removed")
} }
} }
return nil return nil
@ -355,7 +368,16 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
// in the create event (such as for the room types MSC). // in the create event (such as for the room types MSC).
newCreateContent := map[string]interface{}{} newCreateContent := map[string]interface{}{}
_ = json.Unmarshal(oldCreateEvent.Content(), &newCreateContent) _ = json.Unmarshal(oldCreateEvent.Content(), &newCreateContent)
newCreateContent["creator"] = string(senderID)
switch newVersion {
case gomatrixserverlib.RoomVersionV11:
// RoomVersionV11 removed the creator field from the create content: https://github.com/matrix-org/matrix-spec-proposals/pull/2175
// So if we are upgrading from pre v11, we need to remove the field.
delete(newCreateContent, "creator")
default:
newCreateContent["creator"] = senderID
}
newCreateContent["room_version"] = newVersion newCreateContent["room_version"] = newVersion
newCreateContent["predecessor"] = gomatrixserverlib.PreviousRoom{ newCreateContent["predecessor"] = gomatrixserverlib.PreviousRoom{
EventID: tombstoneEvent.EventID(), EventID: tombstoneEvent.EventID(),

View file

@ -32,6 +32,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/synctypes"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
fsAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/roomserver/acls" "github.com/matrix-org/dendrite/roomserver/acls"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -47,6 +48,7 @@ type Queryer struct {
IsLocalServerName func(spec.ServerName) bool IsLocalServerName func(spec.ServerName) bool
ServerACLs *acls.ServerACLs ServerACLs *acls.ServerACLs
Cfg *config.Dendrite Cfg *config.Dendrite
FSAPI fsAPI.RoomserverFederationAPI
} }
func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) {
@ -163,6 +165,13 @@ func (r *Queryer) QueryStateAfterEvents(
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.QueryUserIDForSender(ctx, roomID, senderID) return r.QueryUserIDForSender(ctx, roomID, senderID)
}, },
func(eventID string) bool {
isRejected, rejectedErr := r.DB.IsEventRejected(ctx, info.RoomNID, eventID)
if rejectedErr != nil {
return true
}
return isRejected
},
) )
if err != nil { if err != nil {
return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err) return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err)
@ -228,6 +237,33 @@ func (r *Queryer) QueryMembershipForSenderID(
senderID spec.SenderID, senderID spec.SenderID,
response *api.QueryMembershipForUserResponse, response *api.QueryMembershipForUserResponse,
) error { ) error {
return r.queryMembershipForOptionalSenderID(ctx, roomID, &senderID, response)
}
// QueryMembershipForUser implements api.RoomserverInternalAPI
func (r *Queryer) QueryMembershipForUser(
ctx context.Context,
request *api.QueryMembershipForUserRequest,
response *api.QueryMembershipForUserResponse,
) error {
roomID, err := spec.NewRoomID(request.RoomID)
if err != nil {
return err
}
senderID, err := r.QuerySenderIDForUser(ctx, *roomID, request.UserID)
if err != nil {
return err
}
return r.queryMembershipForOptionalSenderID(ctx, *roomID, senderID, response)
}
// Query membership information for provided sender ID and room ID
//
// If sender ID is nil, then act as if the provided sender is not a member of the room.
func (r *Queryer) queryMembershipForOptionalSenderID(ctx context.Context, roomID spec.RoomID, senderID *spec.SenderID, response *api.QueryMembershipForUserResponse) error {
response.SenderID = senderID
info, err := r.DB.RoomInfo(ctx, roomID.String()) info, err := r.DB.RoomInfo(ctx, roomID.String())
if err != nil { if err != nil {
return err return err
@ -238,11 +274,20 @@ func (r *Queryer) QueryMembershipForSenderID(
} }
response.RoomExists = true response.RoomExists = true
membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, senderID) if senderID == nil {
return nil
}
membershipEventNID, membershipState, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, *senderID)
if err != nil { if err != nil {
return err return err
} }
if membershipState == tables.MembershipStateInvite {
response.Membership = spec.Invite
response.IsInRoom = true
}
response.IsRoomForgotten = isRoomforgotten response.IsRoomForgotten = isRoomforgotten
if membershipEventNID == 0 { if membershipEventNID == 0 {
@ -266,70 +311,55 @@ func (r *Queryer) QueryMembershipForSenderID(
return err return err
} }
// QueryMembershipForUser implements api.RoomserverInternalAPI
func (r *Queryer) QueryMembershipForUser(
ctx context.Context,
request *api.QueryMembershipForUserRequest,
response *api.QueryMembershipForUserResponse,
) error {
roomID, err := spec.NewRoomID(request.RoomID)
if err != nil {
return err
}
senderID, err := r.QuerySenderIDForUser(ctx, *roomID, request.UserID)
if err != nil {
return err
}
return r.QueryMembershipForSenderID(ctx, *roomID, senderID, response)
}
// QueryMembershipAtEvent returns the known memberships at a given event. // QueryMembershipAtEvent returns the known memberships at a given event.
// If the state before an event is not known, an empty list will be returned // If the state before an event is not known, an empty list will be returned
// for that event instead. // for that event instead.
//
// Returned map from eventID to membership event. Events that
// do not have known state will return a nil event, resulting in a "leave" membership
// when calculating history visibility.
func (r *Queryer) QueryMembershipAtEvent( func (r *Queryer) QueryMembershipAtEvent(
ctx context.Context, ctx context.Context,
request *api.QueryMembershipAtEventRequest, roomID spec.RoomID,
response *api.QueryMembershipAtEventResponse, eventIDs []string,
) error { senderID spec.SenderID,
response.Membership = make(map[string]*types.HeaderedEvent) ) (map[string]*types.HeaderedEvent, error) {
info, err := r.DB.RoomInfo(ctx, roomID.String())
info, err := r.DB.RoomInfo(ctx, request.RoomID)
if err != nil { if err != nil {
return fmt.Errorf("unable to get roomInfo: %w", err) return nil, fmt.Errorf("unable to get roomInfo: %w", err)
} }
if info == nil { if info == nil {
return fmt.Errorf("no roomInfo found") return nil, fmt.Errorf("no roomInfo found")
} }
// get the users stateKeyNID // get the users stateKeyNID
stateKeyNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.UserID}) stateKeyNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{string(senderID)})
if err != nil { if err != nil {
return fmt.Errorf("unable to get stateKeyNIDs for %s: %w", request.UserID, err) return nil, fmt.Errorf("unable to get stateKeyNIDs for %s: %w", senderID, err)
} }
if _, ok := stateKeyNIDs[request.UserID]; !ok { if _, ok := stateKeyNIDs[string(senderID)]; !ok {
return fmt.Errorf("requested stateKeyNID for %s was not found", request.UserID) return nil, fmt.Errorf("requested stateKeyNID for %s was not found", senderID)
} }
response.Membership, err = r.DB.GetMembershipForHistoryVisibility(ctx, stateKeyNIDs[request.UserID], info, request.EventIDs...) eventIDMembershipMap, err := r.DB.GetMembershipForHistoryVisibility(ctx, stateKeyNIDs[string(senderID)], info, eventIDs...)
switch err { switch err {
case nil: case nil:
return nil return eventIDMembershipMap, nil
case tables.OptimisationNotSupportedError: // fallthrough, slow way of getting the membership events for each event case tables.OptimisationNotSupportedError: // fallthrough, slow way of getting the membership events for each event
default: default:
return err return eventIDMembershipMap, err
} }
response.Membership = make(map[string]*types.HeaderedEvent) eventIDMembershipMap = make(map[string]*types.HeaderedEvent)
stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID], r) stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, eventIDs, stateKeyNIDs[string(senderID)], r)
if err != nil { if err != nil {
return fmt.Errorf("unable to get state before event: %w", err) return eventIDMembershipMap, fmt.Errorf("unable to get state before event: %w", err)
} }
// If we only have one or less state entries, we can short circuit the below // If we only have one or less state entries, we can short circuit the below
// loop and avoid hitting the database // loop and avoid hitting the database
allStateEventNIDs := make(map[types.EventNID]types.StateEntry) allStateEventNIDs := make(map[types.EventNID]types.StateEntry)
for _, eventID := range request.EventIDs { for _, eventID := range eventIDs {
stateEntry := stateEntries[eventID] stateEntry := stateEntries[eventID]
for _, s := range stateEntry { for _, s := range stateEntry {
allStateEventNIDs[s.EventNID] = s allStateEventNIDs[s.EventNID] = s
@ -342,10 +372,10 @@ func (r *Queryer) QueryMembershipAtEvent(
} }
var memberships []types.Event var memberships []types.Event
for _, eventID := range request.EventIDs { for _, eventID := range eventIDs {
stateEntry, ok := stateEntries[eventID] stateEntry, ok := stateEntries[eventID]
if !ok || len(stateEntry) == 0 { if !ok || len(stateEntry) == 0 {
response.Membership[eventID] = nil eventIDMembershipMap[eventID] = nil
continue continue
} }
@ -359,7 +389,7 @@ func (r *Queryer) QueryMembershipAtEvent(
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false) memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false)
} }
if err != nil { if err != nil {
return fmt.Errorf("unable to get memberships at state: %w", err) return eventIDMembershipMap, fmt.Errorf("unable to get memberships at state: %w", err)
} }
// Iterate over all membership events we got. Given we only query the membership for // Iterate over all membership events we got. Given we only query the membership for
@ -367,13 +397,13 @@ func (r *Queryer) QueryMembershipAtEvent(
// a given event, overwrite any other existing membership events. // a given event, overwrite any other existing membership events.
for i := range memberships { for i := range memberships {
ev := memberships[i] ev := memberships[i]
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(request.UserID) { if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) {
response.Membership[eventID] = &types.HeaderedEvent{PDU: ev.PDU} eventIDMembershipMap[eventID] = &types.HeaderedEvent{PDU: ev.PDU}
} }
} }
} }
return nil return eventIDMembershipMap, nil
} }
// QueryMembershipsForRoom implements api.RoomserverInternalAPI // QueryMembershipsForRoom implements api.RoomserverInternalAPI
@ -416,7 +446,7 @@ func (r *Queryer) QueryMembershipsForRoom(
return nil return nil
} }
membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.SenderID) membershipEventNID, _, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.SenderID)
if err != nil { if err != nil {
return err return err
} }
@ -658,6 +688,13 @@ func (r *Queryer) QueryStateAndAuthChain(
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.QueryUserIDForSender(ctx, roomID, senderID) return r.QueryUserIDForSender(ctx, roomID, senderID)
}, },
func(eventID string) bool {
isRejected, rejectedErr := r.DB.IsEventRejected(ctx, info.RoomNID, eventID)
if rejectedErr != nil {
return true
}
return isRejected
},
) )
if err != nil { if err != nil {
return err return err
@ -828,13 +865,20 @@ func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentSt
return nil return nil
} }
func (r *Queryer) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error { func (r *Queryer) QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error) {
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, req.WantMembership) roomIDStrs, err := r.DB.GetRoomsByMembership(ctx, userID, desiredMembership)
if err != nil { if err != nil {
return err return nil, err
} }
res.RoomIDs = roomIDs roomIDs := make([]spec.RoomID, len(roomIDStrs))
return nil for i, roomIDStr := range roomIDStrs {
roomID, err := spec.NewRoomID(roomIDStr)
if err != nil {
return nil, err
}
roomIDs[i] = *roomID
}
return roomIDs, nil
} }
func (r *Queryer) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error { func (r *Queryer) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error {
@ -877,7 +921,12 @@ func (r *Queryer) QueryLeftUsers(ctx context.Context, req *api.QueryLeftUsersReq
} }
func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error { func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error {
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join") parsedUserID, err := spec.NewUserID(req.UserID, true)
if err != nil {
return err
}
roomIDs, err := r.DB.GetRoomsByMembership(ctx, *parsedUserID, "join")
if err != nil { if err != nil {
return err return err
} }
@ -945,7 +994,7 @@ func (r *Queryer) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eve
} }
func (r *Queryer) UserJoinedToRoom(ctx context.Context, roomNID types.RoomNID, senderID spec.SenderID) (bool, error) { func (r *Queryer) UserJoinedToRoom(ctx context.Context, roomNID types.RoomNID, senderID spec.SenderID) (bool, error) {
_, isIn, _, err := r.DB.GetMembership(ctx, roomNID, senderID) _, _, isIn, _, err := r.DB.GetMembership(ctx, roomNID, senderID)
return isIn, err return isIn, err
} }
@ -974,6 +1023,20 @@ func (r *Queryer) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixse
return joinedUsers, nil return joinedUsers, nil
} }
func (r *Queryer) JoinedUserCount(ctx context.Context, roomID string) (int, error) {
info, err := r.DB.RoomInfo(ctx, roomID)
if err != nil {
return 0, err
}
if info == nil {
return 0, nil
}
// TODO: this can be further optimised by just using a SELECT COUNT query
nids, err := r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false)
return len(nids), err
}
// nolint:gocyclo // nolint:gocyclo
func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) { func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
// Look up if we know anything about the room. If it doesn't exist // Look up if we know anything about the room. If it doesn't exist
@ -993,21 +1056,26 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro
return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, senderID) return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, senderID)
} }
func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) { func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (*spec.SenderID, error) {
version, err := r.DB.GetRoomVersion(ctx, roomID.String()) version, err := r.DB.GetRoomVersion(ctx, roomID.String())
if err != nil { if err != nil {
return "", err return nil, err
} }
switch version { switch version {
case gomatrixserverlib.RoomVersionPseudoIDs: case gomatrixserverlib.RoomVersionPseudoIDs:
key, err := r.DB.SelectUserRoomPublicKey(ctx, userID, roomID) key, err := r.DB.SelectUserRoomPublicKey(ctx, userID, roomID)
if err != nil { if err != nil {
return "", err return nil, err
} else if key == nil {
return nil, nil
} else {
senderID := spec.SenderID(spec.Base64Bytes(key).Encode())
return &senderID, nil
} }
return spec.SenderID(spec.Base64Bytes(key).Encode()), nil
default: default:
return spec.SenderID(userID.String()), nil senderID := spec.SenderID(userID.String())
return &senderID, nil
} }
} }

View file

@ -0,0 +1,530 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package query
import (
"context"
"encoding/json"
"fmt"
"sort"
fs "github.com/matrix-org/dendrite/federationapi/api"
roomserver "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
"github.com/tidwall/gjson"
)
// Traverse the room hierarchy using the provided walker up to the provided limit,
// returning a new walker which can be used to fetch the next page.
//
// If limit is -1, this is treated as no limit, and the entire hierarchy will be traversed.
//
// If returned walker is nil, then there are no more rooms left to traverse. This method does not modify the provided walker, so it
// can be cached.
func (querier *Queryer) QueryNextRoomHierarchyPage(ctx context.Context, walker roomserver.RoomHierarchyWalker, limit int) ([]fclient.RoomHierarchyRoom, *roomserver.RoomHierarchyWalker, error) {
if authorised, _ := authorised(ctx, querier, walker.Caller, walker.RootRoomID, nil); !authorised {
return nil, nil, roomserver.ErrRoomUnknownOrNotAllowed{Err: fmt.Errorf("room is unknown/forbidden")}
}
discoveredRooms := []fclient.RoomHierarchyRoom{}
// Copy unvisited and processed to avoid modifying original walker (which is typically in cache)
unvisited := make([]roomserver.RoomHierarchyWalkerQueuedRoom, len(walker.Unvisited))
copy(unvisited, walker.Unvisited)
processed := walker.Processed.Copy()
// Depth first -> stack data structure
for len(unvisited) > 0 {
if len(discoveredRooms) >= limit && limit != -1 {
break
}
// pop the stack
queuedRoom := unvisited[len(unvisited)-1]
unvisited = unvisited[:len(unvisited)-1]
// If this room has already been processed, skip.
// If this room exceeds the specified depth, skip.
if processed.Contains(queuedRoom.RoomID) || (walker.MaxDepth > 0 && queuedRoom.Depth > walker.MaxDepth) {
continue
}
// Mark this room as processed.
processed.Add(queuedRoom.RoomID)
// if this room is not a space room, skip.
var roomType string
create := stateEvent(ctx, querier, queuedRoom.RoomID, spec.MRoomCreate, "")
if create != nil {
var createContent gomatrixserverlib.CreateContent
err := json.Unmarshal(create.Content(), &createContent)
if err != nil {
util.GetLogger(ctx).WithError(err).WithField("create_content", create.Content()).Warn("failed to unmarshal m.room.create event")
}
roomType = createContent.RoomType
}
// Collect rooms/events to send back (either locally or fetched via federation)
var discoveredChildEvents []fclient.RoomHierarchyStrippedEvent
// If we know about this room and the caller is authorised (joined/world_readable) then pull
// events locally
roomExists := roomExists(ctx, querier, queuedRoom.RoomID)
if !roomExists {
// attempt to query this room over federation, as either we've never heard of it before
// or we've left it and hence are not authorised (but info may be exposed regardless)
fedRes := federatedRoomInfo(ctx, querier, walker.Caller, walker.SuggestedOnly, queuedRoom.RoomID, queuedRoom.Vias)
if fedRes != nil {
discoveredChildEvents = fedRes.Room.ChildrenState
discoveredRooms = append(discoveredRooms, fedRes.Room)
if len(fedRes.Children) > 0 {
discoveredRooms = append(discoveredRooms, fedRes.Children...)
}
// mark this room as a space room as the federated server responded.
// we need to do this so we add the children of this room to the unvisited stack
// as these children may be rooms we do know about.
roomType = spec.MSpace
}
} else if authorised, isJoinedOrInvited := authorised(ctx, querier, walker.Caller, queuedRoom.RoomID, queuedRoom.ParentRoomID); authorised {
// Get all `m.space.child` state events for this room
events, err := childReferences(ctx, querier, walker.SuggestedOnly, queuedRoom.RoomID)
if err != nil {
util.GetLogger(ctx).WithError(err).WithField("room_id", queuedRoom.RoomID).Error("failed to extract references for room")
continue
}
discoveredChildEvents = events
pubRoom := publicRoomsChunk(ctx, querier, queuedRoom.RoomID)
discoveredRooms = append(discoveredRooms, fclient.RoomHierarchyRoom{
PublicRoom: *pubRoom,
RoomType: roomType,
ChildrenState: events,
})
// don't walk children if the user is not joined/invited to the space
if !isJoinedOrInvited {
continue
}
} else {
// room exists but user is not authorised
continue
}
// don't walk the children
// if the parent is not a space room
if roomType != spec.MSpace {
continue
}
// For each referenced room ID in the child events being returned to the caller
// add the room ID to the queue of unvisited rooms. Loop from the beginning.
// We need to invert the order here because the child events are lo->hi on the timestamp,
// so we need to ensure we pop in the same lo->hi order, which won't be the case if we
// insert the highest timestamp last in a stack.
for i := len(discoveredChildEvents) - 1; i >= 0; i-- {
spaceContent := struct {
Via []string `json:"via"`
}{}
ev := discoveredChildEvents[i]
_ = json.Unmarshal(ev.Content, &spaceContent)
childRoomID, err := spec.NewRoomID(ev.StateKey)
if err != nil {
util.GetLogger(ctx).WithError(err).WithField("invalid_room_id", ev.StateKey).WithField("parent_room_id", queuedRoom.RoomID).Warn("Invalid room ID in m.space.child state event")
} else {
unvisited = append(unvisited, roomserver.RoomHierarchyWalkerQueuedRoom{
RoomID: *childRoomID,
ParentRoomID: &queuedRoom.RoomID,
Depth: queuedRoom.Depth + 1,
Vias: spaceContent.Via,
})
}
}
}
if len(unvisited) == 0 {
// If no more rooms to walk, then don't return a walker for future pages
return discoveredRooms, nil, nil
} else {
// If there are more rooms to walk, then return a new walker to resume walking from (for querying more pages)
newWalker := roomserver.RoomHierarchyWalker{
RootRoomID: walker.RootRoomID,
Caller: walker.Caller,
SuggestedOnly: walker.SuggestedOnly,
MaxDepth: walker.MaxDepth,
Unvisited: unvisited,
Processed: processed,
}
return discoveredRooms, &newWalker, nil
}
}
// authorised returns true iff the user is joined this room or the room is world_readable
func authorised(ctx context.Context, querier *Queryer, caller types.DeviceOrServerName, roomID spec.RoomID, parentRoomID *spec.RoomID) (authed, isJoinedOrInvited bool) {
if clientCaller := caller.Device(); clientCaller != nil {
return authorisedUser(ctx, querier, clientCaller, roomID, parentRoomID)
} else {
return authorisedServer(ctx, querier, roomID, *caller.ServerName()), false
}
}
// authorisedServer returns true iff the server is joined this room or the room is world_readable, public, or knockable
func authorisedServer(ctx context.Context, querier *Queryer, roomID spec.RoomID, callerServerName spec.ServerName) bool {
// Check history visibility / join rules first
hisVisTuple := gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomHistoryVisibility,
StateKey: "",
}
joinRuleTuple := gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomJoinRules,
StateKey: "",
}
var queryRoomRes roomserver.QueryCurrentStateResponse
err := querier.QueryCurrentState(ctx, &roomserver.QueryCurrentStateRequest{
RoomID: roomID.String(),
StateTuples: []gomatrixserverlib.StateKeyTuple{
hisVisTuple, joinRuleTuple,
},
}, &queryRoomRes)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("failed to QueryCurrentState")
return false
}
hisVisEv := queryRoomRes.StateEvents[hisVisTuple]
if hisVisEv != nil {
hisVis, _ := hisVisEv.HistoryVisibility()
if hisVis == "world_readable" {
return true
}
}
// check if this room is a restricted room and if so, we need to check if the server is joined to an allowed room ID
// in addition to the actual room ID (but always do the actual one first as it's quicker in the common case)
allowJoinedToRoomIDs := []spec.RoomID{roomID}
joinRuleEv := queryRoomRes.StateEvents[joinRuleTuple]
if joinRuleEv != nil {
rule, ruleErr := joinRuleEv.JoinRule()
if ruleErr != nil {
util.GetLogger(ctx).WithError(ruleErr).WithField("parent_room_id", roomID).Warn("failed to get join rule")
return false
}
if rule == spec.Public || rule == spec.Knock {
return true
}
if rule == spec.Restricted {
allowJoinedToRoomIDs = append(allowJoinedToRoomIDs, restrictedJoinRuleAllowedRooms(ctx, joinRuleEv)...)
}
}
// check if server is joined to any allowed room
for _, allowedRoomID := range allowJoinedToRoomIDs {
var queryRes fs.QueryJoinedHostServerNamesInRoomResponse
err = querier.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{
RoomID: allowedRoomID.String(),
}, &queryRes)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("failed to QueryJoinedHostServerNamesInRoom")
continue
}
for _, srv := range queryRes.ServerNames {
if srv == callerServerName {
return true
}
}
}
return false
}
// authorisedUser returns true iff the user is invited/joined this room or the room is world_readable
// or if the room has a public or knock join rule.
// Failing that, if the room has a restricted join rule and belongs to the space parent listed, it will return true.
func authorisedUser(ctx context.Context, querier *Queryer, clientCaller *userapi.Device, roomID spec.RoomID, parentRoomID *spec.RoomID) (authed bool, isJoinedOrInvited bool) {
hisVisTuple := gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomHistoryVisibility,
StateKey: "",
}
joinRuleTuple := gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomJoinRules,
StateKey: "",
}
roomMemberTuple := gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomMember,
StateKey: clientCaller.UserID,
}
var queryRes roomserver.QueryCurrentStateResponse
err := querier.QueryCurrentState(ctx, &roomserver.QueryCurrentStateRequest{
RoomID: roomID.String(),
StateTuples: []gomatrixserverlib.StateKeyTuple{
hisVisTuple, joinRuleTuple, roomMemberTuple,
},
}, &queryRes)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("failed to QueryCurrentState")
return false, false
}
memberEv := queryRes.StateEvents[roomMemberTuple]
if memberEv != nil {
membership, _ := memberEv.Membership()
if membership == spec.Join || membership == spec.Invite {
return true, true
}
}
hisVisEv := queryRes.StateEvents[hisVisTuple]
if hisVisEv != nil {
hisVis, _ := hisVisEv.HistoryVisibility()
if hisVis == "world_readable" {
return true, false
}
}
joinRuleEv := queryRes.StateEvents[joinRuleTuple]
if parentRoomID != nil && joinRuleEv != nil {
var allowed bool
rule, ruleErr := joinRuleEv.JoinRule()
if ruleErr != nil {
util.GetLogger(ctx).WithError(ruleErr).WithField("parent_room_id", parentRoomID).Warn("failed to get join rule")
} else if rule == spec.Public || rule == spec.Knock {
allowed = true
} else if rule == spec.Restricted {
allowedRoomIDs := restrictedJoinRuleAllowedRooms(ctx, joinRuleEv)
// check parent is in the allowed set
for _, a := range allowedRoomIDs {
if *parentRoomID == a {
allowed = true
break
}
}
}
if allowed {
// ensure caller is joined to the parent room
var queryRes2 roomserver.QueryCurrentStateResponse
err = querier.QueryCurrentState(ctx, &roomserver.QueryCurrentStateRequest{
RoomID: parentRoomID.String(),
StateTuples: []gomatrixserverlib.StateKeyTuple{
roomMemberTuple,
},
}, &queryRes2)
if err != nil {
util.GetLogger(ctx).WithError(err).WithField("parent_room_id", parentRoomID).Warn("failed to check user is joined to parent room")
} else {
memberEv = queryRes2.StateEvents[roomMemberTuple]
if memberEv != nil {
membership, _ := memberEv.Membership()
if membership == spec.Join {
return true, false
}
}
}
}
}
return false, false
}
// helper function to fetch a state event
func stateEvent(ctx context.Context, querier *Queryer, roomID spec.RoomID, evType, stateKey string) *types.HeaderedEvent {
var queryRes roomserver.QueryCurrentStateResponse
tuple := gomatrixserverlib.StateKeyTuple{
EventType: evType,
StateKey: stateKey,
}
err := querier.QueryCurrentState(ctx, &roomserver.QueryCurrentStateRequest{
RoomID: roomID.String(),
StateTuples: []gomatrixserverlib.StateKeyTuple{tuple},
}, &queryRes)
if err != nil {
return nil
}
return queryRes.StateEvents[tuple]
}
// returns true if the current server is participating in the provided room
func roomExists(ctx context.Context, querier *Queryer, roomID spec.RoomID) bool {
var queryRes roomserver.QueryServerJoinedToRoomResponse
err := querier.QueryServerJoinedToRoom(ctx, &roomserver.QueryServerJoinedToRoomRequest{
RoomID: roomID.String(),
ServerName: querier.Cfg.Global.ServerName,
}, &queryRes)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("failed to QueryServerJoinedToRoom")
return false
}
// if the room exists but we aren't in the room then we might have stale data so we want to fetch
// it fresh via federation
return queryRes.RoomExists && queryRes.IsInRoom
}
// federatedRoomInfo returns more of the spaces graph from another server. Returns nil if this was
// unsuccessful.
func federatedRoomInfo(ctx context.Context, querier *Queryer, caller types.DeviceOrServerName, suggestedOnly bool, roomID spec.RoomID, vias []string) *fclient.RoomHierarchyResponse {
// only do federated requests for client requests
if caller.Device() == nil {
return nil
}
resp, ok := querier.Cache.GetRoomHierarchy(roomID.String())
if ok {
util.GetLogger(ctx).Debugf("Returning cached response for %s", roomID)
return &resp
}
util.GetLogger(ctx).Debugf("Querying %s via %+v", roomID, vias)
innerCtx := context.Background()
// query more of the spaces graph using these servers
for _, serverName := range vias {
if serverName == string(querier.Cfg.Global.ServerName) {
continue
}
res, err := querier.FSAPI.RoomHierarchies(innerCtx, querier.Cfg.Global.ServerName, spec.ServerName(serverName), roomID.String(), suggestedOnly)
if err != nil {
util.GetLogger(ctx).WithError(err).Warnf("failed to call RoomHierarchies on server %s", serverName)
continue
}
// ensure nil slices are empty as we send this to the client sometimes
if res.Room.ChildrenState == nil {
res.Room.ChildrenState = []fclient.RoomHierarchyStrippedEvent{}
}
for i := 0; i < len(res.Children); i++ {
child := res.Children[i]
if child.ChildrenState == nil {
child.ChildrenState = []fclient.RoomHierarchyStrippedEvent{}
}
res.Children[i] = child
}
querier.Cache.StoreRoomHierarchy(roomID.String(), res)
return &res
}
return nil
}
// references returns all child references pointing to or from this room.
func childReferences(ctx context.Context, querier *Queryer, suggestedOnly bool, roomID spec.RoomID) ([]fclient.RoomHierarchyStrippedEvent, error) {
createTuple := gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomCreate,
StateKey: "",
}
var res roomserver.QueryCurrentStateResponse
err := querier.QueryCurrentState(context.Background(), &roomserver.QueryCurrentStateRequest{
RoomID: roomID.String(),
AllowWildcards: true,
StateTuples: []gomatrixserverlib.StateKeyTuple{
createTuple, {
EventType: spec.MSpaceChild,
StateKey: "*",
},
},
}, &res)
if err != nil {
return nil, err
}
// don't return any child refs if the room is not a space room
if create := res.StateEvents[createTuple]; create != nil {
var createContent gomatrixserverlib.CreateContent
err := json.Unmarshal(create.Content(), &createContent)
if err != nil {
util.GetLogger(ctx).WithError(err).WithField("create_content", create.Content()).Warn("failed to unmarshal m.room.create event")
}
roomType := createContent.RoomType
if roomType != spec.MSpace {
return []fclient.RoomHierarchyStrippedEvent{}, nil
}
}
delete(res.StateEvents, createTuple)
el := make([]fclient.RoomHierarchyStrippedEvent, 0, len(res.StateEvents))
for _, ev := range res.StateEvents {
content := gjson.ParseBytes(ev.Content())
// only return events that have a `via` key as per MSC1772
// else we'll incorrectly walk redacted events (as the link
// is in the state_key)
if content.Get("via").Exists() {
strip := stripped(ev.PDU)
if strip == nil {
continue
}
// if suggested only and this child isn't suggested, skip it.
// if suggested only = false we include everything so don't need to check the content.
if suggestedOnly && !content.Get("suggested").Bool() {
continue
}
el = append(el, *strip)
}
}
// sort by origin_server_ts as per MSC2946
sort.Slice(el, func(i, j int) bool {
return el[i].OriginServerTS < el[j].OriginServerTS
})
return el, nil
}
// fetch public room information for provided room
func publicRoomsChunk(ctx context.Context, querier *Queryer, roomID spec.RoomID) *fclient.PublicRoom {
pubRooms, err := roomserver.PopulatePublicRooms(ctx, []string{roomID.String()}, querier)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("failed to PopulatePublicRooms")
return nil
}
if len(pubRooms) == 0 {
return nil
}
return &pubRooms[0]
}
func stripped(ev gomatrixserverlib.PDU) *fclient.RoomHierarchyStrippedEvent {
if ev.StateKey() == nil {
return nil
}
return &fclient.RoomHierarchyStrippedEvent{
Type: ev.Type(),
StateKey: *ev.StateKey(),
Content: ev.Content(),
Sender: string(ev.SenderID()),
OriginServerTS: ev.OriginServerTS(),
}
}
// given join_rule event, return list of rooms where membership of that room allows joining.
func restrictedJoinRuleAllowedRooms(ctx context.Context, joinRuleEv *types.HeaderedEvent) (allows []spec.RoomID) {
rule, _ := joinRuleEv.JoinRule()
if rule != spec.Restricted {
return nil
}
var jrContent gomatrixserverlib.JoinRuleContent
if err := json.Unmarshal(joinRuleEv.Content(), &jrContent); err != nil {
util.GetLogger(ctx).Warnf("failed to check join_rule on room %s: %s", joinRuleEv.RoomID().String(), err)
return nil
}
for _, allow := range jrContent.Allow {
if allow.Type == spec.MRoomMembership {
allowedRoomID, err := spec.NewRoomID(allow.RoomID)
if err != nil {
util.GetLogger(ctx).Warnf("invalid room ID '%s' found in join_rule on room %s: %s", allow.RoomID, joinRuleEv.RoomID().String(), err)
} else {
allows = append(allows, *allowedRoomID)
}
}
}
return
}

View file

@ -28,10 +28,13 @@ import (
) )
// NewInternalAPI returns a concrete implementation of the internal API. // NewInternalAPI returns a concrete implementation of the internal API.
//
// Many of the methods provided by this API depend on access to a federation API, and so
// you may wish to call `SetFederationAPI` on the returned struct to avoid nil-dereference errors.
func NewInternalAPI( func NewInternalAPI(
processContext *process.ProcessContext, processContext *process.ProcessContext,
cfg *config.Dendrite, cfg *config.Dendrite,
cm sqlutil.Connections, cm *sqlutil.Connections,
natsInstance *jetstream.NATSInstance, natsInstance *jetstream.NATSInstance,
caches caching.RoomServerCaches, caches caching.RoomServerCaches,
enableMetrics bool, enableMetrics bool,

View file

@ -45,6 +45,7 @@ type StateResolutionStorage interface {
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
IsEventRejected(ctx context.Context, roomNID types.RoomNID, eventID string) (bool, error)
} }
type StateResolution struct { type StateResolution struct {
@ -1066,6 +1067,13 @@ func (v *StateResolution) resolveConflictsV2(
func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return v.Querier.QueryUserIDForSender(ctx, roomID, senderID) return v.Querier.QueryUserIDForSender(ctx, roomID, senderID)
}, },
func(eventID string) bool {
isRejected, err := v.db.IsEventRejected(ctx, v.roomInfo.RoomNID, eventID)
if err != nil {
return true
}
return isRejected
},
) )
}() }()

View file

@ -133,7 +133,7 @@ type Database interface {
// in this room, along a boolean set to true if the user is still in this room, // in this room, along a boolean set to true if the user is still in this room,
// false if not. // false if not.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderID spec.SenderID) (membershipEventNID types.EventNID, stillInRoom, isRoomForgotten bool, err error) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderID spec.SenderID) (membershipEventNID types.EventNID, membershipNID tables.MembershipState, stillInRoom, isRoomForgotten bool, err error)
// Lookup the membership event numeric IDs for all user that are or have // Lookup the membership event numeric IDs for all user that are or have
// been members of a given room. Only lookup events of "join" membership if // been members of a given room. Only lookup events of "join" membership if
// joinOnly is set to true. // joinOnly is set to true.
@ -158,7 +158,7 @@ type Database interface {
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error)
GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*types.HeaderedEvent, error) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*types.HeaderedEvent, error)
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) GetRoomsByMembership(ctx context.Context, userID spec.UserID, membership string) ([]string, error)
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)

View file

@ -37,7 +37,7 @@ type Database struct {
} }
// Open a postgres database. // Open a postgres database.
func Open(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { func Open(ctx context.Context, conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) {
var d Database var d Database
var err error var err error
db, writer, err := conMan.Connection(dbProperties) db, writer, err := conMan.Connection(dbProperties)

View file

@ -56,12 +56,15 @@ const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_use
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)` const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)`
const selectAllUserRoomPublicKeyForUserSQL = `SELECT room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE user_nid = $1`
type userRoomKeysStatements struct { type userRoomKeysStatements struct {
insertUserRoomPrivateKeyStmt *sql.Stmt insertUserRoomPrivateKeyStmt *sql.Stmt
insertUserRoomPublicKeyStmt *sql.Stmt insertUserRoomPublicKeyStmt *sql.Stmt
selectUserRoomKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt
selectUserRoomPublicKeyStmt *sql.Stmt selectUserRoomPublicKeyStmt *sql.Stmt
selectUserNIDsStmt *sql.Stmt selectUserNIDsStmt *sql.Stmt
selectAllUserRoomPublicKeysForUser *sql.Stmt
} }
func CreateUserRoomKeysTable(db *sql.DB) error { func CreateUserRoomKeysTable(db *sql.DB) error {
@ -77,6 +80,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL}, {&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
{&s.selectUserNIDsStmt, selectUserNIDsSQL}, {&s.selectUserNIDsStmt, selectUserNIDsSQL},
{&s.selectAllUserRoomPublicKeysForUser, selectAllUserRoomPublicKeyForUserSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -150,3 +154,24 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sq
} }
return result, rows.Err() return result, rows.Err()
} }
func (s *userRoomKeysStatements) SelectAllPublicKeysForUser(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) (map[types.RoomNID]ed25519.PublicKey, error) {
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectAllUserRoomPublicKeysForUser)
rows, err := stmt.QueryContext(ctx, userNID)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
resultMap := make(map[types.RoomNID]ed25519.PublicKey)
var roomNID types.RoomNID
var pubkey ed25519.PublicKey
for rows.Next() {
if err = rows.Scan(&roomNID, &pubkey); err != nil {
return nil, err
}
resultMap[roomNID] = pubkey
}
return resultMap, err
}

View file

@ -250,3 +250,7 @@ func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error {
func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
} }
func (u *RoomUpdater) IsEventRejected(ctx context.Context, roomNID types.RoomNID, eventID string) (bool, error) {
return u.d.IsEventRejected(ctx, roomNID, eventID)
}

View file

@ -491,14 +491,14 @@ func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
}) })
} }
func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderID spec.SenderID) (membershipEventNID types.EventNID, stillInRoom, isRoomforgotten bool, err error) { func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderID spec.SenderID) (membershipEventNID types.EventNID, membershipState tables.MembershipState, stillInRoom, isRoomforgotten bool, err error) {
var requestSenderUserNID types.EventStateKeyNID var requestSenderUserNID types.EventStateKeyNID
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
requestSenderUserNID, err = d.assignStateKeyNID(ctx, txn, string(requestSenderID)) requestSenderUserNID, err = d.assignStateKeyNID(ctx, txn, string(requestSenderID))
return err return err
}) })
if err != nil { if err != nil {
return 0, false, false, fmt.Errorf("d.assignStateKeyNID: %w", err) return 0, 0, false, false, fmt.Errorf("d.assignStateKeyNID: %w", err)
} }
senderMembershipEventNID, senderMembership, isRoomforgotten, err := senderMembershipEventNID, senderMembership, isRoomforgotten, err :=
@ -507,12 +507,12 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req
) )
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// The user has never been a member of that room // The user has never been a member of that room
return 0, false, false, nil return 0, 0, false, false, nil
} else if err != nil { } else if err != nil {
return return
} }
return senderMembershipEventNID, senderMembership == tables.MembershipStateJoin, isRoomforgotten, nil return senderMembershipEventNID, senderMembership, senderMembership == tables.MembershipStateJoin, isRoomforgotten, nil
} }
func (d *Database) GetMembershipEventNIDsForRoom( func (d *Database) GetMembershipEventNIDsForRoom(
@ -696,8 +696,8 @@ func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserver
return nil, fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err) return nil, fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err)
} }
roomNID, nidOK := d.Cache.GetRoomServerRoomNID(event.RoomID()) roomNID, nidOK := d.Cache.GetRoomServerRoomNID(event.RoomID().String())
cachedRoomVersion, versionOK := d.Cache.GetRoomVersion(event.RoomID()) cachedRoomVersion, versionOK := d.Cache.GetRoomVersion(event.RoomID().String())
// if we found both, the roomNID and version in our cache, no need to query the database // if we found both, the roomNID and version in our cache, no need to query the database
if nidOK && versionOK { if nidOK && versionOK {
return &types.RoomInfo{ return &types.RoomInfo{
@ -707,14 +707,14 @@ func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserver
} }
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion) roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID().String(), roomVersion)
if err != nil { if err != nil {
return err return err
} }
return nil return nil
}) })
if roomVersion != "" { if roomVersion != "" {
d.Cache.StoreRoomVersion(event.RoomID(), roomVersion) d.Cache.StoreRoomVersion(event.RoomID().String(), roomVersion)
} }
return &types.RoomInfo{ return &types.RoomInfo{
RoomVersion: roomVersion, RoomVersion: roomVersion,
@ -1026,24 +1026,19 @@ func (d *EventDatabase) MaybeRedactEvent(
case validated || redactedEvent == nil || redactionEvent == nil: case validated || redactedEvent == nil || redactionEvent == nil:
// we've seen this redaction before or there is nothing to redact // we've seen this redaction before or there is nothing to redact
return nil return nil
case redactedEvent.RoomID() != redactionEvent.RoomID(): case redactedEvent.RoomID().String() != redactionEvent.RoomID().String():
// redactions across rooms aren't allowed // redactions across rooms aren't allowed
ignoreRedaction = true ignoreRedaction = true
return nil return nil
} }
var validRoomID *spec.RoomID
validRoomID, err = spec.NewRoomID(redactedEvent.RoomID())
if err != nil {
return err
}
sender1Domain := "" sender1Domain := ""
sender1, err1 := querier.QueryUserIDForSender(ctx, *validRoomID, redactedEvent.SenderID()) sender1, err1 := querier.QueryUserIDForSender(ctx, redactedEvent.RoomID(), redactedEvent.SenderID())
if err1 == nil { if err1 == nil {
sender1Domain = string(sender1.Domain()) sender1Domain = string(sender1.Domain())
} }
sender2Domain := "" sender2Domain := ""
sender2, err2 := querier.QueryUserIDForSender(ctx, *validRoomID, redactionEvent.SenderID()) sender2, err2 := querier.QueryUserIDForSender(ctx, redactedEvent.RoomID(), redactionEvent.SenderID())
if err2 == nil { if err2 == nil {
sender2Domain = string(sender2.Domain()) sender2Domain = string(sender2.Domain())
} }
@ -1347,7 +1342,7 @@ func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evTy
} }
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) { func (d *Database) GetRoomsByMembership(ctx context.Context, userID spec.UserID, membership string) ([]string, error) {
var membershipState tables.MembershipState var membershipState tables.MembershipState
switch membership { switch membership {
case "join": case "join":
@ -1361,17 +1356,73 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership
default: default:
return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership) return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership)
} }
stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID)
// Convert provided user ID to NID
userNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID.String())
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} else {
return nil, fmt.Errorf("SelectEventStateKeyNID: cannot map user ID to state key NIDs: %w", err)
} }
return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err)
} }
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNID, membershipState)
// Use this NID to fetch all associated room keys (for pseudo ID rooms)
roomKeyMap, err := d.UserRoomKeyTable.SelectAllPublicKeysForUser(ctx, nil, userNID)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err) if err == sql.ErrNoRows {
roomKeyMap = map[types.RoomNID]ed25519.PublicKey{}
} else {
return nil, fmt.Errorf("SelectAllPublicKeysForUser: could not select user room public keys for user: %w", err)
}
} }
var eventStateKeyNIDs []types.EventStateKeyNID
// If there are room keys (i.e. this user is in pseudo ID rooms), then gather the appropriate NIDs
if len(roomKeyMap) != 0 {
// Convert keys to string representation
userRoomKeys := make([]string, len(roomKeyMap))
i := 0
for _, key := range roomKeyMap {
userRoomKeys[i] = spec.Base64Bytes(key).Encode()
i += 1
}
// Convert the string representation to its NID
pseudoIDStateKeys, sqlErr := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, userRoomKeys)
if sqlErr != nil {
if sqlErr == sql.ErrNoRows {
pseudoIDStateKeys = map[string]types.EventStateKeyNID{}
} else {
return nil, fmt.Errorf("BulkSelectEventStateKeyNID: could not select state keys for public room keys: %w", err)
}
}
// Collect all NIDs together
eventStateKeyNIDs = make([]types.EventStateKeyNID, len(pseudoIDStateKeys)+1)
eventStateKeyNIDs[0] = userNID
i = 1
for _, nid := range pseudoIDStateKeys {
eventStateKeyNIDs[i] = nid
i += 1
}
} else {
// If there are no room keys (so no pseudo ID rooms), we only need to care about the user ID NID.
eventStateKeyNIDs = []types.EventStateKeyNID{userNID}
}
// Fetch rooms that match membership for each NID
roomNIDs := []types.RoomNID{}
for _, nid := range eventStateKeyNIDs {
var roomNIDsChunk []types.RoomNID
roomNIDsChunk, err = d.MembershipTable.SelectRoomsWithMembership(ctx, nil, nid, membershipState)
if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err)
}
roomNIDs = append(roomNIDs, roomNIDsChunk...)
}
roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, roomNIDs) roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, roomNIDs)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to lookup room nids: %w", err) return nil, fmt.Errorf("GetRoomsByMembership: failed to lookup room nids: %w", err)
@ -1466,7 +1517,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
} }
result[i] = tables.StrippedEvent{ result[i] = tables.StrippedEvent{
EventType: ev.Type(), EventType: ev.Type(),
RoomID: ev.RoomID(), RoomID: ev.RoomID().String(),
StateKey: *ev.StateKey(), StateKey: *ev.StateKey(),
ContentValue: tables.ExtractContentValue(&types.HeaderedEvent{PDU: ev}), ContentValue: tables.ExtractContentValue(&types.HeaderedEvent{PDU: ev}),
} }

View file

@ -36,7 +36,7 @@ type Database struct {
} }
// Open a sqlite database. // Open a sqlite database.
func Open(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { func Open(ctx context.Context, conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) {
var d Database var d Database
var err error var err error
db, writer, err := conMan.Connection(dbProperties) db, writer, err := conMan.Connection(dbProperties)

View file

@ -56,12 +56,15 @@ const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_use
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)` const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)`
const selectAllUserRoomPublicKeyForUserSQL = `SELECT room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE user_nid = $1`
type userRoomKeysStatements struct { type userRoomKeysStatements struct {
db *sql.DB db *sql.DB
insertUserRoomPrivateKeyStmt *sql.Stmt insertUserRoomPrivateKeyStmt *sql.Stmt
insertUserRoomPublicKeyStmt *sql.Stmt insertUserRoomPublicKeyStmt *sql.Stmt
selectUserRoomKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt
selectUserRoomPublicKeyStmt *sql.Stmt selectUserRoomPublicKeyStmt *sql.Stmt
selectAllUserRoomPublicKeysForUser *sql.Stmt
//selectUserNIDsStmt *sql.Stmt //prepared at runtime //selectUserNIDsStmt *sql.Stmt //prepared at runtime
} }
@ -77,6 +80,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL}, {&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
{&s.selectAllUserRoomPublicKeysForUser, selectAllUserRoomPublicKeyForUserSQL},
//{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime //{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime
}.Prepare(db) }.Prepare(db)
} }
@ -165,3 +169,24 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sq
} }
return result, rows.Err() return result, rows.Err()
} }
func (s *userRoomKeysStatements) SelectAllPublicKeysForUser(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) (map[types.RoomNID]ed25519.PublicKey, error) {
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectAllUserRoomPublicKeysForUser)
rows, err := stmt.QueryContext(ctx, userNID)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
resultMap := make(map[types.RoomNID]ed25519.PublicKey)
var roomNID types.RoomNID
var pubkey ed25519.PublicKey
for rows.Next() {
if err = rows.Scan(&roomNID, &pubkey); err != nil {
return nil, err
}
resultMap[roomNID] = pubkey
}
return resultMap, err
}

View file

@ -29,7 +29,7 @@ import (
) )
// Open opens a database connection. // Open opens a database connection.
func Open(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (Database, error) { func Open(ctx context.Context, conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (Database, error) {
switch { switch {
case dbProperties.ConnectionString.IsSQLite(): case dbProperties.ConnectionString.IsSQLite():
return sqlite3.Open(ctx, conMan, dbProperties, cache) return sqlite3.Open(ctx, conMan, dbProperties, cache)

View file

@ -25,7 +25,7 @@ import (
) )
// NewPublicRoomsServerDatabase opens a database connection. // NewPublicRoomsServerDatabase opens a database connection.
func Open(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (Database, error) { func Open(ctx context.Context, conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (Database, error) {
switch { switch {
case dbProperties.ConnectionString.IsSQLite(): case dbProperties.ConnectionString.IsSQLite():
return sqlite3.Open(ctx, conMan, dbProperties, cache) return sqlite3.Open(ctx, conMan, dbProperties, cache)

View file

@ -198,6 +198,8 @@ type UserRoomKeys interface {
// BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair. // BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair.
// If a senderKey can't be found, it is omitted in the result. // If a senderKey can't be found, it is omitted in the result.
BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error)
// SelectAllPublicKeysForUser returns all known public keys for a user. Returns a map from room NID -> public key
SelectAllPublicKeysForUser(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) (map[types.RoomNID]ed25519.PublicKey, error)
} }
// StrippedEvent represents a stripped event for returning extracted content values. // StrippedEvent represents a stripped event for returning extracted content values.

View file

@ -22,7 +22,9 @@ import (
"strings" "strings"
"sync" "sync"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"golang.org/x/crypto/blake2b" "golang.org/x/crypto/blake2b"
) )
@ -336,3 +338,36 @@ func (r *RoomInfo) CopyFrom(r2 *RoomInfo) {
} }
var ErrorInvalidRoomInfo = fmt.Errorf("room info is invalid") var ErrorInvalidRoomInfo = fmt.Errorf("room info is invalid")
// Struct to represent a device or a server name.
//
// May be used to designate a caller for functions that can be called
// by a client (device) or by a server (server name).
//
// Exactly 1 of Device() and ServerName() will return a non-nil result.
type DeviceOrServerName struct {
device *userapi.Device
serverName *spec.ServerName
}
func NewDeviceNotServerName(device userapi.Device) DeviceOrServerName {
return DeviceOrServerName{
device: &device,
serverName: nil,
}
}
func NewServerNameNotDevice(serverName spec.ServerName) DeviceOrServerName {
return DeviceOrServerName{
device: nil,
serverName: &serverName,
}
}
func (s *DeviceOrServerName) Device() *userapi.Device {
return s.device
}
func (s *DeviceOrServerName) ServerName() *spec.ServerName {
return s.serverName
}

View file

@ -1,12 +1,22 @@
package config package config
import (
"fmt"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
type RoomServer struct { type RoomServer struct {
Matrix *Global `yaml:"-"` Matrix *Global `yaml:"-"`
DefaultRoomVersion gomatrixserverlib.RoomVersion `yaml:"default_room_version,omitempty"`
Database DatabaseOptions `yaml:"database,omitempty"` Database DatabaseOptions `yaml:"database,omitempty"`
} }
func (c *RoomServer) Defaults(opts DefaultOpts) { func (c *RoomServer) Defaults(opts DefaultOpts) {
c.DefaultRoomVersion = gomatrixserverlib.RoomVersionV10
if opts.Generate { if opts.Generate {
if !opts.SingleDatabase { if !opts.SingleDatabase {
c.Database.ConnectionString = "file:roomserver.db" c.Database.ConnectionString = "file:roomserver.db"
@ -18,4 +28,10 @@ func (c *RoomServer) Verify(configErrs *ConfigErrors) {
if c.Matrix.DatabaseOptions.ConnectionString == "" { if c.Matrix.DatabaseOptions.ConnectionString == "" {
checkNotEmpty(configErrs, "room_server.database.connection_string", string(c.Database.ConnectionString)) checkNotEmpty(configErrs, "room_server.database.connection_string", string(c.Database.ConnectionString))
} }
if !gomatrixserverlib.KnownRoomVersion(c.DefaultRoomVersion) {
configErrs.Add(fmt.Sprintf("invalid value for config key 'room_server.default_room_version': unsupported room version: %q", c.DefaultRoomVersion))
} else if !gomatrixserverlib.StableRoomVersion(c.DefaultRoomVersion) {
log.Warnf("WARNING: Provided default room version %q is unstable", c.DefaultRoomVersion)
}
} }

View file

@ -58,7 +58,7 @@ func (m *Monolith) AddAllPublicRoutes(
processCtx *process.ProcessContext, processCtx *process.ProcessContext,
cfg *config.Dendrite, cfg *config.Dendrite,
routers httputil.Routers, routers httputil.Routers,
cm sqlutil.Connections, cm *sqlutil.Connections,
natsInstance *jetstream.NATSInstance, natsInstance *jetstream.NATSInstance,
caches *caching.Caches, caches *caching.Caches,
enableMetrics bool, enableMetrics bool,

View file

@ -105,7 +105,7 @@ func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsRespons
// Enable this MSC // Enable this MSC
func Enable( func Enable(
cfg *config.Dendrite, cm sqlutil.Connections, routers httputil.Routers, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, cfg *config.Dendrite, cm *sqlutil.Connections, routers httputil.Routers, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI,
userAPI userapi.UserInternalAPI, keyRing gomatrixserverlib.JSONVerifier, userAPI userapi.UserInternalAPI, keyRing gomatrixserverlib.JSONVerifier,
) error { ) error {
db, err := NewDatabase(cm, &cfg.MSCs.Database) db, err := NewDatabase(cm, &cfg.MSCs.Database)
@ -271,7 +271,7 @@ func (rc *reqCtx) process() (*MSC2836EventRelationshipsResponse, *util.JSONRespo
event = rc.fetchUnknownEvent(rc.req.EventID, rc.req.RoomID) event = rc.fetchUnknownEvent(rc.req.EventID, rc.req.RoomID)
} }
if rc.req.RoomID == "" && event != nil { if rc.req.RoomID == "" && event != nil {
rc.req.RoomID = event.RoomID() rc.req.RoomID = event.RoomID().String()
} }
if event == nil || !rc.authorisedToSeeEvent(event) { if event == nil || !rc.authorisedToSeeEvent(event) {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
@ -526,7 +526,7 @@ func (rc *reqCtx) authorisedToSeeEvent(event *types.HeaderedEvent) bool {
// make sure the server is in this room // make sure the server is in this room
var res fs.QueryJoinedHostServerNamesInRoomResponse var res fs.QueryJoinedHostServerNamesInRoomResponse
err := rc.fsAPI.QueryJoinedHostServerNamesInRoom(rc.ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{ err := rc.fsAPI.QueryJoinedHostServerNamesInRoom(rc.ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{
RoomID: event.RoomID(), RoomID: event.RoomID().String(),
}, &res) }, &res)
if err != nil { if err != nil {
util.GetLogger(rc.ctx).WithError(err).Error("authorisedToSeeEvent: failed to QueryJoinedHostServerNamesInRoom") util.GetLogger(rc.ctx).WithError(err).Error("authorisedToSeeEvent: failed to QueryJoinedHostServerNamesInRoom")
@ -545,7 +545,7 @@ func (rc *reqCtx) authorisedToSeeEvent(event *types.HeaderedEvent) bool {
// TODO: This does not honour m.room.create content // TODO: This does not honour m.room.create content
var queryMembershipRes roomserver.QueryMembershipForUserResponse var queryMembershipRes roomserver.QueryMembershipForUserResponse
err := rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{ err := rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{
RoomID: event.RoomID(), RoomID: event.RoomID().String(),
UserID: rc.userID, UserID: rc.userID,
}, &queryMembershipRes) }, &queryMembershipRes)
if err != nil { if err != nil {
@ -612,7 +612,7 @@ func (rc *reqCtx) lookForEvent(eventID string) *types.HeaderedEvent {
// inject all the events into the roomserver then return the event in question // inject all the events into the roomserver then return the event in question
rc.injectResponseToRoomserver(queryRes) rc.injectResponseToRoomserver(queryRes)
for _, ev := range queryRes.ParsedEvents { for _, ev := range queryRes.ParsedEvents {
if ev.EventID() == eventID && rc.req.RoomID == ev.RoomID() { if ev.EventID() == eventID && rc.req.RoomID == ev.RoomID().String() {
return &types.HeaderedEvent{PDU: ev} return &types.HeaderedEvent{PDU: ev}
} }
} }
@ -629,7 +629,7 @@ func (rc *reqCtx) lookForEvent(eventID string) *types.HeaderedEvent {
} }
} }
} }
if rc.req.RoomID == event.RoomID() { if rc.req.RoomID == event.RoomID().String() {
return event return event
} }
return nil return nil

View file

@ -59,14 +59,14 @@ type DB struct {
} }
// NewDatabase loads the database for msc2836 // NewDatabase loads the database for msc2836
func NewDatabase(conMan sqlutil.Connections, dbOpts *config.DatabaseOptions) (Database, error) { func NewDatabase(conMan *sqlutil.Connections, dbOpts *config.DatabaseOptions) (Database, error) {
if dbOpts.ConnectionString.IsPostgres() { if dbOpts.ConnectionString.IsPostgres() {
return newPostgresDatabase(conMan, dbOpts) return newPostgresDatabase(conMan, dbOpts)
} }
return newSQLiteDatabase(conMan, dbOpts) return newSQLiteDatabase(conMan, dbOpts)
} }
func newPostgresDatabase(conMan sqlutil.Connections, dbOpts *config.DatabaseOptions) (Database, error) { func newPostgresDatabase(conMan *sqlutil.Connections, dbOpts *config.DatabaseOptions) (Database, error) {
d := DB{} d := DB{}
var err error var err error
if d.db, d.writer, err = conMan.Connection(dbOpts); err != nil { if d.db, d.writer, err = conMan.Connection(dbOpts); err != nil {
@ -144,7 +144,7 @@ func newPostgresDatabase(conMan sqlutil.Connections, dbOpts *config.DatabaseOpti
return &d, err return &d, err
} }
func newSQLiteDatabase(conMan sqlutil.Connections, dbOpts *config.DatabaseOptions) (Database, error) { func newSQLiteDatabase(conMan *sqlutil.Connections, dbOpts *config.DatabaseOptions) (Database, error) {
d := DB{} d := DB{}
var err error var err error
if d.db, d.writer, err = conMan.Connection(dbOpts); err != nil { if d.db, d.writer, err = conMan.Connection(dbOpts); err != nil {
@ -239,7 +239,7 @@ func (p *DB) StoreRelation(ctx context.Context, ev *types.HeaderedEvent) error {
return err return err
} }
util.GetLogger(ctx).Infof("StoreRelation child=%s parent=%s rel_type=%s", child, parent, relType) util.GetLogger(ctx).Infof("StoreRelation child=%s parent=%s rel_type=%s", child, parent, relType)
_, err = txn.Stmt(p.insertNodeStmt).ExecContext(ctx, ev.EventID(), ev.OriginServerTS(), ev.RoomID(), count, base64.RawStdEncoding.EncodeToString(hash), 0) _, err = txn.Stmt(p.insertNodeStmt).ExecContext(ctx, ev.EventID(), ev.OriginServerTS(), ev.RoomID().String(), count, base64.RawStdEncoding.EncodeToString(hash), 0)
return err return err
}) })
} }

View file

@ -1,744 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package msc2946 'Spaces Summary' implements https://github.com/matrix-org/matrix-doc/pull/2946
package msc2946
import (
"context"
"encoding/json"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/mux"
fs "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil"
roomserver "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
"github.com/tidwall/gjson"
)
const (
ConstCreateEventContentKey = "type"
ConstCreateEventContentValueSpace = "m.space"
ConstSpaceChildEventType = "m.space.child"
ConstSpaceParentEventType = "m.space.parent"
)
type MSC2946ClientResponse struct {
Rooms []fclient.MSC2946Room `json:"rooms"`
NextBatch string `json:"next_batch,omitempty"`
}
// Enable this MSC
func Enable(
cfg *config.Dendrite, routers httputil.Routers, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI,
fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier, cache caching.SpaceSummaryRoomsCache,
) error {
clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, cache, cfg.Global.ServerName), httputil.WithAllowGuests())
routers.Client.Handle("/v1/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions)
routers.Client.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions)
fedAPI := httputil.MakeExternalAPI(
"msc2946_fed_spaces", func(req *http.Request) util.JSONResponse {
fedReq, errResp := fclient.VerifyHTTPRequest(
req, time.Now(), cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing,
)
if fedReq == nil {
return errResp
}
// Extract the room ID from the request. Sanity check request data.
params, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
roomID := params["roomID"]
return federatedSpacesHandler(req.Context(), fedReq, roomID, cache, rsAPI, fsAPI, cfg.Global.ServerName)
},
)
routers.Federation.Handle("/unstable/org.matrix.msc2946/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet)
routers.Federation.Handle("/v1/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet)
return nil
}
func federatedSpacesHandler(
ctx context.Context, fedReq *fclient.FederationRequest, roomID string,
cache caching.SpaceSummaryRoomsCache,
rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI,
thisServer spec.ServerName,
) util.JSONResponse {
u, err := url.Parse(fedReq.RequestURI())
if err != nil {
return util.JSONResponse{
Code: 400,
JSON: spec.InvalidParam("bad request uri"),
}
}
w := walker{
rootRoomID: roomID,
serverName: fedReq.Origin(),
thisServer: thisServer,
ctx: ctx,
cache: cache,
suggestedOnly: u.Query().Get("suggested_only") == "true",
limit: 1000,
// The main difference is that it does not recurse into spaces and does not support pagination.
// This is somewhat equivalent to a Client-Server request with a max_depth=1.
maxDepth: 1,
rsAPI: rsAPI,
fsAPI: fsAPI,
// inline cache as we don't have pagination in federation mode
paginationCache: make(map[string]paginationInfo),
}
return w.walk()
}
func spacesHandler(
rsAPI roomserver.RoomserverInternalAPI,
fsAPI fs.FederationInternalAPI,
cache caching.SpaceSummaryRoomsCache,
thisServer spec.ServerName,
) func(*http.Request, *userapi.Device) util.JSONResponse {
// declared outside the returned handler so it persists between calls
// TODO: clear based on... time?
paginationCache := make(map[string]paginationInfo)
return func(req *http.Request, device *userapi.Device) util.JSONResponse {
// Extract the room ID from the request. Sanity check request data.
params, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
roomID := params["roomID"]
w := walker{
suggestedOnly: req.URL.Query().Get("suggested_only") == "true",
limit: parseInt(req.URL.Query().Get("limit"), 1000),
maxDepth: parseInt(req.URL.Query().Get("max_depth"), -1),
paginationToken: req.URL.Query().Get("from"),
rootRoomID: roomID,
caller: device,
thisServer: thisServer,
ctx: req.Context(),
cache: cache,
rsAPI: rsAPI,
fsAPI: fsAPI,
paginationCache: paginationCache,
}
return w.walk()
}
}
type paginationInfo struct {
processed set
unvisited []roomVisit
}
type walker struct {
rootRoomID string
caller *userapi.Device
serverName spec.ServerName
thisServer spec.ServerName
rsAPI roomserver.RoomserverInternalAPI
fsAPI fs.FederationInternalAPI
ctx context.Context
cache caching.SpaceSummaryRoomsCache
suggestedOnly bool
limit int
maxDepth int
paginationToken string
paginationCache map[string]paginationInfo
mu sync.Mutex
}
func (w *walker) newPaginationCache() (string, paginationInfo) {
p := paginationInfo{
processed: make(set),
unvisited: nil,
}
tok := uuid.NewString()
return tok, p
}
func (w *walker) loadPaginationCache(paginationToken string) *paginationInfo {
w.mu.Lock()
defer w.mu.Unlock()
p := w.paginationCache[paginationToken]
return &p
}
func (w *walker) storePaginationCache(paginationToken string, cache paginationInfo) {
w.mu.Lock()
defer w.mu.Unlock()
w.paginationCache[paginationToken] = cache
}
type roomVisit struct {
roomID string
parentRoomID string
depth int
vias []string // vias to query this room by
}
func (w *walker) walk() util.JSONResponse {
if authorised, _ := w.authorised(w.rootRoomID, ""); !authorised {
if w.caller != nil {
// CS API format
return util.JSONResponse{
Code: 403,
JSON: spec.Forbidden("room is unknown/forbidden"),
}
} else {
// SS API format
return util.JSONResponse{
Code: 404,
JSON: spec.NotFound("room is unknown/forbidden"),
}
}
}
var discoveredRooms []fclient.MSC2946Room
var cache *paginationInfo
if w.paginationToken != "" {
cache = w.loadPaginationCache(w.paginationToken)
if cache == nil {
return util.JSONResponse{
Code: 400,
JSON: spec.InvalidParam("invalid from"),
}
}
} else {
tok, c := w.newPaginationCache()
cache = &c
w.paginationToken = tok
// Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms
c.unvisited = append(c.unvisited, roomVisit{
roomID: w.rootRoomID,
parentRoomID: "",
depth: 0,
})
}
processed := cache.processed
unvisited := cache.unvisited
// Depth first -> stack data structure
for len(unvisited) > 0 {
if len(discoveredRooms) >= w.limit {
break
}
// pop the stack
rv := unvisited[len(unvisited)-1]
unvisited = unvisited[:len(unvisited)-1]
// If this room has already been processed, skip.
// If this room exceeds the specified depth, skip.
if processed.isSet(rv.roomID) || rv.roomID == "" || (w.maxDepth > 0 && rv.depth > w.maxDepth) {
continue
}
// Mark this room as processed.
processed.set(rv.roomID)
// if this room is not a space room, skip.
var roomType string
create := w.stateEvent(rv.roomID, spec.MRoomCreate, "")
if create != nil {
// escape the `.`s so gjson doesn't think it's nested
roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str
}
// Collect rooms/events to send back (either locally or fetched via federation)
var discoveredChildEvents []fclient.MSC2946StrippedEvent
// If we know about this room and the caller is authorised (joined/world_readable) then pull
// events locally
roomExists := w.roomExists(rv.roomID)
if !roomExists {
// attempt to query this room over federation, as either we've never heard of it before
// or we've left it and hence are not authorised (but info may be exposed regardless)
fedRes := w.federatedRoomInfo(rv.roomID, rv.vias)
if fedRes != nil {
discoveredChildEvents = fedRes.Room.ChildrenState
discoveredRooms = append(discoveredRooms, fedRes.Room)
if len(fedRes.Children) > 0 {
discoveredRooms = append(discoveredRooms, fedRes.Children...)
}
// mark this room as a space room as the federated server responded.
// we need to do this so we add the children of this room to the unvisited stack
// as these children may be rooms we do know about.
roomType = ConstCreateEventContentValueSpace
}
} else if authorised, isJoinedOrInvited := w.authorised(rv.roomID, rv.parentRoomID); authorised {
// Get all `m.space.child` state events for this room
events, err := w.childReferences(rv.roomID)
if err != nil {
util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Error("failed to extract references for room")
continue
}
discoveredChildEvents = events
pubRoom := w.publicRoomsChunk(rv.roomID)
discoveredRooms = append(discoveredRooms, fclient.MSC2946Room{
PublicRoom: *pubRoom,
RoomType: roomType,
ChildrenState: events,
})
// don't walk children if the user is not joined/invited to the space
if !isJoinedOrInvited {
continue
}
} else {
// room exists but user is not authorised
continue
}
// don't walk the children
// if the parent is not a space room
if roomType != ConstCreateEventContentValueSpace {
continue
}
// For each referenced room ID in the child events being returned to the caller
// add the room ID to the queue of unvisited rooms. Loop from the beginning.
// We need to invert the order here because the child events are lo->hi on the timestamp,
// so we need to ensure we pop in the same lo->hi order, which won't be the case if we
// insert the highest timestamp last in a stack.
for i := len(discoveredChildEvents) - 1; i >= 0; i-- {
spaceContent := struct {
Via []string `json:"via"`
}{}
ev := discoveredChildEvents[i]
_ = json.Unmarshal(ev.Content, &spaceContent)
unvisited = append(unvisited, roomVisit{
roomID: ev.StateKey,
parentRoomID: rv.roomID,
depth: rv.depth + 1,
vias: spaceContent.Via,
})
}
}
if len(unvisited) > 0 {
// we still have more rooms so we need to send back a pagination token,
// we probably hit a room limit
cache.processed = processed
cache.unvisited = unvisited
w.storePaginationCache(w.paginationToken, *cache)
} else {
// clear the pagination token so we don't send it back to the client
// Note we do NOT nuke the cache just in case this response is lost
// and the client retries it.
w.paginationToken = ""
}
if w.caller != nil {
// return CS API format
return util.JSONResponse{
Code: 200,
JSON: MSC2946ClientResponse{
Rooms: discoveredRooms,
NextBatch: w.paginationToken,
},
}
}
// return SS API format
// the first discovered room will be the room asked for, and subsequent ones the depth=1 children
if len(discoveredRooms) == 0 {
return util.JSONResponse{
Code: 404,
JSON: spec.NotFound("room is unknown/forbidden"),
}
}
return util.JSONResponse{
Code: 200,
JSON: fclient.MSC2946SpacesResponse{
Room: discoveredRooms[0],
Children: discoveredRooms[1:],
},
}
}
func (w *walker) stateEvent(roomID, evType, stateKey string) *types.HeaderedEvent {
var queryRes roomserver.QueryCurrentStateResponse
tuple := gomatrixserverlib.StateKeyTuple{
EventType: evType,
StateKey: stateKey,
}
err := w.rsAPI.QueryCurrentState(w.ctx, &roomserver.QueryCurrentStateRequest{
RoomID: roomID,
StateTuples: []gomatrixserverlib.StateKeyTuple{tuple},
}, &queryRes)
if err != nil {
return nil
}
return queryRes.StateEvents[tuple]
}
func (w *walker) publicRoomsChunk(roomID string) *fclient.PublicRoom {
pubRooms, err := roomserver.PopulatePublicRooms(w.ctx, []string{roomID}, w.rsAPI)
if err != nil {
util.GetLogger(w.ctx).WithError(err).Error("failed to PopulatePublicRooms")
return nil
}
if len(pubRooms) == 0 {
return nil
}
return &pubRooms[0]
}
// federatedRoomInfo returns more of the spaces graph from another server. Returns nil if this was
// unsuccessful.
func (w *walker) federatedRoomInfo(roomID string, vias []string) *fclient.MSC2946SpacesResponse {
// only do federated requests for client requests
if w.caller == nil {
return nil
}
resp, ok := w.cache.GetSpaceSummary(roomID)
if ok {
util.GetLogger(w.ctx).Debugf("Returning cached response for %s", roomID)
return &resp
}
util.GetLogger(w.ctx).Debugf("Querying %s via %+v", roomID, vias)
ctx := context.Background()
// query more of the spaces graph using these servers
for _, serverName := range vias {
if serverName == string(w.thisServer) {
continue
}
res, err := w.fsAPI.MSC2946Spaces(ctx, w.thisServer, spec.ServerName(serverName), roomID, w.suggestedOnly)
if err != nil {
util.GetLogger(w.ctx).WithError(err).Warnf("failed to call MSC2946Spaces on server %s", serverName)
continue
}
// ensure nil slices are empty as we send this to the client sometimes
if res.Room.ChildrenState == nil {
res.Room.ChildrenState = []fclient.MSC2946StrippedEvent{}
}
for i := 0; i < len(res.Children); i++ {
child := res.Children[i]
if child.ChildrenState == nil {
child.ChildrenState = []fclient.MSC2946StrippedEvent{}
}
res.Children[i] = child
}
w.cache.StoreSpaceSummary(roomID, res)
return &res
}
return nil
}
func (w *walker) roomExists(roomID string) bool {
var queryRes roomserver.QueryServerJoinedToRoomResponse
err := w.rsAPI.QueryServerJoinedToRoom(w.ctx, &roomserver.QueryServerJoinedToRoomRequest{
RoomID: roomID,
ServerName: w.thisServer,
}, &queryRes)
if err != nil {
util.GetLogger(w.ctx).WithError(err).Error("failed to QueryServerJoinedToRoom")
return false
}
// if the room exists but we aren't in the room then we might have stale data so we want to fetch
// it fresh via federation
return queryRes.RoomExists && queryRes.IsInRoom
}
// authorised returns true iff the user is joined this room or the room is world_readable
func (w *walker) authorised(roomID, parentRoomID string) (authed, isJoinedOrInvited bool) {
if w.caller != nil {
return w.authorisedUser(roomID, parentRoomID)
}
return w.authorisedServer(roomID), false
}
// authorisedServer returns true iff the server is joined this room or the room is world_readable, public, or knockable
func (w *walker) authorisedServer(roomID string) bool {
// Check history visibility / join rules first
hisVisTuple := gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomHistoryVisibility,
StateKey: "",
}
joinRuleTuple := gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomJoinRules,
StateKey: "",
}
var queryRoomRes roomserver.QueryCurrentStateResponse
err := w.rsAPI.QueryCurrentState(w.ctx, &roomserver.QueryCurrentStateRequest{
RoomID: roomID,
StateTuples: []gomatrixserverlib.StateKeyTuple{
hisVisTuple, joinRuleTuple,
},
}, &queryRoomRes)
if err != nil {
util.GetLogger(w.ctx).WithError(err).Error("failed to QueryCurrentState")
return false
}
hisVisEv := queryRoomRes.StateEvents[hisVisTuple]
if hisVisEv != nil {
hisVis, _ := hisVisEv.HistoryVisibility()
if hisVis == "world_readable" {
return true
}
}
// check if this room is a restricted room and if so, we need to check if the server is joined to an allowed room ID
// in addition to the actual room ID (but always do the actual one first as it's quicker in the common case)
allowJoinedToRoomIDs := []string{roomID}
joinRuleEv := queryRoomRes.StateEvents[joinRuleTuple]
if joinRuleEv != nil {
rule, ruleErr := joinRuleEv.JoinRule()
if ruleErr != nil {
util.GetLogger(w.ctx).WithError(ruleErr).WithField("parent_room_id", roomID).Warn("failed to get join rule")
return false
}
if rule == spec.Public || rule == spec.Knock {
return true
}
if rule == spec.Restricted {
allowJoinedToRoomIDs = append(allowJoinedToRoomIDs, w.restrictedJoinRuleAllowedRooms(joinRuleEv, "m.room_membership")...)
}
}
// check if server is joined to any allowed room
for _, allowedRoomID := range allowJoinedToRoomIDs {
var queryRes fs.QueryJoinedHostServerNamesInRoomResponse
err = w.fsAPI.QueryJoinedHostServerNamesInRoom(w.ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{
RoomID: allowedRoomID,
}, &queryRes)
if err != nil {
util.GetLogger(w.ctx).WithError(err).Error("failed to QueryJoinedHostServerNamesInRoom")
continue
}
for _, srv := range queryRes.ServerNames {
if srv == w.serverName {
return true
}
}
}
return false
}
// authorisedUser returns true iff the user is invited/joined this room or the room is world_readable
// or if the room has a public or knock join rule.
// Failing that, if the room has a restricted join rule and belongs to the space parent listed, it will return true.
func (w *walker) authorisedUser(roomID, parentRoomID string) (authed bool, isJoinedOrInvited bool) {
hisVisTuple := gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomHistoryVisibility,
StateKey: "",
}
joinRuleTuple := gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomJoinRules,
StateKey: "",
}
roomMemberTuple := gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomMember,
StateKey: w.caller.UserID,
}
var queryRes roomserver.QueryCurrentStateResponse
err := w.rsAPI.QueryCurrentState(w.ctx, &roomserver.QueryCurrentStateRequest{
RoomID: roomID,
StateTuples: []gomatrixserverlib.StateKeyTuple{
hisVisTuple, joinRuleTuple, roomMemberTuple,
},
}, &queryRes)
if err != nil {
util.GetLogger(w.ctx).WithError(err).Error("failed to QueryCurrentState")
return false, false
}
memberEv := queryRes.StateEvents[roomMemberTuple]
if memberEv != nil {
membership, _ := memberEv.Membership()
if membership == spec.Join || membership == spec.Invite {
return true, true
}
}
hisVisEv := queryRes.StateEvents[hisVisTuple]
if hisVisEv != nil {
hisVis, _ := hisVisEv.HistoryVisibility()
if hisVis == "world_readable" {
return true, false
}
}
joinRuleEv := queryRes.StateEvents[joinRuleTuple]
if parentRoomID != "" && joinRuleEv != nil {
var allowed bool
rule, ruleErr := joinRuleEv.JoinRule()
if ruleErr != nil {
util.GetLogger(w.ctx).WithError(ruleErr).WithField("parent_room_id", parentRoomID).Warn("failed to get join rule")
} else if rule == spec.Public || rule == spec.Knock {
allowed = true
} else if rule == spec.Restricted {
allowedRoomIDs := w.restrictedJoinRuleAllowedRooms(joinRuleEv, "m.room_membership")
// check parent is in the allowed set
for _, a := range allowedRoomIDs {
if parentRoomID == a {
allowed = true
break
}
}
}
if allowed {
// ensure caller is joined to the parent room
var queryRes2 roomserver.QueryCurrentStateResponse
err = w.rsAPI.QueryCurrentState(w.ctx, &roomserver.QueryCurrentStateRequest{
RoomID: parentRoomID,
StateTuples: []gomatrixserverlib.StateKeyTuple{
roomMemberTuple,
},
}, &queryRes2)
if err != nil {
util.GetLogger(w.ctx).WithError(err).WithField("parent_room_id", parentRoomID).Warn("failed to check user is joined to parent room")
} else {
memberEv = queryRes2.StateEvents[roomMemberTuple]
if memberEv != nil {
membership, _ := memberEv.Membership()
if membership == spec.Join {
return true, false
}
}
}
}
}
return false, false
}
func (w *walker) restrictedJoinRuleAllowedRooms(joinRuleEv *types.HeaderedEvent, allowType string) (allows []string) {
rule, _ := joinRuleEv.JoinRule()
if rule != spec.Restricted {
return nil
}
var jrContent gomatrixserverlib.JoinRuleContent
if err := json.Unmarshal(joinRuleEv.Content(), &jrContent); err != nil {
util.GetLogger(w.ctx).Warnf("failed to check join_rule on room %s: %s", joinRuleEv.RoomID(), err)
return nil
}
for _, allow := range jrContent.Allow {
if allow.Type == allowType {
allows = append(allows, allow.RoomID)
}
}
return
}
// references returns all child references pointing to or from this room.
func (w *walker) childReferences(roomID string) ([]fclient.MSC2946StrippedEvent, error) {
createTuple := gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomCreate,
StateKey: "",
}
var res roomserver.QueryCurrentStateResponse
err := w.rsAPI.QueryCurrentState(context.Background(), &roomserver.QueryCurrentStateRequest{
RoomID: roomID,
AllowWildcards: true,
StateTuples: []gomatrixserverlib.StateKeyTuple{
createTuple, {
EventType: ConstSpaceChildEventType,
StateKey: "*",
},
},
}, &res)
if err != nil {
return nil, err
}
// don't return any child refs if the room is not a space room
if res.StateEvents[createTuple] != nil {
// escape the `.`s so gjson doesn't think it's nested
roomType := gjson.GetBytes(res.StateEvents[createTuple].Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str
if roomType != ConstCreateEventContentValueSpace {
return []fclient.MSC2946StrippedEvent{}, nil
}
}
delete(res.StateEvents, createTuple)
el := make([]fclient.MSC2946StrippedEvent, 0, len(res.StateEvents))
for _, ev := range res.StateEvents {
content := gjson.ParseBytes(ev.Content())
// only return events that have a `via` key as per MSC1772
// else we'll incorrectly walk redacted events (as the link
// is in the state_key)
if content.Get("via").Exists() {
strip := stripped(ev.PDU)
if strip == nil {
continue
}
// if suggested only and this child isn't suggested, skip it.
// if suggested only = false we include everything so don't need to check the content.
if w.suggestedOnly && !content.Get("suggested").Bool() {
continue
}
el = append(el, *strip)
}
}
// sort by origin_server_ts as per MSC2946
sort.Slice(el, func(i, j int) bool {
return el[i].OriginServerTS < el[j].OriginServerTS
})
return el, nil
}
type set map[string]struct{}
func (s set) set(val string) {
s[val] = struct{}{}
}
func (s set) isSet(val string) bool {
_, ok := s[val]
return ok
}
func stripped(ev gomatrixserverlib.PDU) *fclient.MSC2946StrippedEvent {
if ev.StateKey() == nil {
return nil
}
return &fclient.MSC2946StrippedEvent{
Type: ev.Type(),
StateKey: *ev.StateKey(),
Content: ev.Content(),
Sender: string(ev.SenderID()),
OriginServerTS: ev.OriginServerTS(),
}
}
func parseInt(intstr string, defaultVal int) int {
i, err := strconv.ParseInt(intstr, 10, 32)
if err != nil {
return defaultVal
}
return int(i)
}

View file

@ -17,7 +17,6 @@ package mscs
import ( import (
"context" "context"
"fmt"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
@ -25,12 +24,12 @@ import (
"github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/mscs/msc2836" "github.com/matrix-org/dendrite/setup/mscs/msc2836"
"github.com/matrix-org/dendrite/setup/mscs/msc2946"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
) )
// Enable MSCs - returns an error on unknown MSCs // Enable MSCs - returns an error on unknown MSCs
func Enable(cfg *config.Dendrite, cm sqlutil.Connections, routers httputil.Routers, monolith *setup.Monolith, caches *caching.Caches) error { func Enable(cfg *config.Dendrite, cm *sqlutil.Connections, routers httputil.Routers, monolith *setup.Monolith, caches *caching.Caches) error {
for _, msc := range cfg.MSCs.MSCs { for _, msc := range cfg.MSCs.MSCs {
util.GetLogger(context.Background()).WithField("msc", msc).Info("Enabling MSC") util.GetLogger(context.Background()).WithField("msc", msc).Info("Enabling MSC")
if err := EnableMSC(cfg, cm, routers, monolith, msc, caches); err != nil { if err := EnableMSC(cfg, cm, routers, monolith, msc, caches); err != nil {
@ -40,16 +39,14 @@ func Enable(cfg *config.Dendrite, cm sqlutil.Connections, routers httputil.Route
return nil return nil
} }
func EnableMSC(cfg *config.Dendrite, cm sqlutil.Connections, routers httputil.Routers, monolith *setup.Monolith, msc string, caches *caching.Caches) error { func EnableMSC(cfg *config.Dendrite, cm *sqlutil.Connections, routers httputil.Routers, monolith *setup.Monolith, msc string, caches *caching.Caches) error {
switch msc { switch msc {
case "msc2836": case "msc2836":
return msc2836.Enable(cfg, cm, routers, monolith.RoomserverAPI, monolith.FederationAPI, monolith.UserAPI, monolith.KeyRing) return msc2836.Enable(cfg, cm, routers, monolith.RoomserverAPI, monolith.FederationAPI, monolith.UserAPI, monolith.KeyRing)
case "msc2946":
return msc2946.Enable(cfg, routers, monolith.RoomserverAPI, monolith.UserAPI, monolith.FederationAPI, monolith.KeyRing, caches)
case "msc2444": // enabled inside federationapi case "msc2444": // enabled inside federationapi
case "msc2753": // enabled inside clientapi case "msc2753": // enabled inside clientapi
default: default:
return fmt.Errorf("EnableMSC: unknown msc '%s'", msc) logrus.Warnf("EnableMSC: unknown MSC '%s', this MSC is either not supported or is natively supported by Dendrite", msc)
} }
return nil return nil
} }

View file

@ -113,7 +113,7 @@ func (s *OutputClientDataConsumer) Start() error {
id = streamPos id = streamPos
e := fulltext.IndexElement{ e := fulltext.IndexElement{
EventID: ev.EventID(), EventID: ev.EventID(),
RoomID: ev.RoomID(), RoomID: ev.RoomID().String(),
StreamPosition: streamPos, StreamPosition: streamPos,
} }
e.SetContentType(ev.Type()) e.SetContentType(ev.Type())

View file

@ -17,16 +17,12 @@ package consumers
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/base64"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/fulltext"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -37,7 +33,13 @@ import (
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/synctypes"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
) )
// OutputRoomEventConsumer consumes events that originated in the room server. // OutputRoomEventConsumer consumes events that originated in the room server.
@ -141,7 +143,14 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms
) )
} }
if err != nil { if err != nil {
log.WithError(err).Error("roomserver output log: failed to process event") if errors.As(err, new(base64.CorruptInputError)) {
// no matter how often we retry this event, we will always get this error, discard the event
return true
}
log.WithFields(log.Fields{
"type": output.Type,
}).WithError(err).Error("roomserver output log: failed to process event")
sentry.CaptureException(err)
return false return false
} }
@ -157,9 +166,9 @@ func (s *OutputRoomEventConsumer) onRedactEvent(
return err return err
} }
if err = s.db.RedactRelations(ctx, msg.RedactedBecause.RoomID(), msg.RedactedEventID); err != nil { if err = s.db.RedactRelations(ctx, msg.RedactedBecause.RoomID().String(), msg.RedactedEventID); err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"room_id": msg.RedactedBecause.RoomID(), "room_id": msg.RedactedBecause.RoomID().String(),
"event_id": msg.RedactedBecause.EventID(), "event_id": msg.RedactedBecause.EventID(),
"redacted_event_id": msg.RedactedEventID, "redacted_event_id": msg.RedactedEventID,
}).WithError(err).Warn("Failed to redact relations") }).WithError(err).Warn("Failed to redact relations")
@ -213,7 +222,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
// Finally, work out if there are any more events missing. // Finally, work out if there are any more events missing.
if len(missingEventIDs) > 0 { if len(missingEventIDs) > 0 {
eventsReq := &api.QueryEventsByIDRequest{ eventsReq := &api.QueryEventsByIDRequest{
RoomID: ev.RoomID(), RoomID: ev.RoomID().String(),
EventIDs: missingEventIDs, EventIDs: missingEventIDs,
} }
eventsRes := &api.QueryEventsByIDResponse{} eventsRes := &api.QueryEventsByIDResponse{}
@ -237,31 +246,23 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
ev, err := s.updateStateEvent(ev) ev, err := s.updateStateEvent(ev)
if err != nil { if err != nil {
sentry.CaptureException(err)
return err return err
} }
for i := range addsStateEvents { for i := range addsStateEvents {
addsStateEvents[i], err = s.updateStateEvent(addsStateEvents[i]) addsStateEvents[i], err = s.updateStateEvent(addsStateEvents[i])
if err != nil { if err != nil {
sentry.CaptureException(err)
return err return err
} }
} }
if msg.RewritesState { if msg.RewritesState {
if err = s.db.PurgeRoomState(ctx, ev.RoomID()); err != nil { if err = s.db.PurgeRoomState(ctx, ev.RoomID().String()); err != nil {
sentry.CaptureException(err)
return fmt.Errorf("s.db.PurgeRoom: %w", err) return fmt.Errorf("s.db.PurgeRoom: %w", err)
} }
} }
validRoomID, err := spec.NewRoomID(ev.RoomID()) userID, err := s.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID())
if err != nil {
return err
}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, ev.SenderID())
if err != nil { if err != nil {
return err return err
} }
@ -289,7 +290,6 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil { if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil {
log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos) log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos)
sentry.CaptureException(err)
return err return err
} }
@ -302,7 +302,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
} }
s.pduStream.Advance(pduPos) s.pduStream.Advance(pduPos)
s.notifier.OnNewEvent(ev, ev.RoomID(), nil, types.StreamingToken{PDUPosition: pduPos}) s.notifier.OnNewEvent(ev, ev.RoomID().String(), nil, types.StreamingToken{PDUPosition: pduPos})
return nil return nil
} }
@ -319,12 +319,7 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent(
// old events in the sync API, this should at least prevent us // old events in the sync API, this should at least prevent us
// from confusing clients into thinking they've joined/left rooms. // from confusing clients into thinking they've joined/left rooms.
validRoomID, err := spec.NewRoomID(ev.RoomID()) userID, err := s.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID())
if err != nil {
return err
}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, ev.SenderID())
if err != nil { if err != nil {
return err return err
} }
@ -350,7 +345,7 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent(
if err = s.db.UpdateRelations(ctx, ev); err != nil { if err = s.db.UpdateRelations(ctx, ev); err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"room_id": ev.RoomID(), "room_id": ev.RoomID().String(),
"event_id": ev.EventID(), "event_id": ev.EventID(),
"type": ev.Type(), "type": ev.Type(),
}).WithError(err).Warn("Failed to update relations") }).WithError(err).Warn("Failed to update relations")
@ -363,7 +358,7 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent(
} }
s.pduStream.Advance(pduPos) s.pduStream.Advance(pduPos)
s.notifier.OnNewEvent(ev, ev.RoomID(), nil, types.StreamingToken{PDUPosition: pduPos}) s.notifier.OnNewEvent(ev, ev.RoomID().String(), nil, types.StreamingToken{PDUPosition: pduPos})
return nil return nil
} }
@ -383,11 +378,7 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *rst
return sp, fmt.Errorf("unexpected nil state_key") return sp, fmt.Errorf("unexpected nil state_key")
} }
validRoomID, err := spec.NewRoomID(ev.RoomID()) userID, err := s.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey()))
if err != nil {
return sp, err
}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*ev.StateKey()))
if err != nil || userID == nil { if err != nil || userID == nil {
return sp, fmt.Errorf("failed getting userID for sender: %w", err) return sp, fmt.Errorf("failed getting userID for sender: %w", err)
} }
@ -396,7 +387,7 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *rst
} }
// cancel any peeks for it // cancel any peeks for it
peekSP, peekErr := s.db.DeletePeeks(ctx, ev.RoomID(), *ev.StateKey()) peekSP, peekErr := s.db.DeletePeeks(ctx, ev.RoomID().String(), *ev.StateKey())
if peekErr != nil { if peekErr != nil {
return sp, fmt.Errorf("s.db.DeletePeeks: %w", peekErr) return sp, fmt.Errorf("s.db.DeletePeeks: %w", peekErr)
} }
@ -414,11 +405,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
return return
} }
validRoomID, err := spec.NewRoomID(msg.Event.RoomID()) userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.Event.RoomID(), spec.SenderID(*msg.Event.StateKey()))
if err != nil {
return
}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*msg.Event.StateKey()))
if err != nil || userID == nil { if err != nil || userID == nil {
return return
} }
@ -430,7 +417,6 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
pduPos, err := s.db.AddInviteEvent(ctx, msg.Event) pduPos, err := s.db.AddInviteEvent(ctx, msg.Event)
if err != nil { if err != nil {
sentry.CaptureException(err)
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"event_id": msg.Event.EventID(), "event_id": msg.Event.EventID(),
@ -452,7 +438,6 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
// It's possible we just haven't heard of this invite yet, so // It's possible we just haven't heard of this invite yet, so
// we should not panic if we try to retire it. // we should not panic if we try to retire it.
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
sentry.CaptureException(err)
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"event_id": msg.EventID, "event_id": msg.EventID,
@ -496,7 +481,6 @@ func (s *OutputRoomEventConsumer) onNewPeek(
) { ) {
sp, err := s.db.AddPeek(ctx, msg.RoomID, msg.UserID, msg.DeviceID) sp, err := s.db.AddPeek(ctx, msg.RoomID, msg.UserID, msg.DeviceID)
if err != nil { if err != nil {
sentry.CaptureException(err)
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
log.WithFields(log.Fields{ log.WithFields(log.Fields{
log.ErrorKey: err, log.ErrorKey: err,
@ -558,30 +542,24 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *rstypes.HeaderedEvent)
var succeeded bool var succeeded bool
defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err)
sKeyUser := ""
if stateKey != "" {
var sku *spec.UserID
sku, err = s.rsAPI.QueryUserIDForSender(s.ctx, event.RoomID(), spec.SenderID(stateKey))
if err == nil && sku != nil {
sKeyUser = sku.String()
event.StateKeyResolved = &sKeyUser
}
}
prevEvent, err := snapshot.GetStateEvent( prevEvent, err := snapshot.GetStateEvent(
s.ctx, event.RoomID(), event.Type(), stateKey, s.ctx, event.RoomID().String(), event.Type(), sKeyUser,
) )
if err != nil { if err != nil {
return event, err return event, err
} }
validRoomID, err := spec.NewRoomID(event.RoomID()) userID, err := s.rsAPI.QueryUserIDForSender(s.ctx, event.RoomID(), event.SenderID())
if err != nil {
return event, err
}
if event.StateKey() != nil {
if *event.StateKey() != "" {
var sku *spec.UserID
sku, err = s.rsAPI.QueryUserIDForSender(s.ctx, *validRoomID, spec.SenderID(stateKey))
if err == nil && sku != nil {
sKey := sku.String()
event.StateKeyResolved = &sKey
}
}
}
userID, err := s.rsAPI.QueryUserIDForSender(s.ctx, *validRoomID, event.SenderID())
if err != nil { if err != nil {
return event, err return event, err
} }
@ -592,7 +570,7 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *rstypes.HeaderedEvent)
return event, nil return event, nil
} }
prev := types.PrevEventRef{ prev := synctypes.PrevEventRef{
PrevContent: prevEvent.Content(), PrevContent: prevEvent.Content(),
ReplacesState: prevEvent.EventID(), ReplacesState: prevEvent.EventID(),
PrevSenderID: string(prevEvent.SenderID()), PrevSenderID: string(prevEvent.SenderID()),
@ -609,7 +587,7 @@ func (s *OutputRoomEventConsumer) writeFTS(ev *rstypes.HeaderedEvent, pduPositio
} }
e := fulltext.IndexElement{ e := fulltext.IndexElement{
EventID: ev.EventID(), EventID: ev.EventID(),
RoomID: ev.RoomID(), RoomID: ev.RoomID().String(),
StreamPosition: int64(pduPosition), StreamPosition: int64(pduPosition),
} }
e.SetContentType(ev.Type()) e.SetContentType(ev.Type())

View file

@ -16,6 +16,7 @@ package internal
import ( import (
"context" "context"
"fmt"
"math" "math"
"time" "time"
@ -101,13 +102,15 @@ func (ev eventVisibility) allowed() (allowed bool) {
// ApplyHistoryVisibilityFilter applies the room history visibility filter on types.HeaderedEvents. // ApplyHistoryVisibilityFilter applies the room history visibility filter on types.HeaderedEvents.
// Returns the filtered events and an error, if any. // Returns the filtered events and an error, if any.
//
// This function assumes that all provided events are from the same room.
func ApplyHistoryVisibilityFilter( func ApplyHistoryVisibilityFilter(
ctx context.Context, ctx context.Context,
syncDB storage.DatabaseTransaction, syncDB storage.DatabaseTransaction,
rsAPI api.SyncRoomserverAPI, rsAPI api.SyncRoomserverAPI,
events []*types.HeaderedEvent, events []*types.HeaderedEvent,
alwaysIncludeEventIDs map[string]struct{}, alwaysIncludeEventIDs map[string]struct{},
userID, endpoint string, userID spec.UserID, endpoint string,
) ([]*types.HeaderedEvent, error) { ) ([]*types.HeaderedEvent, error) {
if len(events) == 0 { if len(events) == 0 {
return events, nil return events, nil
@ -115,15 +118,26 @@ func ApplyHistoryVisibilityFilter(
start := time.Now() start := time.Now()
// try to get the current membership of the user // try to get the current membership of the user
membershipCurrent, _, err := syncDB.SelectMembershipForUser(ctx, events[0].RoomID(), userID, math.MaxInt64) membershipCurrent, _, err := syncDB.SelectMembershipForUser(ctx, events[0].RoomID().String(), userID.String(), math.MaxInt64)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Get the mapping from eventID -> eventVisibility // Get the mapping from eventID -> eventVisibility
eventsFiltered := make([]*types.HeaderedEvent, 0, len(events)) eventsFiltered := make([]*types.HeaderedEvent, 0, len(events))
visibilities := visibilityForEvents(ctx, rsAPI, events, userID, events[0].RoomID()) firstEvRoomID := events[0].RoomID()
senderID, err := rsAPI.QuerySenderIDForUser(ctx, firstEvRoomID, userID)
if err != nil {
return nil, err
}
visibilities := visibilityForEvents(ctx, rsAPI, events, senderID, firstEvRoomID)
for _, ev := range events { for _, ev := range events {
// Validate same room assumption
if ev.RoomID().String() != firstEvRoomID.String() {
return nil, fmt.Errorf("events from different rooms supplied to ApplyHistoryVisibilityFilter")
}
evVis := visibilities[ev.EventID()] evVis := visibilities[ev.EventID()]
evVis.membershipCurrent = membershipCurrent evVis.membershipCurrent = membershipCurrent
// Always include specific state events for /sync responses // Always include specific state events for /sync responses
@ -133,38 +147,36 @@ func ApplyHistoryVisibilityFilter(
continue continue
} }
} }
// NOTSPEC: Always allow user to see their own membership events (spec contains more "rules")
user, err := spec.NewUserID(userID, true) // NOTSPEC: Always allow user to see their own membership events (spec contains more "rules")
if err != nil { if senderID != nil {
return nil, err if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(*senderID)) {
}
roomID, err := spec.NewRoomID(ev.RoomID())
if err != nil {
return nil, err
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *user)
if err == nil {
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) {
eventsFiltered = append(eventsFiltered, ev) eventsFiltered = append(eventsFiltered, ev)
continue continue
} }
} }
// Always allow history evVis events on boundaries. This is done // Always allow history evVis events on boundaries. This is done
// by setting the effective evVis to the least restrictive // by setting the effective evVis to the least restrictive
// of the old vs new. // of the old vs new.
// https://spec.matrix.org/v1.3/client-server-api/#server-behaviour-5 // https://spec.matrix.org/v1.3/client-server-api/#server-behaviour-5
if hisVis, err := ev.HistoryVisibility(); err == nil { if ev.Type() == spec.MRoomHistoryVisibility {
prevHisVis := gjson.GetBytes(ev.Unsigned(), "prev_content.history_visibility").String() hisVis, err := ev.HistoryVisibility()
oldPrio, ok := historyVisibilityPriority[gomatrixserverlib.HistoryVisibility(prevHisVis)]
// if we can't get the previous history visibility, default to shared. if err == nil && hisVis != "" {
if !ok { prevHisVis := gjson.GetBytes(ev.Unsigned(), "prev_content.history_visibility").String()
oldPrio = historyVisibilityPriority[gomatrixserverlib.HistoryVisibilityShared] oldPrio, ok := historyVisibilityPriority[gomatrixserverlib.HistoryVisibility(prevHisVis)]
} // if we can't get the previous history visibility, default to shared.
// no OK check, since this should have been validated when setting the value if !ok {
newPrio := historyVisibilityPriority[hisVis] oldPrio = historyVisibilityPriority[gomatrixserverlib.HistoryVisibilityShared]
if oldPrio < newPrio { }
evVis.visibility = gomatrixserverlib.HistoryVisibility(prevHisVis) // no OK check, since this should have been validated when setting the value
newPrio := historyVisibilityPriority[hisVis]
if oldPrio < newPrio {
evVis.visibility = gomatrixserverlib.HistoryVisibility(prevHisVis)
} else {
evVis.visibility = hisVis
}
} }
} }
// do the actual check // do the actual check
@ -178,13 +190,13 @@ func ApplyHistoryVisibilityFilter(
} }
// visibilityForEvents returns a map from eventID to eventVisibility containing the visibility and the membership // visibilityForEvents returns a map from eventID to eventVisibility containing the visibility and the membership
// of `userID` at the given event. // of `senderID` at the given event. If provided sender ID is nil, assume that membership is Leave
// Returns an error if the roomserver can't calculate the memberships. // Returns an error if the roomserver can't calculate the memberships.
func visibilityForEvents( func visibilityForEvents(
ctx context.Context, ctx context.Context,
rsAPI api.SyncRoomserverAPI, rsAPI api.SyncRoomserverAPI,
events []*types.HeaderedEvent, events []*types.HeaderedEvent,
userID, roomID string, senderID *spec.SenderID, roomID spec.RoomID,
) map[string]eventVisibility { ) map[string]eventVisibility {
eventIDs := make([]string, len(events)) eventIDs := make([]string, len(events))
for i := range events { for i := range events {
@ -194,15 +206,13 @@ func visibilityForEvents(
result := make(map[string]eventVisibility, len(eventIDs)) result := make(map[string]eventVisibility, len(eventIDs))
// get the membership events for all eventIDs // get the membership events for all eventIDs
membershipResp := &api.QueryMembershipAtEventResponse{} var err error
membershipEvents := make(map[string]*types.HeaderedEvent)
err := rsAPI.QueryMembershipAtEvent(ctx, &api.QueryMembershipAtEventRequest{ if senderID != nil {
RoomID: roomID, membershipEvents, err = rsAPI.QueryMembershipAtEvent(ctx, roomID, eventIDs, *senderID)
EventIDs: eventIDs, if err != nil {
UserID: userID, logrus.WithError(err).Error("visibilityForEvents: failed to fetch membership at event, defaulting to 'leave'")
}, membershipResp) }
if err != nil {
logrus.WithError(err).Error("visibilityForEvents: failed to fetch membership at event, defaulting to 'leave'")
} }
// Create a map from eventID -> eventVisibility // Create a map from eventID -> eventVisibility
@ -212,7 +222,7 @@ func visibilityForEvents(
membershipAtEvent: spec.Leave, // default to leave, to not expose events by accident membershipAtEvent: spec.Leave, // default to leave, to not expose events by accident
visibility: event.Visibility, visibility: event.Visibility,
} }
ev, ok := membershipResp.Membership[eventID] ev, ok := membershipEvents[eventID]
if !ok || ev == nil { if !ok || ev == nil {
result[eventID] = vis result[eventID] = vis
continue continue

View file

@ -0,0 +1,214 @@
package internal
import (
"context"
"fmt"
"math"
"testing"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"gotest.tools/v3/assert"
)
type mockHisVisRoomserverAPI struct {
rsapi.RoomserverInternalAPI
events []*types.HeaderedEvent
roomID string
}
func (s *mockHisVisRoomserverAPI) QueryMembershipAtEvent(ctx context.Context, roomID spec.RoomID, eventIDs []string, senderID spec.SenderID) (map[string]*types.HeaderedEvent, error) {
if roomID.String() == s.roomID {
membershipMap := map[string]*types.HeaderedEvent{}
for _, queriedEventID := range eventIDs {
for _, event := range s.events {
if event.EventID() == queriedEventID {
membershipMap[queriedEventID] = event
}
}
}
return membershipMap, nil
} else {
return nil, fmt.Errorf("room not found: \"%v\"", roomID)
}
}
func (s *mockHisVisRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (*spec.SenderID, error) {
senderID := spec.SenderIDFromUserID(userID)
return &senderID, nil
}
func (s *mockHisVisRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
userID := senderID.ToUserID()
if userID == nil {
return nil, fmt.Errorf("sender ID not user ID")
}
return userID, nil
}
type mockDB struct {
storage.DatabaseTransaction
// user ID -> membership (i.e. 'join', 'leave', etc.)
currentMembership map[string]string
roomID string
}
func (s *mockDB) SelectMembershipForUser(ctx context.Context, roomID string, userID string, pos int64) (string, int64, error) {
if roomID == s.roomID {
membership, ok := s.currentMembership[userID]
if !ok {
return spec.Leave, math.MaxInt64, nil
}
return membership, math.MaxInt64, nil
}
return "", 0, fmt.Errorf("room not found: \"%v\"", roomID)
}
// Tests logic around history visibility boundaries
//
// Specifically that if a room's history visibility before or after a particular history visibility event
// allows them to see events (a boundary), then the history visibility event itself should be shown
// ( spec: https://spec.matrix.org/v1.8/client-server-api/#server-behaviour-5 )
//
// This also aims to emulate "Only see history_visibility changes on bounadries" in sytest/tests/30rooms/30history-visibility.pl
func Test_ApplyHistoryVisbility_Boundaries(t *testing.T) {
ctx := context.Background()
roomID := "!roomid:domain"
creatorUserID := spec.NewUserIDOrPanic("@creator:domain", false)
otherUserID := spec.NewUserIDOrPanic("@other:domain", false)
roomVersion := gomatrixserverlib.RoomVersionV10
roomVerImpl := gomatrixserverlib.MustGetRoomVersion(roomVersion)
eventsJSON := []struct {
id string
json string
}{
{id: "$create-event", json: fmt.Sprintf(`{
"type": "m.room.create", "state_key": "",
"room_id": "%v", "sender": "%v",
"content": {"creator": "%v", "room_version": "%v"}
}`, roomID, creatorUserID.String(), creatorUserID.String(), roomVersion)},
{id: "$creator-joined", json: fmt.Sprintf(`{
"type": "m.room.member", "state_key": "%v",
"room_id": "%v", "sender": "%v",
"content": {"membership": "join"}
}`, creatorUserID.String(), roomID, creatorUserID.String())},
{id: "$hisvis-1", json: fmt.Sprintf(`{
"type": "m.room.history_visibility", "state_key": "",
"room_id": "%v", "sender": "%v",
"content": {"history_visibility": "shared"}
}`, roomID, creatorUserID.String())},
{id: "$msg-1", json: fmt.Sprintf(`{
"type": "m.room.message",
"room_id": "%v", "sender": "%v",
"content": {"body": "1"}
}`, roomID, creatorUserID.String())},
{id: "$hisvis-2", json: fmt.Sprintf(`{
"type": "m.room.history_visibility", "state_key": "",
"room_id": "%v", "sender": "%v",
"content": {"history_visibility": "joined"},
"unsigned": {"prev_content": {"history_visibility": "shared"}}
}`, roomID, creatorUserID.String())},
{id: "$msg-2", json: fmt.Sprintf(`{
"type": "m.room.message",
"room_id": "%v", "sender": "%v",
"content": {"body": "1"}
}`, roomID, creatorUserID.String())},
{id: "$hisvis-3", json: fmt.Sprintf(`{
"type": "m.room.history_visibility", "state_key": "",
"room_id": "%v", "sender": "%v",
"content": {"history_visibility": "invited"},
"unsigned": {"prev_content": {"history_visibility": "joined"}}
}`, roomID, creatorUserID.String())},
{id: "$msg-3", json: fmt.Sprintf(`{
"type": "m.room.message",
"room_id": "%v", "sender": "%v",
"content": {"body": "2"}
}`, roomID, creatorUserID.String())},
{id: "$hisvis-4", json: fmt.Sprintf(`{
"type": "m.room.history_visibility", "state_key": "",
"room_id": "%v", "sender": "%v",
"content": {"history_visibility": "shared"},
"unsigned": {"prev_content": {"history_visibility": "invited"}}
}`, roomID, creatorUserID.String())},
{id: "$msg-4", json: fmt.Sprintf(`{
"type": "m.room.message",
"room_id": "%v", "sender": "%v",
"content": {"body": "3"}
}`, roomID, creatorUserID.String())},
{id: "$other-joined", json: fmt.Sprintf(`{
"type": "m.room.member", "state_key": "%v",
"room_id": "%v", "sender": "%v",
"content": {"membership": "join"}
}`, otherUserID.String(), roomID, otherUserID.String())},
}
events := make([]*types.HeaderedEvent, len(eventsJSON))
hisVis := gomatrixserverlib.HistoryVisibilityShared
for i, eventJSON := range eventsJSON {
pdu, err := roomVerImpl.NewEventFromTrustedJSONWithEventID(eventJSON.id, []byte(eventJSON.json), false)
if err != nil {
t.Fatalf("failed to prepare event %s for test: %s", eventJSON.id, err.Error())
}
events[i] = &types.HeaderedEvent{PDU: pdu}
// 'Visibility' should be the visibility of the room just before this event was sent
// (according to processRoomEvent in roomserver/internal/input/input_events.go)
events[i].Visibility = hisVis
if pdu.Type() == spec.MRoomHistoryVisibility {
newHisVis, err := pdu.HistoryVisibility()
if err != nil {
t.Fatalf("failed to prepare history visibility event: %s", err.Error())
}
hisVis = newHisVis
}
}
rsAPI := &mockHisVisRoomserverAPI{
events: events,
roomID: roomID,
}
syncDB := &mockDB{
roomID: roomID,
currentMembership: map[string]string{
creatorUserID.String(): spec.Join,
otherUserID.String(): spec.Join,
},
}
filteredEvents, err := ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, events, nil, otherUserID, "hisVisTest")
if err != nil {
t.Fatalf("ApplyHistoryVisibility returned non-nil error: %s", err.Error())
}
filteredEventIDs := make([]string, len(filteredEvents))
for i, event := range filteredEvents {
filteredEventIDs[i] = event.EventID()
}
assert.DeepEqual(t,
[]string{
"$create-event", // Always see m.room.create
"$creator-joined", // Always see membership
"$hisvis-1", // Sets room to shared (technically the room is already shared since shared is default)
"$msg-1", // Room currently 'shared'
"$hisvis-2", // Room changed from 'shared' to 'joined', so boundary event and should be shared
// Other events hidden, as other is not joined yet
// hisvis-3 is also hidden, as it changes from joined to invited, neither of which is visible to other
"$hisvis-4", // Changes from 'invited' to 'shared', so is a boundary event and visible
"$msg-4", // Room is 'shared', so visible
"$other-joined", // other's membership
},
filteredEventIDs,
)
}

View file

@ -101,21 +101,14 @@ func (n *Notifier) OnNewEvent(
n._removeEmptyUserStreams() n._removeEmptyUserStreams()
if ev != nil { if ev != nil {
validRoomID, err := spec.NewRoomID(ev.RoomID())
if err != nil {
log.WithError(err).WithField("event_id", ev.EventID()).Errorf(
"Notifier.OnNewEvent: RoomID is invalid",
)
return
}
// Map this event's room_id to a list of joined users, and wake them up. // Map this event's room_id to a list of joined users, and wake them up.
usersToNotify := n._joinedUsers(ev.RoomID()) usersToNotify := n._joinedUsers(ev.RoomID().String())
// Map this event's room_id to a list of peeking devices, and wake them up. // Map this event's room_id to a list of peeking devices, and wake them up.
peekingDevicesToNotify := n._peekingDevices(ev.RoomID()) peekingDevicesToNotify := n._peekingDevices(ev.RoomID().String())
// If this is an invite, also add in the invitee to this list. // If this is an invite, also add in the invitee to this list.
if ev.Type() == "m.room.member" && ev.StateKey() != nil { if ev.Type() == "m.room.member" && ev.StateKey() != nil {
targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), *validRoomID, spec.SenderID(*ev.StateKey())) targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), ev.RoomID(), spec.SenderID(*ev.StateKey()))
if err != nil { if err != nil || targetUserID == nil {
log.WithError(err).WithField("event_id", ev.EventID()).Errorf( log.WithError(err).WithField("event_id", ev.EventID()).Errorf(
"Notifier.OnNewEvent: Failed to find the userID for this event", "Notifier.OnNewEvent: Failed to find the userID for this event",
) )
@ -134,11 +127,11 @@ func (n *Notifier) OnNewEvent(
// Manually append the new user's ID so they get notified // Manually append the new user's ID so they get notified
// along all members in the room // along all members in the room
usersToNotify = append(usersToNotify, targetUserID.String()) usersToNotify = append(usersToNotify, targetUserID.String())
n._addJoinedUser(ev.RoomID(), targetUserID.String()) n._addJoinedUser(ev.RoomID().String(), targetUserID.String())
case spec.Leave: case spec.Leave:
fallthrough fallthrough
case spec.Ban: case spec.Ban:
n._removeJoinedUser(ev.RoomID(), targetUserID.String()) n._removeJoinedUser(ev.RoomID().String(), targetUserID.String())
} }
} }
} }

View file

@ -138,7 +138,7 @@ func Context(
// verify the user is allowed to see the context for this room/event // verify the user is allowed to see the context for this room/event
startTime := time.Now() startTime := time.Now()
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, []*rstypes.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context") filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, []*rstypes.HeaderedEvent{&requestedEvent}, nil, *userID, "context")
if err != nil { if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter") logrus.WithError(err).Error("unable to apply history visibility filter")
return util.JSONResponse{ return util.JSONResponse{
@ -176,7 +176,7 @@ func Context(
} }
startTime = time.Now() startTime = time.Now()
eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, snapshot, rsAPI, eventsBefore, eventsAfter, device.UserID) eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, snapshot, rsAPI, eventsBefore, eventsAfter, *userID)
if err != nil { if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter") logrus.WithError(err).Error("unable to apply history visibility filter")
return util.JSONResponse{ return util.JSONResponse{
@ -257,7 +257,7 @@ func Context(
func applyHistoryVisibilityOnContextEvents( func applyHistoryVisibilityOnContextEvents(
ctx context.Context, snapshot storage.DatabaseTransaction, rsAPI roomserver.SyncRoomserverAPI, ctx context.Context, snapshot storage.DatabaseTransaction, rsAPI roomserver.SyncRoomserverAPI,
eventsBefore, eventsAfter []*rstypes.HeaderedEvent, eventsBefore, eventsAfter []*rstypes.HeaderedEvent,
userID string, userID spec.UserID,
) (filteredBefore, filteredAfter []*rstypes.HeaderedEvent, err error) { ) (filteredBefore, filteredAfter []*rstypes.HeaderedEvent, err error) {
eventIDsBefore := make(map[string]struct{}, len(eventsBefore)) eventIDsBefore := make(map[string]struct{}, len(eventsBefore))
eventIDsAfter := make(map[string]struct{}, len(eventsAfter)) eventIDsAfter := make(map[string]struct{}, len(eventsAfter))

Some files were not shown because too many files have changed in this diff Show more