Pass in new "IsEventRejected" parameter

This commit is contained in:
Till Faelligen 2023-10-24 09:04:39 +02:00
parent 8b3adaf244
commit f70996d8d3
No known key found for this signature in database
GPG key ID: 3DF82D8AB9211D4E
5 changed files with 46 additions and 0 deletions

View file

@ -202,12 +202,25 @@ func main() {
authEvents[i] = authEventEntries[i].PDU authEvents[i] = authEventEntries[i].PDU
} }
// Get the roomNID
roomInfo, err = roomserverDB.RoomInfo(ctx, authEvents[0].RoomID().String())
if err != nil {
panic(err)
}
fmt.Println("Resolving state") fmt.Println("Resolving state")
var resolved Events var resolved Events
resolved, err = gomatrixserverlib.ResolveConflicts( resolved, err = gomatrixserverlib.ResolveConflicts(
gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, },
func(eventID string) bool {
isRejected, err := roomserverDB.IsEventRejected(ctx, roomInfo.RoomNID, eventID)
if err != nil {
return true
}
return isRejected
},
) )
if err != nil { if err != nil {
panic(err) panic(err)

View file

@ -498,6 +498,13 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion
roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}, },
func(eventID string) bool {
isRejected, err := t.db.IsEventRejected(ctx, t.roomInfo.RoomNID, eventID)
if err != nil {
return true
}
return isRejected
},
) )
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -165,6 +165,13 @@ func (r *Queryer) QueryStateAfterEvents(
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.QueryUserIDForSender(ctx, roomID, senderID) return r.QueryUserIDForSender(ctx, roomID, senderID)
}, },
func(eventID string) bool {
isRejected, err := r.DB.IsEventRejected(ctx, info.RoomNID, eventID)
if err != nil {
return true
}
return isRejected
},
) )
if err != nil { if err != nil {
return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err) return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err)
@ -676,6 +683,13 @@ func (r *Queryer) QueryStateAndAuthChain(
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.QueryUserIDForSender(ctx, roomID, senderID) return r.QueryUserIDForSender(ctx, roomID, senderID)
}, },
func(eventID string) bool {
isRejected, err := r.DB.IsEventRejected(ctx, info.RoomNID, eventID)
if err != nil {
return true
}
return isRejected
},
) )
if err != nil { if err != nil {
return err return err

View file

@ -45,6 +45,7 @@ type StateResolutionStorage interface {
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
IsEventRejected(ctx context.Context, roomNID types.RoomNID, eventID string) (bool, error)
} }
type StateResolution struct { type StateResolution struct {
@ -1066,6 +1067,13 @@ func (v *StateResolution) resolveConflictsV2(
func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return v.Querier.QueryUserIDForSender(ctx, roomID, senderID) return v.Querier.QueryUserIDForSender(ctx, roomID, senderID)
}, },
func(eventID string) bool {
isRejected, err := v.db.IsEventRejected(ctx, v.roomInfo.RoomNID, eventID)
if err != nil {
return true
}
return isRejected
},
) )
}() }()

View file

@ -250,3 +250,7 @@ func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error {
func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
} }
func (u *RoomUpdater) IsEventRejected(ctx context.Context, roomNID types.RoomNID, eventID string) (bool, error) {
return u.d.IsEventRejected(ctx, roomNID, eventID)
}