mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-12 00:23:10 -06:00
Ensure all senderid/userid conversions go through rsAPI
This commit is contained in:
parent
77d9e4e93d
commit
83e2cb42c8
|
|
@ -11,11 +11,13 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver"
|
||||
"github.com/matrix-org/dendrite/roomserver/state"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/dendrite/setup"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
|
|
@ -66,10 +68,15 @@ func main() {
|
|||
panic(err)
|
||||
}
|
||||
|
||||
natsInstance := &jetstream.NATSInstance{}
|
||||
_, _ = natsInstance.Prepare(processCtx, &cfg.Global.JetStream)
|
||||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm,
|
||||
natsInstance, caching.NewRistrettoCache(128*1024*1024, time.Hour, true), false)
|
||||
|
||||
roomInfo := &types.RoomInfo{
|
||||
RoomVersion: gomatrixserverlib.RoomVersion(*roomVersion),
|
||||
}
|
||||
stateres := state.NewStateResolution(roomserverDB, roomInfo)
|
||||
stateres := state.NewStateResolution(roomserverDB, roomInfo, rsAPI)
|
||||
|
||||
if *difference {
|
||||
if len(snapshotNIDs) != 2 {
|
||||
|
|
@ -184,7 +191,7 @@ func main() {
|
|||
var resolved Events
|
||||
resolved, err = gomatrixserverlib.ResolveConflicts(
|
||||
gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return roomserverDB.GetUserIDForSender(ctx, roomID, senderID)
|
||||
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ package auth
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
)
|
||||
|
|
@ -25,7 +25,7 @@ import (
|
|||
// IsServerAllowed returns true if the server is allowed to see events in the room
|
||||
// at this particular state. This function implements https://matrix.org/docs/spec/client_server/r0.6.0#id87
|
||||
func IsServerAllowed(
|
||||
ctx context.Context, db storage.RoomDatabase,
|
||||
ctx context.Context, querier api.QuerySenderIDAPI,
|
||||
serverName spec.ServerName,
|
||||
serverCurrentlyInRoom bool,
|
||||
authEvents []gomatrixserverlib.PDU,
|
||||
|
|
@ -41,7 +41,7 @@ func IsServerAllowed(
|
|||
return true
|
||||
}
|
||||
// 2. If the user's membership was join, allow.
|
||||
joinedUserExists := IsAnyUserOnServerWithMembership(ctx, db, serverName, authEvents, spec.Join)
|
||||
joinedUserExists := IsAnyUserOnServerWithMembership(ctx, querier, serverName, authEvents, spec.Join)
|
||||
if joinedUserExists {
|
||||
return true
|
||||
}
|
||||
|
|
@ -50,7 +50,7 @@ func IsServerAllowed(
|
|||
return true
|
||||
}
|
||||
// 4. If the user's membership was invite, and the history_visibility was set to invited, allow.
|
||||
invitedUserExists := IsAnyUserOnServerWithMembership(ctx, db, serverName, authEvents, spec.Invite)
|
||||
invitedUserExists := IsAnyUserOnServerWithMembership(ctx, querier, serverName, authEvents, spec.Invite)
|
||||
if invitedUserExists && historyVisibility == gomatrixserverlib.HistoryVisibilityInvited {
|
||||
return true
|
||||
}
|
||||
|
|
@ -74,7 +74,7 @@ func HistoryVisibilityForRoom(authEvents []gomatrixserverlib.PDU) gomatrixserver
|
|||
return visibility
|
||||
}
|
||||
|
||||
func IsAnyUserOnServerWithMembership(ctx context.Context, db storage.RoomDatabase, serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool {
|
||||
func IsAnyUserOnServerWithMembership(ctx context.Context, querier api.QuerySenderIDAPI, serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool {
|
||||
for _, ev := range authEvents {
|
||||
if ev.Type() != spec.MRoomMember {
|
||||
continue
|
||||
|
|
@ -89,7 +89,7 @@ func IsAnyUserOnServerWithMembership(ctx context.Context, db storage.RoomDatabas
|
|||
continue
|
||||
}
|
||||
|
||||
userID, err := db.GetUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*stateKey))
|
||||
userID, err := querier.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*stateKey))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,17 +4,17 @@ import (
|
|||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
)
|
||||
|
||||
type FakeStorageDB struct {
|
||||
storage.RoomDatabase
|
||||
type FakeQuerier struct {
|
||||
api.QuerySenderIDAPI
|
||||
}
|
||||
|
||||
func (f *FakeStorageDB) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
func (f *FakeQuerier) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return spec.NewUserID(string(senderID), true)
|
||||
}
|
||||
|
||||
|
|
@ -87,7 +87,7 @@ func TestIsServerAllowed(t *testing.T) {
|
|||
authEvents = append(authEvents, ev.PDU)
|
||||
}
|
||||
|
||||
if got := IsServerAllowed(context.Background(), &FakeStorageDB{}, tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want {
|
||||
if got := IsServerAllowed(context.Background(), &FakeQuerier{}, tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want {
|
||||
t.Errorf("IsServerAllowed() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
|
|
|
|||
|
|
@ -206,7 +206,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
|
|||
}
|
||||
|
||||
stateRes := &api.QueryLatestEventsAndStateResponse{}
|
||||
if err = helpers.QueryLatestEventsAndState(ctx, r.DB, &api.QueryLatestEventsAndStateRequest{RoomID: roomID, StateToFetch: eventsNeeded.Tuples()}, stateRes); err != nil {
|
||||
if err = helpers.QueryLatestEventsAndState(ctx, r.DB, r, &api.QueryLatestEventsAndStateRequest{RoomID: roomID, StateToFetch: eventsNeeded.Tuples()}, stateRes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -176,6 +176,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
|||
IsLocalServerName: r.Cfg.Global.IsLocalServerName,
|
||||
DB: r.DB,
|
||||
FSAPI: r.fsAPI,
|
||||
Querier: r.Queryer,
|
||||
KeyRing: r.KeyRing,
|
||||
// Perspective servers are trusted to not lie about server keys, so we will also
|
||||
// prefer these servers when backfilling (assuming they are in the room) rather
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/roomserver/state"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
|
|
@ -36,6 +37,7 @@ func CheckForSoftFail(
|
|||
roomInfo *types.RoomInfo,
|
||||
event *types.HeaderedEvent,
|
||||
stateEventIDs []string,
|
||||
querier api.QuerySenderIDAPI,
|
||||
) (bool, error) {
|
||||
rewritesState := len(stateEventIDs) > 1
|
||||
|
||||
|
|
@ -49,7 +51,7 @@ func CheckForSoftFail(
|
|||
} else {
|
||||
// Then get the state entries for the current state snapshot.
|
||||
// We'll use this to check if the event is allowed right now.
|
||||
roomState := state.NewStateResolution(db, roomInfo)
|
||||
roomState := state.NewStateResolution(db, roomInfo, querier)
|
||||
authStateEntries, err = roomState.LoadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID())
|
||||
if err != nil {
|
||||
return true, fmt.Errorf("roomState.LoadStateAtSnapshot: %w", err)
|
||||
|
|
@ -77,7 +79,7 @@ func CheckForSoftFail(
|
|||
|
||||
// Check if the event is allowed.
|
||||
if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return db.GetUserIDForSender(ctx, roomID, senderID)
|
||||
return querier.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}); err != nil {
|
||||
// return true, nil
|
||||
return true, err
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ func UpdateToInviteMembership(
|
|||
// memberships. If the servername is not supplied then the local server will be
|
||||
// checked instead using a faster code path.
|
||||
// TODO: This should probably be replaced by an API call.
|
||||
func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverName spec.ServerName, roomID string) (bool, error) {
|
||||
func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI, serverName spec.ServerName, roomID string) (bool, error) {
|
||||
info, err := db.RoomInfo(ctx, roomID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
|
@ -94,7 +94,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam
|
|||
for i := range events {
|
||||
gmslEvents[i] = events[i].PDU
|
||||
}
|
||||
return auth.IsAnyUserOnServerWithMembership(ctx, db, serverName, gmslEvents, spec.Join), nil
|
||||
return auth.IsAnyUserOnServerWithMembership(ctx, querier, serverName, gmslEvents, spec.Join), nil
|
||||
}
|
||||
|
||||
func IsInvitePending(
|
||||
|
|
@ -211,8 +211,8 @@ func GetMembershipsAtState(
|
|||
return events, nil
|
||||
}
|
||||
|
||||
func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) {
|
||||
roomState := state.NewStateResolution(db, info)
|
||||
func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventNID types.EventNID, querier api.QuerySenderIDAPI) ([]types.StateEntry, error) {
|
||||
roomState := state.NewStateResolution(db, info, querier)
|
||||
// Lookup the event NID
|
||||
eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
|
||||
if err != nil {
|
||||
|
|
@ -229,8 +229,8 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.Room
|
|||
return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
|
||||
}
|
||||
|
||||
func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) {
|
||||
roomState := state.NewStateResolution(db, info)
|
||||
func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID, querier api.QuerySenderIDAPI) (map[string][]types.StateEntry, error) {
|
||||
roomState := state.NewStateResolution(db, info, querier)
|
||||
// Fetch the state as it was when this event was fired
|
||||
return roomState.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID)
|
||||
}
|
||||
|
|
@ -264,7 +264,7 @@ func LoadStateEvents(
|
|||
}
|
||||
|
||||
func CheckServerAllowedToSeeEvent(
|
||||
ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool,
|
||||
ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool, querier api.QuerySenderIDAPI,
|
||||
) (bool, error) {
|
||||
stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName))
|
||||
switch err {
|
||||
|
|
@ -273,7 +273,7 @@ func CheckServerAllowedToSeeEvent(
|
|||
case tables.OptimisationNotSupportedError:
|
||||
// The database engine didn't support this optimisation, so fall back to using
|
||||
// the old and slow method
|
||||
stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName)
|
||||
stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName, querier)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
|
@ -288,13 +288,13 @@ func CheckServerAllowedToSeeEvent(
|
|||
return false, err
|
||||
}
|
||||
}
|
||||
return auth.IsServerAllowed(ctx, db, serverName, isServerInRoom, stateAtEvent), nil
|
||||
return auth.IsServerAllowed(ctx, querier, serverName, isServerInRoom, stateAtEvent), nil
|
||||
}
|
||||
|
||||
func slowGetHistoryVisibilityState(
|
||||
ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName,
|
||||
ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName, querier api.QuerySenderIDAPI,
|
||||
) ([]gomatrixserverlib.PDU, error) {
|
||||
roomState := state.NewStateResolution(db, info)
|
||||
roomState := state.NewStateResolution(db, info, querier)
|
||||
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
|
|
@ -320,7 +320,7 @@ func slowGetHistoryVisibilityState(
|
|||
// are "" since these will contain history visibility etc.
|
||||
for nid, key := range stateKeys {
|
||||
if key != "" {
|
||||
userID, err := db.GetUserIDForSender(ctx, roomID, spec.SenderID(key))
|
||||
userID, err := querier.QueryUserIDForSender(ctx, roomID, spec.SenderID(key))
|
||||
if err == nil && userID != nil {
|
||||
if userID.Domain() != serverName {
|
||||
delete(stateKeys, nid)
|
||||
|
|
@ -349,7 +349,7 @@ func slowGetHistoryVisibilityState(
|
|||
// TODO: Remove this when we have tests to assert correctness of this function
|
||||
func ScanEventTree(
|
||||
ctx context.Context, db storage.Database, info *types.RoomInfo, front []string, visited map[string]bool, limit int,
|
||||
serverName spec.ServerName,
|
||||
serverName spec.ServerName, querier api.QuerySenderIDAPI,
|
||||
) ([]types.EventNID, map[string]struct{}, error) {
|
||||
var resultNIDs []types.EventNID
|
||||
var err error
|
||||
|
|
@ -392,7 +392,7 @@ BFSLoop:
|
|||
// It's nasty that we have to extract the room ID from an event, but many federation requests
|
||||
// only talk in event IDs, no room IDs at all (!!!)
|
||||
ev := events[0]
|
||||
isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, serverName, ev.RoomID())
|
||||
isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, querier, serverName, ev.RoomID())
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.")
|
||||
}
|
||||
|
|
@ -415,7 +415,7 @@ BFSLoop:
|
|||
// hasn't been seen before.
|
||||
if !visited[pre] {
|
||||
visited[pre] = true
|
||||
allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom)
|
||||
allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom, querier)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
|
||||
"Error checking if allowed to see event",
|
||||
|
|
@ -444,7 +444,7 @@ BFSLoop:
|
|||
}
|
||||
|
||||
func QueryLatestEventsAndState(
|
||||
ctx context.Context, db storage.Database,
|
||||
ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI,
|
||||
request *api.QueryLatestEventsAndStateRequest,
|
||||
response *api.QueryLatestEventsAndStateResponse,
|
||||
) error {
|
||||
|
|
@ -457,7 +457,7 @@ func QueryLatestEventsAndState(
|
|||
return nil
|
||||
}
|
||||
|
||||
roomState := state.NewStateResolution(db, roomInfo)
|
||||
roomState := state.NewStateResolution(db, roomInfo, querier)
|
||||
response.RoomExists = true
|
||||
response.RoomVersion = roomInfo.RoomVersion
|
||||
|
||||
|
|
|
|||
|
|
@ -128,7 +128,7 @@ func (r *Inputer) processRoomEvent(
|
|||
if roomInfo == nil && !isCreateEvent {
|
||||
return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID())
|
||||
}
|
||||
sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID())
|
||||
sender, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed getting userID for sender %q. %w", event.SenderID(), err)
|
||||
}
|
||||
|
|
@ -283,7 +283,7 @@ func (r *Inputer) processRoomEvent(
|
|||
// Check if the event is allowed by its auth events. If it isn't then
|
||||
// we consider the event to be "rejected" — it will still be persisted.
|
||||
if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
||||
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}); err != nil {
|
||||
isRejected = true
|
||||
rejectionErr = err
|
||||
|
|
@ -321,7 +321,7 @@ func (r *Inputer) processRoomEvent(
|
|||
if input.Kind == api.KindNew && !isCreateEvent {
|
||||
// Check that the event passes authentication checks based on the
|
||||
// current room state.
|
||||
softfail, err = helpers.CheckForSoftFail(ctx, r.DB, roomInfo, headered, input.StateEventIDs)
|
||||
softfail, err = helpers.CheckForSoftFail(ctx, r.DB, roomInfo, headered, input.StateEventIDs, r.Queryer)
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("Error authing soft-failed event")
|
||||
}
|
||||
|
|
@ -401,7 +401,7 @@ func (r *Inputer) processRoomEvent(
|
|||
redactedEvent gomatrixserverlib.PDU
|
||||
)
|
||||
if !isRejected && !isCreateEvent {
|
||||
resolver := state.NewStateResolution(r.DB, roomInfo)
|
||||
resolver := state.NewStateResolution(r.DB, roomInfo, r.Queryer)
|
||||
redactionEvent, redactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, &resolver)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -588,7 +588,7 @@ func (r *Inputer) processStateBefore(
|
|||
gomatrixserverlib.ToPDUs(stateBeforeEvent),
|
||||
)
|
||||
if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
||||
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}); rejectionErr != nil {
|
||||
rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr)
|
||||
return
|
||||
|
|
@ -701,7 +701,7 @@ nextAuthEvent:
|
|||
// skip it, because gomatrixserverlib.Allowed() will notice a problem
|
||||
// if a critical event is missing anyway.
|
||||
if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
||||
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}); err != nil {
|
||||
continue nextAuthEvent
|
||||
}
|
||||
|
|
@ -719,7 +719,7 @@ nextAuthEvent:
|
|||
|
||||
// Check if the auth event should be rejected.
|
||||
err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
||||
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
})
|
||||
if isRejected = err != nil; isRejected {
|
||||
logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID())
|
||||
|
|
@ -783,7 +783,7 @@ func (r *Inputer) calculateAndSetState(
|
|||
return fmt.Errorf("r.DB.GetRoomUpdater: %w", err)
|
||||
}
|
||||
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
|
||||
roomState := state.NewStateResolution(updater, roomInfo)
|
||||
roomState := state.NewStateResolution(updater, roomInfo, r.Queryer)
|
||||
|
||||
if input.HasState {
|
||||
// We've been told what the state at the event is so we don't need to calculate it.
|
||||
|
|
|
|||
|
|
@ -213,7 +213,7 @@ func (u *latestEventsUpdater) latestState() error {
|
|||
defer trace.EndRegion()
|
||||
|
||||
var err error
|
||||
roomState := state.NewStateResolution(u.updater, u.roomInfo)
|
||||
roomState := state.NewStateResolution(u.updater, u.roomInfo, u.api.Queryer)
|
||||
|
||||
// Work out if the state at the extremities has actually changed
|
||||
// or not. If they haven't then we won't bother doing all of the
|
||||
|
|
|
|||
|
|
@ -383,7 +383,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even
|
|||
defer trace.EndRegion()
|
||||
|
||||
var res parsedRespState
|
||||
roomState := state.NewStateResolution(t.db, t.roomInfo)
|
||||
roomState := state.NewStateResolution(t.db, t.roomInfo, t.inputer.Queryer)
|
||||
stateAtEvents, err := t.db.StateAtEventIDs(ctx, []string{eventID})
|
||||
if err != nil {
|
||||
t.log.WithError(err).Warnf("failed to get state after %s locally", eventID)
|
||||
|
|
@ -474,7 +474,7 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion
|
|||
}
|
||||
resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(
|
||||
roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return t.db.GetUserIDForSender(ctx, roomID, senderID)
|
||||
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
|
|
@ -483,7 +483,7 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion
|
|||
// apply the current event
|
||||
retryAllowedState:
|
||||
if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return t.db.GetUserIDForSender(ctx, roomID, senderID)
|
||||
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}); err != nil {
|
||||
switch missing := err.(type) {
|
||||
case gomatrixserverlib.MissingAuthEventError:
|
||||
|
|
@ -570,7 +570,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver
|
|||
missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events))
|
||||
for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) {
|
||||
if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return t.db.GetUserIDForSender(ctx, roomID, senderID)
|
||||
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}); err != nil {
|
||||
continue
|
||||
}
|
||||
|
|
@ -661,7 +661,7 @@ func (t *missingStateReq) lookupMissingStateViaState(
|
|||
StateEvents: state.GetStateEvents(),
|
||||
AuthEvents: state.GetAuthEvents(),
|
||||
}, roomVersion, t.keys, nil, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return t.db.GetUserIDForSender(ctx, roomID, senderID)
|
||||
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -898,7 +898,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
|
|||
return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers))
|
||||
}
|
||||
if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return t.db.GetUserIDForSender(ctx, roomID, senderID)
|
||||
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}); err != nil {
|
||||
t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())
|
||||
return nil, verifySigError{event.EventID(), err}
|
||||
|
|
|
|||
|
|
@ -265,7 +265,7 @@ func (r *Admin) PerformAdminDownloadState(
|
|||
}
|
||||
for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) {
|
||||
if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
||||
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}); err != nil {
|
||||
continue
|
||||
}
|
||||
|
|
@ -273,7 +273,7 @@ func (r *Admin) PerformAdminDownloadState(
|
|||
}
|
||||
for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) {
|
||||
if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
||||
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}); err != nil {
|
||||
continue
|
||||
}
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ type Backfiller struct {
|
|||
DB storage.Database
|
||||
FSAPI federationAPI.RoomserverFederationAPI
|
||||
KeyRing gomatrixserverlib.JSONVerifier
|
||||
Querier api.QuerySenderIDAPI
|
||||
|
||||
// The servers which should be preferred above other servers when backfilling
|
||||
PreferServers []spec.ServerName
|
||||
|
|
@ -79,7 +80,7 @@ func (r *Backfiller) PerformBackfill(
|
|||
}
|
||||
|
||||
// Scan the event tree for events to send back.
|
||||
resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName)
|
||||
resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName, r.Querier)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -113,7 +114,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
|
|||
if info == nil || info.IsStub() {
|
||||
return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID)
|
||||
}
|
||||
requester := newBackfillRequester(r.DB, r.FSAPI, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers)
|
||||
requester := newBackfillRequester(r.DB, r.FSAPI, r.Querier, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers)
|
||||
// Request 100 items regardless of what the query asks for.
|
||||
// We don't want to go much higher than this.
|
||||
// We can't honour exactly the limit as some sytests rely on requesting more for tests to pass
|
||||
|
|
@ -122,7 +123,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
|
|||
events, err := gomatrixserverlib.RequestBackfill(
|
||||
ctx, req.VirtualHost, requester,
|
||||
r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
||||
return r.Querier.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
},
|
||||
)
|
||||
// Only return an error if we really couldn't get any events.
|
||||
|
|
@ -135,7 +136,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
|
|||
logrus.WithError(err).WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events))
|
||||
|
||||
// persist these new events - auth checks have already been done
|
||||
roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events)
|
||||
roomNID, backfilledEventMap := persistEvents(ctx, r.DB, r.Querier, events)
|
||||
|
||||
for _, ev := range backfilledEventMap {
|
||||
// now add state for these events
|
||||
|
|
@ -213,7 +214,7 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
|
|||
}
|
||||
loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false)
|
||||
result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
||||
return r.Querier.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
})
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("failed to load and verify event")
|
||||
|
|
@ -246,13 +247,14 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
|
|||
}
|
||||
}
|
||||
util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents))
|
||||
persistEvents(ctx, r.DB, newEvents)
|
||||
persistEvents(ctx, r.DB, r.Querier, newEvents)
|
||||
}
|
||||
|
||||
// backfillRequester implements gomatrixserverlib.BackfillRequester
|
||||
type backfillRequester struct {
|
||||
db storage.Database
|
||||
fsAPI federationAPI.RoomserverFederationAPI
|
||||
querier api.QuerySenderIDAPI
|
||||
virtualHost spec.ServerName
|
||||
isLocalServerName func(spec.ServerName) bool
|
||||
preferServer map[spec.ServerName]bool
|
||||
|
|
@ -268,6 +270,7 @@ type backfillRequester struct {
|
|||
|
||||
func newBackfillRequester(
|
||||
db storage.Database, fsAPI federationAPI.RoomserverFederationAPI,
|
||||
querier api.QuerySenderIDAPI,
|
||||
virtualHost spec.ServerName,
|
||||
isLocalServerName func(spec.ServerName) bool,
|
||||
bwExtrems map[string][]string, preferServers []spec.ServerName,
|
||||
|
|
@ -279,6 +282,7 @@ func newBackfillRequester(
|
|||
return &backfillRequester{
|
||||
db: db,
|
||||
fsAPI: fsAPI,
|
||||
querier: querier,
|
||||
virtualHost: virtualHost,
|
||||
isLocalServerName: isLocalServerName,
|
||||
eventIDToBeforeStateIDs: make(map[string][]string),
|
||||
|
|
@ -460,14 +464,14 @@ FindSuccessor:
|
|||
return nil
|
||||
}
|
||||
|
||||
stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID].EventNID)
|
||||
stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID].EventNID, b.querier)
|
||||
if err != nil {
|
||||
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event")
|
||||
return nil
|
||||
}
|
||||
|
||||
// possibly return all joined servers depending on history visiblity
|
||||
memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, info, stateEntries, b.virtualHost)
|
||||
memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, b.querier, info, stateEntries, b.virtualHost)
|
||||
b.historyVisiblity = visibility
|
||||
if err != nil {
|
||||
logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules")
|
||||
|
|
@ -488,7 +492,7 @@ FindSuccessor:
|
|||
// Store the server names in a temporary map to avoid duplicates.
|
||||
serverSet := make(map[spec.ServerName]bool)
|
||||
for _, event := range memberEvents {
|
||||
if sender, err := b.db.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()); err == nil {
|
||||
if sender, err := b.querier.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()); err == nil {
|
||||
serverSet[sender.Domain()] = true
|
||||
}
|
||||
}
|
||||
|
|
@ -554,7 +558,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
|
|||
// TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just
|
||||
// pull all events and then filter by that table.
|
||||
func joinEventsFromHistoryVisibility(
|
||||
ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry,
|
||||
ctx context.Context, db storage.RoomDatabase, querier api.QuerySenderIDAPI, roomInfo *types.RoomInfo, stateEntries []types.StateEntry,
|
||||
thisServer spec.ServerName) ([]types.Event, gomatrixserverlib.HistoryVisibility, error) {
|
||||
|
||||
var eventNIDs []types.EventNID
|
||||
|
|
@ -582,7 +586,7 @@ func joinEventsFromHistoryVisibility(
|
|||
}
|
||||
|
||||
// Can we see events in the room?
|
||||
canSeeEvents := auth.IsServerAllowed(ctx, db, thisServer, true, events)
|
||||
canSeeEvents := auth.IsServerAllowed(ctx, querier, thisServer, true, events)
|
||||
visibility := auth.HistoryVisibilityForRoom(events)
|
||||
if !canSeeEvents {
|
||||
logrus.Infof("ServersAtEvent history not visible to us: %s", visibility)
|
||||
|
|
@ -597,7 +601,7 @@ func joinEventsFromHistoryVisibility(
|
|||
return evs, visibility, err
|
||||
}
|
||||
|
||||
func persistEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.PDU) (types.RoomNID, map[string]types.Event) {
|
||||
func persistEvents(ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI, events []gomatrixserverlib.PDU) (types.RoomNID, map[string]types.Event) {
|
||||
var roomNID types.RoomNID
|
||||
var eventNID types.EventNID
|
||||
backfilledEventMap := make(map[string]types.Event)
|
||||
|
|
@ -639,7 +643,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []gomatrixse
|
|||
continue
|
||||
}
|
||||
|
||||
resolver := state.NewStateResolution(db, roomInfo)
|
||||
resolver := state.NewStateResolution(db, roomInfo, querier)
|
||||
|
||||
_, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev, &resolver)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
|
|||
}
|
||||
}
|
||||
}
|
||||
senderID, err := c.DB.GetSenderIDForUser(ctx, roomID.String(), userID)
|
||||
senderID, err := c.RSAPI.QuerySenderIDForUser(ctx, roomID.String(), userID)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user")
|
||||
return "", &util.JSONResponse{
|
||||
|
|
@ -324,7 +324,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
|
|||
}
|
||||
|
||||
if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return c.DB.GetUserIDForSender(ctx, roomID, senderID)
|
||||
return c.RSAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed")
|
||||
return "", &util.JSONResponse{
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ func (r *InboundPeeker) PerformInboundPeek(
|
|||
response.LatestEvent = &types.HeaderedEvent{PDU: sortedLatestEvents[0]}
|
||||
|
||||
// XXX: do we actually need to do a state resolution here?
|
||||
roomState := state.NewStateResolution(r.DB, info)
|
||||
roomState := state.NewStateResolution(r.DB, info, r.Inputer.Queryer)
|
||||
|
||||
var stateEntries []types.StateEntry
|
||||
stateEntries, err = roomState.LoadStateAtSnapshot(
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ import (
|
|||
|
||||
type QueryState struct {
|
||||
storage.Database
|
||||
querier api.QuerySenderIDAPI
|
||||
}
|
||||
|
||||
func (q *QueryState) GetAuthEvents(ctx context.Context, event gomatrixserverlib.PDU) (gomatrixserverlib.AuthEventProvider, error) {
|
||||
|
|
@ -46,7 +47,7 @@ func (q *QueryState) GetState(ctx context.Context, roomID spec.RoomID, stateWant
|
|||
return nil, fmt.Errorf("failed to load RoomInfo: %w", err)
|
||||
}
|
||||
if info != nil {
|
||||
roomState := state.NewStateResolution(q.Database, info)
|
||||
roomState := state.NewStateResolution(q.Database, info, q.querier)
|
||||
stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples(
|
||||
ctx, info.StateSnapshotNID(), stateWanted,
|
||||
)
|
||||
|
|
@ -126,7 +127,7 @@ func (r *Inviter) PerformInvite(
|
|||
) error {
|
||||
event := req.Event
|
||||
|
||||
sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID())
|
||||
sender, err := r.RSAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
|
||||
if err != nil {
|
||||
return spec.InvalidParam("The sender user ID is invalid")
|
||||
}
|
||||
|
|
@ -161,9 +162,9 @@ func (r *Inviter) PerformInvite(
|
|||
IsTargetLocal: isTargetLocal,
|
||||
StrippedState: req.InviteRoomState,
|
||||
MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI},
|
||||
StateQuerier: &QueryState{r.DB},
|
||||
StateQuerier: &QueryState{r.DB, r.RSAPI},
|
||||
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
||||
return r.RSAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
},
|
||||
}
|
||||
inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI)
|
||||
|
|
|
|||
|
|
@ -133,7 +133,7 @@ func (r *Leaver) performLeaveRoomByID(
|
|||
},
|
||||
}
|
||||
latestRes := api.QueryLatestEventsAndStateResponse{}
|
||||
if err = helpers.QueryLatestEventsAndState(ctx, r.DB, &latestReq, &latestRes); err != nil {
|
||||
if err = helpers.QueryLatestEventsAndState(ctx, r.DB, r.RSAPI, &latestReq, &latestRes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !latestRes.RoomExists {
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ func (r *Queryer) QueryLatestEventsAndState(
|
|||
request *api.QueryLatestEventsAndStateRequest,
|
||||
response *api.QueryLatestEventsAndStateResponse,
|
||||
) error {
|
||||
return helpers.QueryLatestEventsAndState(ctx, r.DB, request, response)
|
||||
return helpers.QueryLatestEventsAndState(ctx, r.DB, r, request, response)
|
||||
}
|
||||
|
||||
// QueryStateAfterEvents implements api.RoomserverInternalAPI
|
||||
|
|
@ -106,7 +106,7 @@ func (r *Queryer) QueryStateAfterEvents(
|
|||
return nil
|
||||
}
|
||||
|
||||
roomState := state.NewStateResolution(r.DB, info)
|
||||
roomState := state.NewStateResolution(r.DB, info, r)
|
||||
response.RoomExists = true
|
||||
response.RoomVersion = info.RoomVersion
|
||||
|
||||
|
|
@ -160,7 +160,7 @@ func (r *Queryer) QueryStateAfterEvents(
|
|||
|
||||
stateEvents, err = gomatrixserverlib.ResolveConflicts(
|
||||
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
||||
return r.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
|
|
@ -271,7 +271,7 @@ func (r *Queryer) QueryMembershipForUser(
|
|||
request *api.QueryMembershipForUserRequest,
|
||||
response *api.QueryMembershipForUserResponse,
|
||||
) error {
|
||||
senderID, err := r.DB.GetSenderIDForUser(ctx, request.RoomID, request.UserID)
|
||||
senderID, err := r.QuerySenderIDForUser(ctx, request.RoomID, request.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -320,7 +320,7 @@ func (r *Queryer) QueryMembershipAtEvent(
|
|||
}
|
||||
|
||||
response.Membership = make(map[string]*types.HeaderedEvent)
|
||||
stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID])
|
||||
stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID], r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get state before event: %w", err)
|
||||
}
|
||||
|
|
@ -445,7 +445,7 @@ func (r *Queryer) QueryMembershipsForRoom(
|
|||
|
||||
events, err = r.DB.Events(ctx, info.RoomVersion, eventNIDs)
|
||||
} else {
|
||||
stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID)
|
||||
stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID, r)
|
||||
if err != nil {
|
||||
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
|
||||
return err
|
||||
|
|
@ -532,7 +532,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent(
|
|||
}
|
||||
|
||||
return helpers.CheckServerAllowedToSeeEvent(
|
||||
ctx, r.DB, info, roomID, eventID, serverName, isInRoom,
|
||||
ctx, r.DB, info, roomID, eventID, serverName, isInRoom, r,
|
||||
)
|
||||
}
|
||||
|
||||
|
|
@ -573,7 +573,7 @@ func (r *Queryer) QueryMissingEvents(
|
|||
return fmt.Errorf("missing RoomInfo for room %d", events[front[0]].RoomNID)
|
||||
}
|
||||
|
||||
resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName)
|
||||
resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -652,7 +652,7 @@ func (r *Queryer) QueryStateAndAuthChain(
|
|||
if request.ResolveState {
|
||||
stateEvents, err = gomatrixserverlib.ResolveConflicts(
|
||||
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
||||
return r.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
|
|
@ -673,7 +673,7 @@ func (r *Queryer) QueryStateAndAuthChain(
|
|||
|
||||
// first bool: is rejected, second bool: state missing
|
||||
func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]gomatrixserverlib.PDU, bool, bool, error) {
|
||||
roomState := state.NewStateResolution(r.DB, roomInfo)
|
||||
roomState := state.NewStateResolution(r.DB, roomInfo, r)
|
||||
prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs)
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
|
|
|
|||
|
|
@ -516,6 +516,10 @@ func TestRedaction(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
natsInstance := &jetstream.NATSInstance{}
|
||||
_, _ = natsInstance.Prepare(processCtx, &cfg.Global.JetStream)
|
||||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
authEvents := []types.EventNID{}
|
||||
|
|
@ -551,7 +555,7 @@ func TestRedaction(t *testing.T) {
|
|||
}
|
||||
|
||||
// Calculate the snapshotNID etc.
|
||||
plResolver := state.NewStateResolution(db, roomInfo)
|
||||
plResolver := state.NewStateResolution(db, roomInfo, rsAPI)
|
||||
stateAtEvent.BeforeStateSnapshotNID, err = plResolver.CalculateAndStoreStateBeforeEvent(ctx, ev.PDU, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ import (
|
|||
"github.com/prometheus/client_golang/prometheus"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
|
|
@ -44,20 +45,21 @@ type StateResolutionStorage interface {
|
|||
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
|
||||
Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
|
||||
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
|
||||
GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
|
||||
}
|
||||
|
||||
type StateResolution struct {
|
||||
db StateResolutionStorage
|
||||
roomInfo *types.RoomInfo
|
||||
events map[types.EventNID]gomatrixserverlib.PDU
|
||||
Querier api.QuerySenderIDAPI
|
||||
}
|
||||
|
||||
func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) StateResolution {
|
||||
func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo, querier api.QuerySenderIDAPI) StateResolution {
|
||||
return StateResolution{
|
||||
db: db,
|
||||
roomInfo: roomInfo,
|
||||
events: make(map[types.EventNID]gomatrixserverlib.PDU),
|
||||
Querier: querier,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -948,7 +950,7 @@ func (v *StateResolution) resolveConflictsV1(
|
|||
|
||||
// Resolve the conflicts.
|
||||
resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return v.db.GetUserIDForSender(ctx, roomID, senderID)
|
||||
return v.Querier.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
})
|
||||
|
||||
// Map from the full events back to numeric state entries.
|
||||
|
|
@ -1062,7 +1064,7 @@ func (v *StateResolution) resolveConflictsV2(
|
|||
nonConflictedEvents,
|
||||
authEvents,
|
||||
func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return v.db.GetUserIDForSender(ctx, roomID, senderID)
|
||||
return v.Querier.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
},
|
||||
)
|
||||
}()
|
||||
|
|
|
|||
|
|
@ -231,7 +231,6 @@ type RoomDatabase interface {
|
|||
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
|
||||
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
|
||||
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error)
|
||||
GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
|
||||
}
|
||||
|
||||
type EventDatabase interface {
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
|
@ -251,7 +250,3 @@ func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error {
|
|||
func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
|
||||
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
|
||||
}
|
||||
|
||||
func (u *RoomUpdater) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return u.d.GetUserIDForSender(ctx, roomID, senderID)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue