Use a MembershipQuerier to get the current membership

This commit is contained in:
Till Faelligen 2023-07-05 10:57:41 +02:00
parent 3085928906
commit 9ddd62925c
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
10 changed files with 36 additions and 31 deletions

View file

@ -73,6 +73,7 @@ type RoomserverFederationAPI interface {
GetEventAuth(ctx context.Context, origin, s spec.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res fclient.RespEventAuth, err error) GetEventAuth(ctx context.Context, origin, s spec.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res fclient.RespEventAuth, err error)
GetEvent(ctx context.Context, origin, s spec.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) GetEvent(ctx context.Context, origin, s spec.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
LookupMissingEvents(ctx context.Context, origin, s spec.ServerName, roomID string, missing fclient.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res fclient.RespMissingEvents, err error) LookupMissingEvents(ctx context.Context, origin, s spec.ServerName, roomID string, missing fclient.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res fclient.RespMissingEvents, err error)
CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error)
} }
type P2PFederationAPI interface { type P2PFederationAPI interface {

View file

@ -1,6 +1,7 @@
package internal package internal
import ( import (
"context"
"crypto/ed25519" "crypto/ed25519"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
@ -182,3 +183,7 @@ func (a *FederationInternalAPI) doRequestIfNotBlacklisted(
} }
return request() return request()
} }
func (a *FederationInternalAPI) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
return a.rsAPI.CurrentMembership(ctx, roomID, senderID)
}

View file

@ -167,7 +167,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
KeyRing: r.keyRing, KeyRing: r.keyRing,
EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}), }, r.rsAPI),
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, },
@ -190,6 +190,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
} }
return r.rsAPI.StoreUserRoomPublicKey(ctx, senderID, *storeUserID, roomID) return r.rsAPI.StoreUserRoomPublicKey(ctx, senderID, *storeUserID, roomID)
}, },
MembershipQuerier: r.rsAPI,
} }
response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput) response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput)
@ -387,7 +388,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
} }
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse( authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(
ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedEventProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName, userIDProvider), userIDProvider, ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedEventProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName, userIDProvider, r.rsAPI), userIDProvider, r.rsAPI,
) )
if err != nil { if err != nil {
return fmt.Errorf("error checking state returned from peeking: %w", err) return fmt.Errorf("error checking state returned from peeking: %w", err)
@ -676,7 +677,7 @@ func checkEventsContainCreateEvent(events []gomatrixserverlib.PDU) error {
func federatedEventProvider( func federatedEventProvider(
ctx context.Context, federation fclient.FederationClient, ctx context.Context, federation fclient.FederationClient,
keyRing gomatrixserverlib.JSONVerifier, origin, server spec.ServerName, keyRing gomatrixserverlib.JSONVerifier, origin, server spec.ServerName,
userIDForSender spec.UserIDForSender, userIDForSender spec.UserIDForSender, rsAPI gomatrixserverlib.MembershipQuerier,
) gomatrixserverlib.EventProvider { ) gomatrixserverlib.EventProvider {
// A list of events that we have retried, if they were not included in // A list of events that we have retried, if they were not included in
// the auth events supplied in the send_join. // the auth events supplied in the send_join.
@ -726,7 +727,7 @@ func federatedEventProvider(
} }
// Check the signatures of the event. // Check the signatures of the event.
if err := gomatrixserverlib.VerifyEventSignatures(ctx, ev, keyRing, userIDForSender); err != nil { if err := gomatrixserverlib.VerifyEventSignatures(ctx, ev, keyRing, userIDForSender, rsAPI); err != nil {
return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err) return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err)
} }

View file

@ -168,24 +168,9 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut
continue continue
} }
// If the user is already joined and we receive a new "join" event, we're adding the previous
// content to unsigned, this way VerifyEventSignatures skips the mxid_mapping check
// FIXME: this is not great..
origEvent := event
unsignedUpdated := false
if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && event.Type() == spec.MRoomMember && event.StateKey() != nil {
unsignedUpdated, err = t.updateUnsignedIfNeeded(ctx, event)
if err != nil {
results[event.EventID()] = fclient.PDUResult{
Error: err.Error(),
}
continue
}
}
if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }, t.rsAPI); err != nil {
util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID())
results[event.EventID()] = fclient.PDUResult{ results[event.EventID()] = fclient.PDUResult{
Error: err.Error(), Error: err.Error(),
@ -193,11 +178,6 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut
continue continue
} }
// switch the event again, so we don't store a wrong value in the DB
if unsignedUpdated {
event = origEvent
}
// pass the event to the roomserver which will do auth checks // pass the event to the roomserver which will do auth checks
// If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently // If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently
// discarded by the caller of this function // discarded by the caller of this function

View file

@ -236,6 +236,7 @@ type FederationRoomserverAPI interface {
QueryBulkStateContentAPI QueryBulkStateContentAPI
QuerySenderIDAPI QuerySenderIDAPI
UserRoomPrivateKeyCreator UserRoomPrivateKeyCreator
CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error)
AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error) AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error)
SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error)
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.

View file

@ -780,7 +780,7 @@ nextAuthEvent:
// if a critical event is missing anyway. // if a critical event is missing anyway.
if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }, r.Queryer); err != nil {
continue nextAuthEvent continue nextAuthEvent
} }

View file

@ -571,7 +571,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver
for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) { for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }, t.inputer.Queryer); err != nil {
continue continue
} }
missingEvents = append(missingEvents, t.cacheAndReturn(ev)) missingEvents = append(missingEvents, t.cacheAndReturn(ev))
@ -662,7 +662,7 @@ func (t *missingStateReq) lookupMissingStateViaState(
AuthEvents: state.GetAuthEvents(), AuthEvents: state.GetAuthEvents(),
}, roomVersion, t.keys, nil, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { }, roomVersion, t.keys, nil, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}) }, t.inputer.Queryer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -899,7 +899,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
} }
if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }, t.inputer.Queryer); err != nil {
t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())
return nil, verifySigError{event.EventID(), err} return nil, verifySigError{event.EventID(), err}
} }

View file

@ -270,7 +270,7 @@ func (r *Admin) PerformAdminDownloadState(
for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) { for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }, r.Queryer); err != nil {
continue continue
} }
authEventMap[authEvent.EventID()] = authEvent authEventMap[authEvent.EventID()] = authEvent
@ -278,7 +278,7 @@ func (r *Admin) PerformAdminDownloadState(
for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) { for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }, r.Queryer); err != nil {
continue continue
} }
stateEventMap[stateEvent.EventID()] = stateEvent stateEventMap[stateEvent.EventID()] = stateEvent

View file

@ -266,6 +266,7 @@ type backfillRequester struct {
eventIDMap map[string]gomatrixserverlib.PDU eventIDMap map[string]gomatrixserverlib.PDU
historyVisiblity gomatrixserverlib.HistoryVisibility historyVisiblity gomatrixserverlib.HistoryVisibility
roomVersion gomatrixserverlib.RoomVersion roomVersion gomatrixserverlib.RoomVersion
membershipQuerier gomatrixserverlib.MembershipQuerier
} }
func newBackfillRequester( func newBackfillRequester(
@ -292,9 +293,14 @@ func newBackfillRequester(
preferServer: preferServer, preferServer: preferServer,
historyVisiblity: gomatrixserverlib.HistoryVisibilityShared, historyVisiblity: gomatrixserverlib.HistoryVisibilityShared,
roomVersion: roomVersion, roomVersion: roomVersion,
membershipQuerier: fsAPI,
} }
} }
func (b *backfillRequester) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
return b.fsAPI.CurrentMembership(ctx, roomID, senderID)
}
func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent gomatrixserverlib.PDU) ([]string, error) { func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent gomatrixserverlib.PDU) ([]string, error) {
b.eventIDMap[targetEvent.EventID()] = targetEvent b.eventIDMap[targetEvent.EventID()] = targetEvent
if ids, ok := b.eventIDToBeforeStateIDs[targetEvent.EventID()]; ok { if ids, ok := b.eventIDToBeforeStateIDs[targetEvent.EventID()]; ok {

View file

@ -1036,3 +1036,14 @@ func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID,
return nil, nil return nil, nil
} }
func (r *Queryer) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
res := api.QueryMembershipForUserResponse{}
err := r.QueryMembershipForSenderID(ctx, roomID, senderID, &res)
membership := ""
if err == nil {
membership = res.Membership
}
return membership, err
}