428 lines
13 KiB
Go
428 lines
13 KiB
Go
package helpers
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/matrix-org/dendrite/roomserver/api"
|
|
"github.com/matrix-org/dendrite/roomserver/auth"
|
|
"github.com/matrix-org/dendrite/roomserver/state"
|
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
|
"github.com/matrix-org/dendrite/roomserver/types"
|
|
"github.com/matrix-org/gomatrixserverlib"
|
|
"github.com/matrix-org/util"
|
|
)
|
|
|
|
// TODO: temporary package which has helper functions used by both internal/perform packages.
|
|
// Move these to a more sensible place.
|
|
|
|
func UpdateToInviteMembership(
|
|
mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
|
|
roomVersion gomatrixserverlib.RoomVersion,
|
|
) ([]api.OutputEvent, error) {
|
|
// We may have already sent the invite to the user, either because we are
|
|
// reprocessing this event, or because the we received this invite from a
|
|
// remote server via the federation invite API. In those cases we don't need
|
|
// to send the event.
|
|
needsSending, err := mu.SetToInvite(add)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if needsSending {
|
|
// We notify the consumers using a special event even though we will
|
|
// notify them about the change in current state as part of the normal
|
|
// room event stream. This ensures that the consumers only have to
|
|
// consider a single stream of events when determining whether a user
|
|
// is invited, rather than having to combine multiple streams themselves.
|
|
onie := api.OutputNewInviteEvent{
|
|
Event: add.Headered(roomVersion),
|
|
RoomVersion: roomVersion,
|
|
}
|
|
updates = append(updates, api.OutputEvent{
|
|
Type: api.OutputTypeNewInviteEvent,
|
|
NewInviteEvent: &onie,
|
|
})
|
|
}
|
|
return updates, nil
|
|
}
|
|
|
|
// IsServerCurrentlyInRoom checks if a server is in a given room, based on the room
|
|
// memberships. If the servername is not supplied then the local server will be
|
|
// checked instead using a faster code path.
|
|
// TODO: This should probably be replaced by an API call.
|
|
func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) {
|
|
info, err := db.RoomInfo(ctx, roomID)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if info == nil {
|
|
return false, fmt.Errorf("unknown room %s", roomID)
|
|
}
|
|
|
|
if serverName == "" {
|
|
return db.GetLocalServerInRoom(ctx, info.RoomNID)
|
|
}
|
|
|
|
eventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
events, err := db.Events(ctx, eventNIDs)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
gmslEvents := make([]*gomatrixserverlib.Event, len(events))
|
|
for i := range events {
|
|
gmslEvents[i] = events[i].Event
|
|
}
|
|
return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, gomatrixserverlib.Join), nil
|
|
}
|
|
|
|
func IsInvitePending(
|
|
ctx context.Context, db storage.Database,
|
|
roomID, userID string,
|
|
) (bool, string, string, error) {
|
|
// Look up the room NID for the supplied room ID.
|
|
info, err := db.RoomInfo(ctx, roomID)
|
|
if err != nil {
|
|
return false, "", "", fmt.Errorf("r.DB.RoomInfo: %w", err)
|
|
}
|
|
if info == nil {
|
|
return false, "", "", fmt.Errorf("cannot get RoomInfo: unknown room ID %s", roomID)
|
|
}
|
|
|
|
// Look up the state key NID for the supplied user ID.
|
|
targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{userID})
|
|
if err != nil {
|
|
return false, "", "", fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err)
|
|
}
|
|
targetUserNID, targetUserFound := targetUserNIDs[userID]
|
|
if !targetUserFound {
|
|
return false, "", "", fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs)
|
|
}
|
|
|
|
// Let's see if we have an event active for the user in the room. If
|
|
// we do then it will contain a server name that we can direct the
|
|
// send_leave to.
|
|
senderUserNIDs, eventIDs, err := db.GetInvitesForUser(ctx, info.RoomNID, targetUserNID)
|
|
if err != nil {
|
|
return false, "", "", fmt.Errorf("r.DB.GetInvitesForUser: %w", err)
|
|
}
|
|
if len(senderUserNIDs) == 0 {
|
|
return false, "", "", nil
|
|
}
|
|
userNIDToEventID := make(map[types.EventStateKeyNID]string)
|
|
for i, nid := range senderUserNIDs {
|
|
userNIDToEventID[nid] = eventIDs[i]
|
|
}
|
|
|
|
// Look up the user ID from the NID.
|
|
senderUsers, err := db.EventStateKeys(ctx, senderUserNIDs)
|
|
if err != nil {
|
|
return false, "", "", fmt.Errorf("r.DB.EventStateKeys: %w", err)
|
|
}
|
|
if len(senderUsers) == 0 {
|
|
return false, "", "", fmt.Errorf("no senderUsers")
|
|
}
|
|
|
|
senderUser, senderUserFound := senderUsers[senderUserNIDs[0]]
|
|
if !senderUserFound {
|
|
return false, "", "", fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers)
|
|
}
|
|
|
|
return true, senderUser, userNIDToEventID[senderUserNIDs[0]], nil
|
|
}
|
|
|
|
// GetMembershipsAtState filters the state events to
|
|
// only keep the "m.room.member" events with a "join" membership. These events are returned.
|
|
// Returns an error if there was an issue fetching the events.
|
|
func GetMembershipsAtState(
|
|
ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool,
|
|
) ([]types.Event, error) {
|
|
|
|
var eventNIDs []types.EventNID
|
|
for _, entry := range stateEntries {
|
|
// Filter the events to retrieve to only keep the membership events
|
|
if entry.EventTypeNID == types.MRoomMemberNID {
|
|
eventNIDs = append(eventNIDs, entry.EventNID)
|
|
}
|
|
}
|
|
|
|
// Get all of the events in this state
|
|
stateEvents, err := db.Events(ctx, eventNIDs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if !joinedOnly {
|
|
return stateEvents, nil
|
|
}
|
|
|
|
// Filter the events to only keep the "join" membership events
|
|
var events []types.Event
|
|
for _, event := range stateEvents {
|
|
membership, err := event.Membership()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if membership == gomatrixserverlib.Join {
|
|
events = append(events, event)
|
|
}
|
|
}
|
|
|
|
return events, nil
|
|
}
|
|
|
|
func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) {
|
|
roomState := state.NewStateResolution(db, info)
|
|
// Lookup the event NID
|
|
eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
eventIDs := []string{eIDs[eventNID]}
|
|
|
|
prevState, err := db.StateAtEventIDs(ctx, eventIDs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Fetch the state as it was when this event was fired
|
|
return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
|
|
}
|
|
|
|
func LoadEvents(
|
|
ctx context.Context, db storage.Database, eventNIDs []types.EventNID,
|
|
) ([]*gomatrixserverlib.Event, error) {
|
|
stateEvents, err := db.Events(ctx, eventNIDs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result := make([]*gomatrixserverlib.Event, len(stateEvents))
|
|
for i := range stateEvents {
|
|
result[i] = stateEvents[i].Event
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func LoadStateEvents(
|
|
ctx context.Context, db storage.Database, stateEntries []types.StateEntry,
|
|
) ([]*gomatrixserverlib.Event, error) {
|
|
eventNIDs := make([]types.EventNID, len(stateEntries))
|
|
for i := range stateEntries {
|
|
eventNIDs[i] = stateEntries[i].EventNID
|
|
}
|
|
return LoadEvents(ctx, db, eventNIDs)
|
|
}
|
|
|
|
func CheckServerAllowedToSeeEvent(
|
|
ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool,
|
|
) (bool, error) {
|
|
roomState := state.NewStateResolution(db, info)
|
|
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return false, nil
|
|
}
|
|
return false, fmt.Errorf("roomState.LoadStateAtEvent: %w", err)
|
|
}
|
|
|
|
// Extract all of the event state key NIDs from the room state.
|
|
var stateKeyNIDs []types.EventStateKeyNID
|
|
for _, entry := range stateEntries {
|
|
stateKeyNIDs = append(stateKeyNIDs, entry.EventStateKeyNID)
|
|
}
|
|
|
|
// Then request those state key NIDs from the database.
|
|
stateKeys, err := db.EventStateKeys(ctx, stateKeyNIDs)
|
|
if err != nil {
|
|
return false, fmt.Errorf("db.EventStateKeys: %w", err)
|
|
}
|
|
|
|
// If the event state key doesn't match the given servername
|
|
// then we'll filter it out. This does preserve state keys that
|
|
// are "" since these will contain history visibility etc.
|
|
for nid, key := range stateKeys {
|
|
if key != "" && !strings.HasSuffix(key, ":"+string(serverName)) {
|
|
delete(stateKeys, nid)
|
|
}
|
|
}
|
|
|
|
// Now filter through all of the state events for the room.
|
|
// If the state key NID appears in the list of valid state
|
|
// keys then we'll add it to the list of filtered entries.
|
|
var filteredEntries []types.StateEntry
|
|
for _, entry := range stateEntries {
|
|
if _, ok := stateKeys[entry.EventStateKeyNID]; ok {
|
|
filteredEntries = append(filteredEntries, entry)
|
|
}
|
|
}
|
|
|
|
if len(filteredEntries) == 0 {
|
|
return false, nil
|
|
}
|
|
|
|
stateAtEvent, err := LoadStateEvents(ctx, db, filteredEntries)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil
|
|
}
|
|
|
|
// TODO: Remove this when we have tests to assert correctness of this function
|
|
func ScanEventTree(
|
|
ctx context.Context, db storage.Database, info *types.RoomInfo, front []string, visited map[string]bool, limit int,
|
|
serverName gomatrixserverlib.ServerName,
|
|
) ([]types.EventNID, error) {
|
|
var resultNIDs []types.EventNID
|
|
var err error
|
|
var allowed bool
|
|
var events []types.Event
|
|
var next []string
|
|
var pre string
|
|
|
|
// TODO: add tests for this function to ensure it meets the contract that callers expect (and doc what that is supposed to be)
|
|
// Currently, callers like PerformBackfill will call scanEventTree with a pre-populated `visited` map, assuming that by doing
|
|
// so means that the events in that map will NOT be returned from this function. That is not currently true, resulting in
|
|
// duplicate events being sent in response to /backfill requests.
|
|
initialIgnoreList := make(map[string]bool, len(visited))
|
|
for k, v := range visited {
|
|
initialIgnoreList[k] = v
|
|
}
|
|
|
|
resultNIDs = make([]types.EventNID, 0, limit)
|
|
|
|
var checkedServerInRoom bool
|
|
var isServerInRoom bool
|
|
|
|
// Loop through the event IDs to retrieve the requested events and go
|
|
// through the whole tree (up to the provided limit) using the events'
|
|
// "prev_event" key.
|
|
BFSLoop:
|
|
for len(front) > 0 {
|
|
// Prevent unnecessary allocations: reset the slice only when not empty.
|
|
if len(next) > 0 {
|
|
next = make([]string, 0)
|
|
}
|
|
// Retrieve the events to process from the database.
|
|
events, err = db.EventsFromIDs(ctx, front)
|
|
if err != nil {
|
|
return resultNIDs, err
|
|
}
|
|
|
|
if !checkedServerInRoom && len(events) > 0 {
|
|
// It's nasty that we have to extract the room ID from an event, but many federation requests
|
|
// only talk in event IDs, no room IDs at all (!!!)
|
|
ev := events[0]
|
|
isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, serverName, ev.RoomID())
|
|
if err != nil {
|
|
util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.")
|
|
}
|
|
checkedServerInRoom = true
|
|
}
|
|
|
|
for _, ev := range events {
|
|
// Break out of the loop if the provided limit is reached.
|
|
if len(resultNIDs) == limit {
|
|
break BFSLoop
|
|
}
|
|
|
|
if !initialIgnoreList[ev.EventID()] {
|
|
// Update the list of events to retrieve.
|
|
resultNIDs = append(resultNIDs, ev.EventNID)
|
|
}
|
|
// Loop through the event's parents.
|
|
for _, pre = range ev.PrevEventIDs() {
|
|
// Only add an event to the list of next events to process if it
|
|
// hasn't been seen before.
|
|
if !visited[pre] {
|
|
visited[pre] = true
|
|
allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, pre, serverName, isServerInRoom)
|
|
if err != nil {
|
|
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
|
|
"Error checking if allowed to see event",
|
|
)
|
|
// drop the error, as we will often error at the DB level if we don't have the prev_event itself. Let's
|
|
// just return what we have.
|
|
return resultNIDs, nil
|
|
}
|
|
|
|
// If the event hasn't been seen before and the HS
|
|
// requesting to retrieve it is allowed to do so, add it to
|
|
// the list of events to retrieve.
|
|
if allowed {
|
|
next = append(next, pre)
|
|
} else {
|
|
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).Info("Not allowed to see event")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
// Repeat the same process with the parent events we just processed.
|
|
front = next
|
|
}
|
|
|
|
return resultNIDs, err
|
|
}
|
|
|
|
func QueryLatestEventsAndState(
|
|
ctx context.Context, db storage.Database,
|
|
request *api.QueryLatestEventsAndStateRequest,
|
|
response *api.QueryLatestEventsAndStateResponse,
|
|
) error {
|
|
roomInfo, err := db.RoomInfo(ctx, request.RoomID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if roomInfo == nil || roomInfo.IsStub {
|
|
response.RoomExists = false
|
|
return nil
|
|
}
|
|
|
|
roomState := state.NewStateResolution(db, roomInfo)
|
|
response.RoomExists = true
|
|
response.RoomVersion = roomInfo.RoomVersion
|
|
|
|
var currentStateSnapshotNID types.StateSnapshotNID
|
|
response.LatestEvents, currentStateSnapshotNID, response.Depth, err =
|
|
db.LatestEventIDs(ctx, roomInfo.RoomNID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var stateEntries []types.StateEntry
|
|
if len(request.StateToFetch) == 0 {
|
|
// Look up all room state.
|
|
stateEntries, err = roomState.LoadStateAtSnapshot(
|
|
ctx, currentStateSnapshotNID,
|
|
)
|
|
} else {
|
|
// Look up the current state for the requested tuples.
|
|
stateEntries, err = roomState.LoadStateAtSnapshotForStringTuples(
|
|
ctx, currentStateSnapshotNID, request.StateToFetch,
|
|
)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
stateEvents, err := LoadStateEvents(ctx, db, stateEntries)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, event := range stateEvents {
|
|
response.StateEvents = append(response.StateEvents, event.Headered(roomInfo.RoomVersion))
|
|
}
|
|
|
|
return nil
|
|
}
|