fix ApplyHistoryVisibilityFilter to work with pseudo IDs

This commit is contained in:
Sam Wedgwood 2023-08-10 18:46:39 +01:00
parent 64c4119127
commit 6d3bc3937b
10 changed files with 161 additions and 118 deletions

View file

@ -141,11 +141,28 @@ type QueryRoomHierarchyAPI interface {
QueryNextRoomHierarchyPage(ctx context.Context, walker RoomHierarchyWalker, limit int) ([]fclient.RoomHierarchyRoom, *RoomHierarchyWalker, error) QueryNextRoomHierarchyPage(ctx context.Context, walker RoomHierarchyWalker, limit int) ([]fclient.RoomHierarchyRoom, *RoomHierarchyWalker, error)
} }
type QueryMembershipAPI interface {
QueryMembershipForSenderID(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, res *QueryMembershipForUserResponse) error
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
// QueryMembershipAtEvent queries the memberships at the given events.
// Returns a map from eventID to *types.HeaderedEvent of membership events.
QueryMembershipAtEvent(
ctx context.Context,
roomID spec.RoomID,
eventIDs []string,
senderID spec.SenderID,
) (map[string]*types.HeaderedEvent, error)
}
// API functions required by the syncapi // API functions required by the syncapi
type SyncRoomserverAPI interface { type SyncRoomserverAPI interface {
QueryLatestEventsAndStateAPI QueryLatestEventsAndStateAPI
QueryBulkStateContentAPI QueryBulkStateContentAPI
QuerySenderIDAPI QuerySenderIDAPI
QueryMembershipAPI
// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
@ -155,12 +172,6 @@ type SyncRoomserverAPI interface {
req *QueryEventsByIDRequest, req *QueryEventsByIDRequest,
res *QueryEventsByIDResponse, res *QueryEventsByIDResponse,
) error ) error
// Query the membership event for an user for a room.
QueryMembershipForUser(
ctx context.Context,
req *QueryMembershipForUserRequest,
res *QueryMembershipForUserResponse,
) error
// Query the state after a list of events in a room from the room server. // Query the state after a list of events in a room from the room server.
QueryStateAfterEvents( QueryStateAfterEvents(
@ -175,14 +186,6 @@ type SyncRoomserverAPI interface {
req *PerformBackfillRequest, req *PerformBackfillRequest,
res *PerformBackfillResponse, res *PerformBackfillResponse,
) error ) error
// QueryMembershipAtEvent queries the memberships at the given events.
// Returns a map from eventID to a slice of types.HeaderedEvent.
QueryMembershipAtEvent(
ctx context.Context,
request *QueryMembershipAtEventRequest,
response *QueryMembershipAtEventResponse,
) error
} }
type AppserviceRoomserverAPI interface { type AppserviceRoomserverAPI interface {
@ -278,15 +281,12 @@ type FederationRoomserverAPI interface {
QueryBulkStateContentAPI QueryBulkStateContentAPI
QuerySenderIDAPI QuerySenderIDAPI
QueryRoomHierarchyAPI QueryRoomHierarchyAPI
QueryMembershipAPI
UserRoomPrivateKeyCreator UserRoomPrivateKeyCreator
AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error) AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error)
SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error)
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error
QueryMembershipForSenderID(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, res *QueryMembershipForUserResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
// which room to use by querying the first events roomID. // which room to use by querying the first events roomID.

View file

@ -132,6 +132,8 @@ type QueryMembershipForUserResponse struct {
// True if the user asked to forget this room. // True if the user asked to forget this room.
IsRoomForgotten bool `json:"is_room_forgotten"` IsRoomForgotten bool `json:"is_room_forgotten"`
RoomExists bool `json:"room_exists"` RoomExists bool `json:"room_exists"`
// The sender ID of the user in the room, if it exists
SenderID *spec.SenderID
} }
// QueryMembershipsForRoomRequest is a request to QueryMembershipsForRoom // QueryMembershipsForRoomRequest is a request to QueryMembershipsForRoom
@ -414,22 +416,6 @@ func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error {
return nil return nil
} }
// QueryMembershipAtEventRequest requests the membership event for a user
// for a list of eventIDs.
type QueryMembershipAtEventRequest struct {
RoomID string
EventIDs []string
UserID string
}
// QueryMembershipAtEventResponse is the response to QueryMembershipAtEventRequest.
type QueryMembershipAtEventResponse struct {
// Membership is a map from eventID to membership event. Events that
// do not have known state will return a nil event, resulting in a "leave" membership
// when calculating history visibility.
Membership map[string]*types.HeaderedEvent `json:"membership"`
}
// QueryLeftUsersRequest is a request to calculate users that we (the server) don't share a // QueryLeftUsersRequest is a request to calculate users that we (the server) don't share a
// a room with anymore. This is used to cleanup stale device list entries, where we would // a room with anymore. This is used to cleanup stale device list entries, where we would
// otherwise keep on trying to get device lists. // otherwise keep on trying to get device lists.

View file

@ -255,6 +255,8 @@ func (r *Queryer) QueryMembershipForUser(
// //
// If sender ID is nil, then act as if the provided sender is not a member of the room. // If sender ID is nil, then act as if the provided sender is not a member of the room.
func (r *Queryer) queryMembershipForOptionalSenderID(ctx context.Context, roomID spec.RoomID, senderID *spec.SenderID, response *api.QueryMembershipForUserResponse) error { func (r *Queryer) queryMembershipForOptionalSenderID(ctx context.Context, roomID spec.RoomID, senderID *spec.SenderID, response *api.QueryMembershipForUserResponse) error {
response.SenderID = senderID
info, err := r.DB.RoomInfo(ctx, roomID.String()) info, err := r.DB.RoomInfo(ctx, roomID.String())
if err != nil { if err != nil {
return err return err
@ -300,49 +302,52 @@ func (r *Queryer) queryMembershipForOptionalSenderID(ctx context.Context, roomID
// QueryMembershipAtEvent returns the known memberships at a given event. // QueryMembershipAtEvent returns the known memberships at a given event.
// If the state before an event is not known, an empty list will be returned // If the state before an event is not known, an empty list will be returned
// for that event instead. // for that event instead.
//
// Returned map from eventID to membership event. Events that
// do not have known state will return a nil event, resulting in a "leave" membership
// when calculating history visibility.
func (r *Queryer) QueryMembershipAtEvent( func (r *Queryer) QueryMembershipAtEvent(
ctx context.Context, ctx context.Context,
request *api.QueryMembershipAtEventRequest, roomID spec.RoomID,
response *api.QueryMembershipAtEventResponse, eventIDs []string,
) error { senderID spec.SenderID,
response.Membership = make(map[string]*types.HeaderedEvent) ) (map[string]*types.HeaderedEvent, error) {
info, err := r.DB.RoomInfo(ctx, roomID.String())
info, err := r.DB.RoomInfo(ctx, request.RoomID)
if err != nil { if err != nil {
return fmt.Errorf("unable to get roomInfo: %w", err) return nil, fmt.Errorf("unable to get roomInfo: %w", err)
} }
if info == nil { if info == nil {
return fmt.Errorf("no roomInfo found") return nil, fmt.Errorf("no roomInfo found")
} }
// get the users stateKeyNID // get the users stateKeyNID
stateKeyNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.UserID}) stateKeyNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{string(senderID)})
if err != nil { if err != nil {
return fmt.Errorf("unable to get stateKeyNIDs for %s: %w", request.UserID, err) return nil, fmt.Errorf("unable to get stateKeyNIDs for %s: %w", senderID, err)
} }
if _, ok := stateKeyNIDs[request.UserID]; !ok { if _, ok := stateKeyNIDs[string(senderID)]; !ok {
return fmt.Errorf("requested stateKeyNID for %s was not found", request.UserID) return nil, fmt.Errorf("requested stateKeyNID for %s was not found", senderID)
} }
response.Membership, err = r.DB.GetMembershipForHistoryVisibility(ctx, stateKeyNIDs[request.UserID], info, request.EventIDs...) eventIDMembershipMap, err := r.DB.GetMembershipForHistoryVisibility(ctx, stateKeyNIDs[string(senderID)], info, eventIDs...)
switch err { switch err {
case nil: case nil:
return nil return eventIDMembershipMap, nil
case tables.OptimisationNotSupportedError: // fallthrough, slow way of getting the membership events for each event case tables.OptimisationNotSupportedError: // fallthrough, slow way of getting the membership events for each event
default: default:
return err return eventIDMembershipMap, err
} }
response.Membership = make(map[string]*types.HeaderedEvent) eventIDMembershipMap = make(map[string]*types.HeaderedEvent)
stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID], r) stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, eventIDs, stateKeyNIDs[string(senderID)], r)
if err != nil { if err != nil {
return fmt.Errorf("unable to get state before event: %w", err) return eventIDMembershipMap, fmt.Errorf("unable to get state before event: %w", err)
} }
// If we only have one or less state entries, we can short circuit the below // If we only have one or less state entries, we can short circuit the below
// loop and avoid hitting the database // loop and avoid hitting the database
allStateEventNIDs := make(map[types.EventNID]types.StateEntry) allStateEventNIDs := make(map[types.EventNID]types.StateEntry)
for _, eventID := range request.EventIDs { for _, eventID := range eventIDs {
stateEntry := stateEntries[eventID] stateEntry := stateEntries[eventID]
for _, s := range stateEntry { for _, s := range stateEntry {
allStateEventNIDs[s.EventNID] = s allStateEventNIDs[s.EventNID] = s
@ -355,10 +360,10 @@ func (r *Queryer) QueryMembershipAtEvent(
} }
var memberships []types.Event var memberships []types.Event
for _, eventID := range request.EventIDs { for _, eventID := range eventIDs {
stateEntry, ok := stateEntries[eventID] stateEntry, ok := stateEntries[eventID]
if !ok || len(stateEntry) == 0 { if !ok || len(stateEntry) == 0 {
response.Membership[eventID] = nil eventIDMembershipMap[eventID] = nil
continue continue
} }
@ -372,7 +377,7 @@ func (r *Queryer) QueryMembershipAtEvent(
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false) memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false)
} }
if err != nil { if err != nil {
return fmt.Errorf("unable to get memberships at state: %w", err) return eventIDMembershipMap, fmt.Errorf("unable to get memberships at state: %w", err)
} }
// Iterate over all membership events we got. Given we only query the membership for // Iterate over all membership events we got. Given we only query the membership for
@ -380,13 +385,13 @@ func (r *Queryer) QueryMembershipAtEvent(
// a given event, overwrite any other existing membership events. // a given event, overwrite any other existing membership events.
for i := range memberships { for i := range memberships {
ev := memberships[i] ev := memberships[i]
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(request.UserID) { if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) {
response.Membership[eventID] = &types.HeaderedEvent{PDU: ev.PDU} eventIDMembershipMap[eventID] = &types.HeaderedEvent{PDU: ev.PDU}
} }
} }
} }
return nil return eventIDMembershipMap, nil
} }
// QueryMembershipsForRoom implements api.RoomserverInternalAPI // QueryMembershipsForRoom implements api.RoomserverInternalAPI

View file

@ -16,6 +16,7 @@ package internal
import ( import (
"context" "context"
"fmt"
"math" "math"
"time" "time"
@ -101,13 +102,15 @@ func (ev eventVisibility) allowed() (allowed bool) {
// ApplyHistoryVisibilityFilter applies the room history visibility filter on types.HeaderedEvents. // ApplyHistoryVisibilityFilter applies the room history visibility filter on types.HeaderedEvents.
// Returns the filtered events and an error, if any. // Returns the filtered events and an error, if any.
//
// This function assumes that all provided events are from the same room.
func ApplyHistoryVisibilityFilter( func ApplyHistoryVisibilityFilter(
ctx context.Context, ctx context.Context,
syncDB storage.DatabaseTransaction, syncDB storage.DatabaseTransaction,
rsAPI api.SyncRoomserverAPI, rsAPI api.SyncRoomserverAPI,
events []*types.HeaderedEvent, events []*types.HeaderedEvent,
alwaysIncludeEventIDs map[string]struct{}, alwaysIncludeEventIDs map[string]struct{},
userID, endpoint string, userID spec.UserID, endpoint string,
) ([]*types.HeaderedEvent, error) { ) ([]*types.HeaderedEvent, error) {
if len(events) == 0 { if len(events) == 0 {
return events, nil return events, nil
@ -115,15 +118,29 @@ func ApplyHistoryVisibilityFilter(
start := time.Now() start := time.Now()
// try to get the current membership of the user // try to get the current membership of the user
membershipCurrent, _, err := syncDB.SelectMembershipForUser(ctx, events[0].RoomID(), userID, math.MaxInt64) membershipCurrent, _, err := syncDB.SelectMembershipForUser(ctx, events[0].RoomID(), userID.String(), math.MaxInt64)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Get the mapping from eventID -> eventVisibility // Get the mapping from eventID -> eventVisibility
eventsFiltered := make([]*types.HeaderedEvent, 0, len(events)) eventsFiltered := make([]*types.HeaderedEvent, 0, len(events))
visibilities := visibilityForEvents(ctx, rsAPI, events, userID, events[0].RoomID()) firstEvRoomID, err := spec.NewRoomID(events[0].RoomID())
if err != nil {
return nil, err
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *firstEvRoomID, userID)
if err != nil {
return nil, err
}
visibilities := visibilityForEvents(ctx, rsAPI, events, senderID, *firstEvRoomID)
for _, ev := range events { for _, ev := range events {
// Validate same room assumption
if ev.RoomID() != firstEvRoomID.String() {
return nil, fmt.Errorf("events from different rooms supplied to ApplyHistoryVisibilityFilter")
}
evVis := visibilities[ev.EventID()] evVis := visibilities[ev.EventID()]
evVis.membershipCurrent = membershipCurrent evVis.membershipCurrent = membershipCurrent
// Always include specific state events for /sync responses // Always include specific state events for /sync responses
@ -133,23 +150,15 @@ func ApplyHistoryVisibilityFilter(
continue continue
} }
} }
// NOTSPEC: Always allow user to see their own membership events (spec contains more "rules")
user, err := spec.NewUserID(userID, true) // NOTSPEC: Always allow user to see their own membership events (spec contains more "rules")
if err != nil { if senderID != nil {
return nil, err
}
roomID, err := spec.NewRoomID(ev.RoomID())
if err != nil {
return nil, err
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *user)
if err == nil && senderID != nil {
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(*senderID)) { if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(*senderID)) {
eventsFiltered = append(eventsFiltered, ev) eventsFiltered = append(eventsFiltered, ev)
continue continue
} }
} }
// Always allow history evVis events on boundaries. This is done // Always allow history evVis events on boundaries. This is done
// by setting the effective evVis to the least restrictive // by setting the effective evVis to the least restrictive
// of the old vs new. // of the old vs new.
@ -178,13 +187,13 @@ func ApplyHistoryVisibilityFilter(
} }
// visibilityForEvents returns a map from eventID to eventVisibility containing the visibility and the membership // visibilityForEvents returns a map from eventID to eventVisibility containing the visibility and the membership
// of `userID` at the given event. // of `senderID` at the given event. If provided sender ID is nil, assume that membership is Leave
// Returns an error if the roomserver can't calculate the memberships. // Returns an error if the roomserver can't calculate the memberships.
func visibilityForEvents( func visibilityForEvents(
ctx context.Context, ctx context.Context,
rsAPI api.SyncRoomserverAPI, rsAPI api.SyncRoomserverAPI,
events []*types.HeaderedEvent, events []*types.HeaderedEvent,
userID, roomID string, senderID *spec.SenderID, roomID spec.RoomID,
) map[string]eventVisibility { ) map[string]eventVisibility {
eventIDs := make([]string, len(events)) eventIDs := make([]string, len(events))
for i := range events { for i := range events {
@ -194,16 +203,14 @@ func visibilityForEvents(
result := make(map[string]eventVisibility, len(eventIDs)) result := make(map[string]eventVisibility, len(eventIDs))
// get the membership events for all eventIDs // get the membership events for all eventIDs
membershipResp := &api.QueryMembershipAtEventResponse{} var err error
membershipEvents := make(map[string]*types.HeaderedEvent)
err := rsAPI.QueryMembershipAtEvent(ctx, &api.QueryMembershipAtEventRequest{ if senderID != nil {
RoomID: roomID, membershipEvents, err = rsAPI.QueryMembershipAtEvent(ctx, roomID, eventIDs, *senderID)
EventIDs: eventIDs,
UserID: userID,
}, membershipResp)
if err != nil { if err != nil {
logrus.WithError(err).Error("visibilityForEvents: failed to fetch membership at event, defaulting to 'leave'") logrus.WithError(err).Error("visibilityForEvents: failed to fetch membership at event, defaulting to 'leave'")
} }
}
// Create a map from eventID -> eventVisibility // Create a map from eventID -> eventVisibility
for _, event := range events { for _, event := range events {
@ -212,7 +219,7 @@ func visibilityForEvents(
membershipAtEvent: spec.Leave, // default to leave, to not expose events by accident membershipAtEvent: spec.Leave, // default to leave, to not expose events by accident
visibility: event.Visibility, visibility: event.Visibility,
} }
ev, ok := membershipResp.Membership[eventID] ev, ok := membershipEvents[eventID]
if !ok || ev == nil { if !ok || ev == nil {
result[eventID] = vis result[eventID] = vis
continue continue

View file

@ -138,7 +138,7 @@ func Context(
// verify the user is allowed to see the context for this room/event // verify the user is allowed to see the context for this room/event
startTime := time.Now() startTime := time.Now()
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, []*rstypes.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context") filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, []*rstypes.HeaderedEvent{&requestedEvent}, nil, *userID, "context")
if err != nil { if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter") logrus.WithError(err).Error("unable to apply history visibility filter")
return util.JSONResponse{ return util.JSONResponse{
@ -176,7 +176,7 @@ func Context(
} }
startTime = time.Now() startTime = time.Now()
eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, snapshot, rsAPI, eventsBefore, eventsAfter, device.UserID) eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, snapshot, rsAPI, eventsBefore, eventsAfter, *userID)
if err != nil { if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter") logrus.WithError(err).Error("unable to apply history visibility filter")
return util.JSONResponse{ return util.JSONResponse{
@ -257,7 +257,7 @@ func Context(
func applyHistoryVisibilityOnContextEvents( func applyHistoryVisibilityOnContextEvents(
ctx context.Context, snapshot storage.DatabaseTransaction, rsAPI roomserver.SyncRoomserverAPI, ctx context.Context, snapshot storage.DatabaseTransaction, rsAPI roomserver.SyncRoomserverAPI,
eventsBefore, eventsAfter []*rstypes.HeaderedEvent, eventsBefore, eventsAfter []*rstypes.HeaderedEvent,
userID string, userID spec.UserID,
) (filteredBefore, filteredAfter []*rstypes.HeaderedEvent, err error) { ) (filteredBefore, filteredAfter []*rstypes.HeaderedEvent, err error) {
eventIDsBefore := make(map[string]struct{}, len(eventsBefore)) eventIDsBefore := make(map[string]struct{}, len(eventsBefore))
eventIDsAfter := make(map[string]struct{}, len(eventsAfter)) eventIDsAfter := make(map[string]struct{}, len(eventsAfter))

View file

@ -37,7 +37,7 @@ import (
func GetEvent( func GetEvent(
req *http.Request, req *http.Request,
device *userapi.Device, device *userapi.Device,
roomID string, rawRoomID string,
eventID string, eventID string,
cfg *config.SyncAPI, cfg *config.SyncAPI,
syncDB storage.Database, syncDB storage.Database,
@ -47,7 +47,7 @@ func GetEvent(
db, err := syncDB.NewDatabaseTransaction(ctx) db, err := syncDB.NewDatabaseTransaction(ctx)
logger := util.GetLogger(ctx).WithFields(logrus.Fields{ logger := util.GetLogger(ctx).WithFields(logrus.Fields{
"event_id": eventID, "event_id": eventID,
"room_id": roomID, "room_id": rawRoomID,
}) })
if err != nil { if err != nil {
logger.WithError(err).Error("GetEvent: syncDB.NewDatabaseTransaction failed") logger.WithError(err).Error("GetEvent: syncDB.NewDatabaseTransaction failed")
@ -57,6 +57,14 @@ func GetEvent(
} }
} }
roomID, err := spec.NewRoomID(rawRoomID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam("invalid room ID"),
}
}
events, err := db.Events(ctx, []string{eventID}) events, err := db.Events(ctx, []string{eventID})
if err != nil { if err != nil {
logger.WithError(err).Error("GetEvent: syncDB.Events failed") logger.WithError(err).Error("GetEvent: syncDB.Events failed")
@ -76,13 +84,22 @@ func GetEvent(
} }
// If the request is coming from an appservice, get the user from the request // If the request is coming from an appservice, get the user from the request
userID := device.UserID rawUserID := device.UserID
if asUserID := req.FormValue("user_id"); device.AppserviceID != "" && asUserID != "" { if asUserID := req.FormValue("user_id"); device.AppserviceID != "" && asUserID != "" {
userID = asUserID rawUserID = asUserID
}
userID, err := spec.NewUserID(rawUserID, true)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("invalid device.UserID")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown("internal server error"),
}
} }
// Apply history visibility to determine if the user is allowed to view the event // Apply history visibility to determine if the user is allowed to view the event
events, err = internal.ApplyHistoryVisibilityFilter(ctx, db, rsAPI, events, nil, userID, "event") events, err = internal.ApplyHistoryVisibilityFilter(ctx, db, rsAPI, events, nil, *userID, "event")
if err != nil { if err != nil {
logger.WithError(err).Error("GetEvent: internal.ApplyHistoryVisibilityFilter failed") logger.WithError(err).Error("GetEvent: internal.ApplyHistoryVisibilityFilter failed")
return util.JSONResponse{ return util.JSONResponse{
@ -101,18 +118,14 @@ func GetEvent(
} }
} }
sender := spec.UserID{} senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, events[0].SenderID())
validRoomID, err := spec.NewRoomID(roomID) if err != nil || senderUserID == nil {
if err != nil { util.GetLogger(req.Context()).WithError(err).WithField("senderID", events[0].SenderID()).WithField("roomID", *roomID).Error("QueryUserIDForSender errored or returned nil-user ID when user should be part of a room")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusInternalServerError,
JSON: spec.BadJSON("roomID is invalid"), JSON: spec.Unknown("internal server error"),
} }
} }
senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, events[0].SenderID())
if err == nil && senderUserID != nil {
sender = *senderUserID
}
sk := events[0].StateKey() sk := events[0].StateKey()
if sk != nil && *sk != "" { if sk != nil && *sk != "" {
@ -131,6 +144,6 @@ func GetEvent(
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender, sk), JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, *senderUserID, sk),
} }
} }

View file

@ -50,6 +50,7 @@ type messagesReq struct {
from *types.TopologyToken from *types.TopologyToken
to *types.TopologyToken to *types.TopologyToken
device *userapi.Device device *userapi.Device
deviceUserID spec.UserID
wasToProvided bool wasToProvided bool
backwardOrdering bool backwardOrdering bool
filter *synctypes.RoomEventFilter filter *synctypes.RoomEventFilter
@ -77,6 +78,15 @@ func OnIncomingMessagesRequest(
) util.JSONResponse { ) util.JSONResponse {
var err error var err error
deviceUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("device.UserID invalid")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown("internal server error"),
}
}
// NewDatabaseTransaction is used here instead of NewDatabaseSnapshot as we // NewDatabaseTransaction is used here instead of NewDatabaseSnapshot as we
// expect to be able to write to the database in response to a /messages // expect to be able to write to the database in response to a /messages
// request that requires backfilling from the roomserver or federation. // request that requires backfilling from the roomserver or federation.
@ -240,6 +250,7 @@ func OnIncomingMessagesRequest(
filter: filter, filter: filter,
backwardOrdering: backwardOrdering, backwardOrdering: backwardOrdering,
device: device, device: device,
deviceUserID: *deviceUserID,
} }
clientEvents, start, end, err := mReq.retrieveEvents(req.Context(), rsAPI) clientEvents, start, end, err := mReq.retrieveEvents(req.Context(), rsAPI)
@ -359,7 +370,7 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv
// Apply room history visibility filter // Apply room history visibility filter
startTime := time.Now() startTime := time.Now()
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.snapshot, r.rsAPI, events, nil, r.device.UserID, "messages") filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.snapshot, r.rsAPI, events, nil, r.deviceUserID, "messages")
if err != nil { if err != nil {
return []synctypes.ClientEvent{}, *r.from, *r.to, nil return []synctypes.ClientEvent{}, *r.from, *r.to, nil
} }

View file

@ -43,9 +43,25 @@ func Relations(
req *http.Request, device *userapi.Device, req *http.Request, device *userapi.Device,
syncDB storage.Database, syncDB storage.Database,
rsAPI api.SyncRoomserverAPI, rsAPI api.SyncRoomserverAPI,
roomID, eventID, relType, eventType string, rawRoomID, eventID, relType, eventType string,
) util.JSONResponse { ) util.JSONResponse {
var err error roomID, err := spec.NewRoomID(rawRoomID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam("invalid room ID"),
}
}
userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("device.UserID invalid")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown("internal server error"),
}
}
var from, to types.StreamPosition var from, to types.StreamPosition
var limit int var limit int
dir := req.URL.Query().Get("dir") dir := req.URL.Query().Get("dir")
@ -93,7 +109,7 @@ func Relations(
} }
var events []types.StreamEvent var events []types.StreamEvent
events, res.PrevBatch, res.NextBatch, err = snapshot.RelationsFor( events, res.PrevBatch, res.NextBatch, err = snapshot.RelationsFor(
req.Context(), roomID, eventID, relType, eventType, from, to, dir == "b", limit, req.Context(), roomID.String(), eventID, relType, eventType, from, to, dir == "b", limit,
) )
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -105,12 +121,7 @@ func Relations(
} }
// Apply history visibility to the result events. // Apply history visibility to the result events.
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(req.Context(), snapshot, rsAPI, headeredEvents, nil, device.UserID, "relations") filteredEvents, err := internal.ApplyHistoryVisibilityFilter(req.Context(), snapshot, rsAPI, headeredEvents, nil, *userID, "relations")
if err != nil {
return util.ErrorResponse(err)
}
validRoomID, err := spec.NewRoomID(roomID)
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
@ -120,14 +131,14 @@ func Relations(
res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents)) res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents))
for _, event := range filteredEvents { for _, event := range filteredEvents {
sender := spec.UserID{} sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, event.SenderID()) userID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, event.SenderID())
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := event.StateKey() sk := event.StateKey()
if sk != nil && *sk != "" { if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, spec.SenderID(*event.StateKey())) skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil { if err == nil && skUserID != nil {
skString := skUserID.String() skString := skUserID.String()
sk = &skString sk = &skString

View file

@ -562,8 +562,13 @@ func applyHistoryVisibilityFilter(
} }
} }
parsedUserID, err := spec.NewUserID(userID, true)
if err != nil {
return nil, err
}
startTime := time.Now() startTime := time.Now()
events, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync") events, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, recentEvents, alwaysIncludeIDs, *parsedUserID, "sync")
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -74,8 +74,13 @@ func (s *syncRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *rsa
return nil return nil
} }
func (s *syncRoomserverAPI) QueryMembershipAtEvent(ctx context.Context, req *rsapi.QueryMembershipAtEventRequest, res *rsapi.QueryMembershipAtEventResponse) error { func (s *syncRoomserverAPI) QueryMembershipAtEvent(
return nil ctx context.Context,
roomID spec.RoomID,
eventIDs []string,
senderID spec.SenderID,
) (map[string]*rstypes.HeaderedEvent, error) {
return nil, nil
} }
type syncUserAPI struct { type syncUserAPI struct {