Compare commits

...

24 commits

Author SHA1 Message Date
Devon Hudson cae2b61c88
Update gmsl 2023-07-21 11:41:59 -06:00
Devon Hudson 0aa1b6c218
Merge branch 'main' into s7evink/memberships 2023-07-21 11:37:50 -06:00
Till Faelligen 16075ce657
Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/memberships 2023-07-13 14:33:44 +02:00
Till Faelligen 5a7d36d1b3
Only leave/kick if the redacted event is the current state event 2023-07-12 14:50:15 +02:00
Till Faelligen 92a633d0e4
Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/memberships 2023-07-12 11:17:57 +02:00
Till Faelligen 6942c198b9
Update GMSL 2023-07-12 11:16:06 +02:00
Devon Hudson 4e27ff28b8
Merge changes from main 2023-07-06 17:58:17 -06:00
Devon Hudson f825ce2935
Fix gmsl to point to memberships branch 2023-07-06 15:45:19 -06:00
Devon Hudson 2bcb89ad4c
Merge branch 'main' into s7evink/memberships 2023-07-06 15:41:27 -06:00
Till Faelligen 63f239f336
Add comment for MaybeRedactEvent 2023-07-05 11:02:12 +02:00
Till Faelligen b93a9e4615
Remove unused method 2023-07-05 11:01:12 +02:00
Till Faelligen 9569498761
Update GMSL 2023-07-05 10:59:42 +02:00
Till Faelligen 9ddd62925c
Use a MembershipQuerier to get the current membership 2023-07-05 10:57:41 +02:00
Till Faelligen 3085928906
Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/memberships 2023-07-05 08:15:05 +02:00
Till Faelligen fa4f7021a1
Close the stmt after usage 2023-07-03 09:04:26 +02:00
Till Faelligen da51f32e03
Update GMSL 2023-07-03 09:01:43 +02:00
Till Faelligen 4bf57a2519
Verify the mxid_mapping only if it is an actual join event 2023-07-03 08:59:13 +02:00
Till Faelligen dcd28e3614
Kick users on redaction of join events 2023-07-03 08:54:18 +02:00
Till Faelligen 641bac0ce5
Remove mxid_mapping before storing it in unsigned.prev_content 2023-06-30 14:32:46 +02:00
Till Faelligen 552eaf2940
Redact the event in the current room state table 2023-06-30 14:32:24 +02:00
Till Faelligen 8b5afcf680
Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/memberships 2023-06-30 11:15:43 +02:00
Till Faelligen 302d8d7089
Fix getting prev content 2023-06-29 11:35:15 +02:00
Till Faelligen 2b3b355ebd
Fix getting the current state event 2023-06-29 11:34:37 +02:00
Till Faelligen bc8e83fd28
Add possibility to query all user keys; Get all joined rooms 2023-06-29 11:33:28 +02:00
25 changed files with 346 additions and 64 deletions

View file

@ -74,6 +74,7 @@ type RoomserverFederationAPI interface {
GetEventAuth(ctx context.Context, origin, s spec.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res fclient.RespEventAuth, err error) GetEventAuth(ctx context.Context, origin, s spec.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res fclient.RespEventAuth, err error)
GetEvent(ctx context.Context, origin, s spec.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) GetEvent(ctx context.Context, origin, s spec.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
LookupMissingEvents(ctx context.Context, origin, s spec.ServerName, roomID string, missing fclient.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res fclient.RespMissingEvents, err error) LookupMissingEvents(ctx context.Context, origin, s spec.ServerName, roomID string, missing fclient.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res fclient.RespMissingEvents, err error)
CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error)
RoomHierarchies(ctx context.Context, origin, dst spec.ServerName, roomID string, suggestedOnly bool) (res fclient.RoomHierarchyResponse, err error) RoomHierarchies(ctx context.Context, origin, dst spec.ServerName, roomID string, suggestedOnly bool) (res fclient.RoomHierarchyResponse, err error)
} }

View file

@ -1,6 +1,7 @@
package internal package internal
import ( import (
"context"
"crypto/ed25519" "crypto/ed25519"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
@ -185,3 +186,7 @@ func (a *FederationInternalAPI) doRequestIfNotBlacklisted(
} }
return request() return request()
} }
func (a *FederationInternalAPI) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
return a.rsAPI.CurrentMembership(ctx, roomID, senderID)
}

View file

@ -167,7 +167,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
KeyRing: r.keyRing, KeyRing: r.keyRing,
EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}), }, r.rsAPI),
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, },
@ -190,6 +190,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
} }
return r.rsAPI.StoreUserRoomPublicKey(ctx, senderID, *storeUserID, roomID) return r.rsAPI.StoreUserRoomPublicKey(ctx, senderID, *storeUserID, roomID)
}, },
MembershipQuerier: r.rsAPI,
} }
response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput) response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput)
@ -387,7 +388,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
} }
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse( authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(
ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedEventProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName, userIDProvider), userIDProvider, ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedEventProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName, userIDProvider, r.rsAPI), userIDProvider, r.rsAPI,
) )
if err != nil { if err != nil {
return fmt.Errorf("error checking state returned from peeking: %w", err) return fmt.Errorf("error checking state returned from peeking: %w", err)
@ -728,7 +729,7 @@ func checkEventsContainCreateEvent(events []gomatrixserverlib.PDU) error {
func federatedEventProvider( func federatedEventProvider(
ctx context.Context, federation fclient.FederationClient, ctx context.Context, federation fclient.FederationClient,
keyRing gomatrixserverlib.JSONVerifier, origin, server spec.ServerName, keyRing gomatrixserverlib.JSONVerifier, origin, server spec.ServerName,
userIDForSender spec.UserIDForSender, userIDForSender spec.UserIDForSender, rsAPI gomatrixserverlib.MembershipQuerier,
) gomatrixserverlib.EventProvider { ) gomatrixserverlib.EventProvider {
// A list of events that we have retried, if they were not included in // A list of events that we have retried, if they were not included in
// the auth events supplied in the send_join. // the auth events supplied in the send_join.
@ -778,7 +779,7 @@ func federatedEventProvider(
} }
// Check the signatures of the event. // Check the signatures of the event.
if err := gomatrixserverlib.VerifyEventSignatures(ctx, ev, keyRing, userIDForSender); err != nil { if err := gomatrixserverlib.VerifyEventSignatures(ctx, ev, keyRing, userIDForSender, rsAPI); err != nil {
return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err) return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err)
} }

2
go.mod
View file

@ -22,7 +22,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20230721154317-b5b0448aa378 github.com/matrix-org/gomatrixserverlib v0.0.0-20230721173823-6788b4fb4400
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
github.com/mattn/go-sqlite3 v1.14.17 github.com/mattn/go-sqlite3 v1.14.17

4
go.sum
View file

@ -207,8 +207,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230721154317-b5b0448aa378 h1:a6sfiJiNZWVbPRHvEB/YlpqSg+Dh7El+824mzccSk68= github.com/matrix-org/gomatrixserverlib v0.0.0-20230721173823-6788b4fb4400 h1:wTtUS3rjADu788S071rgT2Twg7uKsWt6hrq0cFu3zmo=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230721154317-b5b0448aa378/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/gomatrixserverlib v0.0.0-20230721173823-6788b4fb4400/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=

View file

@ -167,9 +167,10 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut
} }
continue continue
} }
if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }, t.rsAPI); err != nil {
util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID())
results[event.EventID()] = fclient.PDUResult{ results[event.EventID()] = fclient.PDUResult{
Error: err.Error(), Error: err.Error(),

View file

@ -261,6 +261,7 @@ type FederationRoomserverAPI interface {
QuerySenderIDAPI QuerySenderIDAPI
QueryRoomHierarchyAPI QueryRoomHierarchyAPI
UserRoomPrivateKeyCreator UserRoomPrivateKeyCreator
CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error)
AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error) AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error)
SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error)
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.

View file

@ -418,10 +418,11 @@ func (r *Inputer) processRoomEvent(
redactedEventID string redactedEventID string
redactionEvent gomatrixserverlib.PDU redactionEvent gomatrixserverlib.PDU
redactedEvent gomatrixserverlib.PDU redactedEvent gomatrixserverlib.PDU
originalRedactedEvent gomatrixserverlib.PDU
) )
if !isRejected && !isCreateEvent { if !isRejected && !isCreateEvent {
resolver := state.NewStateResolution(r.DB, roomInfo, r.Queryer) resolver := state.NewStateResolution(r.DB, roomInfo, r.Queryer)
redactionEvent, redactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, &resolver, r.Queryer) redactionEvent, redactedEvent, originalRedactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, &resolver, r.Queryer)
if err != nil { if err != nil {
return err return err
} }
@ -519,6 +520,12 @@ func (r *Inputer) processRoomEvent(
if err != nil { if err != nil {
return fmt.Errorf("r.WriteOutputEvents (redactions): %w", err) return fmt.Errorf("r.WriteOutputEvents (redactions): %w", err)
} }
// if we're in a pseudoID room, and we redacted a m.room.member event, also leave/kick the user
if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && redactedEvent.Type() == spec.MRoomMember {
if err = r.leavePseudoIDRoom(ctx, *validRoomID, originalRedactedEvent, redactionEvent); err != nil {
logrus.WithError(err).Error("failed to leave user after membership event redaction")
}
}
} }
// If guest_access changed and is not can_join, kick all guest users. // If guest_access changed and is not can_join, kick all guest users.
@ -534,6 +541,81 @@ func (r *Inputer) processRoomEvent(
return nil return nil
} }
// leavePseudoIDRoom leaves/kicks a user in the event of a membership event redaction.
// TODO: This doesn't play well with users re-joining rooms, as in this case we have multiple join events with a mxid_mapping.
func (r *Inputer) leavePseudoIDRoom(ctx context.Context, roomID spec.RoomID, originalRedactedEvent, redactionEvent gomatrixserverlib.PDU) error {
stateKey := originalRedactedEvent.StateKey()
currentStateEvent, err := r.DB.GetStateEvent(ctx, roomID.String(), originalRedactedEvent.Type(), *stateKey)
if err != nil {
return err
}
// If the redacted event is NOT the current state event, do nothing
if currentStateEvent.EventID() != originalRedactedEvent.EventID() {
return nil
}
var memberContent gomatrixserverlib.MemberContent
if err = json.Unmarshal(originalRedactedEvent.Content(), &memberContent); err != nil {
return err
}
if memberContent.Membership != spec.Join {
return nil
}
// no mxid_mapping, nothing to do
if memberContent.MXIDMapping == nil {
return nil
}
userID, err := r.Queryer.QueryUserIDForSender(ctx, roomID, redactionEvent.SenderID())
if err != nil {
return err
}
// We can only create the leave event on servers the redaction originated on.
// We are going to receive the leave event anyway.
if !r.Cfg.Matrix.IsLocalServerName(userID.Domain()) {
return nil
}
signingIdentity, err := r.SigningIdentity(ctx, roomID, *userID)
if err != nil {
return err
}
fledglingEvent := &gomatrixserverlib.ProtoEvent{
RoomID: originalRedactedEvent.RoomID(),
Type: spec.MRoomMember,
StateKey: stateKey,
SenderID: string(redactionEvent.SenderID()),
}
if fledglingEvent.Content, err = json.Marshal(gomatrixserverlib.MemberContent{
Membership: spec.Leave,
}); err != nil {
return err
}
event, err := eventutil.QueryAndBuildEvent(ctx, fledglingEvent, &signingIdentity, time.Now(), r.Queryer, nil)
if err != nil {
return err
}
inputReq := &api.InputRoomEventsRequest{
InputRoomEvents: []api.InputRoomEvent{{
Kind: api.KindNew,
Event: event,
Origin: userID.Domain(),
SendAsServer: string(userID.Domain()),
}},
Asynchronous: true, // Needs to be async, as we otherwise create a deadlock
}
inputRes := &api.InputRoomEventsResponse{}
r.InputRoomEvents(ctx, inputReq, inputRes)
return nil
}
// handleRemoteRoomUpgrade updates published rooms and room aliases // handleRemoteRoomUpgrade updates published rooms and room aliases
func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event gomatrixserverlib.PDU) error { func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event gomatrixserverlib.PDU) error {
oldRoomID := event.RoomID() oldRoomID := event.RoomID()
@ -739,7 +821,7 @@ nextAuthEvent:
// if a critical event is missing anyway. // if a critical event is missing anyway.
if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }, r.Queryer); err != nil {
continue nextAuthEvent continue nextAuthEvent
} }

View file

@ -599,7 +599,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver
for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) { for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }, t.inputer.Queryer); err != nil {
continue continue
} }
missingEvents = append(missingEvents, t.cacheAndReturn(ev)) missingEvents = append(missingEvents, t.cacheAndReturn(ev))
@ -690,7 +690,7 @@ func (t *missingStateReq) lookupMissingStateViaState(
AuthEvents: state.GetAuthEvents(), AuthEvents: state.GetAuthEvents(),
}, roomVersion, t.keys, nil, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { }, roomVersion, t.keys, nil, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}) }, t.inputer.Queryer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -946,7 +946,7 @@ serverLoop:
} }
if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }, t.inputer.Queryer); err != nil {
t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())
return nil, verifySigError{event.EventID(), err} return nil, verifySigError{event.EventID(), err}
} }

View file

@ -270,7 +270,7 @@ func (r *Admin) PerformAdminDownloadState(
for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) { for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }, r.Queryer); err != nil {
continue continue
} }
authEventMap[authEvent.EventID()] = authEvent authEventMap[authEvent.EventID()] = authEvent
@ -278,7 +278,7 @@ func (r *Admin) PerformAdminDownloadState(
for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) { for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }, r.Queryer); err != nil {
continue continue
} }
stateEventMap[stateEvent.EventID()] = stateEvent stateEventMap[stateEvent.EventID()] = stateEvent

View file

@ -266,6 +266,7 @@ type backfillRequester struct {
eventIDMap map[string]gomatrixserverlib.PDU eventIDMap map[string]gomatrixserverlib.PDU
historyVisiblity gomatrixserverlib.HistoryVisibility historyVisiblity gomatrixserverlib.HistoryVisibility
roomVersion gomatrixserverlib.RoomVersion roomVersion gomatrixserverlib.RoomVersion
membershipQuerier gomatrixserverlib.MembershipQuerier
} }
func newBackfillRequester( func newBackfillRequester(
@ -292,9 +293,14 @@ func newBackfillRequester(
preferServer: preferServer, preferServer: preferServer,
historyVisiblity: gomatrixserverlib.HistoryVisibilityShared, historyVisiblity: gomatrixserverlib.HistoryVisibilityShared,
roomVersion: roomVersion, roomVersion: roomVersion,
membershipQuerier: fsAPI,
} }
} }
func (b *backfillRequester) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
return b.fsAPI.CurrentMembership(ctx, roomID, senderID)
}
func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent gomatrixserverlib.PDU) ([]string, error) { func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent gomatrixserverlib.PDU) ([]string, error) {
b.eventIDMap[targetEvent.EventID()] = targetEvent b.eventIDMap[targetEvent.EventID()] = targetEvent
if ids, ok := b.eventIDToBeforeStateIDs[targetEvent.EventID()]; ok { if ids, ok := b.eventIDToBeforeStateIDs[targetEvent.EventID()]; ok {
@ -647,7 +653,7 @@ func persistEvents(ctx context.Context, db storage.Database, querier api.QuerySe
resolver := state.NewStateResolution(db, roomInfo, querier) resolver := state.NewStateResolution(db, roomInfo, querier)
_, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev, &resolver, querier) _, redactedEvent, _, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev, &resolver, querier)
if err != nil { if err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event") logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event")
continue continue

View file

@ -1052,3 +1052,14 @@ func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID,
return nil, nil return nil, nil
} }
func (r *Queryer) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
res := api.QueryMembershipForUserResponse{}
err := r.QueryMembershipForSenderID(ctx, roomID, senderID, &res)
membership := ""
if err == nil {
membership = res.Membership
}
return membership, err
}

View file

@ -575,7 +575,7 @@ func TestRedaction(t *testing.T) {
err = updater.Commit() err = updater.Commit()
assert.NoError(t, err) assert.NoError(t, err)
_, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev.PDU, &plResolver, &FakeQuerier{}) _, redactedEvent, _, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev.PDU, &plResolver, &FakeQuerier{})
assert.NoError(t, err) assert.NoError(t, err)
if redactedEvent != nil { if redactedEvent != nil {
assert.Equal(t, ev.Redacts(), redactedEvent.EventID()) assert.Equal(t, ev.Redacts(), redactedEvent.EventID())

View file

@ -190,9 +190,9 @@ type Database interface {
GetRoomVersion(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) GetRoomVersion(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
MaybeRedactEvent( // MaybeRedactEvent returns the redaction event, the redacted event and the event before redaction if this call resulted in a redaction, else an error
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, querier api.QuerySenderIDAPI, // (nil if there was nothing to do)
) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error) MaybeRedactEvent(ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, querier api.QuerySenderIDAPI) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, gomatrixserverlib.PDU, error)
} }
type UserRoomKeys interface { type UserRoomKeys interface {
@ -249,10 +249,8 @@ type EventDatabase interface {
EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
// MaybeRedactEvent returns the redaction event and the redacted event if this call resulted in a redaction, else an error // MaybeRedactEvent returns the redaction event, the redacted event and the event before redaction if this call resulted in a redaction, else an error
// (nil if there was nothing to do) // (nil if there was nothing to do)
MaybeRedactEvent( MaybeRedactEvent(ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, querier api.QuerySenderIDAPI) (redactionEvent, redactedEvent, originalRedactedEvent gomatrixserverlib.PDU, err error)
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, querier api.QuerySenderIDAPI,
) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error)
StoreEvent(ctx context.Context, event gomatrixserverlib.PDU, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error) StoreEvent(ctx context.Context, event gomatrixserverlib.PDU, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error)
} }

View file

@ -128,7 +128,7 @@ const deleteMembershipSQL = "" +
"DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2" "DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2"
const selectRoomsWithMembershipSQL = "" + const selectRoomsWithMembershipSQL = "" +
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false" "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = ANY($2) and forgotten = false"
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is // selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
// joined to. Since this information is used to populate the user directory, we will // joined to. Since this information is used to populate the user directory, we will
@ -347,10 +347,10 @@ func (s *membershipStatements) UpdateMembership(
func (s *membershipStatements) SelectRoomsWithMembership( func (s *membershipStatements) SelectRoomsWithMembership(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
userID types.EventStateKeyNID, membershipState tables.MembershipState, userIDs []types.EventStateKeyNID, membershipState tables.MembershipState,
) ([]types.RoomNID, error) { ) ([]types.RoomNID, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt) stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
rows, err := stmt.QueryContext(ctx, membershipState, userID) rows, err := stmt.QueryContext(ctx, membershipState, pq.Array(userIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -56,12 +56,15 @@ const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_use
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)` const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)`
const selectUserRoomKeysSQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND pseudo_id_key IS NOT NULL`
type userRoomKeysStatements struct { type userRoomKeysStatements struct {
insertUserRoomPrivateKeyStmt *sql.Stmt insertUserRoomPrivateKeyStmt *sql.Stmt
insertUserRoomPublicKeyStmt *sql.Stmt insertUserRoomPublicKeyStmt *sql.Stmt
selectUserRoomKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt
selectUserRoomPublicKeyStmt *sql.Stmt selectUserRoomPublicKeyStmt *sql.Stmt
selectUserNIDsStmt *sql.Stmt selectUserNIDsStmt *sql.Stmt
selectUserRoomKeysStmt *sql.Stmt
} }
func CreateUserRoomKeysTable(db *sql.DB) error { func CreateUserRoomKeysTable(db *sql.DB) error {
@ -77,6 +80,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL}, {&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
{&s.selectUserNIDsStmt, selectUserNIDsSQL}, {&s.selectUserNIDsStmt, selectUserNIDsSQL},
{&s.selectUserRoomKeysStmt, selectUserRoomKeysSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -150,3 +154,25 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sq
} }
return result, rows.Err() return result, rows.Err()
} }
func (s *userRoomKeysStatements) SelectPrivateKeysForUserNID(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) ([]ed25519.PrivateKey, error) {
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeysStmt)
rows, err := stmt.QueryContext(ctx, userNID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows")
var result []ed25519.PrivateKey
var pk ed25519.PrivateKey
for rows.Next() {
if err = rows.Scan(&pk); err != nil {
return nil, err
}
result = append(result, pk)
}
return result, rows.Err()
}

View file

@ -990,13 +990,13 @@ func extractRoomVersionFromCreateEvent(event gomatrixserverlib.PDU) (
// to cross-reference with other tables when loading. // to cross-reference with other tables when loading.
// //
// Returns the redaction event and the redacted event if this call resulted in a redaction. // Returns the redaction event and the redacted event if this call resulted in a redaction.
// nolint: gocylo
func (d *EventDatabase) MaybeRedactEvent( func (d *EventDatabase) MaybeRedactEvent(
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver,
querier api.QuerySenderIDAPI, querier api.QuerySenderIDAPI,
) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error) { ) (redactionEvent, redactedEvent, originalRedactedEvent gomatrixserverlib.PDU, err error) {
var ( var (
redactionEvent, redactedEvent *types.Event redactionEv, redactedEv *types.Event
err error
validated bool validated bool
ignoreRedaction bool ignoreRedaction bool
) )
@ -1019,42 +1019,42 @@ func (d *EventDatabase) MaybeRedactEvent(
} }
} }
redactionEvent, redactedEvent, validated, err = d.loadRedactionPair(ctx, txn, roomInfo, eventNID, event) redactionEv, redactedEv, validated, err = d.loadRedactionPair(ctx, txn, roomInfo, eventNID, event)
switch { switch {
case err != nil: case err != nil:
return fmt.Errorf("d.loadRedactionPair: %w", err) return fmt.Errorf("d.loadRedactionPair: %w", err)
case validated || redactedEvent == nil || redactionEvent == nil: case validated || redactedEv == nil || redactionEv == nil:
// we've seen this redaction before or there is nothing to redact // we've seen this redaction before or there is nothing to redact
return nil return nil
case redactedEvent.RoomID() != redactionEvent.RoomID(): case redactedEv.RoomID() != redactionEv.RoomID():
// redactions across rooms aren't allowed // redactions across rooms aren't allowed
ignoreRedaction = true ignoreRedaction = true
return nil return nil
} }
var validRoomID *spec.RoomID var validRoomID *spec.RoomID
validRoomID, err = spec.NewRoomID(redactedEvent.RoomID()) validRoomID, err = spec.NewRoomID(redactedEv.RoomID())
if err != nil { if err != nil {
return err return err
} }
sender1Domain := "" sender1Domain := ""
sender1, err1 := querier.QueryUserIDForSender(ctx, *validRoomID, redactedEvent.SenderID()) sender1, err1 := querier.QueryUserIDForSender(ctx, *validRoomID, redactedEv.SenderID())
if err1 == nil { if err1 == nil {
sender1Domain = string(sender1.Domain()) sender1Domain = string(sender1.Domain())
} }
sender2Domain := "" sender2Domain := ""
sender2, err2 := querier.QueryUserIDForSender(ctx, *validRoomID, redactionEvent.SenderID()) sender2, err2 := querier.QueryUserIDForSender(ctx, *validRoomID, redactionEv.SenderID())
if err2 == nil { if err2 == nil {
sender2Domain = string(sender2.Domain()) sender2Domain = string(sender2.Domain())
} }
var powerlevels *gomatrixserverlib.PowerLevelContent var powerlevels *gomatrixserverlib.PowerLevelContent
powerlevels, err = plResolver.Resolve(ctx, redactionEvent.EventID()) powerlevels, err = plResolver.Resolve(ctx, redactionEv.EventID())
if err != nil { if err != nil {
return err return err
} }
switch { switch {
case powerlevels.UserLevel(redactionEvent.SenderID()) >= powerlevels.Redact: case powerlevels.UserLevel(redactionEv.SenderID()) >= powerlevels.Redact:
// 1. The power level of the redaction events sender is greater than or equal to the redact level. // 1. The power level of the redaction events sender is greater than or equal to the redact level.
case sender1Domain != "" && sender2Domain != "" && sender1Domain == sender2Domain: case sender1Domain != "" && sender2Domain != "" && sender1Domain == sender2Domain:
// 2. The domain of the redaction events sender matches that of the original events sender. // 2. The domain of the redaction events sender matches that of the original events sender.
@ -1065,42 +1065,47 @@ func (d *EventDatabase) MaybeRedactEvent(
// mark the event as redacted // mark the event as redacted
if redactionsArePermanent { if redactionsArePermanent {
redactedEvent.Redact() originalRedactedEvent, err = gomatrixserverlib.MustGetRoomVersion(redactedEv.Version()).
NewEventFromTrustedJSON(redactedEv.JSON(), false)
if err != nil {
return err
}
redactedEv.Redact()
} }
err = redactedEvent.SetUnsignedField("redacted_because", redactionEvent) err = redactedEv.SetUnsignedField("redacted_because", redactionEv)
if err != nil { if err != nil {
return fmt.Errorf("redactedEvent.SetUnsignedField: %w", err) return fmt.Errorf("redactedEvent.SetUnsignedField: %w", err)
} }
// NOTSPEC: sytest relies on this unspecced field existing :( // NOTSPEC: sytest relies on this unspecced field existing :(
err = redactedEvent.SetUnsignedField("redacted_by", redactionEvent.EventID()) err = redactedEv.SetUnsignedField("redacted_by", redactionEv.EventID())
if err != nil { if err != nil {
return fmt.Errorf("redactedEvent.SetUnsignedField: %w", err) return fmt.Errorf("redactedEvent.SetUnsignedField: %w", err)
} }
// overwrite the eventJSON table // overwrite the eventJSON table
err = d.EventJSONTable.InsertEventJSON(ctx, txn, redactedEvent.EventNID, redactedEvent.JSON()) err = d.EventJSONTable.InsertEventJSON(ctx, txn, redactedEv.EventNID, redactedEv.JSON())
if err != nil { if err != nil {
return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err) return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err)
} }
err = d.RedactionsTable.MarkRedactionValidated(ctx, txn, redactionEvent.EventID(), true) err = d.RedactionsTable.MarkRedactionValidated(ctx, txn, redactionEv.EventID(), true)
if err != nil { if err != nil {
return fmt.Errorf("d.RedactionsTable.MarkRedactionValidated: %w", err) return fmt.Errorf("d.RedactionsTable.MarkRedactionValidated: %w", err)
} }
// We remove the entry from the cache, as if we just "StoreRoomServerEvent", we can't be // We remove the entry from the cache, as if we just "StoreRoomServerEvent", we can't be
// certain that the cached entry actually is updated, since ristretto is eventual-persistent. // certain that the cached entry actually is updated, since ristretto is eventual-persistent.
d.Cache.InvalidateRoomServerEvent(redactedEvent.EventNID) d.Cache.InvalidateRoomServerEvent(redactedEv.EventNID)
return nil return nil
}) })
if wErr != nil { if wErr != nil {
return nil, nil, err return nil, nil, nil, err
} }
if ignoreRedaction || redactionEvent == nil || redactedEvent == nil { if ignoreRedaction || redactionEv == nil || redactedEv == nil {
return nil, nil, nil return nil, nil, nil, nil
} }
return redactionEvent.PDU, redactedEvent.PDU, nil return redactionEv.PDU, redactedEv.PDU, originalRedactedEvent, nil
} }
// loadRedactionPair returns both the redaction event and the redacted event, else nil. // loadRedactionPair returns both the redaction event and the redacted event, else nil.
@ -1361,14 +1366,38 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership
default: default:
return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership) return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership)
} }
stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID) stateKeyNID, err := d.EventStateKeyNIDs(ctx, []string{userID})
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err) return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err)
} }
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNID, membershipState)
// get the pseudo IDs, if any, as otherwise we don't get the correct room list
pseudoIDKeys, err := d.UserRoomKeyTable.SelectPrivateKeysForUserNID(ctx, nil, stateKeyNID[userID])
if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectPrivateKeysForUserNID: %w", err)
}
senderIDs := make([]string, len(pseudoIDKeys))
var senderID spec.SenderID
for _, key := range pseudoIDKeys {
senderID = spec.SenderIDFromPseudoIDKey(key)
senderIDs = append(senderIDs, string(senderID))
}
stateKeyNIDMap, err := d.EventStateKeyNIDs(ctx, senderIDs)
if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to EventStateKeyNIDs: %w", err)
}
stateKeyNIDs := make([]types.EventStateKeyNID, 0, len(stateKeyNIDMap)+1)
stateKeyNIDs = append(stateKeyNIDs, stateKeyNID[userID])
for _, stateKeyNID := range stateKeyNIDMap {
stateKeyNIDs = append(stateKeyNIDs, stateKeyNID)
}
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNIDs, membershipState)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err) return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err)
} }

View file

@ -100,7 +100,7 @@ const updateMembershipForgetRoom = "" +
" WHERE room_nid = $2 AND target_nid = $3" " WHERE room_nid = $2 AND target_nid = $3"
const selectRoomsWithMembershipSQL = "" + const selectRoomsWithMembershipSQL = "" +
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false" "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid IN ($2) and forgotten = false"
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is // selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
// joined to. Since this information is used to populate the user directory, we will // joined to. Since this information is used to populate the user directory, we will
@ -297,10 +297,28 @@ func (s *membershipStatements) UpdateMembership(
} }
func (s *membershipStatements) SelectRoomsWithMembership( func (s *membershipStatements) SelectRoomsWithMembership(
ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState tables.MembershipState, ctx context.Context, txn *sql.Tx, userIDs []types.EventStateKeyNID, membershipState tables.MembershipState,
) ([]types.RoomNID, error) { ) ([]types.RoomNID, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
rows, err := stmt.QueryContext(ctx, membershipState, userID) query := strings.Replace(selectRoomsWithMembershipSQL, "($2)", sqlutil.QueryVariadicOffset(len(userIDs), 1), 1)
var stmt *sql.Stmt
var err error
if txn != nil {
stmt, err = txn.PrepareContext(ctx, query)
} else {
stmt, err = s.db.Prepare(query)
}
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, stmt, "SelectRoomsWithMembership: stmt.close() failed")
params := make([]any, len(userIDs)+1)
params[0] = membershipState
for i, userID := range userIDs {
params[i+1] = userID
}
rows, err := stmt.QueryContext(ctx, params...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -56,6 +56,8 @@ const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_use
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)` const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)`
const selectUserRoomKeysSQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND pseudo_id_key IS NOT NULL`
type userRoomKeysStatements struct { type userRoomKeysStatements struct {
db *sql.DB db *sql.DB
insertUserRoomPrivateKeyStmt *sql.Stmt insertUserRoomPrivateKeyStmt *sql.Stmt
@ -63,6 +65,7 @@ type userRoomKeysStatements struct {
selectUserRoomKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt
selectUserRoomPublicKeyStmt *sql.Stmt selectUserRoomPublicKeyStmt *sql.Stmt
//selectUserNIDsStmt *sql.Stmt //prepared at runtime //selectUserNIDsStmt *sql.Stmt //prepared at runtime
selectUserRoomKeysStmt *sql.Stmt
} }
func CreateUserRoomKeysTable(db *sql.DB) error { func CreateUserRoomKeysTable(db *sql.DB) error {
@ -77,6 +80,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL}, {&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
{&s.selectUserRoomKeysStmt, selectUserRoomKeysSQL},
//{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime //{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime
}.Prepare(db) }.Prepare(db)
} }
@ -165,3 +169,25 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sq
} }
return result, rows.Err() return result, rows.Err()
} }
func (s *userRoomKeysStatements) SelectPrivateKeysForUserNID(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) ([]ed25519.PrivateKey, error) {
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeysStmt)
rows, err := stmt.QueryContext(ctx, userNID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows")
var result []ed25519.PrivateKey
var pk ed25519.PrivateKey
for rows.Next() {
if err = rows.Scan(&pk); err != nil {
return nil, err
}
result = append(result, pk)
}
return result, rows.Err()
}

View file

@ -142,7 +142,7 @@ type Membership interface {
SelectMembershipsFromRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) SelectMembershipsFromRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) (bool, error) UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) (bool, error)
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userIDs []types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
// SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms. // SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms.
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID, localOnly bool) (map[types.EventStateKeyNID]int, error) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID, localOnly bool) (map[types.EventStateKeyNID]int, error)
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
@ -198,6 +198,8 @@ type UserRoomKeys interface {
// BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair. // BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair.
// If a senderKey can't be found, it is omitted in the result. // If a senderKey can't be found, it is omitted in the result.
BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error)
// SelectRoomNIDs selects all roomNIDs for a specific user
SelectPrivateKeysForUserNID(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) ([]ed25519.PrivateKey, error)
} }
// StrippedEvent represents a stripped event for returning extracted content values. // StrippedEvent represents a stripped event for returning extracted content values.

View file

@ -99,12 +99,12 @@ func TestMembershipTable(t *testing.T) {
assert.Equal(t, 10, len(members)) assert.Equal(t, 10, len(members))
// Get correct user // Get correct user
roomNIDs, err := tab.SelectRoomsWithMembership(ctx, nil, userNIDs[1], tables.MembershipStateLeaveOrBan) roomNIDs, err := tab.SelectRoomsWithMembership(ctx, nil, []types.EventStateKeyNID{userNIDs[1]}, tables.MembershipStateLeaveOrBan)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, []types.RoomNID{1}, roomNIDs) assert.Equal(t, []types.RoomNID{1}, roomNIDs)
// User is not joined to room // User is not joined to room
roomNIDs, err = tab.SelectRoomsWithMembership(ctx, nil, userNIDs[5], tables.MembershipStateJoin) roomNIDs, err = tab.SelectRoomsWithMembership(ctx, nil, []types.EventStateKeyNID{userNIDs[5]}, tables.MembershipStateJoin)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 0, len(roomNIDs)) assert.Equal(t, 0, len(roomNIDs))

View file

@ -115,6 +115,11 @@ func TestUserRoomKeysTable(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, key4, gotPublicKey) assert.Equal(t, key4, gotPublicKey)
// query rooms for a specific user
var pks []ed25519.PrivateKey
pks, err = tab.SelectPrivateKeysForUserNID(context.Background(), txn, userNID)
assert.NoError(t, err)
assert.Equal(t, []ed25519.PrivateKey{key}, pks)
return nil return nil
}) })
assert.NoError(t, err) assert.NoError(t, err)

View file

@ -34,11 +34,13 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson"
) )
// OutputRoomEventConsumer consumes events that originated in the room server. // OutputRoomEventConsumer consumes events that originated in the room server.
@ -592,6 +594,15 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *rstypes.HeaderedEvent)
return event, nil return event, nil
} }
prevContent := prevEvent.Content()
// if we're storing a pseudoID event, make sure to delete the mxid_mapping
if event.Type() == spec.MRoomMember && event.Version() == gomatrixserverlib.RoomVersionPseudoIDs {
prevContent, err = sjson.DeleteBytes(prevEvent.Content(), "mxid_mapping")
if err != nil {
return event, err
}
}
prevEventSender := string(prevEvent.SenderID()) prevEventSender := string(prevEvent.SenderID())
prevUser, err := s.rsAPI.QueryUserIDForSender(s.ctx, *validRoomID, prevEvent.SenderID()) prevUser, err := s.rsAPI.QueryUserIDForSender(s.ctx, *validRoomID, prevEvent.SenderID())
if err == nil && prevUser != nil { if err == nil && prevUser != nil {
@ -599,7 +610,7 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *rstypes.HeaderedEvent)
} }
prev := types.PrevEventRef{ prev := types.PrevEventRef{
PrevContent: prevEvent.Content(), PrevContent: prevContent,
ReplacesState: prevEvent.EventID(), ReplacesState: prevEvent.EventID(),
PrevSenderID: prevEventSender, PrevSenderID: prevEventSender,
} }

View file

@ -381,11 +381,50 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda
newEvent := &rstypes.HeaderedEvent{PDU: eventToRedact} newEvent := &rstypes.HeaderedEvent{PDU: eventToRedact}
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// if we are redacting a state event, also update the current_room_state table
if newEvent.StateKey() != nil {
if err = d.redactCurrentStateEvent(ctx, txn, newEvent, querier); err != nil {
return err
}
}
return d.OutputEvents.UpdateEventJSON(ctx, txn, newEvent) return d.OutputEvents.UpdateEventJSON(ctx, txn, newEvent)
}) })
return err return err
} }
// redactCurrentStateEvent updates the JSON data in the current_room_state table
func (d *Database) redactCurrentStateEvent(ctx context.Context, txn *sql.Tx, newEvent *rstypes.HeaderedEvent, querier api.QuerySenderIDAPI) error {
// resolve the state key, which may be user pseudoID
if *newEvent.StateKey() != "" {
validRoomID, err := spec.NewRoomID(newEvent.RoomID())
if err != nil {
return err
}
var sku *spec.UserID
stateKey := newEvent.StateKey()
sku, err = querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*stateKey))
if err == nil && sku != nil {
sKey := sku.String()
newEvent.StateKeyResolved = &sKey
}
}
// get the current stream position of the event
streamEvents, err := d.CurrentRoomState.SelectEventsWithEventIDs(ctx, txn, []string{newEvent.EventID()})
if err == nil && len(streamEvents) > 0 {
var membershipPtr *string
var membership string
membership, err = streamEvents[0].Membership()
if err == nil {
membershipPtr = &membership
}
if err = d.CurrentRoomState.UpsertRoomState(ctx, txn, newEvent, membershipPtr, streamEvents[0].StreamPosition); err != nil {
return err
}
}
return nil
}
// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database. // fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database.
// Returns a map of room ID to list of events. // Returns a map of room ID to list of events.
func (d *Database) fetchStateEvents( func (d *Database) fetchStateEvents(

View file

@ -1030,7 +1030,7 @@ func TestRedaction(t *testing.T) {
alice := test.NewUser(t) alice := test.NewUser(t)
room := test.NewRoom(t, alice) room := test.NewRoom(t, alice)
redactedEvent := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"}) redactedEvent := room.CreateAndInsert(t, alice, "m.room.member", map[string]interface{}{"membership": "join", "displayname": "alice"}, test.WithStateKey(alice.ID))
redactionEvent := room.CreateEvent(t, alice, spec.MRoomRedaction, map[string]string{"redacts": redactedEvent.EventID()}) redactionEvent := room.CreateEvent(t, alice, spec.MRoomRedaction, map[string]string{"redacts": redactedEvent.EventID()})
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := MustCreateDatabase(t, dbType) db, close := MustCreateDatabase(t, dbType)
@ -1064,5 +1064,25 @@ func TestRedaction(t *testing.T) {
if depth.Exists() { if depth.Exists() {
t.Error("unexpected auth_events in redacted event") t.Error("unexpected auth_events in redacted event")
} }
dbTxn, err := db.NewDatabaseTransaction(context.Background())
if err != nil {
t.Fatal(err)
}
filter := synctypes.DefaultStateFilter()
wantTypes := []string{spec.MRoomMember}
filter.Types = &wantTypes
evs, err = dbTxn.CurrentRoomState.SelectCurrentState(context.Background(), nil, redactedEvent.RoomID(), &filter, nil)
if err != nil {
t.Fatal(err)
}
if count := len(evs); count != 1 {
t.Fatalf("expected 1 event, got %d", count)
}
// we expect that the displayname does not exist anymore
displayname := gjson.GetBytes(evs[0].Content(), "displayname")
if displayname.Exists() {
t.Fatal("expected displayname to be redacted, but wasn't")
}
}) })
} }