diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 41004cf51..30406a155 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -202,6 +202,14 @@ func SendJoin( } } + // Check that the event is from the server sending the request. + if event.Origin() != request.Origin() { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("The join must be sent by the server it originated on"), + } + } + // Check that a state key is provided. if event.StateKey() == nil || event.StateKeyEquals("") { return util.JSONResponse{ @@ -216,6 +224,22 @@ func SendJoin( } } + // Check that the sender belongs to the server that is sending us + // the request. By this point we've already asserted that the sender + // and the state key are equal so we don't need to check both. + var domain gomatrixserverlib.ServerName + if _, domain, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("The sender of the join is invalid"), + } + } else if domain != request.Origin() { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("The sender of the join must belong to the origin server"), + } + } + // Check that the room ID is correct. if event.RoomID() != roomID { return util.JSONResponse{ @@ -242,14 +266,6 @@ func SendJoin( } } - // Check that the event is from the server sending the request. - if event.Origin() != request.Origin() { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The join must be sent by the server it originated on"), - } - } - // Check that this is in fact a join event membership, err := event.Membership() if err != nil { diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index edc153b7f..1fe25e38a 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -375,11 +375,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, room defer span.Finish() var res parsedRespState - roomInfo, err := t.db.RoomInfo(ctx, roomID) - if err != nil { - return nil - } - roomState := state.NewStateResolution(t.db, roomInfo) + roomState := state.NewStateResolution(t.db, t.roomInfo) stateAtEvents, err := t.db.StateAtEventIDs(ctx, []string{eventID}) if err != nil { util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to get state after %s locally", eventID) diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 50fc98ed2..c41e1ea67 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -399,8 +399,8 @@ func (r *Queryer) QueryServerAllowedToSeeEvent( if err != nil { return err } - if info == nil { - return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID) + if info == nil || info.IsStub() { + return nil } response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent( ctx, r.DB, info, request.EventID, request.ServerName, inRoomRes.IsInRoom, diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index a7ae26d43..42c0c8f2d 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -217,6 +217,14 @@ func (u *RoomUpdater) SetLatestEvents( roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, currentStateSnapshotNID types.StateSnapshotNID, ) error { + switch { + case len(latest) == 0: + return fmt.Errorf("cannot set latest events with no latest event references") + case currentStateSnapshotNID == 0: + return fmt.Errorf("cannot set latest events with invalid state snapshot NID") + case lastEventNIDSent == 0: + return fmt.Errorf("cannot set latest events with invalid latest event NID") + } eventNIDs := make([]types.EventNID, len(latest)) for i := range latest { eventNIDs[i] = latest[i].EventNID @@ -229,8 +237,10 @@ func (u *RoomUpdater) SetLatestEvents( // Since it's entirely possible that this types.RoomInfo came from the // cache, we should make sure to update that entry so that the next run // works from live data. - u.roomInfo.SetStateSnapshotNID(currentStateSnapshotNID) - u.roomInfo.SetIsStub(false) + if u.roomInfo != nil { + u.roomInfo.SetStateSnapshotNID(currentStateSnapshotNID) + u.roomInfo.SetIsStub(false) + } return nil }) } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 05b897149..9e6a4142c 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -156,15 +156,30 @@ func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo } func (d *Database) roomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) { - if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok && roomInfo != nil { + roomInfo, ok := d.Cache.GetRoomInfo(roomID) + if ok && roomInfo != nil && !roomInfo.IsStub() { + // The data that's in the cache is not stubby, so return it. return roomInfo, nil } - roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, txn, roomID) - if err == nil && roomInfo != nil { - d.Cache.StoreRoomServerRoomID(roomInfo.RoomNID, roomID) - d.Cache.StoreRoomInfo(roomID, roomInfo) + // At this point we either don't have an entry in the cache, or + // it is stubby, so let's check the roomserver_rooms table again. + roomInfoFromDB, err := d.RoomsTable.SelectRoomInfo(ctx, txn, roomID) + if err != nil { + return nil, err } - return roomInfo, err + // If we have a stubby cache entry already, update it and return + // the reference to the cache entry. + if roomInfo != nil { + roomInfo.CopyFrom(roomInfoFromDB) + return roomInfo, nil + } + // Otherwise, try to admit the data into the cache and return the + // new reference from the database. + if roomInfoFromDB != nil { + d.Cache.StoreRoomServerRoomID(roomInfoFromDB.RoomNID, roomID) + d.Cache.StoreRoomInfo(roomID, roomInfoFromDB) + } + return roomInfoFromDB, err } func (d *Database) AddState( @@ -676,7 +691,7 @@ func (d *Database) storeEvent( succeeded := false if updater == nil { var roomInfo *types.RoomInfo - roomInfo, err = d.RoomInfo(ctx, event.RoomID()) + roomInfo, err = d.roomInfo(ctx, txn, event.RoomID()) if err != nil { return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err) } diff --git a/roomserver/types/types.go b/roomserver/types/types.go index 726659ea0..f40980994 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -310,3 +310,16 @@ func (r *RoomInfo) SetIsStub(isStub bool) { defer r.mu.Unlock() r.isStub = isStub } + +func (r *RoomInfo) CopyFrom(r2 *RoomInfo) { + r.mu.Lock() + defer r.mu.Unlock() + + r2.mu.RLock() + defer r2.mu.RUnlock() + + r.RoomNID = r2.RoomNID + r.RoomVersion = r2.RoomVersion + r.stateSnapshotNID = r2.stateSnapshotNID + r.isStub = r2.isStub +}