Merge branch 'main' into s7evink/hisvismessages

This commit is contained in:
Neil Alexander 2022-08-02 12:49:32 +01:00 committed by GitHub
commit a0427c4b09
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 74 additions and 24 deletions

View file

@ -202,6 +202,14 @@ func SendJoin(
}
}
// Check that the event is from the server sending the request.
if event.Origin() != request.Origin() {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("The join must be sent by the server it originated on"),
}
}
// Check that a state key is provided.
if event.StateKey() == nil || event.StateKeyEquals("") {
return util.JSONResponse{
@ -216,6 +224,22 @@ func SendJoin(
}
}
// Check that the sender belongs to the server that is sending us
// the request. By this point we've already asserted that the sender
// and the state key are equal so we don't need to check both.
var domain gomatrixserverlib.ServerName
if _, domain, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("The sender of the join is invalid"),
}
} else if domain != request.Origin() {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("The sender of the join must belong to the origin server"),
}
}
// Check that the room ID is correct.
if event.RoomID() != roomID {
return util.JSONResponse{
@ -242,14 +266,6 @@ func SendJoin(
}
}
// Check that the event is from the server sending the request.
if event.Origin() != request.Origin() {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("The join must be sent by the server it originated on"),
}
}
// Check that this is in fact a join event
membership, err := event.Membership()
if err != nil {

View file

@ -375,11 +375,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, room
defer span.Finish()
var res parsedRespState
roomInfo, err := t.db.RoomInfo(ctx, roomID)
if err != nil {
return nil
}
roomState := state.NewStateResolution(t.db, roomInfo)
roomState := state.NewStateResolution(t.db, t.roomInfo)
stateAtEvents, err := t.db.StateAtEventIDs(ctx, []string{eventID})
if err != nil {
util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to get state after %s locally", eventID)

View file

@ -399,8 +399,8 @@ func (r *Queryer) QueryServerAllowedToSeeEvent(
if err != nil {
return err
}
if info == nil {
return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID)
if info == nil || info.IsStub() {
return nil
}
response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent(
ctx, r.DB, info, request.EventID, request.ServerName, inRoomRes.IsInRoom,

View file

@ -217,6 +217,14 @@ func (u *RoomUpdater) SetLatestEvents(
roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
currentStateSnapshotNID types.StateSnapshotNID,
) error {
switch {
case len(latest) == 0:
return fmt.Errorf("cannot set latest events with no latest event references")
case currentStateSnapshotNID == 0:
return fmt.Errorf("cannot set latest events with invalid state snapshot NID")
case lastEventNIDSent == 0:
return fmt.Errorf("cannot set latest events with invalid latest event NID")
}
eventNIDs := make([]types.EventNID, len(latest))
for i := range latest {
eventNIDs[i] = latest[i].EventNID
@ -229,8 +237,10 @@ func (u *RoomUpdater) SetLatestEvents(
// Since it's entirely possible that this types.RoomInfo came from the
// cache, we should make sure to update that entry so that the next run
// works from live data.
u.roomInfo.SetStateSnapshotNID(currentStateSnapshotNID)
u.roomInfo.SetIsStub(false)
if u.roomInfo != nil {
u.roomInfo.SetStateSnapshotNID(currentStateSnapshotNID)
u.roomInfo.SetIsStub(false)
}
return nil
})
}

View file

@ -156,15 +156,30 @@ func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo
}
func (d *Database) roomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok && roomInfo != nil {
roomInfo, ok := d.Cache.GetRoomInfo(roomID)
if ok && roomInfo != nil && !roomInfo.IsStub() {
// The data that's in the cache is not stubby, so return it.
return roomInfo, nil
}
roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, txn, roomID)
if err == nil && roomInfo != nil {
d.Cache.StoreRoomServerRoomID(roomInfo.RoomNID, roomID)
d.Cache.StoreRoomInfo(roomID, roomInfo)
// At this point we either don't have an entry in the cache, or
// it is stubby, so let's check the roomserver_rooms table again.
roomInfoFromDB, err := d.RoomsTable.SelectRoomInfo(ctx, txn, roomID)
if err != nil {
return nil, err
}
return roomInfo, err
// If we have a stubby cache entry already, update it and return
// the reference to the cache entry.
if roomInfo != nil {
roomInfo.CopyFrom(roomInfoFromDB)
return roomInfo, nil
}
// Otherwise, try to admit the data into the cache and return the
// new reference from the database.
if roomInfoFromDB != nil {
d.Cache.StoreRoomServerRoomID(roomInfoFromDB.RoomNID, roomID)
d.Cache.StoreRoomInfo(roomID, roomInfoFromDB)
}
return roomInfoFromDB, err
}
func (d *Database) AddState(
@ -676,7 +691,7 @@ func (d *Database) storeEvent(
succeeded := false
if updater == nil {
var roomInfo *types.RoomInfo
roomInfo, err = d.RoomInfo(ctx, event.RoomID())
roomInfo, err = d.roomInfo(ctx, txn, event.RoomID())
if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err)
}

View file

@ -310,3 +310,16 @@ func (r *RoomInfo) SetIsStub(isStub bool) {
defer r.mu.Unlock()
r.isStub = isStub
}
func (r *RoomInfo) CopyFrom(r2 *RoomInfo) {
r.mu.Lock()
defer r.mu.Unlock()
r2.mu.RLock()
defer r2.mu.RUnlock()
r.RoomNID = r2.RoomNID
r.RoomVersion = r2.RoomVersion
r.stateSnapshotNID = r2.stateSnapshotNID
r.isStub = r2.isStub
}