mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-11 08:03:09 -06:00
Merge branch 'main' into token-registration
This commit is contained in:
commit
19fc59d3d2
|
|
@ -47,7 +47,7 @@ For a usable federating Dendrite deployment, you will also need:
|
|||
Also recommended are:
|
||||
|
||||
- A PostgreSQL database engine, which will perform better than SQLite with many users and/or larger rooms
|
||||
- A reverse proxy server, such as nginx, configured [like this sample](https://github.com/matrix-org/dendrite/blob/master/docs/nginx/monolith-sample.conf)
|
||||
- A reverse proxy server, such as nginx, configured [like this sample](https://github.com/matrix-org/dendrite/blob/main/docs/nginx/dendrite-sample.conf)
|
||||
|
||||
The [Federation Tester](https://federationtester.matrix.org) can be used to verify your deployment.
|
||||
|
||||
|
|
|
|||
|
|
@ -145,8 +145,16 @@ func SaveReadMarker(
|
|||
userAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
|
||||
syncProducer *producers.SyncAPIProducer, device *api.Device, roomID string,
|
||||
) util.JSONResponse {
|
||||
deviceUserID, err := spec.NewUserID(device.UserID, true)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.BadJSON("userID for this device is invalid"),
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that the user is a member of this room
|
||||
resErr := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID)
|
||||
resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
|
||||
if resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
|
|
|
|||
|
|
@ -55,9 +55,16 @@ func GetAliases(
|
|||
visibility = content.HistoryVisibility
|
||||
}
|
||||
if visibility != spec.WorldReadable {
|
||||
deviceUserID, err := spec.NewUserID(device.UserID, true)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
|
||||
}
|
||||
}
|
||||
queryReq := api.QueryMembershipForUserRequest{
|
||||
RoomID: roomID,
|
||||
UserID: device.UserID,
|
||||
UserID: *deviceUserID,
|
||||
}
|
||||
var queryRes api.QueryMembershipForUserResponse
|
||||
if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil {
|
||||
|
|
|
|||
|
|
@ -224,6 +224,7 @@ func createRoom(
|
|||
PrivateKey: privateKey,
|
||||
EventTime: evTime,
|
||||
}
|
||||
|
||||
roomAlias, createRes := rsAPI.PerformCreateRoom(ctx, *userID, *roomID, &req)
|
||||
if createRes != nil {
|
||||
return *createRes
|
||||
|
|
|
|||
|
|
@ -314,7 +314,22 @@ func SetVisibility(
|
|||
req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, dev *userapi.Device,
|
||||
roomID string,
|
||||
) util.JSONResponse {
|
||||
resErr := checkMemberInRoom(req.Context(), rsAPI, dev.UserID, roomID)
|
||||
deviceUserID, err := spec.NewUserID(dev.UserID, true)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.BadJSON("userID for this device is invalid"),
|
||||
}
|
||||
}
|
||||
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.Unknown("failed to find senderID for this user"),
|
||||
}
|
||||
}
|
||||
|
||||
resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
|
||||
if resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
|
|
@ -327,7 +342,7 @@ func SetVisibility(
|
|||
}},
|
||||
}
|
||||
var queryEventsRes roomserverAPI.QueryLatestEventsAndStateResponse
|
||||
err := rsAPI.QueryLatestEventsAndState(req.Context(), &queryEventsReq, &queryEventsRes)
|
||||
err = rsAPI.QueryLatestEventsAndState(req.Context(), &queryEventsReq, &queryEventsRes)
|
||||
if err != nil || len(queryEventsRes.StateEvents) == 0 {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("could not query events from room")
|
||||
return util.JSONResponse{
|
||||
|
|
@ -338,20 +353,6 @@ func SetVisibility(
|
|||
|
||||
// NOTSPEC: Check if the user's power is greater than power required to change m.room.canonical_alias event
|
||||
power, _ := gomatrixserverlib.NewPowerLevelContentFromEvent(queryEventsRes.StateEvents[0].PDU)
|
||||
fullUserID, err := spec.NewUserID(dev.UserID, true)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
|
||||
}
|
||||
}
|
||||
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
|
||||
}
|
||||
}
|
||||
if power.UserLevel(senderID) < power.EventLevel(spec.MRoomCanonicalAlias, true) {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
|
|
|
|||
|
|
@ -29,10 +29,18 @@ func LeaveRoomByID(
|
|||
rsAPI roomserverAPI.ClientRoomserverAPI,
|
||||
roomID string,
|
||||
) util.JSONResponse {
|
||||
userID, err := spec.NewUserID(device.UserID, true)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.Unknown("device userID is invalid"),
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare to ask the roomserver to perform the room join.
|
||||
leaveReq := roomserverAPI.PerformLeaveRequest{
|
||||
RoomID: roomID,
|
||||
UserID: device.UserID,
|
||||
Leaver: *userID,
|
||||
}
|
||||
leaveRes := roomserverAPI.PerformLeaveResponse{}
|
||||
|
||||
|
|
|
|||
|
|
@ -57,7 +57,22 @@ func SendBan(
|
|||
}
|
||||
}
|
||||
|
||||
errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID)
|
||||
deviceUserID, err := spec.NewUserID(device.UserID, true)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"),
|
||||
}
|
||||
}
|
||||
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("You don't have permission to ban this user, unknown senderID"),
|
||||
}
|
||||
}
|
||||
|
||||
errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
|
||||
if errRes != nil {
|
||||
return *errRes
|
||||
}
|
||||
|
|
@ -66,20 +81,6 @@ func SendBan(
|
|||
if errRes != nil {
|
||||
return *errRes
|
||||
}
|
||||
fullUserID, err := spec.NewUserID(device.UserID, true)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"),
|
||||
}
|
||||
}
|
||||
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("You don't have permission to ban this user, unknown senderID"),
|
||||
}
|
||||
}
|
||||
allowedToBan := pl.UserLevel(senderID) >= pl.Ban
|
||||
if !allowedToBan {
|
||||
return util.JSONResponse{
|
||||
|
|
@ -147,7 +148,22 @@ func SendKick(
|
|||
}
|
||||
}
|
||||
|
||||
errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID)
|
||||
deviceUserID, err := spec.NewUserID(device.UserID, true)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
|
||||
}
|
||||
}
|
||||
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"),
|
||||
}
|
||||
}
|
||||
|
||||
errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
|
||||
if errRes != nil {
|
||||
return *errRes
|
||||
}
|
||||
|
|
@ -156,20 +172,6 @@ func SendKick(
|
|||
if errRes != nil {
|
||||
return *errRes
|
||||
}
|
||||
fullUserID, err := spec.NewUserID(device.UserID, true)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
|
||||
}
|
||||
}
|
||||
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"),
|
||||
}
|
||||
}
|
||||
allowedToKick := pl.UserLevel(senderID) >= pl.Kick
|
||||
if !allowedToKick {
|
||||
return util.JSONResponse{
|
||||
|
|
@ -178,10 +180,17 @@ func SendKick(
|
|||
}
|
||||
}
|
||||
|
||||
bodyUserID, err := spec.NewUserID(body.UserID, true)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.BadJSON("body userID is invalid"),
|
||||
}
|
||||
}
|
||||
var queryRes roomserverAPI.QueryMembershipForUserResponse
|
||||
err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{
|
||||
RoomID: roomID,
|
||||
UserID: body.UserID,
|
||||
UserID: *bodyUserID,
|
||||
}, &queryRes)
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
|
|
@ -213,15 +222,30 @@ func SendUnban(
|
|||
}
|
||||
}
|
||||
|
||||
errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID)
|
||||
deviceUserID, err := spec.NewUserID(device.UserID, true)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
|
||||
}
|
||||
}
|
||||
|
||||
errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
|
||||
if errRes != nil {
|
||||
return *errRes
|
||||
}
|
||||
|
||||
bodyUserID, err := spec.NewUserID(body.UserID, true)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.BadJSON("body userID is invalid"),
|
||||
}
|
||||
}
|
||||
var queryRes roomserverAPI.QueryMembershipForUserResponse
|
||||
err := rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{
|
||||
err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{
|
||||
RoomID: roomID,
|
||||
UserID: body.UserID,
|
||||
UserID: *bodyUserID,
|
||||
}, &queryRes)
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
|
|
@ -272,7 +296,15 @@ func SendInvite(
|
|||
}
|
||||
}
|
||||
|
||||
errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID)
|
||||
deviceUserID, err := spec.NewUserID(device.UserID, true)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
|
||||
}
|
||||
}
|
||||
|
||||
errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
|
||||
if errRes != nil {
|
||||
return *errRes
|
||||
}
|
||||
|
|
@ -340,17 +372,18 @@ func sendInvite(
|
|||
|
||||
func buildMembershipEventDirect(
|
||||
ctx context.Context,
|
||||
targetUserID, reason string, userDisplayName, userAvatarURL string,
|
||||
sender string, senderDomain spec.ServerName,
|
||||
targetSenderID spec.SenderID, reason string, userDisplayName, userAvatarURL string,
|
||||
sender spec.SenderID, senderDomain spec.ServerName,
|
||||
membership, roomID string, isDirect bool,
|
||||
keyID gomatrixserverlib.KeyID, privateKey ed25519.PrivateKey, evTime time.Time,
|
||||
rsAPI roomserverAPI.ClientRoomserverAPI,
|
||||
) (*types.HeaderedEvent, error) {
|
||||
targetSenderString := string(targetSenderID)
|
||||
proto := gomatrixserverlib.ProtoEvent{
|
||||
SenderID: sender,
|
||||
SenderID: string(sender),
|
||||
RoomID: roomID,
|
||||
Type: "m.room.member",
|
||||
StateKey: &targetUserID,
|
||||
StateKey: &targetSenderString,
|
||||
}
|
||||
|
||||
content := gomatrixserverlib.MemberContent{
|
||||
|
|
@ -391,8 +424,25 @@ func buildMembershipEvent(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
return buildMembershipEventDirect(ctx, targetUserID, reason, profile.DisplayName, profile.AvatarURL,
|
||||
device.UserID, device.UserDomain(), membership, roomID, isDirect, identity.KeyID, identity.PrivateKey, evTime, rsAPI)
|
||||
userID, err := spec.NewUserID(device.UserID, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targetID, err := spec.NewUserID(targetUserID, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetSenderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *targetID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buildMembershipEventDirect(ctx, targetSenderID, reason, profile.DisplayName, profile.AvatarURL,
|
||||
senderID, device.UserDomain(), membership, roomID, isDirect, identity.KeyID, identity.PrivateKey, evTime, rsAPI)
|
||||
}
|
||||
|
||||
// loadProfile lookups the profile of a given user from the database and returns
|
||||
|
|
@ -490,7 +540,7 @@ func checkAndProcessThreepid(
|
|||
return
|
||||
}
|
||||
|
||||
func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, userID, roomID string) *util.JSONResponse {
|
||||
func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, userID spec.UserID, roomID string) *util.JSONResponse {
|
||||
var membershipRes roomserverAPI.QueryMembershipForUserResponse
|
||||
err := rsAPI.QueryMembershipForUser(ctx, &roomserverAPI.QueryMembershipForUserRequest{
|
||||
RoomID: roomID,
|
||||
|
|
@ -518,12 +568,21 @@ func SendForget(
|
|||
) util.JSONResponse {
|
||||
ctx := req.Context()
|
||||
logger := util.GetLogger(ctx).WithField("roomID", roomID).WithField("userID", device.UserID)
|
||||
|
||||
deviceUserID, err := spec.NewUserID(device.UserID, true)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
|
||||
}
|
||||
}
|
||||
|
||||
var membershipRes roomserverAPI.QueryMembershipForUserResponse
|
||||
membershipReq := roomserverAPI.QueryMembershipForUserRequest{
|
||||
RoomID: roomID,
|
||||
UserID: device.UserID,
|
||||
UserID: *deviceUserID,
|
||||
}
|
||||
err := rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes)
|
||||
err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("QueryMembershipForUser: could not query membership for user")
|
||||
return util.JSONResponse{
|
||||
|
|
|
|||
|
|
@ -47,7 +47,22 @@ func SendRedaction(
|
|||
txnID *string,
|
||||
txnCache *transactions.Cache,
|
||||
) util.JSONResponse {
|
||||
resErr := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID)
|
||||
deviceUserID, userIDErr := spec.NewUserID(device.UserID, true)
|
||||
if userIDErr != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("userID doesn't have power level to redact"),
|
||||
}
|
||||
}
|
||||
senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID)
|
||||
if queryErr != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("userID doesn't have power level to redact"),
|
||||
}
|
||||
}
|
||||
|
||||
resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
|
||||
if resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
|
|
@ -73,25 +88,10 @@ func SendRedaction(
|
|||
}
|
||||
}
|
||||
|
||||
fullUserID, userIDErr := spec.NewUserID(device.UserID, true)
|
||||
if userIDErr != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("userID doesn't have power level to redact"),
|
||||
}
|
||||
}
|
||||
senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
|
||||
if queryErr != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("userID doesn't have power level to redact"),
|
||||
}
|
||||
}
|
||||
|
||||
// "Users may redact their own events, and any user with a power level greater than or equal
|
||||
// to the redact power level of the room may redact events there"
|
||||
// https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid
|
||||
allowedToRedact := ev.SenderID() == senderID // TODO: Should replace device.UserID with device...PerRoomKey
|
||||
allowedToRedact := ev.SenderID() == senderID
|
||||
if !allowedToRedact {
|
||||
plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{
|
||||
EventType: spec.MRoomPowerLevels,
|
||||
|
|
|
|||
|
|
@ -43,8 +43,16 @@ func SendTyping(
|
|||
}
|
||||
}
|
||||
|
||||
deviceUserID, err := spec.NewUserID(userID, true)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that the user is a member of this room
|
||||
resErr := checkMemberInRoom(req.Context(), rsAPI, userID, roomID)
|
||||
resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
|
||||
if resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ type sendServerNoticeRequest struct {
|
|||
StateKey string `json:"state_key,omitempty"`
|
||||
}
|
||||
|
||||
// nolint:gocyclo
|
||||
// SendServerNotice sends a message to a specific user. It can only be invoked by an admin.
|
||||
func SendServerNotice(
|
||||
req *http.Request,
|
||||
|
|
@ -187,9 +188,17 @@ func SendServerNotice(
|
|||
}
|
||||
} else {
|
||||
// we've found a room in common, check the membership
|
||||
deviceUserID, err := spec.NewUserID(r.UserID, true)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
|
||||
}
|
||||
}
|
||||
|
||||
roomID = commonRooms[0]
|
||||
membershipRes := api.QueryMembershipForUserResponse{}
|
||||
err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: r.UserID, RoomID: roomID}, &membershipRes)
|
||||
err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: *deviceUserID, RoomID: roomID}, &membershipRes)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("unable to query membership for user")
|
||||
return util.JSONResponse{
|
||||
|
|
@ -234,7 +243,7 @@ func SendServerNotice(
|
|||
ctx, rsAPI,
|
||||
api.KindNew,
|
||||
[]*types.HeaderedEvent{
|
||||
&types.HeaderedEvent{PDU: e},
|
||||
{PDU: e},
|
||||
},
|
||||
device.UserDomain(),
|
||||
cfgClient.Matrix.ServerName,
|
||||
|
|
|
|||
|
|
@ -99,9 +99,17 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
|
|||
if !worldReadable {
|
||||
// The room isn't world-readable so try to work out based on the
|
||||
// user's membership if we want the latest state or not.
|
||||
err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{
|
||||
userID, err := spec.NewUserID(device.UserID, true)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("UserID is invalid")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.Unknown("Device UserID is invalid"),
|
||||
}
|
||||
}
|
||||
err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{
|
||||
RoomID: roomID,
|
||||
UserID: device.UserID,
|
||||
UserID: *userID,
|
||||
}, &membershipRes)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser")
|
||||
|
|
@ -140,14 +148,11 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
|
|||
// use the result of the previous QueryLatestEventsAndState response
|
||||
// to find the state event, if provided.
|
||||
for _, ev := range stateRes.StateEvents {
|
||||
sender := spec.UserID{}
|
||||
userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID())
|
||||
if err == nil && userID != nil {
|
||||
sender = *userID
|
||||
}
|
||||
stateEvents = append(
|
||||
stateEvents,
|
||||
synctypes.ToClientEvent(ev, synctypes.FormatAll, sender),
|
||||
synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}, ev),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
|
|
@ -172,9 +177,18 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
|
|||
if err == nil && userID != nil {
|
||||
sender = *userID
|
||||
}
|
||||
|
||||
sk := ev.StateKey()
|
||||
if sk != nil && *sk != "" {
|
||||
skUserID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey()))
|
||||
if err == nil && skUserID != nil {
|
||||
skString := skUserID.String()
|
||||
sk = &skString
|
||||
}
|
||||
}
|
||||
stateEvents = append(
|
||||
stateEvents,
|
||||
synctypes.ToClientEvent(ev, synctypes.FormatAll, sender),
|
||||
synctypes.ToClientEvent(ev, synctypes.FormatAll, sender, sk),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
@ -259,11 +273,19 @@ func OnIncomingStateTypeRequest(
|
|||
// membershipRes will only be populated if the room is not world-readable.
|
||||
var membershipRes api.QueryMembershipForUserResponse
|
||||
if !worldReadable {
|
||||
userID, err := spec.NewUserID(device.UserID, true)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("UserID is invalid")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.Unknown("Device UserID is invalid"),
|
||||
}
|
||||
}
|
||||
// The room isn't world-readable so try to work out based on the
|
||||
// user's membership if we want the latest state or not.
|
||||
err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{
|
||||
err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{
|
||||
RoomID: roomID,
|
||||
UserID: device.UserID,
|
||||
UserID: *userID,
|
||||
}, &membershipRes)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser")
|
||||
|
|
@ -344,13 +366,10 @@ func OnIncomingStateTypeRequest(
|
|||
}
|
||||
}
|
||||
|
||||
sender := spec.UserID{}
|
||||
userID, err := rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
|
||||
if err == nil && userID != nil {
|
||||
sender = *userID
|
||||
}
|
||||
stateEvent := stateEventInStateResp{
|
||||
ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, sender),
|
||||
ClientEvent: synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}, event),
|
||||
}
|
||||
|
||||
var res interface{}
|
||||
|
|
|
|||
|
|
@ -59,7 +59,15 @@ func UpgradeRoom(
|
|||
}
|
||||
}
|
||||
|
||||
newRoomID, err := rsAPI.PerformRoomUpgrade(req.Context(), roomID, device.UserID, gomatrixserverlib.RoomVersion(r.NewVersion))
|
||||
userID, err := spec.NewUserID(device.UserID, true)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("device UserID is invalid")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
}
|
||||
}
|
||||
newRoomID, err := rsAPI.PerformRoomUpgrade(req.Context(), roomID, *userID, gomatrixserverlib.RoomVersion(r.NewVersion))
|
||||
switch e := err.(type) {
|
||||
case nil:
|
||||
case roomserverAPI.ErrNotAllowed:
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ func GetEventAuth(
|
|||
if event.RoomID() != roomID {
|
||||
return util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")}
|
||||
}
|
||||
resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID)
|
||||
resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID())
|
||||
if resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,10 +35,6 @@ func GetEvent(
|
|||
eventID string,
|
||||
origin spec.ServerName,
|
||||
) util.JSONResponse {
|
||||
err := allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID)
|
||||
if err != nil {
|
||||
return *err
|
||||
}
|
||||
// /_matrix/federation/v1/event/{eventId} doesn't have a roomID, we use an empty string,
|
||||
// which results in `QueryEventsByID` to first get the event and use that to determine the roomID.
|
||||
event, err := fetchEvent(ctx, rsAPI, "", eventID)
|
||||
|
|
@ -46,6 +42,11 @@ func GetEvent(
|
|||
return *err
|
||||
}
|
||||
|
||||
err = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID())
|
||||
if err != nil {
|
||||
return *err
|
||||
}
|
||||
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: gomatrixserverlib.Transaction{
|
||||
Origin: origin,
|
||||
OriginServerTS: spec.AsTimestamp(time.Now()),
|
||||
|
|
@ -62,8 +63,9 @@ func allowedToSeeEvent(
|
|||
origin spec.ServerName,
|
||||
rsAPI api.FederationRoomserverAPI,
|
||||
eventID string,
|
||||
roomID string,
|
||||
) *util.JSONResponse {
|
||||
allowed, err := rsAPI.QueryServerAllowedToSeeEvent(ctx, origin, eventID)
|
||||
allowed, err := rsAPI.QueryServerAllowedToSeeEvent(ctx, origin, eventID, roomID)
|
||||
if err != nil {
|
||||
resErr := util.ErrorResponse(err)
|
||||
return &resErr
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ func getState(
|
|||
if event.RoomID() != roomID {
|
||||
return nil, nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")}
|
||||
}
|
||||
resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID)
|
||||
resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID())
|
||||
if resErr != nil {
|
||||
return nil, nil, resErr
|
||||
}
|
||||
|
|
|
|||
2
go.mod
2
go.mod
|
|
@ -22,7 +22,7 @@ require (
|
|||
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
|
||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
|
||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077
|
||||
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
|
||||
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
|
||||
github.com/mattn/go-sqlite3 v1.14.16
|
||||
|
|
|
|||
4
go.sum
4
go.sum
|
|
@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
|
|||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
|
||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
|
||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d h1:MjL8SXRzhO61aXDFL+gA3Bx1SicqLGL9gCWXDv8jkD8=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077 h1:AmKkAUjy9rZA2K+qHXm/O/dPEPnUYfRE2I6SL+Dj+LU=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
|
||||
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
|
||||
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
|
||||
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
|
||||
|
|
|
|||
|
|
@ -34,11 +34,11 @@ func (e ErrNotAllowed) Error() string {
|
|||
|
||||
type RestrictedJoinAPI interface {
|
||||
CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error)
|
||||
InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error)
|
||||
RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error)
|
||||
InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error)
|
||||
RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error)
|
||||
QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types.RoomInfo, error)
|
||||
QueryServerJoinedToRoom(ctx context.Context, req *QueryServerJoinedToRoomRequest, res *QueryServerJoinedToRoomResponse) error
|
||||
UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, userID spec.UserID) (bool, error)
|
||||
UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, senderID spec.SenderID) (bool, error)
|
||||
LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error)
|
||||
}
|
||||
|
||||
|
|
@ -191,7 +191,7 @@ type ClientRoomserverAPI interface {
|
|||
|
||||
PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) (string, *util.JSONResponse)
|
||||
// PerformRoomUpgrade upgrades a room to a newer version
|
||||
PerformRoomUpgrade(ctx context.Context, roomID, userID string, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error)
|
||||
PerformRoomUpgrade(ctx context.Context, roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error)
|
||||
PerformAdminEvacuateRoom(ctx context.Context, roomID string) (affected []string, err error)
|
||||
PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error)
|
||||
PerformAdminPurgeRoom(ctx context.Context, roomID string) error
|
||||
|
|
@ -228,6 +228,7 @@ type FederationRoomserverAPI interface {
|
|||
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
|
||||
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
|
||||
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error
|
||||
QueryMembershipForSenderID(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, res *QueryMembershipForUserResponse) error
|
||||
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
|
||||
QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
|
||||
GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error
|
||||
|
|
@ -238,15 +239,13 @@ type FederationRoomserverAPI interface {
|
|||
// Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate
|
||||
// the state and auth chain to return.
|
||||
QueryStateAndAuthChain(ctx context.Context, req *QueryStateAndAuthChainRequest, res *QueryStateAndAuthChainResponse) error
|
||||
// Query if we think we're still in a room.
|
||||
QueryServerJoinedToRoom(ctx context.Context, req *QueryServerJoinedToRoomRequest, res *QueryServerJoinedToRoomResponse) error
|
||||
QueryPublishedRooms(ctx context.Context, req *QueryPublishedRoomsRequest, res *QueryPublishedRoomsResponse) error
|
||||
// Query missing events for a room from roomserver
|
||||
QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error
|
||||
// Query whether a server is allowed to see an event
|
||||
QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string) (allowed bool, err error)
|
||||
QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string, roomID string) (allowed bool, err error)
|
||||
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
|
||||
QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error)
|
||||
QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error)
|
||||
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
|
||||
HandleInvite(ctx context.Context, event *types.HeaderedEvent) error
|
||||
|
||||
|
|
@ -254,12 +253,6 @@ type FederationRoomserverAPI interface {
|
|||
// Query a given amount (or less) of events prior to a given set of events.
|
||||
PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error
|
||||
|
||||
CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error)
|
||||
InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error)
|
||||
QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types.RoomInfo, error)
|
||||
UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, userID spec.UserID) (bool, error)
|
||||
LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error)
|
||||
|
||||
IsKnownRoom(ctx context.Context, roomID spec.RoomID) (bool, error)
|
||||
StateQuerier() gomatrixserverlib.StateQuerier
|
||||
}
|
||||
|
|
|
|||
|
|
@ -215,8 +215,10 @@ type OutputNewInviteEvent struct {
|
|||
type OutputRetireInviteEvent struct {
|
||||
// The ID of the "m.room.member" invite event.
|
||||
EventID string
|
||||
// The target user ID of the "m.room.member" invite event that was retired.
|
||||
TargetUserID string
|
||||
// The room ID of the "m.room.member" invite event.
|
||||
RoomID string
|
||||
// The target sender ID of the "m.room.member" invite event that was retired.
|
||||
TargetSenderID spec.SenderID
|
||||
// Optional event ID of the event that replaced the invite.
|
||||
// This can be empty if the invite was rejected locally and we were unable
|
||||
// to reach the server that originally sent the invite.
|
||||
|
|
|
|||
|
|
@ -41,8 +41,8 @@ type PerformJoinRequest struct {
|
|||
}
|
||||
|
||||
type PerformLeaveRequest struct {
|
||||
RoomID string `json:"room_id"`
|
||||
UserID string `json:"user_id"`
|
||||
RoomID string
|
||||
Leaver spec.UserID
|
||||
}
|
||||
|
||||
type PerformLeaveResponse struct {
|
||||
|
|
|
|||
|
|
@ -113,9 +113,9 @@ type QueryEventsByIDResponse struct {
|
|||
// QueryMembershipForUserRequest is a request to QueryMembership
|
||||
type QueryMembershipForUserRequest struct {
|
||||
// ID of the room to fetch membership from
|
||||
RoomID string `json:"room_id"`
|
||||
RoomID string
|
||||
// ID of the user for whom membership is requested
|
||||
UserID string `json:"user_id"`
|
||||
UserID spec.UserID
|
||||
}
|
||||
|
||||
// QueryMembershipForUserResponse is a response to QueryMembership
|
||||
|
|
@ -145,7 +145,7 @@ type QueryMembershipsForRoomRequest struct {
|
|||
// Optional - ID of the user sending the request, for checking if the
|
||||
// user is allowed to see the memberships. If not specified then all
|
||||
// room memberships will be returned.
|
||||
Sender string `json:"sender"`
|
||||
SenderID spec.SenderID `json:"sender"`
|
||||
}
|
||||
|
||||
// QueryMembershipsForRoomResponse is a response to QueryMembershipsForRoom
|
||||
|
|
@ -448,11 +448,11 @@ func (rq *JoinRoomQuerier) CurrentStateEvent(ctx context.Context, roomID spec.Ro
|
|||
return rq.Roomserver.CurrentStateEvent(ctx, roomID, eventType, stateKey)
|
||||
}
|
||||
|
||||
func (rq *JoinRoomQuerier) InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) {
|
||||
return rq.Roomserver.InvitePending(ctx, roomID, userID)
|
||||
func (rq *JoinRoomQuerier) InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error) {
|
||||
return rq.Roomserver.InvitePending(ctx, roomID, senderID)
|
||||
}
|
||||
|
||||
func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) {
|
||||
func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) {
|
||||
roomInfo, err := rq.Roomserver.QueryRoomInfo(ctx, roomID)
|
||||
if err != nil || roomInfo == nil || roomInfo.IsStub() {
|
||||
return nil, err
|
||||
|
|
@ -468,7 +468,7 @@ func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID sp
|
|||
return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err)
|
||||
}
|
||||
|
||||
userJoinedToRoom, err := rq.Roomserver.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID)
|
||||
userJoinedToRoom, err := rq.Roomserver.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), senderID)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed")
|
||||
return nil, fmt.Errorf("InternalServerError: %w", err)
|
||||
|
|
@ -492,12 +492,8 @@ type MembershipQuerier struct {
|
|||
}
|
||||
|
||||
func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
|
||||
req := QueryMembershipForUserRequest{
|
||||
RoomID: roomID.String(),
|
||||
UserID: string(senderID),
|
||||
}
|
||||
res := QueryMembershipForUserResponse{}
|
||||
err := mq.Roomserver.QueryMembershipForUser(ctx, &req, &res)
|
||||
err := mq.Roomserver.QueryMembershipForSenderID(ctx, roomID, senderID, &res)
|
||||
|
||||
membership := ""
|
||||
if err == nil {
|
||||
|
|
|
|||
|
|
@ -13,6 +13,9 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
)
|
||||
|
|
@ -22,6 +25,7 @@ import (
|
|||
// IsServerAllowed returns true if the server is allowed to see events in the room
|
||||
// at this particular state. This function implements https://matrix.org/docs/spec/client_server/r0.6.0#id87
|
||||
func IsServerAllowed(
|
||||
ctx context.Context, db storage.RoomDatabase,
|
||||
serverName spec.ServerName,
|
||||
serverCurrentlyInRoom bool,
|
||||
authEvents []gomatrixserverlib.PDU,
|
||||
|
|
@ -37,7 +41,7 @@ func IsServerAllowed(
|
|||
return true
|
||||
}
|
||||
// 2. If the user's membership was join, allow.
|
||||
joinedUserExists := IsAnyUserOnServerWithMembership(serverName, authEvents, spec.Join)
|
||||
joinedUserExists := IsAnyUserOnServerWithMembership(ctx, db, serverName, authEvents, spec.Join)
|
||||
if joinedUserExists {
|
||||
return true
|
||||
}
|
||||
|
|
@ -46,7 +50,7 @@ func IsServerAllowed(
|
|||
return true
|
||||
}
|
||||
// 4. If the user's membership was invite, and the history_visibility was set to invited, allow.
|
||||
invitedUserExists := IsAnyUserOnServerWithMembership(serverName, authEvents, spec.Invite)
|
||||
invitedUserExists := IsAnyUserOnServerWithMembership(ctx, db, serverName, authEvents, spec.Invite)
|
||||
if invitedUserExists && historyVisibility == gomatrixserverlib.HistoryVisibilityInvited {
|
||||
return true
|
||||
}
|
||||
|
|
@ -70,7 +74,7 @@ func HistoryVisibilityForRoom(authEvents []gomatrixserverlib.PDU) gomatrixserver
|
|||
return visibility
|
||||
}
|
||||
|
||||
func IsAnyUserOnServerWithMembership(serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool {
|
||||
func IsAnyUserOnServerWithMembership(ctx context.Context, db storage.RoomDatabase, serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool {
|
||||
for _, ev := range authEvents {
|
||||
if ev.Type() != spec.MRoomMember {
|
||||
continue
|
||||
|
|
@ -85,12 +89,12 @@ func IsAnyUserOnServerWithMembership(serverName spec.ServerName, authEvents []go
|
|||
continue
|
||||
}
|
||||
|
||||
_, domain, err := gomatrixserverlib.SplitID('@', *stateKey)
|
||||
userID, err := db.GetUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*stateKey))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if domain == serverName {
|
||||
if userID.Domain() == serverName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,13 +1,23 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
)
|
||||
|
||||
type FakeStorageDB struct {
|
||||
storage.RoomDatabase
|
||||
}
|
||||
|
||||
func (f *FakeStorageDB) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return spec.NewUserID(string(senderID), true)
|
||||
}
|
||||
|
||||
func TestIsServerAllowed(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
|
||||
|
|
@ -77,7 +87,7 @@ func TestIsServerAllowed(t *testing.T) {
|
|||
authEvents = append(authEvents, ev.PDU)
|
||||
}
|
||||
|
||||
if got := IsServerAllowed(tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want {
|
||||
if got := IsServerAllowed(context.Background(), &FakeStorageDB{}, tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want {
|
||||
t.Errorf("IsServerAllowed() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
|
|
@ -55,9 +54,10 @@ func UpdateToInviteMembership(
|
|||
Type: api.OutputTypeRetireInviteEvent,
|
||||
RetireInviteEvent: &api.OutputRetireInviteEvent{
|
||||
EventID: eventID,
|
||||
RoomID: add.RoomID(),
|
||||
Membership: spec.Join,
|
||||
RetiredByEventID: add.EventID(),
|
||||
TargetUserID: *add.StateKey(),
|
||||
TargetSenderID: spec.SenderID(*add.StateKey()),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
|
@ -94,13 +94,13 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam
|
|||
for i := range events {
|
||||
gmslEvents[i] = events[i].PDU
|
||||
}
|
||||
return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, spec.Join), nil
|
||||
return auth.IsAnyUserOnServerWithMembership(ctx, db, serverName, gmslEvents, spec.Join), nil
|
||||
}
|
||||
|
||||
func IsInvitePending(
|
||||
ctx context.Context, db storage.Database,
|
||||
roomID, userID string,
|
||||
) (bool, string, string, gomatrixserverlib.PDU, error) {
|
||||
roomID string, senderID spec.SenderID,
|
||||
) (bool, spec.SenderID, string, gomatrixserverlib.PDU, error) {
|
||||
// Look up the room NID for the supplied room ID.
|
||||
info, err := db.RoomInfo(ctx, roomID)
|
||||
if err != nil {
|
||||
|
|
@ -111,13 +111,13 @@ func IsInvitePending(
|
|||
}
|
||||
|
||||
// Look up the state key NID for the supplied user ID.
|
||||
targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{userID})
|
||||
targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{string(senderID)})
|
||||
if err != nil {
|
||||
return false, "", "", nil, fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err)
|
||||
}
|
||||
targetUserNID, targetUserFound := targetUserNIDs[userID]
|
||||
targetUserNID, targetUserFound := targetUserNIDs[string(senderID)]
|
||||
if !targetUserFound {
|
||||
return false, "", "", nil, fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs)
|
||||
return false, "", "", nil, fmt.Errorf("missing NID for user %q (%+v)", senderID, targetUserNIDs)
|
||||
}
|
||||
|
||||
// Let's see if we have an event active for the user in the room. If
|
||||
|
|
@ -156,7 +156,7 @@ func IsInvitePending(
|
|||
|
||||
event, err := verImpl.NewEventFromTrustedJSON(eventJSON, false)
|
||||
|
||||
return true, senderUser, userNIDToEventID[senderUserNIDs[0]], event, err
|
||||
return true, spec.SenderID(senderUser), userNIDToEventID[senderUserNIDs[0]], event, err
|
||||
}
|
||||
|
||||
// GetMembershipsAtState filters the state events to
|
||||
|
|
@ -264,7 +264,7 @@ func LoadStateEvents(
|
|||
}
|
||||
|
||||
func CheckServerAllowedToSeeEvent(
|
||||
ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName spec.ServerName, isServerInRoom bool,
|
||||
ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool,
|
||||
) (bool, error) {
|
||||
stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName))
|
||||
switch err {
|
||||
|
|
@ -273,7 +273,7 @@ func CheckServerAllowedToSeeEvent(
|
|||
case tables.OptimisationNotSupportedError:
|
||||
// The database engine didn't support this optimisation, so fall back to using
|
||||
// the old and slow method
|
||||
stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, eventID, serverName)
|
||||
stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
|
@ -288,11 +288,11 @@ func CheckServerAllowedToSeeEvent(
|
|||
return false, err
|
||||
}
|
||||
}
|
||||
return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil
|
||||
return auth.IsServerAllowed(ctx, db, serverName, isServerInRoom, stateAtEvent), nil
|
||||
}
|
||||
|
||||
func slowGetHistoryVisibilityState(
|
||||
ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName spec.ServerName,
|
||||
ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName,
|
||||
) ([]gomatrixserverlib.PDU, error) {
|
||||
roomState := state.NewStateResolution(db, info)
|
||||
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
|
||||
|
|
@ -319,8 +319,13 @@ func slowGetHistoryVisibilityState(
|
|||
// then we'll filter it out. This does preserve state keys that
|
||||
// are "" since these will contain history visibility etc.
|
||||
for nid, key := range stateKeys {
|
||||
if key != "" && !strings.HasSuffix(key, ":"+string(serverName)) {
|
||||
delete(stateKeys, nid)
|
||||
if key != "" {
|
||||
userID, err := db.GetUserIDForSender(ctx, roomID, spec.SenderID(key))
|
||||
if err == nil && userID != nil {
|
||||
if userID.Domain() != serverName {
|
||||
delete(stateKeys, nid)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -410,7 +415,7 @@ BFSLoop:
|
|||
// hasn't been seen before.
|
||||
if !visited[pre] {
|
||||
visited[pre] = true
|
||||
allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, pre, serverName, isServerInRoom)
|
||||
allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
|
||||
"Error checking if allowed to see event",
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
|
|
@ -58,12 +59,12 @@ func TestIsInvitePendingWithoutNID(t *testing.T) {
|
|||
}
|
||||
|
||||
// Alice should have no pending invites and should have a NID
|
||||
pendingInvite, _, _, _, err := IsInvitePending(context.Background(), db, room.ID, alice.ID)
|
||||
pendingInvite, _, _, _, err := IsInvitePending(context.Background(), db, room.ID, spec.SenderID(alice.ID))
|
||||
assert.NoError(t, err, "failed to get pending invites")
|
||||
assert.False(t, pendingInvite, "unexpected pending invite")
|
||||
|
||||
// Bob should have no pending invites and receive a new NID
|
||||
pendingInvite, _, _, _, err = IsInvitePending(context.Background(), db, room.ID, bob.ID)
|
||||
pendingInvite, _, _, _, err = IsInvitePending(context.Background(), db, room.ID, spec.SenderID(bob.ID))
|
||||
assert.NoError(t, err, "failed to get pending invites")
|
||||
assert.False(t, pendingInvite, "unexpected pending invite")
|
||||
})
|
||||
|
|
|
|||
|
|
@ -842,17 +842,15 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r
|
|||
continue
|
||||
}
|
||||
|
||||
// TODO: pseudoIDs: get userID for room using state key (which is now senderID)
|
||||
localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey())
|
||||
memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, memberEvent.RoomID(), spec.SenderID(*memberEvent.StateKey()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO: pseudoIDs: query account by state key (which is now senderID)
|
||||
accountRes := &userAPI.QueryAccountByLocalpartResponse{}
|
||||
if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{
|
||||
Localpart: localpart,
|
||||
ServerName: senderDomain,
|
||||
Localpart: memberUserID.Local(),
|
||||
ServerName: memberUserID.Domain(),
|
||||
}, accountRes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -896,8 +894,8 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r
|
|||
inputEvents = append(inputEvents, api.InputRoomEvent{
|
||||
Kind: api.KindNew,
|
||||
Event: event,
|
||||
Origin: senderDomain,
|
||||
SendAsServer: string(senderDomain),
|
||||
Origin: memberUserID.Domain(),
|
||||
SendAsServer: string(memberUserID.Domain()),
|
||||
})
|
||||
prevEvents = []string{event.EventID()}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
|
|
@ -72,7 +71,7 @@ func (r *Inputer) updateMemberships(
|
|||
if change.addedEventNID != 0 {
|
||||
ae, _ = helpers.EventMap(events).Lookup(change.addedEventNID)
|
||||
}
|
||||
if updates, err = r.updateMembership(updater, targetUserNID, re, ae, updates); err != nil {
|
||||
if updates, err = r.updateMembership(ctx, updater, targetUserNID, re, ae, updates); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
|
@ -80,6 +79,7 @@ func (r *Inputer) updateMemberships(
|
|||
}
|
||||
|
||||
func (r *Inputer) updateMembership(
|
||||
ctx context.Context,
|
||||
updater *shared.RoomUpdater,
|
||||
targetUserNID types.EventStateKeyNID,
|
||||
remove, add *types.Event,
|
||||
|
|
@ -97,7 +97,7 @@ func (r *Inputer) updateMembership(
|
|||
|
||||
var targetLocal bool
|
||||
if add != nil {
|
||||
targetLocal = r.isLocalTarget(add)
|
||||
targetLocal = r.isLocalTarget(ctx, add)
|
||||
}
|
||||
|
||||
mu, err := updater.MembershipUpdater(targetUserNID, targetLocal)
|
||||
|
|
@ -136,11 +136,14 @@ func (r *Inputer) updateMembership(
|
|||
}
|
||||
}
|
||||
|
||||
func (r *Inputer) isLocalTarget(event *types.Event) bool {
|
||||
func (r *Inputer) isLocalTarget(ctx context.Context, event *types.Event) bool {
|
||||
isTargetLocalUser := false
|
||||
if statekey := event.StateKey(); statekey != nil {
|
||||
_, domain, _ := gomatrixserverlib.SplitID('@', *statekey)
|
||||
isTargetLocalUser = domain == r.ServerName
|
||||
userID, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*statekey))
|
||||
if err != nil || userID == nil {
|
||||
return isTargetLocalUser
|
||||
}
|
||||
isTargetLocalUser = userID.Domain() == r.ServerName
|
||||
}
|
||||
return isTargetLocalUser
|
||||
}
|
||||
|
|
@ -161,9 +164,10 @@ func updateToJoinMembership(
|
|||
Type: api.OutputTypeRetireInviteEvent,
|
||||
RetireInviteEvent: &api.OutputRetireInviteEvent{
|
||||
EventID: eventID,
|
||||
RoomID: add.RoomID(),
|
||||
Membership: spec.Join,
|
||||
RetiredByEventID: add.EventID(),
|
||||
TargetUserID: *add.StateKey(),
|
||||
TargetSenderID: spec.SenderID(*add.StateKey()),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
|
@ -187,9 +191,10 @@ func updateToLeaveMembership(
|
|||
Type: api.OutputTypeRetireInviteEvent,
|
||||
RetireInviteEvent: &api.OutputRetireInviteEvent{
|
||||
EventID: eventID,
|
||||
RoomID: add.RoomID(),
|
||||
Membership: newMembership,
|
||||
RetiredByEventID: add.EventID(),
|
||||
TargetUserID: *add.StateKey(),
|
||||
TargetSenderID: spec.SenderID(*add.StateKey()),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -149,11 +149,11 @@ func (r *Admin) PerformAdminEvacuateUser(
|
|||
ctx context.Context,
|
||||
userID string,
|
||||
) (affected []string, err error) {
|
||||
_, domain, err := gomatrixserverlib.SplitID('@', userID)
|
||||
fullUserID, err := spec.NewUserID(userID, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !r.Cfg.Matrix.IsLocalServerName(domain) {
|
||||
if !r.Cfg.Matrix.IsLocalServerName(fullUserID.Domain()) {
|
||||
return nil, fmt.Errorf("can only evacuate local users using this endpoint")
|
||||
}
|
||||
|
||||
|
|
@ -172,7 +172,7 @@ func (r *Admin) PerformAdminEvacuateUser(
|
|||
for _, roomID := range allRooms {
|
||||
leaveReq := &api.PerformLeaveRequest{
|
||||
RoomID: roomID,
|
||||
UserID: userID,
|
||||
Leaver: *fullUserID,
|
||||
}
|
||||
leaveRes := &api.PerformLeaveResponse{}
|
||||
outputEvents, err := r.Leaver.PerformLeave(ctx, leaveReq, leaveRes)
|
||||
|
|
|
|||
|
|
@ -582,7 +582,7 @@ func joinEventsFromHistoryVisibility(
|
|||
}
|
||||
|
||||
// Can we see events in the room?
|
||||
canSeeEvents := auth.IsServerAllowed(thisServer, true, events)
|
||||
canSeeEvents := auth.IsServerAllowed(ctx, db, thisServer, true, events)
|
||||
visibility := auth.HistoryVisibilityForRoom(events)
|
||||
if !canSeeEvents {
|
||||
logrus.Infof("ServersAtEvent history not visible to us: %s", visibility)
|
||||
|
|
|
|||
|
|
@ -63,9 +63,17 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
|
|||
}
|
||||
}
|
||||
}
|
||||
createContent["creator"] = userID.String()
|
||||
senderID, err := c.DB.GetSenderIDForUser(ctx, roomID.String(), userID)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user")
|
||||
return "", &util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
}
|
||||
}
|
||||
createContent["creator"] = senderID
|
||||
createContent["room_version"] = createRequest.RoomVersion
|
||||
powerLevelContent := eventutil.InitialPowerLevelsContent(userID.String())
|
||||
powerLevelContent := eventutil.InitialPowerLevelsContent(string(senderID))
|
||||
joinRuleContent := gomatrixserverlib.JoinRuleContent{
|
||||
JoinRule: spec.Invite,
|
||||
}
|
||||
|
|
@ -121,7 +129,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
|
|||
}
|
||||
membershipEvent := gomatrixserverlib.FledglingEvent{
|
||||
Type: spec.MRoomMember,
|
||||
StateKey: userID.String(),
|
||||
StateKey: string(senderID),
|
||||
Content: gomatrixserverlib.MemberContent{
|
||||
Membership: spec.Join,
|
||||
DisplayName: createRequest.UserDisplayName,
|
||||
|
|
@ -270,7 +278,6 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
|
|||
|
||||
var builtEvents []*types.HeaderedEvent
|
||||
authEvents := gomatrixserverlib.NewAuthEvents(nil)
|
||||
senderID, err := c.RSAPI.QuerySenderIDForUser(ctx, roomID.String(), userID)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("rsapi.QuerySenderIDForUser failed")
|
||||
return "", &util.JSONResponse{
|
||||
|
|
|
|||
|
|
@ -134,12 +134,12 @@ func (r *Inviter) PerformInvite(
|
|||
return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")}
|
||||
}
|
||||
|
||||
if event.StateKey() == nil {
|
||||
if event.StateKey() == nil || *event.StateKey() == "" {
|
||||
return fmt.Errorf("invite must be a state event")
|
||||
}
|
||||
invitedUser, err := spec.NewUserID(*event.StateKey(), true)
|
||||
if err != nil {
|
||||
return spec.InvalidParam("The user ID is invalid")
|
||||
invitedUser, err := r.RSAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
|
||||
if err != nil || invitedUser == nil {
|
||||
return spec.InvalidParam("Could not find the matching senderID for this user")
|
||||
}
|
||||
isTargetLocal := r.Cfg.Matrix.IsLocalServerName(invitedUser.Domain())
|
||||
|
||||
|
|
|
|||
|
|
@ -162,7 +162,7 @@ func (r *Joiner) performJoinRoomByID(
|
|||
}
|
||||
|
||||
// Get the domain part of the room ID.
|
||||
_, domain, err := gomatrixserverlib.SplitID('!', req.RoomIDOrAlias)
|
||||
roomID, err := spec.NewRoomID(req.RoomIDOrAlias)
|
||||
if err != nil {
|
||||
return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid: %w", req.RoomIDOrAlias, err)}
|
||||
}
|
||||
|
|
@ -170,8 +170,8 @@ func (r *Joiner) performJoinRoomByID(
|
|||
// If the server name in the room ID isn't ours then it's a
|
||||
// possible candidate for finding the room via federation. Add
|
||||
// it to the list of servers to try.
|
||||
if !r.Cfg.Matrix.IsLocalServerName(domain) {
|
||||
req.ServerNames = append(req.ServerNames, domain)
|
||||
if !r.Cfg.Matrix.IsLocalServerName(roomID.Domain()) {
|
||||
req.ServerNames = append(req.ServerNames, roomID.Domain())
|
||||
}
|
||||
|
||||
// Prepare the template for the join event.
|
||||
|
|
@ -203,7 +203,7 @@ func (r *Joiner) performJoinRoomByID(
|
|||
req.Content = map[string]interface{}{}
|
||||
}
|
||||
req.Content["membership"] = spec.Join
|
||||
if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req); aerr != nil {
|
||||
if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req, senderID); aerr != nil {
|
||||
return "", "", aerr
|
||||
} else if authorisedVia != "" {
|
||||
req.Content["join_authorised_via_users_server"] = authorisedVia
|
||||
|
|
@ -226,17 +226,17 @@ func (r *Joiner) performJoinRoomByID(
|
|||
|
||||
// Force a federated join if we're dealing with a pending invite
|
||||
// and we aren't in the room.
|
||||
isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, req.UserID)
|
||||
isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID)
|
||||
if err == nil && !serverInRoom && isInvitePending {
|
||||
_, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender)
|
||||
if ierr != nil {
|
||||
return "", "", fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
|
||||
inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomIDOrAlias, inviteSender)
|
||||
if queryErr != nil {
|
||||
return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr)
|
||||
}
|
||||
|
||||
// If we were invited by someone from another server then we can
|
||||
// assume they are in the room so we can join via them.
|
||||
if !r.Cfg.Matrix.IsLocalServerName(inviterDomain) {
|
||||
req.ServerNames = append(req.ServerNames, inviterDomain)
|
||||
if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) {
|
||||
req.ServerNames = append(req.ServerNames, inviter.Domain())
|
||||
forceFederatedJoin = true
|
||||
memberEvent := gjson.Parse(string(inviteEvent.JSON()))
|
||||
// only set unsigned if we've got a content.membership, which we _should_
|
||||
|
|
@ -298,12 +298,8 @@ func (r *Joiner) performJoinRoomByID(
|
|||
// a member of the room. This is best-effort (as in we won't
|
||||
// fail if we can't find the existing membership) because there
|
||||
// is really no harm in just sending another membership event.
|
||||
membershipReq := &api.QueryMembershipForUserRequest{
|
||||
RoomID: req.RoomIDOrAlias,
|
||||
UserID: userID.String(),
|
||||
}
|
||||
membershipRes := &api.QueryMembershipForUserResponse{}
|
||||
_ = r.Queryer.QueryMembershipForUser(ctx, membershipReq, membershipRes)
|
||||
_ = r.Queryer.QueryMembershipForSenderID(ctx, *roomID, senderID, membershipRes)
|
||||
|
||||
// If we haven't already joined the room then send an event
|
||||
// into the room changing our membership status.
|
||||
|
|
@ -328,7 +324,7 @@ func (r *Joiner) performJoinRoomByID(
|
|||
// The room doesn't exist locally. If the room ID looks like it should
|
||||
// be ours then this probably means that we've nuked our database at
|
||||
// some point.
|
||||
if r.Cfg.Matrix.IsLocalServerName(domain) {
|
||||
if r.Cfg.Matrix.IsLocalServerName(roomID.Domain()) {
|
||||
// If there are no more server names to try then give up here.
|
||||
// Otherwise we'll try a federated join as normal, since it's quite
|
||||
// possible that the room still exists on other servers.
|
||||
|
|
@ -376,15 +372,12 @@ func (r *Joiner) performFederatedJoinRoomByID(
|
|||
func (r *Joiner) populateAuthorisedViaUserForRestrictedJoin(
|
||||
ctx context.Context,
|
||||
joinReq *rsAPI.PerformJoinRequest,
|
||||
senderID spec.SenderID,
|
||||
) (string, error) {
|
||||
roomID, err := spec.NewRoomID(joinReq.RoomIDOrAlias)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
userID, err := spec.NewUserID(joinReq.UserID, true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return r.Queryer.QueryRestrictedJoinAllowed(ctx, *roomID, *userID)
|
||||
return r.Queryer.QueryRestrictedJoinAllowed(ctx, *roomID, senderID)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -53,16 +53,12 @@ func (r *Leaver) PerformLeave(
|
|||
req *api.PerformLeaveRequest,
|
||||
res *api.PerformLeaveResponse,
|
||||
) ([]api.OutputEvent, error) {
|
||||
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("supplied user ID %q in incorrect format", req.UserID)
|
||||
}
|
||||
if !r.Cfg.Matrix.IsLocalServerName(domain) {
|
||||
return nil, fmt.Errorf("user %q does not belong to this homeserver", req.UserID)
|
||||
if !r.Cfg.Matrix.IsLocalServerName(req.Leaver.Domain()) {
|
||||
return nil, fmt.Errorf("user %q does not belong to this homeserver", req.Leaver.String())
|
||||
}
|
||||
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
|
||||
"room_id": req.RoomID,
|
||||
"user_id": req.UserID,
|
||||
"user_id": req.Leaver.String(),
|
||||
})
|
||||
logger.Info("User requested to leave join")
|
||||
if strings.HasPrefix(req.RoomID, "!") {
|
||||
|
|
@ -82,21 +78,26 @@ func (r *Leaver) performLeaveRoomByID(
|
|||
req *api.PerformLeaveRequest,
|
||||
res *api.PerformLeaveResponse, // nolint:unparam
|
||||
) ([]api.OutputEvent, error) {
|
||||
leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomID, req.Leaver)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String())
|
||||
}
|
||||
|
||||
// If there's an invite outstanding for the room then respond to
|
||||
// that.
|
||||
isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID)
|
||||
isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, leaver)
|
||||
if err == nil && isInvitePending {
|
||||
_, senderDomain, serr := gomatrixserverlib.SplitID('@', senderUser)
|
||||
if serr != nil {
|
||||
return nil, fmt.Errorf("sender %q is invalid", senderUser)
|
||||
sender, serr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomID, senderUser)
|
||||
if serr != nil || sender == nil {
|
||||
return nil, fmt.Errorf("sender %q has no matching userID", senderUser)
|
||||
}
|
||||
if !r.Cfg.Matrix.IsLocalServerName(senderDomain) {
|
||||
return r.performFederatedRejectInvite(ctx, req, res, senderUser, eventID)
|
||||
if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) {
|
||||
return r.performFederatedRejectInvite(ctx, req, res, *sender, eventID, leaver)
|
||||
}
|
||||
// check that this is not a "server notice room"
|
||||
accData := &userapi.QueryAccountDataResponse{}
|
||||
if err = r.UserAPI.QueryAccountData(ctx, &userapi.QueryAccountDataRequest{
|
||||
UserID: req.UserID,
|
||||
UserID: req.Leaver.String(),
|
||||
RoomID: req.RoomID,
|
||||
DataType: "m.tag",
|
||||
}, accData); err != nil {
|
||||
|
|
@ -127,7 +128,7 @@ func (r *Leaver) performLeaveRoomByID(
|
|||
StateToFetch: []gomatrixserverlib.StateKeyTuple{
|
||||
{
|
||||
EventType: spec.MRoomMember,
|
||||
StateKey: req.UserID,
|
||||
StateKey: string(leaver),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -141,26 +142,18 @@ func (r *Leaver) performLeaveRoomByID(
|
|||
|
||||
// Now let's see if the user is in the room.
|
||||
if len(latestRes.StateEvents) == 0 {
|
||||
return nil, fmt.Errorf("user %q is not a member of room %q", req.UserID, req.RoomID)
|
||||
return nil, fmt.Errorf("user %q is not a member of room %q", req.Leaver.String(), req.RoomID)
|
||||
}
|
||||
membership, err := latestRes.StateEvents[0].Membership()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting membership: %w", err)
|
||||
}
|
||||
if membership != spec.Join && membership != spec.Invite {
|
||||
return nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.UserID, membership)
|
||||
return nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.Leaver.String(), membership)
|
||||
}
|
||||
|
||||
// Prepare the template for the leave event.
|
||||
fullUserID, err := spec.NewUserID(req.UserID, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomID, *fullUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
senderIDString := string(senderID)
|
||||
senderIDString := string(leaver)
|
||||
proto := gomatrixserverlib.ProtoEvent{
|
||||
Type: spec.MRoomMember,
|
||||
SenderID: senderIDString,
|
||||
|
|
@ -175,16 +168,13 @@ func (r *Leaver) performLeaveRoomByID(
|
|||
return nil, fmt.Errorf("eb.SetUnsigned: %w", err)
|
||||
}
|
||||
|
||||
// Get the sender domain.
|
||||
senderDomain := fullUserID.Domain()
|
||||
|
||||
// We know that the user is in the room at this point so let's build
|
||||
// a leave event.
|
||||
// TODO: Check what happens if the room exists on the server
|
||||
// but everyone has since left. I suspect it does the wrong thing.
|
||||
|
||||
var buildRes rsAPI.QueryLatestEventsAndStateResponse
|
||||
identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain)
|
||||
identity, err := r.Cfg.Matrix.SigningIdentityFor(req.Leaver.Domain())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SigningIdentityFor: %w", err)
|
||||
}
|
||||
|
|
@ -201,8 +191,8 @@ func (r *Leaver) performLeaveRoomByID(
|
|||
{
|
||||
Kind: api.KindNew,
|
||||
Event: event,
|
||||
Origin: senderDomain,
|
||||
SendAsServer: string(senderDomain),
|
||||
Origin: req.Leaver.Domain(),
|
||||
SendAsServer: string(req.Leaver.Domain()),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -219,21 +209,17 @@ func (r *Leaver) performFederatedRejectInvite(
|
|||
ctx context.Context,
|
||||
req *api.PerformLeaveRequest,
|
||||
res *api.PerformLeaveResponse, // nolint:unparam
|
||||
senderUser, eventID string,
|
||||
inviteSender spec.UserID, eventID string,
|
||||
leaver spec.SenderID,
|
||||
) ([]api.OutputEvent, error) {
|
||||
_, domain, err := gomatrixserverlib.SplitID('@', senderUser)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("user ID %q invalid: %w", senderUser, err)
|
||||
}
|
||||
|
||||
// Ask the federation sender to perform a federated leave for us.
|
||||
leaveReq := fsAPI.PerformLeaveRequest{
|
||||
RoomID: req.RoomID,
|
||||
UserID: req.UserID,
|
||||
ServerNames: []spec.ServerName{domain},
|
||||
UserID: req.Leaver.String(),
|
||||
ServerNames: []spec.ServerName{inviteSender.Domain()},
|
||||
}
|
||||
leaveRes := fsAPI.PerformLeaveResponse{}
|
||||
if err = r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil {
|
||||
if err := r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil {
|
||||
// failures in PerformLeave should NEVER stop us from telling other components like the
|
||||
// sync API that the invite was withdrawn. Otherwise we can end up with stuck invites.
|
||||
util.GetLogger(ctx).WithError(err).Errorf("failed to PerformLeave, still retiring invite event")
|
||||
|
|
@ -244,7 +230,7 @@ func (r *Leaver) performFederatedRejectInvite(
|
|||
util.GetLogger(ctx).WithError(err).Errorf("failed to get RoomInfo, still retiring invite event")
|
||||
}
|
||||
|
||||
updater, err := r.DB.MembershipUpdater(ctx, req.RoomID, req.UserID, true, info.RoomVersion)
|
||||
updater, err := r.DB.MembershipUpdater(ctx, req.RoomID, string(leaver), true, info.RoomVersion)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Errorf("failed to get MembershipUpdater, still retiring invite event")
|
||||
}
|
||||
|
|
@ -267,9 +253,10 @@ func (r *Leaver) performFederatedRejectInvite(
|
|||
{
|
||||
Type: api.OutputTypeRetireInviteEvent,
|
||||
RetireInviteEvent: &api.OutputRetireInviteEvent{
|
||||
EventID: eventID,
|
||||
Membership: "leave",
|
||||
TargetUserID: req.UserID,
|
||||
EventID: eventID,
|
||||
RoomID: req.RoomID,
|
||||
Membership: "leave",
|
||||
TargetSenderID: leaver,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
|
|
|
|||
|
|
@ -38,19 +38,15 @@ type Upgrader struct {
|
|||
// PerformRoomUpgrade upgrades a room from one version to another
|
||||
func (r *Upgrader) PerformRoomUpgrade(
|
||||
ctx context.Context,
|
||||
roomID, userID string, roomVersion gomatrixserverlib.RoomVersion,
|
||||
roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion,
|
||||
) (newRoomID string, err error) {
|
||||
return r.performRoomUpgrade(ctx, roomID, userID, roomVersion)
|
||||
}
|
||||
|
||||
func (r *Upgrader) performRoomUpgrade(
|
||||
ctx context.Context,
|
||||
roomID, userID string, roomVersion gomatrixserverlib.RoomVersion,
|
||||
roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion,
|
||||
) (string, error) {
|
||||
_, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID)
|
||||
if err != nil {
|
||||
return "", api.ErrNotAllowed{Err: fmt.Errorf("error validating the user ID")}
|
||||
}
|
||||
evTime := time.Now()
|
||||
|
||||
// Return an immediate error if the room does not exist
|
||||
|
|
@ -58,14 +54,20 @@ func (r *Upgrader) performRoomUpgrade(
|
|||
return "", err
|
||||
}
|
||||
|
||||
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, userID)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user")
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 1. Check if the user is authorized to actually perform the upgrade (can send m.room.tombstone)
|
||||
if !r.userIsAuthorized(ctx, userID, roomID) {
|
||||
if !r.userIsAuthorized(ctx, senderID, roomID) {
|
||||
return "", api.ErrNotAllowed{Err: fmt.Errorf("You don't have permission to upgrade the room, power level too low.")}
|
||||
}
|
||||
|
||||
// TODO (#267): Check room ID doesn't clash with an existing one, and we
|
||||
// probably shouldn't be using pseudo-random strings, maybe GUIDs?
|
||||
newRoomID := fmt.Sprintf("!%s:%s", util.RandomString(16), userDomain)
|
||||
newRoomID := fmt.Sprintf("!%s:%s", util.RandomString(16), userID.Domain())
|
||||
|
||||
// Get the existing room state for the old room.
|
||||
oldRoomReq := &api.QueryLatestEventsAndStateRequest{
|
||||
|
|
@ -77,25 +79,25 @@ func (r *Upgrader) performRoomUpgrade(
|
|||
}
|
||||
|
||||
// Make the tombstone event
|
||||
tombstoneEvent, pErr := r.makeTombstoneEvent(ctx, evTime, userID, roomID, newRoomID)
|
||||
tombstoneEvent, pErr := r.makeTombstoneEvent(ctx, evTime, senderID, userID.Domain(), roomID, newRoomID)
|
||||
if pErr != nil {
|
||||
return "", pErr
|
||||
}
|
||||
|
||||
// Generate the initial events we need to send into the new room. This includes copied state events and bans
|
||||
// as well as the power level events needed to set up the room
|
||||
eventsToMake, pErr := r.generateInitialEvents(ctx, oldRoomRes, userID, roomID, roomVersion, tombstoneEvent)
|
||||
eventsToMake, pErr := r.generateInitialEvents(ctx, oldRoomRes, senderID, roomID, roomVersion, tombstoneEvent)
|
||||
if pErr != nil {
|
||||
return "", pErr
|
||||
}
|
||||
|
||||
// Send the setup events to the new room
|
||||
if pErr = r.sendInitialEvents(ctx, evTime, userID, userDomain, newRoomID, roomVersion, eventsToMake); pErr != nil {
|
||||
if pErr = r.sendInitialEvents(ctx, evTime, senderID, userID.Domain(), newRoomID, roomVersion, eventsToMake); pErr != nil {
|
||||
return "", pErr
|
||||
}
|
||||
|
||||
// 5. Send the tombstone event to the old room
|
||||
if pErr = r.sendHeaderedEvent(ctx, userDomain, tombstoneEvent, string(userDomain)); pErr != nil {
|
||||
if pErr = r.sendHeaderedEvent(ctx, userID.Domain(), tombstoneEvent, string(userID.Domain())); pErr != nil {
|
||||
return "", pErr
|
||||
}
|
||||
|
||||
|
|
@ -105,17 +107,17 @@ func (r *Upgrader) performRoomUpgrade(
|
|||
}
|
||||
|
||||
// If the old room had a canonical alias event, it should be deleted in the old room
|
||||
if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, userID, userDomain, roomID); pErr != nil {
|
||||
if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, senderID, userID.Domain(), roomID); pErr != nil {
|
||||
return "", pErr
|
||||
}
|
||||
|
||||
// 4. Move local aliases to the new room
|
||||
if pErr = moveLocalAliases(ctx, roomID, newRoomID, userID, r.URSAPI); pErr != nil {
|
||||
if pErr = moveLocalAliases(ctx, roomID, newRoomID, senderID, userID, r.URSAPI); pErr != nil {
|
||||
return "", pErr
|
||||
}
|
||||
|
||||
// 6. Restrict power levels in the old room
|
||||
if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, userID, userDomain, roomID); pErr != nil {
|
||||
if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, senderID, userID.Domain(), roomID); pErr != nil {
|
||||
return "", pErr
|
||||
}
|
||||
|
||||
|
|
@ -130,7 +132,7 @@ func (r *Upgrader) getRoomPowerLevels(ctx context.Context, roomID string) (*goma
|
|||
return oldPowerLevelsEvent.PowerLevels()
|
||||
}
|
||||
|
||||
func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, userID string, userDomain spec.ServerName, roomID string) error {
|
||||
func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, senderID spec.SenderID, userDomain spec.ServerName, roomID string) error {
|
||||
restrictedPowerLevelContent, pErr := r.getRoomPowerLevels(ctx, roomID)
|
||||
if pErr != nil {
|
||||
return pErr
|
||||
|
|
@ -147,7 +149,7 @@ func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.T
|
|||
restrictedPowerLevelContent.EventsDefault = restrictedDefaultPowerLevel
|
||||
restrictedPowerLevelContent.Invite = restrictedDefaultPowerLevel
|
||||
|
||||
restrictedPowerLevelsHeadered, resErr := r.makeHeaderedEvent(ctx, evTime, userID, roomID, gomatrixserverlib.FledglingEvent{
|
||||
restrictedPowerLevelsHeadered, resErr := r.makeHeaderedEvent(ctx, evTime, senderID, userDomain, roomID, gomatrixserverlib.FledglingEvent{
|
||||
Type: spec.MRoomPowerLevels,
|
||||
StateKey: "",
|
||||
Content: restrictedPowerLevelContent,
|
||||
|
|
@ -165,7 +167,7 @@ func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.T
|
|||
}
|
||||
|
||||
func moveLocalAliases(ctx context.Context,
|
||||
roomID, newRoomID, userID string,
|
||||
roomID, newRoomID string, senderID spec.SenderID, userID spec.UserID,
|
||||
URSAPI api.RoomserverInternalAPI,
|
||||
) (err error) {
|
||||
|
||||
|
|
@ -175,14 +177,6 @@ func moveLocalAliases(ctx context.Context,
|
|||
return fmt.Errorf("Failed to get old room aliases: %w", err)
|
||||
}
|
||||
|
||||
fullUserID, err := spec.NewUserID(userID, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to get userID: %w", err)
|
||||
}
|
||||
senderID, err := URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to get senderID: %w", err)
|
||||
}
|
||||
for _, alias := range aliasRes.Aliases {
|
||||
removeAliasReq := api.RemoveRoomAliasRequest{SenderID: senderID, Alias: alias}
|
||||
removeAliasRes := api.RemoveRoomAliasResponse{}
|
||||
|
|
@ -190,7 +184,7 @@ func moveLocalAliases(ctx context.Context,
|
|||
return fmt.Errorf("Failed to remove old room alias: %w", err)
|
||||
}
|
||||
|
||||
setAliasReq := api.SetRoomAliasRequest{UserID: userID, Alias: alias, RoomID: newRoomID}
|
||||
setAliasReq := api.SetRoomAliasRequest{UserID: userID.String(), Alias: alias, RoomID: newRoomID}
|
||||
setAliasRes := api.SetRoomAliasResponse{}
|
||||
if err = URSAPI.SetRoomAlias(ctx, &setAliasReq, &setAliasRes); err != nil {
|
||||
return fmt.Errorf("Failed to set new room alias: %w", err)
|
||||
|
|
@ -199,7 +193,7 @@ func moveLocalAliases(ctx context.Context,
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, userID string, userDomain spec.ServerName, roomID string) error {
|
||||
func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, senderID spec.SenderID, userDomain spec.ServerName, roomID string) error {
|
||||
for _, event := range oldRoom.StateEvents {
|
||||
if event.Type() != spec.MRoomCanonicalAlias || !event.StateKeyEquals("") {
|
||||
continue
|
||||
|
|
@ -217,7 +211,7 @@ func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api
|
|||
}
|
||||
}
|
||||
|
||||
emptyCanonicalAliasEvent, resErr := r.makeHeaderedEvent(ctx, evTime, userID, roomID, gomatrixserverlib.FledglingEvent{
|
||||
emptyCanonicalAliasEvent, resErr := r.makeHeaderedEvent(ctx, evTime, senderID, userDomain, roomID, gomatrixserverlib.FledglingEvent{
|
||||
Type: spec.MRoomCanonicalAlias,
|
||||
Content: map[string]interface{}{},
|
||||
})
|
||||
|
|
@ -280,7 +274,7 @@ func (r *Upgrader) validateRoomExists(ctx context.Context, roomID string) error
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *Upgrader) userIsAuthorized(ctx context.Context, userID, roomID string,
|
||||
func (r *Upgrader) userIsAuthorized(ctx context.Context, senderID spec.SenderID, roomID string,
|
||||
) bool {
|
||||
plEvent := api.GetStateEvent(ctx, r.URSAPI, roomID, gomatrixserverlib.StateKeyTuple{
|
||||
EventType: spec.MRoomPowerLevels,
|
||||
|
|
@ -295,26 +289,18 @@ func (r *Upgrader) userIsAuthorized(ctx context.Context, userID, roomID string,
|
|||
}
|
||||
// Check for power level required to send tombstone event (marks the current room as obsolete),
|
||||
// if not found, use the StateDefault power level
|
||||
fullUserID, err := spec.NewUserID(userID, true)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return pl.UserLevel(senderID) >= pl.EventLevel("m.room.tombstone", true)
|
||||
}
|
||||
|
||||
// nolint:gocyclo
|
||||
func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, userID, roomID string, newVersion gomatrixserverlib.RoomVersion, tombstoneEvent *types.HeaderedEvent) ([]gomatrixserverlib.FledglingEvent, error) {
|
||||
func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, senderID spec.SenderID, roomID string, newVersion gomatrixserverlib.RoomVersion, tombstoneEvent *types.HeaderedEvent) ([]gomatrixserverlib.FledglingEvent, error) {
|
||||
state := make(map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent, len(oldRoom.StateEvents))
|
||||
for _, event := range oldRoom.StateEvents {
|
||||
if event.StateKey() == nil {
|
||||
// This shouldn't ever happen, but better to be safe than sorry.
|
||||
continue
|
||||
}
|
||||
if event.Type() == spec.MRoomMember && !event.StateKeyEquals(userID) {
|
||||
if event.Type() == spec.MRoomMember && !event.StateKeyEquals(string(senderID)) {
|
||||
// With the exception of bans which we do want to copy, we
|
||||
// should ignore membership events that aren't our own, as event auth will
|
||||
// prevent us from being able to create membership events on behalf of other
|
||||
|
|
@ -330,6 +316,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
|
|||
}
|
||||
}
|
||||
// skip events that rely on a specific user being present
|
||||
// TODO: What to do here for pseudoIDs? It's checking non-member events for state keys with userIDs.
|
||||
sKey := *event.StateKey()
|
||||
if event.Type() != spec.MRoomMember && len(sKey) > 0 && sKey[:1] == "@" {
|
||||
continue
|
||||
|
|
@ -340,10 +327,10 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
|
|||
// The following events are ones that we are going to override manually
|
||||
// in the following section.
|
||||
override := map[gomatrixserverlib.StateKeyTuple]struct{}{
|
||||
{EventType: spec.MRoomCreate, StateKey: ""}: {},
|
||||
{EventType: spec.MRoomMember, StateKey: userID}: {},
|
||||
{EventType: spec.MRoomPowerLevels, StateKey: ""}: {},
|
||||
{EventType: spec.MRoomJoinRules, StateKey: ""}: {},
|
||||
{EventType: spec.MRoomCreate, StateKey: ""}: {},
|
||||
{EventType: spec.MRoomMember, StateKey: string(senderID)}: {},
|
||||
{EventType: spec.MRoomPowerLevels, StateKey: ""}: {},
|
||||
{EventType: spec.MRoomJoinRules, StateKey: ""}: {},
|
||||
}
|
||||
|
||||
// The overridden events are essential events that must be present in the
|
||||
|
|
@ -355,7 +342,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
|
|||
}
|
||||
|
||||
oldCreateEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomCreate, StateKey: ""}]
|
||||
oldMembershipEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomMember, StateKey: userID}]
|
||||
oldMembershipEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomMember, StateKey: string(senderID)}]
|
||||
oldPowerLevelsEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomPowerLevels, StateKey: ""}]
|
||||
oldJoinRulesEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomJoinRules, StateKey: ""}]
|
||||
|
||||
|
|
@ -364,7 +351,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
|
|||
// in the create event (such as for the room types MSC).
|
||||
newCreateContent := map[string]interface{}{}
|
||||
_ = json.Unmarshal(oldCreateEvent.Content(), &newCreateContent)
|
||||
newCreateContent["creator"] = userID
|
||||
newCreateContent["creator"] = string(senderID)
|
||||
newCreateContent["room_version"] = newVersion
|
||||
newCreateContent["predecessor"] = gomatrixserverlib.PreviousRoom{
|
||||
EventID: tombstoneEvent.EventID(),
|
||||
|
|
@ -385,7 +372,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
|
|||
newMembershipContent["membership"] = spec.Join
|
||||
newMembershipEvent := gomatrixserverlib.FledglingEvent{
|
||||
Type: spec.MRoomMember,
|
||||
StateKey: userID,
|
||||
StateKey: string(senderID),
|
||||
Content: newMembershipContent,
|
||||
}
|
||||
|
||||
|
|
@ -400,14 +387,6 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
|
|||
return nil, fmt.Errorf("Power level event content was invalid")
|
||||
}
|
||||
|
||||
fullUserID, err := spec.NewUserID(userID, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tempPowerLevelsEvent, powerLevelsOverridden := createTemporaryPowerLevels(powerLevelContent, senderID)
|
||||
|
||||
// Now do the join rules event, same as the create and membership
|
||||
|
|
@ -470,21 +449,13 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
|
|||
return eventsToMake, nil
|
||||
}
|
||||
|
||||
func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, userID string, userDomain spec.ServerName, newRoomID string, newVersion gomatrixserverlib.RoomVersion, eventsToMake []gomatrixserverlib.FledglingEvent) error {
|
||||
func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, senderID spec.SenderID, userDomain spec.ServerName, newRoomID string, newVersion gomatrixserverlib.RoomVersion, eventsToMake []gomatrixserverlib.FledglingEvent) error {
|
||||
var err error
|
||||
var builtEvents []*types.HeaderedEvent
|
||||
authEvents := gomatrixserverlib.NewAuthEvents(nil)
|
||||
for i, e := range eventsToMake {
|
||||
depth := i + 1 // depth starts at 1
|
||||
|
||||
fullUserID, userIDErr := spec.NewUserID(userID, true)
|
||||
if userIDErr != nil {
|
||||
return userIDErr
|
||||
}
|
||||
senderID, queryErr := r.URSAPI.QuerySenderIDForUser(ctx, newRoomID, *fullUserID)
|
||||
if queryErr != nil {
|
||||
return queryErr
|
||||
}
|
||||
proto := gomatrixserverlib.ProtoEvent{
|
||||
SenderID: string(senderID),
|
||||
RoomID: newRoomID,
|
||||
|
|
@ -549,7 +520,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user
|
|||
func (r *Upgrader) makeTombstoneEvent(
|
||||
ctx context.Context,
|
||||
evTime time.Time,
|
||||
userID, roomID, newRoomID string,
|
||||
senderID spec.SenderID, senderDomain spec.ServerName, roomID, newRoomID string,
|
||||
) (*types.HeaderedEvent, error) {
|
||||
content := map[string]interface{}{
|
||||
"body": "This room has been replaced",
|
||||
|
|
@ -559,30 +530,21 @@ func (r *Upgrader) makeTombstoneEvent(
|
|||
Type: "m.room.tombstone",
|
||||
Content: content,
|
||||
}
|
||||
return r.makeHeaderedEvent(ctx, evTime, userID, roomID, event)
|
||||
return r.makeHeaderedEvent(ctx, evTime, senderID, senderDomain, roomID, event)
|
||||
}
|
||||
|
||||
func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, userID, roomID string, event gomatrixserverlib.FledglingEvent) (*types.HeaderedEvent, error) {
|
||||
fullUserID, err := spec.NewUserID(userID, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, senderID spec.SenderID, senderDomain spec.ServerName, roomID string, event gomatrixserverlib.FledglingEvent) (*types.HeaderedEvent, error) {
|
||||
proto := gomatrixserverlib.ProtoEvent{
|
||||
SenderID: string(senderID),
|
||||
RoomID: roomID,
|
||||
Type: event.Type,
|
||||
StateKey: &event.StateKey,
|
||||
}
|
||||
err = proto.SetContent(event.Content)
|
||||
err := proto.SetContent(event.Content)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to set new %q event content: %w", proto.Type, err)
|
||||
}
|
||||
// Get the sender domain.
|
||||
senderDomain := fullUserID.Domain()
|
||||
identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get signing identity for %q: %w", senderDomain, err)
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ type Queryer struct {
|
|||
Cfg *config.Dendrite
|
||||
}
|
||||
|
||||
func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) {
|
||||
func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) {
|
||||
roomInfo, err := r.QueryRoomInfo(ctx, roomID)
|
||||
if err != nil || roomInfo == nil || roomInfo.IsStub() {
|
||||
return nil, err
|
||||
|
|
@ -64,7 +64,7 @@ func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID
|
|||
return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err)
|
||||
}
|
||||
|
||||
userJoinedToRoom, err := r.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID)
|
||||
userJoinedToRoom, err := r.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), senderID)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed")
|
||||
return nil, fmt.Errorf("InternalServerError: %w", err)
|
||||
|
|
@ -220,13 +220,14 @@ func (r *Queryer) QueryEventsByID(
|
|||
return nil
|
||||
}
|
||||
|
||||
// QueryMembershipForUser implements api.RoomserverInternalAPI
|
||||
func (r *Queryer) QueryMembershipForUser(
|
||||
// QueryMembershipForSenderID implements api.RoomserverInternalAPI
|
||||
func (r *Queryer) QueryMembershipForSenderID(
|
||||
ctx context.Context,
|
||||
request *api.QueryMembershipForUserRequest,
|
||||
roomID spec.RoomID,
|
||||
senderID spec.SenderID,
|
||||
response *api.QueryMembershipForUserResponse,
|
||||
) error {
|
||||
info, err := r.DB.RoomInfo(ctx, request.RoomID)
|
||||
info, err := r.DB.RoomInfo(ctx, roomID.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -236,7 +237,7 @@ func (r *Queryer) QueryMembershipForUser(
|
|||
}
|
||||
response.RoomExists = true
|
||||
|
||||
membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.UserID)
|
||||
membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, senderID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -264,6 +265,24 @@ func (r *Queryer) QueryMembershipForUser(
|
|||
return err
|
||||
}
|
||||
|
||||
// QueryMembershipForUser implements api.RoomserverInternalAPI
|
||||
func (r *Queryer) QueryMembershipForUser(
|
||||
ctx context.Context,
|
||||
request *api.QueryMembershipForUserRequest,
|
||||
response *api.QueryMembershipForUserResponse,
|
||||
) error {
|
||||
senderID, err := r.DB.GetSenderIDForUser(ctx, request.RoomID, request.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
roomID, err := spec.NewRoomID(request.RoomID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return r.QueryMembershipForSenderID(ctx, *roomID, senderID, response)
|
||||
}
|
||||
|
||||
// QueryMembershipAtEvent returns the known memberships at a given event.
|
||||
// If the state before an event is not known, an empty list will be returned
|
||||
// for that event instead.
|
||||
|
|
@ -373,7 +392,7 @@ func (r *Queryer) QueryMembershipsForRoom(
|
|||
// If no sender is specified then we will just return the entire
|
||||
// set of memberships for the room, regardless of whether a specific
|
||||
// user is allowed to see them or not.
|
||||
if request.Sender == "" {
|
||||
if request.SenderID == "" {
|
||||
var events []types.Event
|
||||
var eventNIDs []types.EventNID
|
||||
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, request.LocalOnly)
|
||||
|
|
@ -388,18 +407,15 @@ func (r *Queryer) QueryMembershipsForRoom(
|
|||
return fmt.Errorf("r.DB.Events: %w", err)
|
||||
}
|
||||
for _, event := range events {
|
||||
sender := spec.UserID{}
|
||||
userID, queryErr := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
|
||||
if queryErr == nil && userID != nil {
|
||||
sender = *userID
|
||||
}
|
||||
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender)
|
||||
clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}, event)
|
||||
response.JoinEvents = append(response.JoinEvents, clientEvent)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.Sender)
|
||||
membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.SenderID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -442,12 +458,9 @@ func (r *Queryer) QueryMembershipsForRoom(
|
|||
}
|
||||
|
||||
for _, event := range events {
|
||||
sender := spec.UserID{}
|
||||
userID, err := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
|
||||
if err == nil && userID != nil {
|
||||
sender = *userID
|
||||
}
|
||||
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender)
|
||||
clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}, event)
|
||||
response.JoinEvents = append(response.JoinEvents, clientEvent)
|
||||
}
|
||||
|
||||
|
|
@ -489,6 +502,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent(
|
|||
ctx context.Context,
|
||||
serverName spec.ServerName,
|
||||
eventID string,
|
||||
roomID string,
|
||||
) (allowed bool, err error) {
|
||||
events, err := r.DB.EventNIDs(ctx, []string{eventID})
|
||||
if err != nil {
|
||||
|
|
@ -518,7 +532,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent(
|
|||
}
|
||||
|
||||
return helpers.CheckServerAllowedToSeeEvent(
|
||||
ctx, r.DB, info, eventID, serverName, isInRoom,
|
||||
ctx, r.DB, info, roomID, eventID, serverName, isInRoom,
|
||||
)
|
||||
}
|
||||
|
||||
|
|
@ -909,8 +923,8 @@ func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainReq
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *Queryer) InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) {
|
||||
pending, _, _, _, err := helpers.IsInvitePending(ctx, r.DB, roomID.String(), userID.String())
|
||||
func (r *Queryer) InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error) {
|
||||
pending, _, _, _, err := helpers.IsInvitePending(ctx, r.DB, roomID.String(), senderID)
|
||||
return pending, err
|
||||
}
|
||||
|
||||
|
|
@ -926,8 +940,8 @@ func (r *Queryer) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eve
|
|||
return res, err
|
||||
}
|
||||
|
||||
func (r *Queryer) UserJoinedToRoom(ctx context.Context, roomNID types.RoomNID, userID spec.UserID) (bool, error) {
|
||||
_, isIn, _, err := r.DB.GetMembership(ctx, roomNID, userID.String())
|
||||
func (r *Queryer) UserJoinedToRoom(ctx context.Context, roomNID types.RoomNID, senderID spec.SenderID) (bool, error) {
|
||||
_, isIn, _, err := r.DB.GetMembership(ctx, roomNID, senderID)
|
||||
return isIn, err
|
||||
}
|
||||
|
||||
|
|
@ -957,7 +971,7 @@ func (r *Queryer) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixse
|
|||
}
|
||||
|
||||
// nolint:gocyclo
|
||||
func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) {
|
||||
func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
|
||||
// Look up if we know anything about the room. If it doesn't exist
|
||||
// or is a stub entry then we can't do anything.
|
||||
roomInfo, err := r.DB.RoomInfo(ctx, roomID.String())
|
||||
|
|
@ -972,7 +986,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro
|
|||
return "", err
|
||||
}
|
||||
|
||||
return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, userID)
|
||||
return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, senderID)
|
||||
}
|
||||
|
||||
func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) {
|
||||
|
|
|
|||
|
|
@ -722,7 +722,7 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) {
|
|||
|
||||
roomID, _ := spec.NewRoomID(testRoom.ID)
|
||||
userID, _ := spec.NewUserID(bob.ID, true)
|
||||
got, err := rsAPI.QueryRestrictedJoinAllowed(processCtx.Context(), *roomID, *userID)
|
||||
got, err := rsAPI.QueryRestrictedJoinAllowed(processCtx.Context(), *roomID, spec.SenderID(userID.String()))
|
||||
if tc.wantError && err == nil {
|
||||
t.Fatal("expected error, got none")
|
||||
}
|
||||
|
|
@ -821,17 +821,6 @@ func TestUpgrade(t *testing.T) {
|
|||
validateFunc func(t *testing.T, oldRoomID, newRoomID string, rsAPI api.RoomserverInternalAPI)
|
||||
wantNewRoom bool
|
||||
}{
|
||||
{
|
||||
name: "invalid userID",
|
||||
upgradeUser: "!notvalid:test",
|
||||
roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
|
||||
room := test.NewRoom(t, alice)
|
||||
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
|
||||
t.Errorf("failed to send events: %v", err)
|
||||
}
|
||||
return room.ID
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid roomID",
|
||||
upgradeUser: alice.ID,
|
||||
|
|
@ -1049,7 +1038,11 @@ func TestUpgrade(t *testing.T) {
|
|||
}
|
||||
roomID := tc.roomFunc(rsAPI)
|
||||
|
||||
newRoomID, err := rsAPI.PerformRoomUpgrade(processCtx.Context(), roomID, tc.upgradeUser, version.DefaultRoomVersion())
|
||||
userID, err := spec.NewUserID(tc.upgradeUser, true)
|
||||
if err != nil {
|
||||
t.Fatalf("upgrade userID is invalid")
|
||||
}
|
||||
newRoomID, err := rsAPI.PerformRoomUpgrade(processCtx.Context(), roomID, *userID, version.DefaultRoomVersion())
|
||||
if err != nil && tc.wantNewRoom {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ package storage
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
|
|
@ -27,6 +28,7 @@ import (
|
|||
)
|
||||
|
||||
type Database interface {
|
||||
UserRoomKeys
|
||||
// Do we support processing input events for more than one room at a time?
|
||||
SupportsConcurrentRoomInputs() bool
|
||||
// RoomInfo returns room information for the given room ID, or nil if there is no room.
|
||||
|
|
@ -129,7 +131,7 @@ type Database interface {
|
|||
// in this room, along a boolean set to true if the user is still in this room,
|
||||
// false if not.
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom, isRoomForgotten bool, err error)
|
||||
GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderID spec.SenderID) (membershipEventNID types.EventNID, stillInRoom, isRoomForgotten bool, err error)
|
||||
// Lookup the membership event numeric IDs for all user that are or have
|
||||
// been members of a given room. Only lookup events of "join" membership if
|
||||
// joinOnly is set to true.
|
||||
|
|
@ -194,8 +196,22 @@ type Database interface {
|
|||
) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error)
|
||||
}
|
||||
|
||||
type UserRoomKeys interface {
|
||||
// InsertUserRoomPrivatePublicKey inserts the given private key as well as the public key for it. This should be used
|
||||
// when creating keys locally.
|
||||
InsertUserRoomPrivatePublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error)
|
||||
// InsertUserRoomPublicKey inserts the given public key, this should be used for users NOT local to this server
|
||||
InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error)
|
||||
// SelectUserRoomPrivateKey selects the private key for the given user and room combination
|
||||
SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error)
|
||||
// SelectUserIDsForPublicKeys selects all userIDs for the requested senderKeys. Returns a map from roomID -> map from publicKey to userID.
|
||||
// If a senderKey can't be found, it is omitted in the result.
|
||||
SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (map[spec.RoomID]map[string]string, error)
|
||||
}
|
||||
|
||||
type RoomDatabase interface {
|
||||
EventDatabase
|
||||
UserRoomKeys
|
||||
// RoomInfo returns room information for the given room ID, or nil if there is no room.
|
||||
RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
|
||||
RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error)
|
||||
|
|
|
|||
|
|
@ -131,6 +131,9 @@ func (d *Database) create(db *sql.DB) error {
|
|||
if err := CreateRedactionsTable(db); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := CreateUserRoomKeysTable(db); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -192,6 +195,11 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userRoomKeys, err := PrepareUserRoomKeysTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.Database = shared.Database{
|
||||
DB: db,
|
||||
EventDatabase: shared.EventDatabase{
|
||||
|
|
@ -215,6 +223,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
|
|||
MembershipTable: membership,
|
||||
PublishedTable: published,
|
||||
Purge: purge,
|
||||
UserRoomKeyTable: userRoomKeys,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
132
roomserver/storage/postgres/user_room_keys_table.go
Normal file
132
roomserver/storage/postgres/user_room_keys_table.go
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
const userRoomKeysSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_user_room_keys (
|
||||
user_nid INTEGER NOT NULL,
|
||||
room_nid INTEGER NOT NULL,
|
||||
pseudo_id_key BYTEA NULL, -- may be null for users not local to the server
|
||||
pseudo_id_pub_key BYTEA NOT NULL,
|
||||
CONSTRAINT roomserver_user_room_keys_pk PRIMARY KEY (user_nid, room_nid)
|
||||
);
|
||||
`
|
||||
|
||||
const insertUserRoomPrivateKeySQL = `
|
||||
INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key, pseudo_id_pub_key) VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT ON CONSTRAINT roomserver_user_room_keys_pk DO UPDATE SET pseudo_id_key = roomserver_user_room_keys.pseudo_id_key
|
||||
RETURNING (pseudo_id_key)
|
||||
`
|
||||
|
||||
const insertUserRoomPublicKeySQL = `
|
||||
INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_pub_key) VALUES ($1, $2, $3)
|
||||
ON CONFLICT ON CONSTRAINT roomserver_user_room_keys_pk DO UPDATE SET pseudo_id_pub_key = $3
|
||||
RETURNING (pseudo_id_pub_key)
|
||||
`
|
||||
|
||||
const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
|
||||
|
||||
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)`
|
||||
|
||||
type userRoomKeysStatements struct {
|
||||
insertUserRoomPrivateKeyStmt *sql.Stmt
|
||||
insertUserRoomPublicKeyStmt *sql.Stmt
|
||||
selectUserRoomKeyStmt *sql.Stmt
|
||||
selectUserNIDsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func CreateUserRoomKeysTable(db *sql.DB) error {
|
||||
_, err := db.Exec(userRoomKeysSchema)
|
||||
return err
|
||||
}
|
||||
|
||||
func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
|
||||
s := &userRoomKeysStatements{}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertUserRoomPrivateKeyStmt, insertUserRoomPrivateKeySQL},
|
||||
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
|
||||
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
|
||||
{&s.selectUserNIDsStmt, selectUserNIDsSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *userRoomKeysStatements) InsertUserRoomPrivatePublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) {
|
||||
stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomPrivateKeyStmt)
|
||||
err = stmt.QueryRowContext(ctx, userNID, roomNID, key, key.Public()).Scan(&result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (s *userRoomKeysStatements) InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) {
|
||||
stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomPublicKeyStmt)
|
||||
err = stmt.QueryRowContext(ctx, userNID, roomNID, key).Scan(&result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (s *userRoomKeysStatements) SelectUserRoomPrivateKey(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
userNID types.EventStateKeyNID,
|
||||
roomNID types.RoomNID,
|
||||
) (ed25519.PrivateKey, error) {
|
||||
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeyStmt)
|
||||
var result ed25519.PrivateKey
|
||||
err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) {
|
||||
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserNIDsStmt)
|
||||
|
||||
roomNIDs := make([]types.RoomNID, 0, len(senderKeys))
|
||||
var senders [][]byte
|
||||
for roomNID := range senderKeys {
|
||||
roomNIDs = append(roomNIDs, roomNID)
|
||||
for _, key := range senderKeys[roomNID] {
|
||||
senders = append(senders, key)
|
||||
}
|
||||
}
|
||||
rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(senders))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows")
|
||||
|
||||
result := make(map[string]types.UserRoomKeyPair, len(senders)+len(roomNIDs))
|
||||
var publicKey []byte
|
||||
userRoomKeyPair := types.UserRoomKeyPair{}
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&userRoomKeyPair.EventStateKeyNID, &userRoomKeyPair.RoomNID, &publicKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[string(publicKey)] = userRoomKeyPair
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
|
@ -2,14 +2,18 @@ package shared
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
|
|
@ -41,6 +45,7 @@ type Database struct {
|
|||
MembershipTable tables.Membership
|
||||
PublishedTable tables.Published
|
||||
Purge tables.Purge
|
||||
UserRoomKeyTable tables.UserRoomKeys
|
||||
GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error)
|
||||
}
|
||||
|
||||
|
|
@ -485,10 +490,10 @@ func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
|
|||
})
|
||||
}
|
||||
|
||||
func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom, isRoomforgotten bool, err error) {
|
||||
func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderID spec.SenderID) (membershipEventNID types.EventNID, stillInRoom, isRoomforgotten bool, err error) {
|
||||
var requestSenderUserNID types.EventStateKeyNID
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
requestSenderUserNID, err = d.assignStateKeyNID(ctx, txn, requestSenderUserID)
|
||||
requestSenderUserNID, err = d.assignStateKeyNID(ctx, txn, string(requestSenderID))
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
|
|
@ -931,6 +936,7 @@ func extractRoomVersionFromCreateEvent(event gomatrixserverlib.PDU) (
|
|||
return roomVersion, err
|
||||
}
|
||||
|
||||
// nolint:gocyclo
|
||||
// MaybeRedactEvent manages the redacted status of events. There's two cases to consider in order to comply with the spec:
|
||||
// "servers should not apply or send redactions to clients until both the redaction event and original event have been seen, and are valid."
|
||||
// https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events
|
||||
|
|
@ -1009,7 +1015,7 @@ func (d *EventDatabase) MaybeRedactEvent(
|
|||
switch {
|
||||
case powerlevels.UserLevel(redactionEvent.SenderID()) >= powerlevels.Redact:
|
||||
// 1. The power level of the redaction event’s sender is greater than or equal to the redact level.
|
||||
case sender1Domain == sender2Domain:
|
||||
case sender1Domain != "" && sender2Domain != "" && sender1Domain == sender2Domain:
|
||||
// 2. The domain of the redaction event’s sender matches that of the original event’s sender.
|
||||
default:
|
||||
ignoreRedaction = true
|
||||
|
|
@ -1609,6 +1615,147 @@ func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventS
|
|||
})
|
||||
}
|
||||
|
||||
// InsertUserRoomPrivatePublicKey inserts a new user room key for the given user and room.
|
||||
// Returns the newly inserted private key or an existing private key. If there is
|
||||
// an error talking to the database, returns that error.
|
||||
func (d *Database) InsertUserRoomPrivatePublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) {
|
||||
uID := userID.String()
|
||||
stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID})
|
||||
if sErr != nil {
|
||||
return nil, sErr
|
||||
}
|
||||
stateKeyNID := stateKeyNIDMap[uID]
|
||||
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String())
|
||||
if rErr != nil {
|
||||
return rErr
|
||||
}
|
||||
if roomInfo == nil {
|
||||
return eventutil.ErrRoomNoExists{}
|
||||
}
|
||||
|
||||
var iErr error
|
||||
result, iErr = d.UserRoomKeyTable.InsertUserRoomPrivatePublicKey(ctx, txn, stateKeyNID, roomInfo.RoomNID, key)
|
||||
return iErr
|
||||
})
|
||||
return result, err
|
||||
}
|
||||
|
||||
// InsertUserRoomPublicKey inserts a new user room key for the given user and room.
|
||||
// Returns the newly inserted public key or an existing public key. If there is
|
||||
// an error talking to the database, returns that error.
|
||||
func (d *Database) InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) {
|
||||
uID := userID.String()
|
||||
stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID})
|
||||
if sErr != nil {
|
||||
return nil, sErr
|
||||
}
|
||||
stateKeyNID := stateKeyNIDMap[uID]
|
||||
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String())
|
||||
if rErr != nil {
|
||||
return rErr
|
||||
}
|
||||
if roomInfo == nil {
|
||||
return eventutil.ErrRoomNoExists{}
|
||||
}
|
||||
|
||||
var iErr error
|
||||
result, iErr = d.UserRoomKeyTable.InsertUserRoomPublicKey(ctx, txn, stateKeyNID, roomInfo.RoomNID, key)
|
||||
return iErr
|
||||
})
|
||||
return result, err
|
||||
}
|
||||
|
||||
// SelectUserRoomPrivateKey queries the users room private key.
|
||||
// If no key exists, returns no key and no error. Otherwise returns
|
||||
// the key and a database error, if any.
|
||||
func (d *Database) SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error) {
|
||||
uID := userID.String()
|
||||
stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID})
|
||||
if sErr != nil {
|
||||
return nil, sErr
|
||||
}
|
||||
stateKeyNID := stateKeyNIDMap[uID]
|
||||
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String())
|
||||
if rErr != nil {
|
||||
return rErr
|
||||
}
|
||||
if roomInfo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key, sErr = d.UserRoomKeyTable.SelectUserRoomPrivateKey(ctx, txn, stateKeyNID, roomInfo.RoomNID)
|
||||
if !errors.Is(sErr, sql.ErrNoRows) {
|
||||
return sErr
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// SelectUserIDsForPublicKeys returns a map from roomID -> map from senderKey -> userID
|
||||
func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (result map[spec.RoomID]map[string]string, err error) {
|
||||
result = make(map[spec.RoomID]map[string]string, len(publicKeys))
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
|
||||
// map all roomIDs to roomNIDs
|
||||
query := make(map[types.RoomNID][]ed25519.PublicKey)
|
||||
rooms := make(map[types.RoomNID]spec.RoomID)
|
||||
for roomID, keys := range publicKeys {
|
||||
roomNID, ok := d.Cache.GetRoomServerRoomNID(roomID.String())
|
||||
if !ok {
|
||||
roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String())
|
||||
if rErr != nil {
|
||||
return rErr
|
||||
}
|
||||
if roomInfo == nil {
|
||||
logrus.Warnf("missing room info for %s, there will be missing users in the response", roomID.String())
|
||||
continue
|
||||
}
|
||||
roomNID = roomInfo.RoomNID
|
||||
}
|
||||
|
||||
query[roomNID] = keys
|
||||
rooms[roomNID] = roomID
|
||||
}
|
||||
|
||||
// get the user room key pars
|
||||
userRoomKeyPairMap, sErr := d.UserRoomKeyTable.BulkSelectUserNIDs(ctx, txn, query)
|
||||
if sErr != nil {
|
||||
return sErr
|
||||
}
|
||||
nids := make([]types.EventStateKeyNID, 0, len(userRoomKeyPairMap))
|
||||
for _, nid := range userRoomKeyPairMap {
|
||||
nids = append(nids, nid.EventStateKeyNID)
|
||||
}
|
||||
// get the userIDs
|
||||
nidMap, seErr := d.EventStateKeys(ctx, nids)
|
||||
if seErr != nil {
|
||||
return seErr
|
||||
}
|
||||
|
||||
// build the result map (roomID -> map publicKey -> userID)
|
||||
for publicKey, userRoomKeyPair := range userRoomKeyPairMap {
|
||||
userID := nidMap[userRoomKeyPair.EventStateKeyNID]
|
||||
roomID := rooms[userRoomKeyPair.RoomNID]
|
||||
resMap, exists := result[roomID]
|
||||
if !exists {
|
||||
resMap = map[string]string{}
|
||||
}
|
||||
resMap[publicKey] = userID
|
||||
result[roomID] = resMap
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return result, err
|
||||
}
|
||||
|
||||
// FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops
|
||||
// it should live in this package!
|
||||
|
||||
|
|
|
|||
|
|
@ -2,11 +2,15 @@ package shared_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
"github.com/stretchr/testify/assert"
|
||||
ed255192 "golang.org/x/crypto/ed25519"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
|
||||
|
|
@ -23,41 +27,62 @@ func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Dat
|
|||
connStr, clearDB := test.PrepareDBConnectionString(t, dbType)
|
||||
dbOpts := &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}
|
||||
|
||||
db, err := sqlutil.Open(dbOpts, sqlutil.NewExclusiveWriter())
|
||||
writer := sqlutil.NewExclusiveWriter()
|
||||
db, err := sqlutil.Open(dbOpts, writer)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var membershipTable tables.Membership
|
||||
var stateKeyTable tables.EventStateKeys
|
||||
var userRoomKeys tables.UserRoomKeys
|
||||
var roomsTable tables.Rooms
|
||||
switch dbType {
|
||||
case test.DBTypePostgres:
|
||||
err = postgres.CreateRoomsTable(db)
|
||||
assert.NoError(t, err)
|
||||
err = postgres.CreateEventStateKeysTable(db)
|
||||
assert.NoError(t, err)
|
||||
err = postgres.CreateMembershipTable(db)
|
||||
assert.NoError(t, err)
|
||||
err = postgres.CreateUserRoomKeysTable(db)
|
||||
assert.NoError(t, err)
|
||||
roomsTable, err = postgres.PrepareRoomsTable(db)
|
||||
assert.NoError(t, err)
|
||||
membershipTable, err = postgres.PrepareMembershipTable(db)
|
||||
assert.NoError(t, err)
|
||||
stateKeyTable, err = postgres.PrepareEventStateKeysTable(db)
|
||||
assert.NoError(t, err)
|
||||
userRoomKeys, err = postgres.PrepareUserRoomKeysTable(db)
|
||||
case test.DBTypeSQLite:
|
||||
err = sqlite3.CreateRoomsTable(db)
|
||||
assert.NoError(t, err)
|
||||
err = sqlite3.CreateEventStateKeysTable(db)
|
||||
assert.NoError(t, err)
|
||||
err = sqlite3.CreateMembershipTable(db)
|
||||
assert.NoError(t, err)
|
||||
err = sqlite3.CreateUserRoomKeysTable(db)
|
||||
assert.NoError(t, err)
|
||||
roomsTable, err = sqlite3.PrepareRoomsTable(db)
|
||||
assert.NoError(t, err)
|
||||
membershipTable, err = sqlite3.PrepareMembershipTable(db)
|
||||
assert.NoError(t, err)
|
||||
stateKeyTable, err = sqlite3.PrepareEventStateKeysTable(db)
|
||||
assert.NoError(t, err)
|
||||
userRoomKeys, err = sqlite3.PrepareUserRoomKeysTable(db)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
cache := caching.NewRistrettoCache(8*1024*1024, time.Hour, false)
|
||||
|
||||
evDb := shared.EventDatabase{EventStateKeysTable: stateKeyTable, Cache: cache}
|
||||
evDb := shared.EventDatabase{EventStateKeysTable: stateKeyTable, Cache: cache, Writer: writer}
|
||||
|
||||
return &shared.Database{
|
||||
DB: db,
|
||||
EventDatabase: evDb,
|
||||
MembershipTable: membershipTable,
|
||||
Writer: sqlutil.NewExclusiveWriter(),
|
||||
Cache: cache,
|
||||
DB: db,
|
||||
EventDatabase: evDb,
|
||||
MembershipTable: membershipTable,
|
||||
UserRoomKeyTable: userRoomKeys,
|
||||
RoomsTable: roomsTable,
|
||||
Writer: writer,
|
||||
Cache: cache,
|
||||
}, func() {
|
||||
clearDB()
|
||||
err = db.Close()
|
||||
|
|
@ -97,3 +122,80 @@ func Test_GetLeftUsers(t *testing.T) {
|
|||
assert.ElementsMatch(t, expectedUserIDs, leftUsers)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserRoomKeys(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
alice := test.NewUser(t)
|
||||
room := test.NewRoom(t, alice)
|
||||
|
||||
userID, err := spec.NewUserID(alice.ID, true)
|
||||
assert.NoError(t, err)
|
||||
roomID, err := spec.NewRoomID(room.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateRoomserverDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
// create a room NID so we can query the room
|
||||
_, err = db.RoomsTable.InsertRoomNID(ctx, nil, roomID.String(), gomatrixserverlib.RoomVersionV10)
|
||||
assert.NoError(t, err)
|
||||
doesNotExist, err := spec.NewRoomID("!doesnotexist:localhost")
|
||||
assert.NoError(t, err)
|
||||
_, err = db.RoomsTable.InsertRoomNID(ctx, nil, doesNotExist.String(), gomatrixserverlib.RoomVersionV10)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, key, err := ed25519.GenerateKey(nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
gotKey, err := db.InsertUserRoomPrivatePublicKey(ctx, *userID, *roomID, key)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, gotKey, key)
|
||||
|
||||
// again, this shouldn't result in an error, but return the existing key
|
||||
_, key2, err := ed25519.GenerateKey(nil)
|
||||
assert.NoError(t, err)
|
||||
gotKey, err = db.InsertUserRoomPrivatePublicKey(ctx, *userID, *roomID, key2)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, gotKey, key)
|
||||
|
||||
gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *roomID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, key, gotKey)
|
||||
|
||||
// Key doesn't exist, we shouldn't get anything back
|
||||
assert.NoError(t, err)
|
||||
gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *doesNotExist)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, gotKey)
|
||||
|
||||
queryUserIDs := map[spec.RoomID][]ed25519.PublicKey{
|
||||
*roomID: {key.Public().(ed25519.PublicKey)},
|
||||
}
|
||||
|
||||
userIDs, err := db.SelectUserIDsForPublicKeys(ctx, queryUserIDs)
|
||||
assert.NoError(t, err)
|
||||
wantKeys := map[spec.RoomID]map[string]string{
|
||||
*roomID: {
|
||||
string(key.Public().(ed25519.PublicKey)): userID.String(),
|
||||
},
|
||||
}
|
||||
assert.Equal(t, wantKeys, userIDs)
|
||||
|
||||
// insert key that came in over federation
|
||||
var gotPublicKey, key4 ed255192.PublicKey
|
||||
key4, _, err = ed25519.GenerateKey(nil)
|
||||
assert.NoError(t, err)
|
||||
gotPublicKey, err = db.InsertUserRoomPublicKey(context.Background(), *userID, *doesNotExist, key4)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, key4, gotPublicKey)
|
||||
|
||||
// test invalid room
|
||||
reallyDoesNotExist, err := spec.NewRoomID("!reallydoesnotexist:localhost")
|
||||
assert.NoError(t, err)
|
||||
_, err = db.InsertUserRoomPublicKey(context.Background(), *userID, *reallyDoesNotExist, key4)
|
||||
assert.Error(t, err)
|
||||
_, err = db.InsertUserRoomPrivatePublicKey(context.Background(), *userID, *reallyDoesNotExist, key)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -138,6 +138,9 @@ func (d *Database) create(db *sql.DB) error {
|
|||
if err := CreateRedactionsTable(db); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := CreateUserRoomKeysTable(db); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -199,6 +202,10 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userRoomKeys, err := PrepareUserRoomKeysTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.Database = shared.Database{
|
||||
DB: db,
|
||||
|
|
@ -224,6 +231,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
|
|||
PublishedTable: published,
|
||||
GetRoomUpdaterFn: d.GetRoomUpdater,
|
||||
Purge: purge,
|
||||
UserRoomKeyTable: userRoomKeys,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
146
roomserver/storage/sqlite3/user_room_keys_table.go
Normal file
146
roomserver/storage/sqlite3/user_room_keys_table.go
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
const userRoomKeysSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_user_room_keys (
|
||||
user_nid INTEGER NOT NULL,
|
||||
room_nid INTEGER NOT NULL,
|
||||
pseudo_id_key TEXT NULL, -- may be null for users not local to the server
|
||||
pseudo_id_pub_key TEXT NOT NULL,
|
||||
CONSTRAINT roomserver_user_room_keys_pk PRIMARY KEY (user_nid, room_nid)
|
||||
);
|
||||
`
|
||||
|
||||
const insertUserRoomKeySQL = `
|
||||
INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key, pseudo_id_pub_key) VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT DO UPDATE SET pseudo_id_key = roomserver_user_room_keys.pseudo_id_key
|
||||
RETURNING (pseudo_id_key)
|
||||
`
|
||||
|
||||
const insertUserRoomPublicKeySQL = `
|
||||
INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_pub_key) VALUES ($1, $2, $3)
|
||||
ON CONFLICT DO UPDATE SET pseudo_id_pub_key = $3
|
||||
RETURNING (pseudo_id_pub_key)
|
||||
`
|
||||
|
||||
const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
|
||||
|
||||
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)`
|
||||
|
||||
type userRoomKeysStatements struct {
|
||||
insertUserRoomPrivateKeyStmt *sql.Stmt
|
||||
insertUserRoomPublicKeyStmt *sql.Stmt
|
||||
selectUserRoomKeyStmt *sql.Stmt
|
||||
//selectUserNIDsStmt *sql.Stmt //prepared at runtime
|
||||
}
|
||||
|
||||
func CreateUserRoomKeysTable(db *sql.DB) error {
|
||||
_, err := db.Exec(userRoomKeysSchema)
|
||||
return err
|
||||
}
|
||||
|
||||
func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
|
||||
s := &userRoomKeysStatements{}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertUserRoomPrivateKeyStmt, insertUserRoomKeySQL},
|
||||
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
|
||||
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
|
||||
//{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *userRoomKeysStatements) InsertUserRoomPrivatePublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) {
|
||||
stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomPrivateKeyStmt)
|
||||
err = stmt.QueryRowContext(ctx, userNID, roomNID, key, key.Public()).Scan(&result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (s *userRoomKeysStatements) InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) {
|
||||
stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomPublicKeyStmt)
|
||||
err = stmt.QueryRowContext(ctx, userNID, roomNID, key).Scan(&result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (s *userRoomKeysStatements) SelectUserRoomPrivateKey(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
userNID types.EventStateKeyNID,
|
||||
roomNID types.RoomNID,
|
||||
) (ed25519.PrivateKey, error) {
|
||||
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeyStmt)
|
||||
var result ed25519.PrivateKey
|
||||
err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) {
|
||||
|
||||
roomNIDs := make([]any, 0, len(senderKeys))
|
||||
var senders []any
|
||||
for roomNID := range senderKeys {
|
||||
roomNIDs = append(roomNIDs, roomNID)
|
||||
|
||||
for _, key := range senderKeys[roomNID] {
|
||||
senders = append(senders, []byte(key))
|
||||
}
|
||||
}
|
||||
|
||||
selectSQL := strings.Replace(selectUserNIDsSQL, "($2)", sqlutil.QueryVariadicOffset(len(senders), len(senderKeys)), 1)
|
||||
selectSQL = strings.Replace(selectSQL, "($1)", sqlutil.QueryVariadic(len(senderKeys)), 1) // replace $1 with the roomNIDs
|
||||
|
||||
selectStmt, err := txn.Prepare(selectSQL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := append(roomNIDs, senders...)
|
||||
|
||||
stmt := sqlutil.TxStmt(txn, selectStmt)
|
||||
defer internal.CloseAndLogIfError(ctx, stmt, "failed to close statement")
|
||||
|
||||
rows, err := stmt.QueryContext(ctx, params...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows")
|
||||
|
||||
result := make(map[string]types.UserRoomKeyPair, len(params))
|
||||
var publicKey []byte
|
||||
userRoomKeyPair := types.UserRoomKeyPair{}
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&userRoomKeyPair.EventStateKeyNID, &userRoomKeyPair.RoomNID, &publicKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[string(publicKey)] = userRoomKeyPair
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
|
@ -2,6 +2,7 @@ package tables
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
|
|
@ -184,6 +185,19 @@ type Purge interface {
|
|||
) error
|
||||
}
|
||||
|
||||
type UserRoomKeys interface {
|
||||
// InsertUserRoomPrivatePublicKey inserts the given private key as well as the public key for it. This should be used
|
||||
// when creating keys locally.
|
||||
InsertUserRoomPrivatePublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (ed25519.PrivateKey, error)
|
||||
// InsertUserRoomPublicKey inserts the given public key, this should be used for users NOT local to this server
|
||||
InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (ed25519.PublicKey, error)
|
||||
// SelectUserRoomPrivateKey selects the private key for the given user and room combination
|
||||
SelectUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error)
|
||||
// BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair.
|
||||
// If a senderKey can't be found, it is omitted in the result.
|
||||
BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error)
|
||||
}
|
||||
|
||||
// StrippedEvent represents a stripped event for returning extracted content values.
|
||||
type StrippedEvent struct {
|
||||
RoomID string
|
||||
|
|
|
|||
115
roomserver/storage/tables/user_room_keys_table_test.go
Normal file
115
roomserver/storage/tables/user_room_keys_table_test.go
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
package tables_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
ed255192 "golang.org/x/crypto/ed25519"
|
||||
)
|
||||
|
||||
func mustCreateUserRoomKeysTable(t *testing.T, dbType test.DBType) (tab tables.UserRoomKeys, db *sql.DB, close func()) {
|
||||
t.Helper()
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, sqlutil.NewExclusiveWriter())
|
||||
assert.NoError(t, err)
|
||||
switch dbType {
|
||||
case test.DBTypePostgres:
|
||||
err = postgres.CreateUserRoomKeysTable(db)
|
||||
assert.NoError(t, err)
|
||||
tab, err = postgres.PrepareUserRoomKeysTable(db)
|
||||
case test.DBTypeSQLite:
|
||||
err = sqlite3.CreateUserRoomKeysTable(db)
|
||||
assert.NoError(t, err)
|
||||
tab, err = sqlite3.PrepareUserRoomKeysTable(db)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
return tab, db, close
|
||||
}
|
||||
|
||||
func TestUserRoomKeysTable(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
tab, db, close := mustCreateUserRoomKeysTable(t, dbType)
|
||||
defer close()
|
||||
userNID := types.EventStateKeyNID(1)
|
||||
roomNID := types.RoomNID(1)
|
||||
_, key, err := ed25519.GenerateKey(nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
|
||||
var gotKey, key2, key3 ed25519.PrivateKey
|
||||
gotKey, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID, roomNID, key)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, gotKey, key)
|
||||
|
||||
// again, this shouldn't result in an error, but return the existing key
|
||||
_, key2, err = ed25519.GenerateKey(nil)
|
||||
assert.NoError(t, err)
|
||||
gotKey, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID, roomNID, key2)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, gotKey, key)
|
||||
|
||||
// add another user
|
||||
_, key3, err = ed25519.GenerateKey(nil)
|
||||
assert.NoError(t, err)
|
||||
userNID2 := types.EventStateKeyNID(2)
|
||||
_, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID2, roomNID, key3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, roomNID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, key, gotKey)
|
||||
|
||||
// try to update an existing key, this should only be done for users NOT on this homeserver
|
||||
var gotPubKey ed25519.PublicKey
|
||||
gotPubKey, err = tab.InsertUserRoomPublicKey(context.Background(), txn, userNID, roomNID, key2.Public().(ed25519.PublicKey))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, key2.Public(), gotPubKey)
|
||||
|
||||
// Key doesn't exist
|
||||
gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, 2)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, gotKey)
|
||||
|
||||
// query user NIDs for senderKeys
|
||||
var gotKeys map[string]types.UserRoomKeyPair
|
||||
query := map[types.RoomNID][]ed25519.PublicKey{
|
||||
roomNID: {key2.Public().(ed25519.PublicKey), key3.Public().(ed25519.PublicKey)},
|
||||
types.RoomNID(2): {key.Public().(ed25519.PublicKey), key3.Public().(ed25519.PublicKey)}, // doesn't exist
|
||||
}
|
||||
gotKeys, err = tab.BulkSelectUserNIDs(context.Background(), txn, query)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, gotKeys)
|
||||
|
||||
wantKeys := map[string]types.UserRoomKeyPair{
|
||||
string(key2.Public().(ed25519.PublicKey)): {RoomNID: roomNID, EventStateKeyNID: userNID},
|
||||
string(key3.Public().(ed25519.PublicKey)): {RoomNID: roomNID, EventStateKeyNID: userNID2},
|
||||
}
|
||||
assert.Equal(t, wantKeys, gotKeys)
|
||||
|
||||
// insert key that came in over federation
|
||||
var gotPublicKey, key4 ed255192.PublicKey
|
||||
key4, _, err = ed25519.GenerateKey(nil)
|
||||
assert.NoError(t, err)
|
||||
gotPublicKey, err = tab.InsertUserRoomPublicKey(context.Background(), txn, userNID, 2, key4)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, key4, gotPublicKey)
|
||||
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
})
|
||||
}
|
||||
|
|
@ -44,6 +44,11 @@ type EventMetadata struct {
|
|||
RoomNID RoomNID
|
||||
}
|
||||
|
||||
type UserRoomKeyPair struct {
|
||||
RoomNID RoomNID
|
||||
EventStateKeyNID EventStateKeyNID
|
||||
}
|
||||
|
||||
// StateSnapshotNID is a numeric ID for the state at an event.
|
||||
type StateSnapshotNID int64
|
||||
|
||||
|
|
|
|||
|
|
@ -154,7 +154,7 @@ type reqCtx struct {
|
|||
rsAPI roomserver.RoomserverInternalAPI
|
||||
db Database
|
||||
req *EventRelationshipRequest
|
||||
userID string
|
||||
userID spec.UserID
|
||||
roomVersion gomatrixserverlib.RoomVersion
|
||||
|
||||
// federated request args
|
||||
|
|
@ -173,10 +173,17 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
|
|||
JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)),
|
||||
}
|
||||
}
|
||||
userID, err := spec.NewUserID(device.UserID, true)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: 400,
|
||||
JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)),
|
||||
}
|
||||
}
|
||||
rc := reqCtx{
|
||||
ctx: req.Context(),
|
||||
req: relation,
|
||||
userID: device.UserID,
|
||||
userID: *userID,
|
||||
rsAPI: rsAPI,
|
||||
fsAPI: fsAPI,
|
||||
isFederatedRequest: false,
|
||||
|
|
|
|||
|
|
@ -529,6 +529,10 @@ func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID str
|
|||
return spec.NewUserID(string(senderID), true)
|
||||
}
|
||||
|
||||
func (r *testRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) {
|
||||
return spec.SenderID(userID.String()), nil
|
||||
}
|
||||
|
||||
func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error {
|
||||
for _, eventID := range req.EventIDs {
|
||||
ev := r.events[eventID]
|
||||
|
|
@ -540,7 +544,7 @@ func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver
|
|||
}
|
||||
|
||||
func (r *testRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *roomserver.QueryMembershipForUserRequest, res *roomserver.QueryMembershipForUserResponse) error {
|
||||
rooms := r.userToJoinedRooms[req.UserID]
|
||||
rooms := r.userToJoinedRooms[req.UserID.String()]
|
||||
for _, roomID := range rooms {
|
||||
if roomID == req.RoomID {
|
||||
res.IsInRoom = true
|
||||
|
|
|
|||
|
|
@ -373,7 +373,15 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *rst
|
|||
// TODO: check that it's a join and not a profile change (means unmarshalling prev_content)
|
||||
if membership == spec.Join {
|
||||
// check it's a local join
|
||||
if _, _, err := s.cfg.Matrix.SplitLocalID('@', *ev.StateKey()); err != nil {
|
||||
if ev.StateKey() == nil {
|
||||
return sp, fmt.Errorf("unexpected nil state_key")
|
||||
}
|
||||
|
||||
userID, err := s.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey()))
|
||||
if err != nil || userID == nil {
|
||||
return sp, fmt.Errorf("failed getting userID for sender: %w", err)
|
||||
}
|
||||
if !s.cfg.Matrix.IsLocalServerName(userID.Domain()) {
|
||||
return sp, nil
|
||||
}
|
||||
|
||||
|
|
@ -395,9 +403,15 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
|
|||
if msg.Event.StateKey() == nil {
|
||||
return
|
||||
}
|
||||
if _, _, err := s.cfg.Matrix.SplitLocalID('@', *msg.Event.StateKey()); err != nil {
|
||||
|
||||
userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.Event.RoomID(), spec.SenderID(*msg.Event.StateKey()))
|
||||
if err != nil || userID == nil {
|
||||
return
|
||||
}
|
||||
if !s.cfg.Matrix.IsLocalServerName(userID.Domain()) {
|
||||
return
|
||||
}
|
||||
|
||||
pduPos, err := s.db.AddInviteEvent(ctx, msg.Event)
|
||||
if err != nil {
|
||||
sentry.CaptureException(err)
|
||||
|
|
@ -440,7 +454,16 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
|
|||
|
||||
// Notify any active sync requests that the invite has been retired.
|
||||
s.inviteStream.Advance(pduPos)
|
||||
s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, msg.TargetUserID)
|
||||
userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.RoomID, msg.TargetSenderID)
|
||||
if err != nil || userID == nil {
|
||||
log.WithFields(log.Fields{
|
||||
"event_id": msg.EventID,
|
||||
"sender_id": msg.TargetSenderID,
|
||||
log.ErrorKey: err,
|
||||
}).Errorf("failed to find userID for sender")
|
||||
return
|
||||
}
|
||||
s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, userID.String())
|
||||
}
|
||||
|
||||
func (s *OutputRoomEventConsumer) onNewPeek(
|
||||
|
|
|
|||
|
|
@ -134,9 +134,17 @@ func ApplyHistoryVisibilityFilter(
|
|||
}
|
||||
}
|
||||
// NOTSPEC: Always allow user to see their own membership events (spec contains more "rules")
|
||||
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(userID) {
|
||||
eventsFiltered = append(eventsFiltered, ev)
|
||||
continue
|
||||
|
||||
user, err := spec.NewUserID(userID, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *user)
|
||||
if err == nil {
|
||||
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) {
|
||||
eventsFiltered = append(eventsFiltered, ev)
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Always allow history evVis events on boundaries. This is done
|
||||
// by setting the effective evVis to the least restrictive
|
||||
|
|
|
|||
|
|
@ -169,12 +169,16 @@ func TrackChangedUsers(
|
|||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for _, state := range stateRes.Rooms {
|
||||
for roomID, state := range stateRes.Rooms {
|
||||
for tuple, membership := range state {
|
||||
if membership != spec.Join {
|
||||
continue
|
||||
}
|
||||
queryRes.UserIDsToCount[tuple.StateKey]--
|
||||
user, queryErr := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey))
|
||||
if queryErr != nil || user == nil {
|
||||
continue
|
||||
}
|
||||
queryRes.UserIDsToCount[user.String()]--
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -211,14 +215,18 @@ func TrackChangedUsers(
|
|||
if err != nil {
|
||||
return nil, left, err
|
||||
}
|
||||
for _, state := range stateRes.Rooms {
|
||||
for roomID, state := range stateRes.Rooms {
|
||||
for tuple, membership := range state {
|
||||
if membership != spec.Join {
|
||||
continue
|
||||
}
|
||||
// new user who we weren't previously sharing rooms with
|
||||
if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok {
|
||||
changed = append(changed, tuple.StateKey) // changed is returned
|
||||
user, err := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey))
|
||||
if err != nil || user == nil {
|
||||
continue
|
||||
}
|
||||
changed = append(changed, user.String()) // changed is returned
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -64,6 +64,10 @@ type mockRoomserverAPI struct {
|
|||
roomIDToJoinedMembers map[string][]string
|
||||
}
|
||||
|
||||
func (s *mockRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return spec.NewUserID(string(senderID), true)
|
||||
}
|
||||
|
||||
// QueryRoomsForUser retrieves a list of room IDs matching the given query.
|
||||
func (s *mockRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error {
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
rstypes "github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
|
|
@ -36,7 +37,8 @@ import (
|
|||
// the event, but the token has already advanced by the time they fetch it, resulting
|
||||
// in missed events.
|
||||
type Notifier struct {
|
||||
lock *sync.RWMutex
|
||||
lock *sync.RWMutex
|
||||
rsAPI api.SyncRoomserverAPI
|
||||
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
|
||||
roomIDToJoinedUsers map[string]*userIDSet
|
||||
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
|
||||
|
|
@ -55,8 +57,9 @@ type Notifier struct {
|
|||
// NewNotifier creates a new notifier set to the given sync position.
|
||||
// In order for this to be of any use, the Notifier needs to be told all rooms and
|
||||
// the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase).
|
||||
func NewNotifier() *Notifier {
|
||||
func NewNotifier(rsAPI api.SyncRoomserverAPI) *Notifier {
|
||||
return &Notifier{
|
||||
rsAPI: rsAPI,
|
||||
roomIDToJoinedUsers: make(map[string]*userIDSet),
|
||||
roomIDToPeekingDevices: make(map[string]peekingDeviceSet),
|
||||
userDeviceStreams: make(map[string]map[string]*UserDeviceStream),
|
||||
|
|
@ -104,26 +107,32 @@ func (n *Notifier) OnNewEvent(
|
|||
peekingDevicesToNotify := n._peekingDevices(ev.RoomID())
|
||||
// If this is an invite, also add in the invitee to this list.
|
||||
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
|
||||
targetUserID := *ev.StateKey()
|
||||
membership, err := ev.Membership()
|
||||
targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), ev.RoomID(), spec.SenderID(*ev.StateKey()))
|
||||
if err != nil {
|
||||
log.WithError(err).WithField("event_id", ev.EventID()).Errorf(
|
||||
"Notifier.OnNewEvent: Failed to unmarshal member event",
|
||||
"Notifier.OnNewEvent: Failed to find the userID for this event",
|
||||
)
|
||||
} else {
|
||||
// Keep the joined user map up-to-date
|
||||
switch membership {
|
||||
case spec.Invite:
|
||||
usersToNotify = append(usersToNotify, targetUserID)
|
||||
case spec.Join:
|
||||
// Manually append the new user's ID so they get notified
|
||||
// along all members in the room
|
||||
usersToNotify = append(usersToNotify, targetUserID)
|
||||
n._addJoinedUser(ev.RoomID(), targetUserID)
|
||||
case spec.Leave:
|
||||
fallthrough
|
||||
case spec.Ban:
|
||||
n._removeJoinedUser(ev.RoomID(), targetUserID)
|
||||
membership, err := ev.Membership()
|
||||
if err != nil {
|
||||
log.WithError(err).WithField("event_id", ev.EventID()).Errorf(
|
||||
"Notifier.OnNewEvent: Failed to unmarshal member event",
|
||||
)
|
||||
} else {
|
||||
// Keep the joined user map up-to-date
|
||||
switch membership {
|
||||
case spec.Invite:
|
||||
usersToNotify = append(usersToNotify, targetUserID.String())
|
||||
case spec.Join:
|
||||
// Manually append the new user's ID so they get notified
|
||||
// along all members in the room
|
||||
usersToNotify = append(usersToNotify, targetUserID.String())
|
||||
n._addJoinedUser(ev.RoomID(), targetUserID.String())
|
||||
case spec.Leave:
|
||||
fallthrough
|
||||
case spec.Ban:
|
||||
n._removeJoinedUser(ev.RoomID(), targetUserID.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,9 +22,11 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
rstypes "github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
|
|
@ -105,9 +107,15 @@ func mustEqualPositions(t *testing.T, got, want types.StreamingToken) {
|
|||
}
|
||||
}
|
||||
|
||||
type TestRoomServer struct{ api.SyncRoomserverAPI }
|
||||
|
||||
func (t *TestRoomServer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return spec.NewUserID(string(senderID), true)
|
||||
}
|
||||
|
||||
// Test that the current position is returned if a request is already behind.
|
||||
func TestImmediateNotification(t *testing.T) {
|
||||
n := NewNotifier()
|
||||
n := NewNotifier(&TestRoomServer{})
|
||||
n.SetCurrentPosition(syncPositionBefore)
|
||||
pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionVeryOld))
|
||||
if err != nil {
|
||||
|
|
@ -118,7 +126,7 @@ func TestImmediateNotification(t *testing.T) {
|
|||
|
||||
// Test that new events to a joined room unblocks the request.
|
||||
func TestNewEventAndJoinedToRoom(t *testing.T) {
|
||||
n := NewNotifier()
|
||||
n := NewNotifier(&TestRoomServer{})
|
||||
n.SetCurrentPosition(syncPositionBefore)
|
||||
n.setUsersJoinedToRooms(map[string][]string{
|
||||
roomID: {alice, bob},
|
||||
|
|
@ -144,7 +152,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCorrectStream(t *testing.T) {
|
||||
n := NewNotifier()
|
||||
n := NewNotifier(&TestRoomServer{})
|
||||
n.SetCurrentPosition(syncPositionBefore)
|
||||
stream := lockedFetchUserStream(n, bob, bobDev)
|
||||
if stream.UserID != bob {
|
||||
|
|
@ -156,7 +164,7 @@ func TestCorrectStream(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCorrectStreamWakeup(t *testing.T) {
|
||||
n := NewNotifier()
|
||||
n := NewNotifier(&TestRoomServer{})
|
||||
n.SetCurrentPosition(syncPositionBefore)
|
||||
awoken := make(chan string)
|
||||
|
||||
|
|
@ -184,7 +192,7 @@ func TestCorrectStreamWakeup(t *testing.T) {
|
|||
|
||||
// Test that an invite unblocks the request
|
||||
func TestNewInviteEventForUser(t *testing.T) {
|
||||
n := NewNotifier()
|
||||
n := NewNotifier(&TestRoomServer{})
|
||||
n.SetCurrentPosition(syncPositionBefore)
|
||||
n.setUsersJoinedToRooms(map[string][]string{
|
||||
roomID: {alice, bob},
|
||||
|
|
@ -241,7 +249,7 @@ func TestEDUWakeup(t *testing.T) {
|
|||
|
||||
// Test that all blocked requests get woken up on a new event.
|
||||
func TestMultipleRequestWakeup(t *testing.T) {
|
||||
n := NewNotifier()
|
||||
n := NewNotifier(&TestRoomServer{})
|
||||
n.SetCurrentPosition(syncPositionBefore)
|
||||
n.setUsersJoinedToRooms(map[string][]string{
|
||||
roomID: {alice, bob},
|
||||
|
|
@ -278,7 +286,7 @@ func TestMultipleRequestWakeup(t *testing.T) {
|
|||
func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
|
||||
// listen as bob. Make bob leave room. Make alice send event to room.
|
||||
// Make sure alice gets woken up only and not bob as well.
|
||||
n := NewNotifier()
|
||||
n := NewNotifier(&TestRoomServer{})
|
||||
n.SetCurrentPosition(syncPositionBefore)
|
||||
n.setUsersJoinedToRooms(map[string][]string{
|
||||
roomID: {alice, bob},
|
||||
|
|
|
|||
|
|
@ -85,9 +85,16 @@ func Context(
|
|||
*filter.Rooms = append(*filter.Rooms, roomID)
|
||||
}
|
||||
|
||||
userID, err := spec.NewUserID(device.UserID, true)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.InvalidParam("Device UserID is invalid"),
|
||||
}
|
||||
}
|
||||
ctx := req.Context()
|
||||
membershipRes := roomserver.QueryMembershipForUserResponse{}
|
||||
membershipReq := roomserver.QueryMembershipForUserRequest{UserID: device.UserID, RoomID: roomID}
|
||||
membershipReq := roomserver.QueryMembershipForUserRequest{UserID: *userID, RoomID: roomID}
|
||||
if err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes); err != nil {
|
||||
logrus.WithError(err).Error("unable to query membership")
|
||||
return util.JSONResponse{
|
||||
|
|
@ -217,12 +224,9 @@ func Context(
|
|||
}
|
||||
}
|
||||
|
||||
sender := spec.UserID{}
|
||||
userID, err := rsAPI.QueryUserIDForSender(ctx, requestedEvent.RoomID(), requestedEvent.SenderID())
|
||||
if err == nil && userID != nil {
|
||||
sender = *userID
|
||||
}
|
||||
ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, sender)
|
||||
ev := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}, requestedEvent)
|
||||
response := ContextRespsonse{
|
||||
Event: &ev,
|
||||
EventsAfter: eventsAfterClient,
|
||||
|
|
|
|||
|
|
@ -106,8 +106,17 @@ func GetEvent(
|
|||
if err == nil && senderUserID != nil {
|
||||
sender = *senderUserID
|
||||
}
|
||||
|
||||
sk := events[0].StateKey()
|
||||
if sk != nil && *sk != "" {
|
||||
skUserID, err := rsAPI.QueryUserIDForSender(ctx, events[0].RoomID(), spec.SenderID(*events[0].StateKey()))
|
||||
if err == nil && skUserID != nil {
|
||||
skString := skUserID.String()
|
||||
sk = &skString
|
||||
}
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender),
|
||||
JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender, sk),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -59,14 +59,21 @@ func GetMemberships(
|
|||
syncDB storage.Database, rsAPI api.SyncRoomserverAPI,
|
||||
joinedOnly bool, membership, notMembership *string, at string,
|
||||
) util.JSONResponse {
|
||||
userID, err := spec.NewUserID(device.UserID, true)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.InvalidParam("Device UserID is invalid"),
|
||||
}
|
||||
}
|
||||
queryReq := api.QueryMembershipForUserRequest{
|
||||
RoomID: roomID,
|
||||
UserID: device.UserID,
|
||||
UserID: *userID,
|
||||
}
|
||||
|
||||
var queryRes api.QueryMembershipForUserResponse
|
||||
if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryMembershipsForRoom failed")
|
||||
if queryErr := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); queryErr != nil {
|
||||
util.GetLogger(req.Context()).WithError(queryErr).Error("rsAPI.QueryMembershipsForRoom failed")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
|
|
|
|||
|
|
@ -296,9 +296,13 @@ func OnIncomingMessagesRequest(
|
|||
}
|
||||
|
||||
func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (resp api.QueryMembershipForUserResponse, err error) {
|
||||
fullUserID, err := spec.NewUserID(userID, true)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
req := api.QueryMembershipForUserRequest{
|
||||
RoomID: roomID,
|
||||
UserID: userID,
|
||||
UserID: *fullUserID,
|
||||
}
|
||||
if err := rsAPI.QueryMembershipForUser(ctx, &req, &resp); err != nil {
|
||||
return api.QueryMembershipForUserResponse{}, err
|
||||
|
|
|
|||
|
|
@ -119,9 +119,18 @@ func Relations(
|
|||
if err == nil && userID != nil {
|
||||
sender = *userID
|
||||
}
|
||||
|
||||
sk := event.StateKey()
|
||||
if sk != nil && *sk != "" {
|
||||
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey()))
|
||||
if err == nil && skUserID != nil {
|
||||
skString := skUserID.String()
|
||||
sk = &skString
|
||||
}
|
||||
}
|
||||
res.Chunk = append(
|
||||
res.Chunk,
|
||||
synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender),
|
||||
synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender, sk),
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -235,6 +235,15 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
|
|||
if err == nil && userID != nil {
|
||||
sender = *userID
|
||||
}
|
||||
|
||||
sk := event.StateKey()
|
||||
if sk != nil && *sk != "" {
|
||||
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey()))
|
||||
if err == nil && skUserID != nil {
|
||||
skString := skUserID.String()
|
||||
sk = &skString
|
||||
}
|
||||
}
|
||||
results = append(results, Result{
|
||||
Context: SearchContextResponse{
|
||||
Start: startToken.String(),
|
||||
|
|
@ -248,7 +257,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
|
|||
ProfileInfo: profileInfos,
|
||||
},
|
||||
Rank: eventScore[event.EventID()].Score,
|
||||
Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender),
|
||||
Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk),
|
||||
})
|
||||
roomGroup := groups[event.RoomID()]
|
||||
roomGroup.Results = append(roomGroup.Results, event.EventID())
|
||||
|
|
|
|||
|
|
@ -507,8 +507,20 @@ func (d *Database) CleanSendToDeviceUpdates(
|
|||
|
||||
// getMembershipFromEvent returns the value of content.membership iff the event is a state event
|
||||
// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned.
|
||||
func getMembershipFromEvent(ev gomatrixserverlib.PDU, userID string) (string, string) {
|
||||
if ev.Type() != "m.room.member" || !ev.StateKeyEquals(userID) {
|
||||
func getMembershipFromEvent(ctx context.Context, ev gomatrixserverlib.PDU, userID string, rsAPI api.SyncRoomserverAPI) (string, string) {
|
||||
if ev.StateKey() == nil || *ev.StateKey() == "" {
|
||||
return "", ""
|
||||
}
|
||||
fullUser, err := spec.NewUserID(userID, true)
|
||||
if err != nil {
|
||||
return "", ""
|
||||
}
|
||||
senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *fullUser)
|
||||
if err != nil {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
if ev.Type() != "m.room.member" || !ev.StateKeyEquals(string(senderID)) {
|
||||
return "", ""
|
||||
}
|
||||
membership, err := ev.Membership()
|
||||
|
|
|
|||
|
|
@ -430,7 +430,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
|
|||
for _, ev := range stateStreamEvents {
|
||||
// Look for our membership in the state events and skip over any
|
||||
// membership events that are not related to us.
|
||||
membership, prevMembership := getMembershipFromEvent(ev.PDU, userID)
|
||||
membership, prevMembership := getMembershipFromEvent(ctx, ev.PDU, userID, rsAPI)
|
||||
if membership == "" {
|
||||
continue
|
||||
}
|
||||
|
|
@ -556,7 +556,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync(
|
|||
|
||||
for roomID, stateStreamEvents := range state {
|
||||
for _, ev := range stateStreamEvents {
|
||||
if membership, _ := getMembershipFromEvent(ev.PDU, userID); membership != "" {
|
||||
if membership, _ := getMembershipFromEvent(ctx, ev.PDU, userID, rsAPI); membership != "" {
|
||||
if membership != spec.Join { // We've already added full state for all joined rooms above.
|
||||
deltas[roomID] = types.StateDelta{
|
||||
Membership: membership,
|
||||
|
|
|
|||
|
|
@ -70,11 +70,20 @@ func (p *InviteStreamProvider) IncrementalSync(
|
|||
user = *sender
|
||||
}
|
||||
|
||||
sk := inviteEvent.StateKey()
|
||||
if sk != nil && *sk != "" {
|
||||
skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey()))
|
||||
if err == nil && skUserID != nil {
|
||||
skString := skUserID.String()
|
||||
sk = &skString
|
||||
}
|
||||
}
|
||||
|
||||
// skip ignored user events
|
||||
if _, ok := req.IgnoredUsers.List[user.String()]; ok {
|
||||
continue
|
||||
}
|
||||
ir := types.NewInviteResponse(inviteEvent, user)
|
||||
ir := types.NewInviteResponse(inviteEvent, user, sk)
|
||||
req.Response.Rooms.Invite[roomID] = ir
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -605,13 +605,17 @@ func (p *PDUStreamProvider) lazyLoadMembers(
|
|||
// If this is a gapped incremental sync, we still want this membership
|
||||
isGappedIncremental := limited && incremental
|
||||
// We want this users membership event, keep it in the list
|
||||
stateKey := *event.StateKey()
|
||||
if _, ok := timelineUsers[stateKey]; ok || isGappedIncremental || stateKey == device.UserID {
|
||||
userID := ""
|
||||
stateKeyUserID, err := p.rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(*event.StateKey()))
|
||||
if err == nil && stateKeyUserID != nil {
|
||||
userID = stateKeyUserID.String()
|
||||
}
|
||||
if _, ok := timelineUsers[userID]; ok || isGappedIncremental || userID == device.UserID {
|
||||
newStateEvents = append(newStateEvents, event)
|
||||
if !stateFilter.IncludeRedundantMembers {
|
||||
p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, stateKey, event.EventID())
|
||||
p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, userID, event.EventID())
|
||||
}
|
||||
delete(timelineUsers, stateKey)
|
||||
delete(timelineUsers, userID)
|
||||
}
|
||||
} else {
|
||||
newStateEvents = append(newStateEvents, event)
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ func AddPublicRoutes(
|
|||
}
|
||||
|
||||
eduCache := caching.NewTypingCache()
|
||||
notifier := notifier.NewNotifier()
|
||||
notifier := notifier.NewNotifier(rsAPI)
|
||||
streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, eduCache, caches, notifier)
|
||||
notifier.SetCurrentPosition(streams.Latest(context.Background()))
|
||||
if err = notifier.Load(context.Background(), syncDB); err != nil {
|
||||
|
|
|
|||
|
|
@ -55,18 +55,27 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat,
|
|||
if err == nil && userID != nil {
|
||||
sender = *userID
|
||||
}
|
||||
evs = append(evs, ToClientEvent(se, format, sender))
|
||||
|
||||
sk := se.StateKey()
|
||||
if sk != nil && *sk != "" {
|
||||
skUserID, err := userIDForSender(se.RoomID(), spec.SenderID(*sk))
|
||||
if err == nil && skUserID != nil {
|
||||
skString := skUserID.String()
|
||||
sk = &skString
|
||||
}
|
||||
}
|
||||
evs = append(evs, ToClientEvent(se, format, sender, sk))
|
||||
}
|
||||
return evs
|
||||
}
|
||||
|
||||
// ToClientEvent converts a single server event to a client event.
|
||||
func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID) ClientEvent {
|
||||
func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID, stateKey *string) ClientEvent {
|
||||
ce := ClientEvent{
|
||||
Content: spec.RawJSON(se.Content()),
|
||||
Sender: sender.String(),
|
||||
Type: se.Type(),
|
||||
StateKey: se.StateKey(),
|
||||
StateKey: stateKey,
|
||||
Unsigned: spec.RawJSON(se.Unsigned()),
|
||||
OriginServerTS: se.OriginServerTS(),
|
||||
EventID: se.EventID(),
|
||||
|
|
@ -77,3 +86,23 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender sp
|
|||
}
|
||||
return ce
|
||||
}
|
||||
|
||||
// ToClientEvent converts a single server event to a client event.
|
||||
// It provides default logic for event.SenderID & event.StateKey -> userID conversions.
|
||||
func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent {
|
||||
sender := spec.UserID{}
|
||||
userID, err := userIDQuery(event.RoomID(), event.SenderID())
|
||||
if err == nil && userID != nil {
|
||||
sender = *userID
|
||||
}
|
||||
|
||||
sk := event.StateKey()
|
||||
if sk != nil && *sk != "" {
|
||||
skUserID, err := userIDQuery(event.RoomID(), spec.SenderID(*event.StateKey()))
|
||||
if err == nil && skUserID != nil {
|
||||
skString := skUserID.String()
|
||||
sk = &skString
|
||||
}
|
||||
}
|
||||
return ToClientEvent(event, FormatAll, sender, sk)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -48,7 +48,8 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo
|
|||
if err != nil {
|
||||
t.Fatalf("failed to create userID: %s", err)
|
||||
}
|
||||
ce := ToClientEvent(ev, FormatAll, *userID)
|
||||
sk := ""
|
||||
ce := ToClientEvent(ev, FormatAll, *userID, &sk)
|
||||
if ce.EventID != ev.EventID() {
|
||||
t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID)
|
||||
}
|
||||
|
|
@ -107,7 +108,8 @@ func TestToClientFormatSync(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("failed to create userID: %s", err)
|
||||
}
|
||||
ce := ToClientEvent(ev, FormatSync, *userID)
|
||||
sk := ""
|
||||
ce := ToClientEvent(ev, FormatSync, *userID, &sk)
|
||||
if ce.RoomID != "" {
|
||||
t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -539,7 +539,7 @@ type InviteResponse struct {
|
|||
}
|
||||
|
||||
// NewInviteResponse creates an empty response with initialised arrays.
|
||||
func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteResponse {
|
||||
func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID, stateKey *string) *InviteResponse {
|
||||
res := InviteResponse{}
|
||||
res.InviteState.Events = []json.RawMessage{}
|
||||
|
||||
|
|
@ -552,7 +552,7 @@ func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteRe
|
|||
|
||||
// Then we'll see if we can create a partial of the invite event itself.
|
||||
// This is needed for clients to work out *who* sent the invite.
|
||||
inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID)
|
||||
inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID, stateKey)
|
||||
inviteEvent.Unsigned = nil
|
||||
if ev, err := json.Marshal(inviteEvent); err == nil {
|
||||
res.InviteState.Events = append(res.InviteState.Events, ev)
|
||||
|
|
|
|||
|
|
@ -65,8 +65,14 @@ func TestNewInviteResponse(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
skUserID, err := spec.NewUserID("@neilalexander:dendrite.neilalexander.dev", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
skString := skUserID.String()
|
||||
sk := &skString
|
||||
|
||||
res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender)
|
||||
res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender, sk)
|
||||
j, err := json.Marshal(res)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
|
|||
|
|
@ -306,7 +306,16 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst
|
|||
if queryErr == nil && userID != nil {
|
||||
sender = *userID
|
||||
}
|
||||
cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender)
|
||||
|
||||
sk := event.StateKey()
|
||||
if sk != nil && *sk != "" {
|
||||
skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
|
||||
if queryErr == nil && skUserID != nil {
|
||||
skString := skUserID.String()
|
||||
sk = &skString
|
||||
}
|
||||
}
|
||||
cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk)
|
||||
var member *localMembership
|
||||
member, err = newLocalMembership(&cevent)
|
||||
if err != nil {
|
||||
|
|
@ -539,12 +548,21 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
|
|||
if err == nil && userID != nil {
|
||||
sender = *userID
|
||||
}
|
||||
|
||||
sk := event.StateKey()
|
||||
if sk != nil && *sk != "" {
|
||||
skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
|
||||
if queryErr == nil && skUserID != nil {
|
||||
skString := skUserID.String()
|
||||
sk = &skString
|
||||
}
|
||||
}
|
||||
n := &api.Notification{
|
||||
Actions: actions,
|
||||
// UNSPEC: the spec doesn't say this is a ClientEvent, but the
|
||||
// fields seem to match. room_id should be missing, which
|
||||
// matches the behaviour of FormatSync.
|
||||
Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender),
|
||||
Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender, sk),
|
||||
// TODO: this is per-device, but it's not part of the primary
|
||||
// key. So inserting one notification per profile tag doesn't
|
||||
// make sense. What is this supposed to be? Sytests require it
|
||||
|
|
@ -792,10 +810,20 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes
|
|||
Type: event.Type(),
|
||||
},
|
||||
}
|
||||
if mem, err := event.Membership(); err == nil {
|
||||
if mem, memberErr := event.Membership(); memberErr == nil {
|
||||
req.Notification.Membership = mem
|
||||
}
|
||||
if event.StateKey() != nil && *event.StateKey() == fmt.Sprintf("@%s:%s", localpart, s.cfg.Matrix.ServerName) {
|
||||
userID, err := spec.NewUserID(fmt.Sprintf("@%s:%s", localpart, s.cfg.Matrix.ServerName), true)
|
||||
if err != nil {
|
||||
logger.WithError(err).Errorf("Failed to convert local user to userID %s", localpart)
|
||||
return nil, err
|
||||
}
|
||||
localSender, err := s.rsAPI.QuerySenderIDForUser(ctx, event.RoomID(), *userID)
|
||||
if err != nil {
|
||||
logger.WithError(err).Errorf("Failed to get local user senderID for room %s: %s", userID.String(), event.RoomID())
|
||||
return nil, err
|
||||
}
|
||||
if event.StateKey() != nil && *event.StateKey() == string(localSender) {
|
||||
req.Notification.UserIsTarget = true
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -104,8 +104,9 @@ func TestNotifyUserCountsAsync(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
sk := ""
|
||||
if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{
|
||||
Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender),
|
||||
Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender, &sk),
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue