Fix statekey usage in roomserver/helpers

This commit is contained in:
Devon Hudson 2023-06-08 15:36:03 -06:00
parent 58d3452b7f
commit aec29c0008
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
6 changed files with 23 additions and 16 deletions

View file

@ -45,7 +45,7 @@ func GetEventAuth(
if event.RoomID() != roomID { if event.RoomID() != roomID {
return util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")} return util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")}
} }
resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID) resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID())
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }

View file

@ -35,10 +35,6 @@ func GetEvent(
eventID string, eventID string,
origin spec.ServerName, origin spec.ServerName,
) util.JSONResponse { ) util.JSONResponse {
err := allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID)
if err != nil {
return *err
}
// /_matrix/federation/v1/event/{eventId} doesn't have a roomID, we use an empty string, // /_matrix/federation/v1/event/{eventId} doesn't have a roomID, we use an empty string,
// which results in `QueryEventsByID` to first get the event and use that to determine the roomID. // which results in `QueryEventsByID` to first get the event and use that to determine the roomID.
event, err := fetchEvent(ctx, rsAPI, "", eventID) event, err := fetchEvent(ctx, rsAPI, "", eventID)
@ -46,6 +42,11 @@ func GetEvent(
return *err return *err
} }
err = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID())
if err != nil {
return *err
}
return util.JSONResponse{Code: http.StatusOK, JSON: gomatrixserverlib.Transaction{ return util.JSONResponse{Code: http.StatusOK, JSON: gomatrixserverlib.Transaction{
Origin: origin, Origin: origin,
OriginServerTS: spec.AsTimestamp(time.Now()), OriginServerTS: spec.AsTimestamp(time.Now()),
@ -62,8 +63,9 @@ func allowedToSeeEvent(
origin spec.ServerName, origin spec.ServerName,
rsAPI api.FederationRoomserverAPI, rsAPI api.FederationRoomserverAPI,
eventID string, eventID string,
roomID string,
) *util.JSONResponse { ) *util.JSONResponse {
allowed, err := rsAPI.QueryServerAllowedToSeeEvent(ctx, origin, eventID) allowed, err := rsAPI.QueryServerAllowedToSeeEvent(ctx, origin, eventID, roomID)
if err != nil { if err != nil {
resErr := util.ErrorResponse(err) resErr := util.ErrorResponse(err)
return &resErr return &resErr

View file

@ -116,7 +116,7 @@ func getState(
if event.RoomID() != roomID { if event.RoomID() != roomID {
return nil, nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")} return nil, nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")}
} }
resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID) resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID())
if resErr != nil { if resErr != nil {
return nil, nil, resErr return nil, nil, resErr
} }

View file

@ -242,7 +242,7 @@ type FederationRoomserverAPI interface {
// Query missing events for a room from roomserver // Query missing events for a room from roomserver
QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error
// Query whether a server is allowed to see an event // Query whether a server is allowed to see an event
QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string) (allowed bool, err error) QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string, roomID string) (allowed bool, err error)
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error)
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error

View file

@ -6,7 +6,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"sort" "sort"
"strings"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
@ -265,7 +264,7 @@ func LoadStateEvents(
} }
func CheckServerAllowedToSeeEvent( func CheckServerAllowedToSeeEvent(
ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName spec.ServerName, isServerInRoom bool, ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool,
) (bool, error) { ) (bool, error) {
stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName)) stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName))
switch err { switch err {
@ -274,7 +273,7 @@ func CheckServerAllowedToSeeEvent(
case tables.OptimisationNotSupportedError: case tables.OptimisationNotSupportedError:
// The database engine didn't support this optimisation, so fall back to using // The database engine didn't support this optimisation, so fall back to using
// the old and slow method // the old and slow method
stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, eventID, serverName) stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -293,7 +292,7 @@ func CheckServerAllowedToSeeEvent(
} }
func slowGetHistoryVisibilityState( func slowGetHistoryVisibilityState(
ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName spec.ServerName, ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName,
) ([]gomatrixserverlib.PDU, error) { ) ([]gomatrixserverlib.PDU, error) {
roomState := state.NewStateResolution(db, info) roomState := state.NewStateResolution(db, info)
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
@ -320,8 +319,13 @@ func slowGetHistoryVisibilityState(
// then we'll filter it out. This does preserve state keys that // then we'll filter it out. This does preserve state keys that
// are "" since these will contain history visibility etc. // are "" since these will contain history visibility etc.
for nid, key := range stateKeys { for nid, key := range stateKeys {
if key != "" && !strings.HasSuffix(key, ":"+string(serverName)) { if key != "" {
delete(stateKeys, nid) userID, err := db.GetUserIDForSender(ctx, roomID, spec.SenderID(key))
if err == nil && userID != nil {
if userID.Domain() != serverName {
delete(stateKeys, nid)
}
}
} }
} }
@ -411,7 +415,7 @@ BFSLoop:
// hasn't been seen before. // hasn't been seen before.
if !visited[pre] { if !visited[pre] {
visited[pre] = true visited[pre] = true
allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, pre, serverName, isServerInRoom) allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom)
if err != nil { if err != nil {
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error( util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
"Error checking if allowed to see event", "Error checking if allowed to see event",

View file

@ -483,6 +483,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent(
ctx context.Context, ctx context.Context,
serverName spec.ServerName, serverName spec.ServerName,
eventID string, eventID string,
roomID string,
) (allowed bool, err error) { ) (allowed bool, err error) {
events, err := r.DB.EventNIDs(ctx, []string{eventID}) events, err := r.DB.EventNIDs(ctx, []string{eventID})
if err != nil { if err != nil {
@ -512,7 +513,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent(
} }
return helpers.CheckServerAllowedToSeeEvent( return helpers.CheckServerAllowedToSeeEvent(
ctx, r.DB, info, eventID, serverName, isInRoom, ctx, r.DB, info, roomID, eventID, serverName, isInRoom,
) )
} }