Compare commits
24 commits
main
...
s7evink/me
Author | SHA1 | Date | |
---|---|---|---|
cae2b61c88 | |||
0aa1b6c218 | |||
16075ce657 | |||
5a7d36d1b3 | |||
92a633d0e4 | |||
6942c198b9 | |||
4e27ff28b8 | |||
f825ce2935 | |||
2bcb89ad4c | |||
63f239f336 | |||
b93a9e4615 | |||
9569498761 | |||
9ddd62925c | |||
3085928906 | |||
fa4f7021a1 | |||
da51f32e03 | |||
4bf57a2519 | |||
dcd28e3614 | |||
641bac0ce5 | |||
552eaf2940 | |||
8b5afcf680 | |||
302d8d7089 | |||
2b3b355ebd | |||
bc8e83fd28 |
|
@ -74,6 +74,7 @@ type RoomserverFederationAPI interface {
|
||||||
GetEventAuth(ctx context.Context, origin, s spec.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res fclient.RespEventAuth, err error)
|
GetEventAuth(ctx context.Context, origin, s spec.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res fclient.RespEventAuth, err error)
|
||||||
GetEvent(ctx context.Context, origin, s spec.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
|
GetEvent(ctx context.Context, origin, s spec.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
|
||||||
LookupMissingEvents(ctx context.Context, origin, s spec.ServerName, roomID string, missing fclient.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res fclient.RespMissingEvents, err error)
|
LookupMissingEvents(ctx context.Context, origin, s spec.ServerName, roomID string, missing fclient.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res fclient.RespMissingEvents, err error)
|
||||||
|
CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error)
|
||||||
|
|
||||||
RoomHierarchies(ctx context.Context, origin, dst spec.ServerName, roomID string, suggestedOnly bool) (res fclient.RoomHierarchyResponse, err error)
|
RoomHierarchies(ctx context.Context, origin, dst spec.ServerName, roomID string, suggestedOnly bool) (res fclient.RoomHierarchyResponse, err error)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package internal
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -185,3 +186,7 @@ func (a *FederationInternalAPI) doRequestIfNotBlacklisted(
|
||||||
}
|
}
|
||||||
return request()
|
return request()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *FederationInternalAPI) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
|
||||||
|
return a.rsAPI.CurrentMembership(ctx, roomID, senderID)
|
||||||
|
}
|
||||||
|
|
|
@ -167,7 +167,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
|
||||||
KeyRing: r.keyRing,
|
KeyRing: r.keyRing,
|
||||||
EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}),
|
}, r.rsAPI),
|
||||||
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
},
|
},
|
||||||
|
@ -190,6 +190,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
|
||||||
}
|
}
|
||||||
return r.rsAPI.StoreUserRoomPublicKey(ctx, senderID, *storeUserID, roomID)
|
return r.rsAPI.StoreUserRoomPublicKey(ctx, senderID, *storeUserID, roomID)
|
||||||
},
|
},
|
||||||
|
MembershipQuerier: r.rsAPI,
|
||||||
}
|
}
|
||||||
response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput)
|
response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput)
|
||||||
|
|
||||||
|
@ -387,7 +388,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
|
||||||
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}
|
}
|
||||||
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(
|
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(
|
||||||
ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedEventProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName, userIDProvider), userIDProvider,
|
ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedEventProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName, userIDProvider, r.rsAPI), userIDProvider, r.rsAPI,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error checking state returned from peeking: %w", err)
|
return fmt.Errorf("error checking state returned from peeking: %w", err)
|
||||||
|
@ -728,7 +729,7 @@ func checkEventsContainCreateEvent(events []gomatrixserverlib.PDU) error {
|
||||||
func federatedEventProvider(
|
func federatedEventProvider(
|
||||||
ctx context.Context, federation fclient.FederationClient,
|
ctx context.Context, federation fclient.FederationClient,
|
||||||
keyRing gomatrixserverlib.JSONVerifier, origin, server spec.ServerName,
|
keyRing gomatrixserverlib.JSONVerifier, origin, server spec.ServerName,
|
||||||
userIDForSender spec.UserIDForSender,
|
userIDForSender spec.UserIDForSender, rsAPI gomatrixserverlib.MembershipQuerier,
|
||||||
) gomatrixserverlib.EventProvider {
|
) gomatrixserverlib.EventProvider {
|
||||||
// A list of events that we have retried, if they were not included in
|
// A list of events that we have retried, if they were not included in
|
||||||
// the auth events supplied in the send_join.
|
// the auth events supplied in the send_join.
|
||||||
|
@ -778,7 +779,7 @@ func federatedEventProvider(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the signatures of the event.
|
// Check the signatures of the event.
|
||||||
if err := gomatrixserverlib.VerifyEventSignatures(ctx, ev, keyRing, userIDForSender); err != nil {
|
if err := gomatrixserverlib.VerifyEventSignatures(ctx, ev, keyRing, userIDForSender, rsAPI); err != nil {
|
||||||
return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err)
|
return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -22,7 +22,7 @@ require (
|
||||||
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
|
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
|
||||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
|
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
|
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230721154317-b5b0448aa378
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20230721173823-6788b4fb4400
|
||||||
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
|
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
|
||||||
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
|
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
|
||||||
github.com/mattn/go-sqlite3 v1.14.17
|
github.com/mattn/go-sqlite3 v1.14.17
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -207,8 +207,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
|
||||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
|
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
|
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230721154317-b5b0448aa378 h1:a6sfiJiNZWVbPRHvEB/YlpqSg+Dh7El+824mzccSk68=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20230721173823-6788b4fb4400 h1:wTtUS3rjADu788S071rgT2Twg7uKsWt6hrq0cFu3zmo=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230721154317-b5b0448aa378/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20230721173823-6788b4fb4400/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
|
||||||
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
|
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
|
||||||
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
|
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
|
||||||
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
|
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
|
||||||
|
|
|
@ -167,9 +167,10 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}); err != nil {
|
}, t.rsAPI); err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID())
|
util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID())
|
||||||
results[event.EventID()] = fclient.PDUResult{
|
results[event.EventID()] = fclient.PDUResult{
|
||||||
Error: err.Error(),
|
Error: err.Error(),
|
||||||
|
|
|
@ -261,6 +261,7 @@ type FederationRoomserverAPI interface {
|
||||||
QuerySenderIDAPI
|
QuerySenderIDAPI
|
||||||
QueryRoomHierarchyAPI
|
QueryRoomHierarchyAPI
|
||||||
UserRoomPrivateKeyCreator
|
UserRoomPrivateKeyCreator
|
||||||
|
CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error)
|
||||||
AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error)
|
AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error)
|
||||||
SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error)
|
SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error)
|
||||||
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
|
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
|
||||||
|
|
|
@ -415,13 +415,14 @@ func (r *Inputer) processRoomEvent(
|
||||||
// if storing this event results in it being redacted then do so.
|
// if storing this event results in it being redacted then do so.
|
||||||
// we do this after calculating state for this event as we may need to get power levels
|
// we do this after calculating state for this event as we may need to get power levels
|
||||||
var (
|
var (
|
||||||
redactedEventID string
|
redactedEventID string
|
||||||
redactionEvent gomatrixserverlib.PDU
|
redactionEvent gomatrixserverlib.PDU
|
||||||
redactedEvent gomatrixserverlib.PDU
|
redactedEvent gomatrixserverlib.PDU
|
||||||
|
originalRedactedEvent gomatrixserverlib.PDU
|
||||||
)
|
)
|
||||||
if !isRejected && !isCreateEvent {
|
if !isRejected && !isCreateEvent {
|
||||||
resolver := state.NewStateResolution(r.DB, roomInfo, r.Queryer)
|
resolver := state.NewStateResolution(r.DB, roomInfo, r.Queryer)
|
||||||
redactionEvent, redactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, &resolver, r.Queryer)
|
redactionEvent, redactedEvent, originalRedactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, &resolver, r.Queryer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -519,6 +520,12 @@ func (r *Inputer) processRoomEvent(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("r.WriteOutputEvents (redactions): %w", err)
|
return fmt.Errorf("r.WriteOutputEvents (redactions): %w", err)
|
||||||
}
|
}
|
||||||
|
// if we're in a pseudoID room, and we redacted a m.room.member event, also leave/kick the user
|
||||||
|
if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && redactedEvent.Type() == spec.MRoomMember {
|
||||||
|
if err = r.leavePseudoIDRoom(ctx, *validRoomID, originalRedactedEvent, redactionEvent); err != nil {
|
||||||
|
logrus.WithError(err).Error("failed to leave user after membership event redaction")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If guest_access changed and is not can_join, kick all guest users.
|
// If guest_access changed and is not can_join, kick all guest users.
|
||||||
|
@ -534,6 +541,81 @@ func (r *Inputer) processRoomEvent(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// leavePseudoIDRoom leaves/kicks a user in the event of a membership event redaction.
|
||||||
|
// TODO: This doesn't play well with users re-joining rooms, as in this case we have multiple join events with a mxid_mapping.
|
||||||
|
func (r *Inputer) leavePseudoIDRoom(ctx context.Context, roomID spec.RoomID, originalRedactedEvent, redactionEvent gomatrixserverlib.PDU) error {
|
||||||
|
|
||||||
|
stateKey := originalRedactedEvent.StateKey()
|
||||||
|
currentStateEvent, err := r.DB.GetStateEvent(ctx, roomID.String(), originalRedactedEvent.Type(), *stateKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// If the redacted event is NOT the current state event, do nothing
|
||||||
|
if currentStateEvent.EventID() != originalRedactedEvent.EventID() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var memberContent gomatrixserverlib.MemberContent
|
||||||
|
if err = json.Unmarshal(originalRedactedEvent.Content(), &memberContent); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if memberContent.Membership != spec.Join {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// no mxid_mapping, nothing to do
|
||||||
|
if memberContent.MXIDMapping == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, err := r.Queryer.QueryUserIDForSender(ctx, roomID, redactionEvent.SenderID())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// We can only create the leave event on servers the redaction originated on.
|
||||||
|
// We are going to receive the leave event anyway.
|
||||||
|
if !r.Cfg.Matrix.IsLocalServerName(userID.Domain()) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
signingIdentity, err := r.SigningIdentity(ctx, roomID, *userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fledglingEvent := &gomatrixserverlib.ProtoEvent{
|
||||||
|
RoomID: originalRedactedEvent.RoomID(),
|
||||||
|
Type: spec.MRoomMember,
|
||||||
|
StateKey: stateKey,
|
||||||
|
SenderID: string(redactionEvent.SenderID()),
|
||||||
|
}
|
||||||
|
|
||||||
|
if fledglingEvent.Content, err = json.Marshal(gomatrixserverlib.MemberContent{
|
||||||
|
Membership: spec.Leave,
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
event, err := eventutil.QueryAndBuildEvent(ctx, fledglingEvent, &signingIdentity, time.Now(), r.Queryer, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
inputReq := &api.InputRoomEventsRequest{
|
||||||
|
InputRoomEvents: []api.InputRoomEvent{{
|
||||||
|
Kind: api.KindNew,
|
||||||
|
Event: event,
|
||||||
|
Origin: userID.Domain(),
|
||||||
|
SendAsServer: string(userID.Domain()),
|
||||||
|
}},
|
||||||
|
Asynchronous: true, // Needs to be async, as we otherwise create a deadlock
|
||||||
|
}
|
||||||
|
inputRes := &api.InputRoomEventsResponse{}
|
||||||
|
r.InputRoomEvents(ctx, inputReq, inputRes)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// handleRemoteRoomUpgrade updates published rooms and room aliases
|
// handleRemoteRoomUpgrade updates published rooms and room aliases
|
||||||
func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event gomatrixserverlib.PDU) error {
|
func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event gomatrixserverlib.PDU) error {
|
||||||
oldRoomID := event.RoomID()
|
oldRoomID := event.RoomID()
|
||||||
|
@ -739,7 +821,7 @@ nextAuthEvent:
|
||||||
// if a critical event is missing anyway.
|
// if a critical event is missing anyway.
|
||||||
if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}); err != nil {
|
}, r.Queryer); err != nil {
|
||||||
continue nextAuthEvent
|
continue nextAuthEvent
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -599,7 +599,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver
|
||||||
for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) {
|
for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) {
|
||||||
if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}); err != nil {
|
}, t.inputer.Queryer); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
missingEvents = append(missingEvents, t.cacheAndReturn(ev))
|
missingEvents = append(missingEvents, t.cacheAndReturn(ev))
|
||||||
|
@ -690,7 +690,7 @@ func (t *missingStateReq) lookupMissingStateViaState(
|
||||||
AuthEvents: state.GetAuthEvents(),
|
AuthEvents: state.GetAuthEvents(),
|
||||||
}, roomVersion, t.keys, nil, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
}, roomVersion, t.keys, nil, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
})
|
}, t.inputer.Queryer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -946,7 +946,7 @@ serverLoop:
|
||||||
}
|
}
|
||||||
if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}); err != nil {
|
}, t.inputer.Queryer); err != nil {
|
||||||
t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())
|
t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())
|
||||||
return nil, verifySigError{event.EventID(), err}
|
return nil, verifySigError{event.EventID(), err}
|
||||||
}
|
}
|
||||||
|
|
|
@ -270,7 +270,7 @@ func (r *Admin) PerformAdminDownloadState(
|
||||||
for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) {
|
for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) {
|
||||||
if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}); err != nil {
|
}, r.Queryer); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
authEventMap[authEvent.EventID()] = authEvent
|
authEventMap[authEvent.EventID()] = authEvent
|
||||||
|
@ -278,7 +278,7 @@ func (r *Admin) PerformAdminDownloadState(
|
||||||
for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) {
|
for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) {
|
||||||
if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}); err != nil {
|
}, r.Queryer); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
stateEventMap[stateEvent.EventID()] = stateEvent
|
stateEventMap[stateEvent.EventID()] = stateEvent
|
||||||
|
|
|
@ -266,6 +266,7 @@ type backfillRequester struct {
|
||||||
eventIDMap map[string]gomatrixserverlib.PDU
|
eventIDMap map[string]gomatrixserverlib.PDU
|
||||||
historyVisiblity gomatrixserverlib.HistoryVisibility
|
historyVisiblity gomatrixserverlib.HistoryVisibility
|
||||||
roomVersion gomatrixserverlib.RoomVersion
|
roomVersion gomatrixserverlib.RoomVersion
|
||||||
|
membershipQuerier gomatrixserverlib.MembershipQuerier
|
||||||
}
|
}
|
||||||
|
|
||||||
func newBackfillRequester(
|
func newBackfillRequester(
|
||||||
|
@ -292,9 +293,14 @@ func newBackfillRequester(
|
||||||
preferServer: preferServer,
|
preferServer: preferServer,
|
||||||
historyVisiblity: gomatrixserverlib.HistoryVisibilityShared,
|
historyVisiblity: gomatrixserverlib.HistoryVisibilityShared,
|
||||||
roomVersion: roomVersion,
|
roomVersion: roomVersion,
|
||||||
|
membershipQuerier: fsAPI,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *backfillRequester) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
|
||||||
|
return b.fsAPI.CurrentMembership(ctx, roomID, senderID)
|
||||||
|
}
|
||||||
|
|
||||||
func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent gomatrixserverlib.PDU) ([]string, error) {
|
func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent gomatrixserverlib.PDU) ([]string, error) {
|
||||||
b.eventIDMap[targetEvent.EventID()] = targetEvent
|
b.eventIDMap[targetEvent.EventID()] = targetEvent
|
||||||
if ids, ok := b.eventIDToBeforeStateIDs[targetEvent.EventID()]; ok {
|
if ids, ok := b.eventIDToBeforeStateIDs[targetEvent.EventID()]; ok {
|
||||||
|
@ -647,7 +653,7 @@ func persistEvents(ctx context.Context, db storage.Database, querier api.QuerySe
|
||||||
|
|
||||||
resolver := state.NewStateResolution(db, roomInfo, querier)
|
resolver := state.NewStateResolution(db, roomInfo, querier)
|
||||||
|
|
||||||
_, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev, &resolver, querier)
|
_, redactedEvent, _, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev, &resolver, querier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event")
|
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event")
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -1052,3 +1052,14 @@ func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID,
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Queryer) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
|
||||||
|
res := api.QueryMembershipForUserResponse{}
|
||||||
|
err := r.QueryMembershipForSenderID(ctx, roomID, senderID, &res)
|
||||||
|
|
||||||
|
membership := ""
|
||||||
|
if err == nil {
|
||||||
|
membership = res.Membership
|
||||||
|
}
|
||||||
|
return membership, err
|
||||||
|
}
|
||||||
|
|
|
@ -575,7 +575,7 @@ func TestRedaction(t *testing.T) {
|
||||||
err = updater.Commit()
|
err = updater.Commit()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
_, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev.PDU, &plResolver, &FakeQuerier{})
|
_, redactedEvent, _, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev.PDU, &plResolver, &FakeQuerier{})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
if redactedEvent != nil {
|
if redactedEvent != nil {
|
||||||
assert.Equal(t, ev.Redacts(), redactedEvent.EventID())
|
assert.Equal(t, ev.Redacts(), redactedEvent.EventID())
|
||||||
|
|
|
@ -190,9 +190,9 @@ type Database interface {
|
||||||
GetRoomVersion(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
|
GetRoomVersion(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
|
||||||
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
|
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
|
||||||
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
|
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
|
||||||
MaybeRedactEvent(
|
// MaybeRedactEvent returns the redaction event, the redacted event and the event before redaction if this call resulted in a redaction, else an error
|
||||||
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, querier api.QuerySenderIDAPI,
|
// (nil if there was nothing to do)
|
||||||
) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error)
|
MaybeRedactEvent(ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, querier api.QuerySenderIDAPI) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, gomatrixserverlib.PDU, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserRoomKeys interface {
|
type UserRoomKeys interface {
|
||||||
|
@ -249,10 +249,8 @@ type EventDatabase interface {
|
||||||
EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
|
EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
|
||||||
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
|
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
|
||||||
Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
|
Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
|
||||||
// MaybeRedactEvent returns the redaction event and the redacted event if this call resulted in a redaction, else an error
|
// MaybeRedactEvent returns the redaction event, the redacted event and the event before redaction if this call resulted in a redaction, else an error
|
||||||
// (nil if there was nothing to do)
|
// (nil if there was nothing to do)
|
||||||
MaybeRedactEvent(
|
MaybeRedactEvent(ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, querier api.QuerySenderIDAPI) (redactionEvent, redactedEvent, originalRedactedEvent gomatrixserverlib.PDU, err error)
|
||||||
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, querier api.QuerySenderIDAPI,
|
|
||||||
) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error)
|
|
||||||
StoreEvent(ctx context.Context, event gomatrixserverlib.PDU, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error)
|
StoreEvent(ctx context.Context, event gomatrixserverlib.PDU, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error)
|
||||||
}
|
}
|
||||||
|
|
|
@ -128,7 +128,7 @@ const deleteMembershipSQL = "" +
|
||||||
"DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2"
|
"DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2"
|
||||||
|
|
||||||
const selectRoomsWithMembershipSQL = "" +
|
const selectRoomsWithMembershipSQL = "" +
|
||||||
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
|
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = ANY($2) and forgotten = false"
|
||||||
|
|
||||||
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
|
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
|
||||||
// joined to. Since this information is used to populate the user directory, we will
|
// joined to. Since this information is used to populate the user directory, we will
|
||||||
|
@ -347,10 +347,10 @@ func (s *membershipStatements) UpdateMembership(
|
||||||
|
|
||||||
func (s *membershipStatements) SelectRoomsWithMembership(
|
func (s *membershipStatements) SelectRoomsWithMembership(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
userIDs []types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||||
) ([]types.RoomNID, error) {
|
) ([]types.RoomNID, error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
|
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
|
||||||
rows, err := stmt.QueryContext(ctx, membershipState, userID)
|
rows, err := stmt.QueryContext(ctx, membershipState, pq.Array(userIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,12 +56,15 @@ const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_use
|
||||||
|
|
||||||
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)`
|
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)`
|
||||||
|
|
||||||
|
const selectUserRoomKeysSQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND pseudo_id_key IS NOT NULL`
|
||||||
|
|
||||||
type userRoomKeysStatements struct {
|
type userRoomKeysStatements struct {
|
||||||
insertUserRoomPrivateKeyStmt *sql.Stmt
|
insertUserRoomPrivateKeyStmt *sql.Stmt
|
||||||
insertUserRoomPublicKeyStmt *sql.Stmt
|
insertUserRoomPublicKeyStmt *sql.Stmt
|
||||||
selectUserRoomKeyStmt *sql.Stmt
|
selectUserRoomKeyStmt *sql.Stmt
|
||||||
selectUserRoomPublicKeyStmt *sql.Stmt
|
selectUserRoomPublicKeyStmt *sql.Stmt
|
||||||
selectUserNIDsStmt *sql.Stmt
|
selectUserNIDsStmt *sql.Stmt
|
||||||
|
selectUserRoomKeysStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateUserRoomKeysTable(db *sql.DB) error {
|
func CreateUserRoomKeysTable(db *sql.DB) error {
|
||||||
|
@ -77,6 +80,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
|
||||||
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
|
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
|
||||||
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
|
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
|
||||||
{&s.selectUserNIDsStmt, selectUserNIDsSQL},
|
{&s.selectUserNIDsStmt, selectUserNIDsSQL},
|
||||||
|
{&s.selectUserRoomKeysStmt, selectUserRoomKeysSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -150,3 +154,25 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sq
|
||||||
}
|
}
|
||||||
return result, rows.Err()
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *userRoomKeysStatements) SelectPrivateKeysForUserNID(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) ([]ed25519.PrivateKey, error) {
|
||||||
|
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeysStmt)
|
||||||
|
|
||||||
|
rows, err := stmt.QueryContext(ctx, userNID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows")
|
||||||
|
|
||||||
|
var result []ed25519.PrivateKey
|
||||||
|
var pk ed25519.PrivateKey
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
if err = rows.Scan(&pk); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, pk)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
|
@ -990,15 +990,15 @@ func extractRoomVersionFromCreateEvent(event gomatrixserverlib.PDU) (
|
||||||
// to cross-reference with other tables when loading.
|
// to cross-reference with other tables when loading.
|
||||||
//
|
//
|
||||||
// Returns the redaction event and the redacted event if this call resulted in a redaction.
|
// Returns the redaction event and the redacted event if this call resulted in a redaction.
|
||||||
|
// nolint: gocylo
|
||||||
func (d *EventDatabase) MaybeRedactEvent(
|
func (d *EventDatabase) MaybeRedactEvent(
|
||||||
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver,
|
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver,
|
||||||
querier api.QuerySenderIDAPI,
|
querier api.QuerySenderIDAPI,
|
||||||
) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error) {
|
) (redactionEvent, redactedEvent, originalRedactedEvent gomatrixserverlib.PDU, err error) {
|
||||||
var (
|
var (
|
||||||
redactionEvent, redactedEvent *types.Event
|
redactionEv, redactedEv *types.Event
|
||||||
err error
|
validated bool
|
||||||
validated bool
|
ignoreRedaction bool
|
||||||
ignoreRedaction bool
|
|
||||||
)
|
)
|
||||||
|
|
||||||
wErr := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
wErr := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
@ -1019,42 +1019,42 @@ func (d *EventDatabase) MaybeRedactEvent(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
redactionEvent, redactedEvent, validated, err = d.loadRedactionPair(ctx, txn, roomInfo, eventNID, event)
|
redactionEv, redactedEv, validated, err = d.loadRedactionPair(ctx, txn, roomInfo, eventNID, event)
|
||||||
switch {
|
switch {
|
||||||
case err != nil:
|
case err != nil:
|
||||||
return fmt.Errorf("d.loadRedactionPair: %w", err)
|
return fmt.Errorf("d.loadRedactionPair: %w", err)
|
||||||
case validated || redactedEvent == nil || redactionEvent == nil:
|
case validated || redactedEv == nil || redactionEv == nil:
|
||||||
// we've seen this redaction before or there is nothing to redact
|
// we've seen this redaction before or there is nothing to redact
|
||||||
return nil
|
return nil
|
||||||
case redactedEvent.RoomID() != redactionEvent.RoomID():
|
case redactedEv.RoomID() != redactionEv.RoomID():
|
||||||
// redactions across rooms aren't allowed
|
// redactions across rooms aren't allowed
|
||||||
ignoreRedaction = true
|
ignoreRedaction = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var validRoomID *spec.RoomID
|
var validRoomID *spec.RoomID
|
||||||
validRoomID, err = spec.NewRoomID(redactedEvent.RoomID())
|
validRoomID, err = spec.NewRoomID(redactedEv.RoomID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
sender1Domain := ""
|
sender1Domain := ""
|
||||||
sender1, err1 := querier.QueryUserIDForSender(ctx, *validRoomID, redactedEvent.SenderID())
|
sender1, err1 := querier.QueryUserIDForSender(ctx, *validRoomID, redactedEv.SenderID())
|
||||||
if err1 == nil {
|
if err1 == nil {
|
||||||
sender1Domain = string(sender1.Domain())
|
sender1Domain = string(sender1.Domain())
|
||||||
}
|
}
|
||||||
sender2Domain := ""
|
sender2Domain := ""
|
||||||
sender2, err2 := querier.QueryUserIDForSender(ctx, *validRoomID, redactionEvent.SenderID())
|
sender2, err2 := querier.QueryUserIDForSender(ctx, *validRoomID, redactionEv.SenderID())
|
||||||
if err2 == nil {
|
if err2 == nil {
|
||||||
sender2Domain = string(sender2.Domain())
|
sender2Domain = string(sender2.Domain())
|
||||||
}
|
}
|
||||||
var powerlevels *gomatrixserverlib.PowerLevelContent
|
var powerlevels *gomatrixserverlib.PowerLevelContent
|
||||||
powerlevels, err = plResolver.Resolve(ctx, redactionEvent.EventID())
|
powerlevels, err = plResolver.Resolve(ctx, redactionEv.EventID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case powerlevels.UserLevel(redactionEvent.SenderID()) >= powerlevels.Redact:
|
case powerlevels.UserLevel(redactionEv.SenderID()) >= powerlevels.Redact:
|
||||||
// 1. The power level of the redaction event’s sender is greater than or equal to the redact level.
|
// 1. The power level of the redaction event’s sender is greater than or equal to the redact level.
|
||||||
case sender1Domain != "" && sender2Domain != "" && sender1Domain == sender2Domain:
|
case sender1Domain != "" && sender2Domain != "" && sender1Domain == sender2Domain:
|
||||||
// 2. The domain of the redaction event’s sender matches that of the original event’s sender.
|
// 2. The domain of the redaction event’s sender matches that of the original event’s sender.
|
||||||
|
@ -1065,42 +1065,47 @@ func (d *EventDatabase) MaybeRedactEvent(
|
||||||
|
|
||||||
// mark the event as redacted
|
// mark the event as redacted
|
||||||
if redactionsArePermanent {
|
if redactionsArePermanent {
|
||||||
redactedEvent.Redact()
|
originalRedactedEvent, err = gomatrixserverlib.MustGetRoomVersion(redactedEv.Version()).
|
||||||
|
NewEventFromTrustedJSON(redactedEv.JSON(), false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
redactedEv.Redact()
|
||||||
}
|
}
|
||||||
|
|
||||||
err = redactedEvent.SetUnsignedField("redacted_because", redactionEvent)
|
err = redactedEv.SetUnsignedField("redacted_because", redactionEv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("redactedEvent.SetUnsignedField: %w", err)
|
return fmt.Errorf("redactedEvent.SetUnsignedField: %w", err)
|
||||||
}
|
}
|
||||||
// NOTSPEC: sytest relies on this unspecced field existing :(
|
// NOTSPEC: sytest relies on this unspecced field existing :(
|
||||||
err = redactedEvent.SetUnsignedField("redacted_by", redactionEvent.EventID())
|
err = redactedEv.SetUnsignedField("redacted_by", redactionEv.EventID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("redactedEvent.SetUnsignedField: %w", err)
|
return fmt.Errorf("redactedEvent.SetUnsignedField: %w", err)
|
||||||
}
|
}
|
||||||
// overwrite the eventJSON table
|
// overwrite the eventJSON table
|
||||||
err = d.EventJSONTable.InsertEventJSON(ctx, txn, redactedEvent.EventNID, redactedEvent.JSON())
|
err = d.EventJSONTable.InsertEventJSON(ctx, txn, redactedEv.EventNID, redactedEv.JSON())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err)
|
return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = d.RedactionsTable.MarkRedactionValidated(ctx, txn, redactionEvent.EventID(), true)
|
err = d.RedactionsTable.MarkRedactionValidated(ctx, txn, redactionEv.EventID(), true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("d.RedactionsTable.MarkRedactionValidated: %w", err)
|
return fmt.Errorf("d.RedactionsTable.MarkRedactionValidated: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// We remove the entry from the cache, as if we just "StoreRoomServerEvent", we can't be
|
// We remove the entry from the cache, as if we just "StoreRoomServerEvent", we can't be
|
||||||
// certain that the cached entry actually is updated, since ristretto is eventual-persistent.
|
// certain that the cached entry actually is updated, since ristretto is eventual-persistent.
|
||||||
d.Cache.InvalidateRoomServerEvent(redactedEvent.EventNID)
|
d.Cache.InvalidateRoomServerEvent(redactedEv.EventNID)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if wErr != nil {
|
if wErr != nil {
|
||||||
return nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
if ignoreRedaction || redactionEvent == nil || redactedEvent == nil {
|
if ignoreRedaction || redactionEv == nil || redactedEv == nil {
|
||||||
return nil, nil, nil
|
return nil, nil, nil, nil
|
||||||
}
|
}
|
||||||
return redactionEvent.PDU, redactedEvent.PDU, nil
|
return redactionEv.PDU, redactedEv.PDU, originalRedactedEvent, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadRedactionPair returns both the redaction event and the redacted event, else nil.
|
// loadRedactionPair returns both the redaction event and the redacted event, else nil.
|
||||||
|
@ -1361,14 +1366,38 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership)
|
return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership)
|
||||||
}
|
}
|
||||||
stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID)
|
stateKeyNID, err := d.EventStateKeyNIDs(ctx, []string{userID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err)
|
return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err)
|
||||||
}
|
}
|
||||||
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNID, membershipState)
|
|
||||||
|
// get the pseudo IDs, if any, as otherwise we don't get the correct room list
|
||||||
|
pseudoIDKeys, err := d.UserRoomKeyTable.SelectPrivateKeysForUserNID(ctx, nil, stateKeyNID[userID])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectPrivateKeysForUserNID: %w", err)
|
||||||
|
}
|
||||||
|
senderIDs := make([]string, len(pseudoIDKeys))
|
||||||
|
var senderID spec.SenderID
|
||||||
|
for _, key := range pseudoIDKeys {
|
||||||
|
senderID = spec.SenderIDFromPseudoIDKey(key)
|
||||||
|
senderIDs = append(senderIDs, string(senderID))
|
||||||
|
}
|
||||||
|
|
||||||
|
stateKeyNIDMap, err := d.EventStateKeyNIDs(ctx, senderIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("GetRoomsByMembership: failed to EventStateKeyNIDs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stateKeyNIDs := make([]types.EventStateKeyNID, 0, len(stateKeyNIDMap)+1)
|
||||||
|
stateKeyNIDs = append(stateKeyNIDs, stateKeyNID[userID])
|
||||||
|
for _, stateKeyNID := range stateKeyNIDMap {
|
||||||
|
stateKeyNIDs = append(stateKeyNIDs, stateKeyNID)
|
||||||
|
}
|
||||||
|
|
||||||
|
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNIDs, membershipState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err)
|
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -100,7 +100,7 @@ const updateMembershipForgetRoom = "" +
|
||||||
" WHERE room_nid = $2 AND target_nid = $3"
|
" WHERE room_nid = $2 AND target_nid = $3"
|
||||||
|
|
||||||
const selectRoomsWithMembershipSQL = "" +
|
const selectRoomsWithMembershipSQL = "" +
|
||||||
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
|
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid IN ($2) and forgotten = false"
|
||||||
|
|
||||||
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
|
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
|
||||||
// joined to. Since this information is used to populate the user directory, we will
|
// joined to. Since this information is used to populate the user directory, we will
|
||||||
|
@ -297,10 +297,28 @@ func (s *membershipStatements) UpdateMembership(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) SelectRoomsWithMembership(
|
func (s *membershipStatements) SelectRoomsWithMembership(
|
||||||
ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
ctx context.Context, txn *sql.Tx, userIDs []types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||||
) ([]types.RoomNID, error) {
|
) ([]types.RoomNID, error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
|
|
||||||
rows, err := stmt.QueryContext(ctx, membershipState, userID)
|
query := strings.Replace(selectRoomsWithMembershipSQL, "($2)", sqlutil.QueryVariadicOffset(len(userIDs), 1), 1)
|
||||||
|
|
||||||
|
var stmt *sql.Stmt
|
||||||
|
var err error
|
||||||
|
if txn != nil {
|
||||||
|
stmt, err = txn.PrepareContext(ctx, query)
|
||||||
|
} else {
|
||||||
|
stmt, err = s.db.Prepare(query)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, stmt, "SelectRoomsWithMembership: stmt.close() failed")
|
||||||
|
params := make([]any, len(userIDs)+1)
|
||||||
|
params[0] = membershipState
|
||||||
|
for i, userID := range userIDs {
|
||||||
|
params[i+1] = userID
|
||||||
|
}
|
||||||
|
rows, err := stmt.QueryContext(ctx, params...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,6 +56,8 @@ const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_use
|
||||||
|
|
||||||
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)`
|
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)`
|
||||||
|
|
||||||
|
const selectUserRoomKeysSQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND pseudo_id_key IS NOT NULL`
|
||||||
|
|
||||||
type userRoomKeysStatements struct {
|
type userRoomKeysStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertUserRoomPrivateKeyStmt *sql.Stmt
|
insertUserRoomPrivateKeyStmt *sql.Stmt
|
||||||
|
@ -63,6 +65,7 @@ type userRoomKeysStatements struct {
|
||||||
selectUserRoomKeyStmt *sql.Stmt
|
selectUserRoomKeyStmt *sql.Stmt
|
||||||
selectUserRoomPublicKeyStmt *sql.Stmt
|
selectUserRoomPublicKeyStmt *sql.Stmt
|
||||||
//selectUserNIDsStmt *sql.Stmt //prepared at runtime
|
//selectUserNIDsStmt *sql.Stmt //prepared at runtime
|
||||||
|
selectUserRoomKeysStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateUserRoomKeysTable(db *sql.DB) error {
|
func CreateUserRoomKeysTable(db *sql.DB) error {
|
||||||
|
@ -77,6 +80,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
|
||||||
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
|
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
|
||||||
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
|
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
|
||||||
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
|
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
|
||||||
|
{&s.selectUserRoomKeysStmt, selectUserRoomKeysSQL},
|
||||||
//{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime
|
//{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
@ -165,3 +169,25 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sq
|
||||||
}
|
}
|
||||||
return result, rows.Err()
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *userRoomKeysStatements) SelectPrivateKeysForUserNID(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) ([]ed25519.PrivateKey, error) {
|
||||||
|
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeysStmt)
|
||||||
|
|
||||||
|
rows, err := stmt.QueryContext(ctx, userNID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows")
|
||||||
|
|
||||||
|
var result []ed25519.PrivateKey
|
||||||
|
var pk ed25519.PrivateKey
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
if err = rows.Scan(&pk); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, pk)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
|
@ -142,7 +142,7 @@ type Membership interface {
|
||||||
SelectMembershipsFromRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
|
SelectMembershipsFromRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
|
||||||
SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
|
SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
|
||||||
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) (bool, error)
|
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) (bool, error)
|
||||||
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
|
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userIDs []types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
|
||||||
// SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms.
|
// SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms.
|
||||||
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID, localOnly bool) (map[types.EventStateKeyNID]int, error)
|
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID, localOnly bool) (map[types.EventStateKeyNID]int, error)
|
||||||
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
|
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
|
||||||
|
@ -198,6 +198,8 @@ type UserRoomKeys interface {
|
||||||
// BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair.
|
// BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair.
|
||||||
// If a senderKey can't be found, it is omitted in the result.
|
// If a senderKey can't be found, it is omitted in the result.
|
||||||
BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error)
|
BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error)
|
||||||
|
// SelectRoomNIDs selects all roomNIDs for a specific user
|
||||||
|
SelectPrivateKeysForUserNID(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) ([]ed25519.PrivateKey, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StrippedEvent represents a stripped event for returning extracted content values.
|
// StrippedEvent represents a stripped event for returning extracted content values.
|
||||||
|
|
|
@ -99,12 +99,12 @@ func TestMembershipTable(t *testing.T) {
|
||||||
assert.Equal(t, 10, len(members))
|
assert.Equal(t, 10, len(members))
|
||||||
|
|
||||||
// Get correct user
|
// Get correct user
|
||||||
roomNIDs, err := tab.SelectRoomsWithMembership(ctx, nil, userNIDs[1], tables.MembershipStateLeaveOrBan)
|
roomNIDs, err := tab.SelectRoomsWithMembership(ctx, nil, []types.EventStateKeyNID{userNIDs[1]}, tables.MembershipStateLeaveOrBan)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, []types.RoomNID{1}, roomNIDs)
|
assert.Equal(t, []types.RoomNID{1}, roomNIDs)
|
||||||
|
|
||||||
// User is not joined to room
|
// User is not joined to room
|
||||||
roomNIDs, err = tab.SelectRoomsWithMembership(ctx, nil, userNIDs[5], tables.MembershipStateJoin)
|
roomNIDs, err = tab.SelectRoomsWithMembership(ctx, nil, []types.EventStateKeyNID{userNIDs[5]}, tables.MembershipStateJoin)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, 0, len(roomNIDs))
|
assert.Equal(t, 0, len(roomNIDs))
|
||||||
|
|
||||||
|
|
|
@ -115,6 +115,11 @@ func TestUserRoomKeysTable(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, key4, gotPublicKey)
|
assert.Equal(t, key4, gotPublicKey)
|
||||||
|
|
||||||
|
// query rooms for a specific user
|
||||||
|
var pks []ed25519.PrivateKey
|
||||||
|
pks, err = tab.SelectPrivateKeysForUserNID(context.Background(), txn, userNID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []ed25519.PrivateKey{key}, pks)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
|
@ -34,11 +34,13 @@ import (
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
"github.com/matrix-org/dendrite/syncapi/streams"
|
"github.com/matrix-org/dendrite/syncapi/streams"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
"github.com/nats-io/nats.go"
|
"github.com/nats-io/nats.go"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OutputRoomEventConsumer consumes events that originated in the room server.
|
// OutputRoomEventConsumer consumes events that originated in the room server.
|
||||||
|
@ -592,6 +594,15 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *rstypes.HeaderedEvent)
|
||||||
return event, nil
|
return event, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
prevContent := prevEvent.Content()
|
||||||
|
// if we're storing a pseudoID event, make sure to delete the mxid_mapping
|
||||||
|
if event.Type() == spec.MRoomMember && event.Version() == gomatrixserverlib.RoomVersionPseudoIDs {
|
||||||
|
prevContent, err = sjson.DeleteBytes(prevEvent.Content(), "mxid_mapping")
|
||||||
|
if err != nil {
|
||||||
|
return event, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
prevEventSender := string(prevEvent.SenderID())
|
prevEventSender := string(prevEvent.SenderID())
|
||||||
prevUser, err := s.rsAPI.QueryUserIDForSender(s.ctx, *validRoomID, prevEvent.SenderID())
|
prevUser, err := s.rsAPI.QueryUserIDForSender(s.ctx, *validRoomID, prevEvent.SenderID())
|
||||||
if err == nil && prevUser != nil {
|
if err == nil && prevUser != nil {
|
||||||
|
@ -599,7 +610,7 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *rstypes.HeaderedEvent)
|
||||||
}
|
}
|
||||||
|
|
||||||
prev := types.PrevEventRef{
|
prev := types.PrevEventRef{
|
||||||
PrevContent: prevEvent.Content(),
|
PrevContent: prevContent,
|
||||||
ReplacesState: prevEvent.EventID(),
|
ReplacesState: prevEvent.EventID(),
|
||||||
PrevSenderID: prevEventSender,
|
PrevSenderID: prevEventSender,
|
||||||
}
|
}
|
||||||
|
|
|
@ -381,11 +381,50 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda
|
||||||
|
|
||||||
newEvent := &rstypes.HeaderedEvent{PDU: eventToRedact}
|
newEvent := &rstypes.HeaderedEvent{PDU: eventToRedact}
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
// if we are redacting a state event, also update the current_room_state table
|
||||||
|
if newEvent.StateKey() != nil {
|
||||||
|
if err = d.redactCurrentStateEvent(ctx, txn, newEvent, querier); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
return d.OutputEvents.UpdateEventJSON(ctx, txn, newEvent)
|
return d.OutputEvents.UpdateEventJSON(ctx, txn, newEvent)
|
||||||
})
|
})
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// redactCurrentStateEvent updates the JSON data in the current_room_state table
|
||||||
|
func (d *Database) redactCurrentStateEvent(ctx context.Context, txn *sql.Tx, newEvent *rstypes.HeaderedEvent, querier api.QuerySenderIDAPI) error {
|
||||||
|
// resolve the state key, which may be user pseudoID
|
||||||
|
if *newEvent.StateKey() != "" {
|
||||||
|
validRoomID, err := spec.NewRoomID(newEvent.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var sku *spec.UserID
|
||||||
|
stateKey := newEvent.StateKey()
|
||||||
|
sku, err = querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*stateKey))
|
||||||
|
if err == nil && sku != nil {
|
||||||
|
sKey := sku.String()
|
||||||
|
newEvent.StateKeyResolved = &sKey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the current stream position of the event
|
||||||
|
streamEvents, err := d.CurrentRoomState.SelectEventsWithEventIDs(ctx, txn, []string{newEvent.EventID()})
|
||||||
|
if err == nil && len(streamEvents) > 0 {
|
||||||
|
var membershipPtr *string
|
||||||
|
var membership string
|
||||||
|
membership, err = streamEvents[0].Membership()
|
||||||
|
if err == nil {
|
||||||
|
membershipPtr = &membership
|
||||||
|
}
|
||||||
|
if err = d.CurrentRoomState.UpsertRoomState(ctx, txn, newEvent, membershipPtr, streamEvents[0].StreamPosition); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database.
|
// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database.
|
||||||
// Returns a map of room ID to list of events.
|
// Returns a map of room ID to list of events.
|
||||||
func (d *Database) fetchStateEvents(
|
func (d *Database) fetchStateEvents(
|
||||||
|
|
|
@ -1030,7 +1030,7 @@ func TestRedaction(t *testing.T) {
|
||||||
alice := test.NewUser(t)
|
alice := test.NewUser(t)
|
||||||
room := test.NewRoom(t, alice)
|
room := test.NewRoom(t, alice)
|
||||||
|
|
||||||
redactedEvent := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"})
|
redactedEvent := room.CreateAndInsert(t, alice, "m.room.member", map[string]interface{}{"membership": "join", "displayname": "alice"}, test.WithStateKey(alice.ID))
|
||||||
redactionEvent := room.CreateEvent(t, alice, spec.MRoomRedaction, map[string]string{"redacts": redactedEvent.EventID()})
|
redactionEvent := room.CreateEvent(t, alice, spec.MRoomRedaction, map[string]string{"redacts": redactedEvent.EventID()})
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
db, close := MustCreateDatabase(t, dbType)
|
db, close := MustCreateDatabase(t, dbType)
|
||||||
|
@ -1064,5 +1064,25 @@ func TestRedaction(t *testing.T) {
|
||||||
if depth.Exists() {
|
if depth.Exists() {
|
||||||
t.Error("unexpected auth_events in redacted event")
|
t.Error("unexpected auth_events in redacted event")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dbTxn, err := db.NewDatabaseTransaction(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
filter := synctypes.DefaultStateFilter()
|
||||||
|
wantTypes := []string{spec.MRoomMember}
|
||||||
|
filter.Types = &wantTypes
|
||||||
|
evs, err = dbTxn.CurrentRoomState.SelectCurrentState(context.Background(), nil, redactedEvent.RoomID(), &filter, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if count := len(evs); count != 1 {
|
||||||
|
t.Fatalf("expected 1 event, got %d", count)
|
||||||
|
}
|
||||||
|
// we expect that the displayname does not exist anymore
|
||||||
|
displayname := gjson.GetBytes(evs[0].Content(), "displayname")
|
||||||
|
if displayname.Exists() {
|
||||||
|
t.Fatal("expected displayname to be redacted, but wasn't")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue