diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index bf8e978e1..e3748731e 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -375,38 +375,8 @@ func (r joinRoomReq) joinRoomUsingServer(roomID string, server gomatrixserverlib return nil, fmt.Errorf("r.federation.SendJoin: %w", err) } - // A list of events that we have retried, if they were not included in - // the auth events supplied in the send_join. - retries := map[string]bool{} - -retryCheck: - if err = respSendJoin.Check(r.req.Context(), r.keyRing, event); err != nil { - switch e := err.(type) { - case gomatrixserverlib.MissingAuthEventError: - // Check that we haven't already retried for this event, prevents - // us from ending up in endless loops - if _, ok := retries[e.AuthEventID]; !ok { - // Ask the server that we're talking to right now for the event - tx, txerr := r.federation.GetEvent(r.req.Context(), server, e.AuthEventID) - if txerr != nil { - return nil, fmt.Errorf("r.federation.GetEvent: %w", txerr) - } - // For each event returned, add it to the auth events. - for _, pdu := range tx.PDUs { - ev, everr := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, respMakeJoin.RoomVersion) - if everr != nil { - return nil, fmt.Errorf("gomatrixserverlib.NewEventFromUntrustedJSON: %w", everr) - } - respSendJoin.AuthEvents = append(respSendJoin.AuthEvents, ev) - } - // Mark the event as retried and then give the check another go. - retries[e.AuthEventID] = true - goto retryCheck - } - return nil, fmt.Errorf("respSendJoin (after retries): %w", e) - default: - return nil, fmt.Errorf("respSendJoin: %w", err) - } + if err = r.checkSendJoinResponse(event, server, respMakeJoin, respSendJoin); err != nil { + return nil, err } util.GetLogger(r.req.Context()).WithFields(logrus.Fields{ @@ -425,7 +395,7 @@ retryCheck: go func() { ctx := context.Background() if err = r.producer.SendEventWithState( - ctx, //r.req.Context(), + ctx, gomatrixserverlib.RespState(respSendJoin.RespState), event.Headered(respMakeJoin.RoomVersion), ); err != nil { @@ -441,3 +411,47 @@ retryCheck: }{roomID}, }, nil } + +// checkSendJoinResponse checks that all of the signatures are correct +// and that the join is allowed by the supplied state. +func (r joinRoomReq) checkSendJoinResponse( + event gomatrixserverlib.Event, + server gomatrixserverlib.ServerName, + respMakeJoin gomatrixserverlib.RespMakeJoin, + respSendJoin gomatrixserverlib.RespSendJoin, +) error { + // A list of events that we have retried, if they were not included in + // the auth events supplied in the send_join. + retries := map[string]bool{} + +retryCheck: + if err := respSendJoin.Check(r.req.Context(), r.keyRing, event); err != nil { + switch e := err.(type) { + case gomatrixserverlib.MissingAuthEventError: + // Check that we haven't already retried for this event, prevents + // us from ending up in endless loops + if _, ok := retries[e.AuthEventID]; !ok { + // Ask the server that we're talking to right now for the event + tx, txerr := r.federation.GetEvent(r.req.Context(), server, e.AuthEventID) + if txerr != nil { + return fmt.Errorf("r.federation.GetEvent: %w", txerr) + } + // For each event returned, add it to the auth events. + for _, pdu := range tx.PDUs { + ev, everr := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, respMakeJoin.RoomVersion) + if everr != nil { + return fmt.Errorf("gomatrixserverlib.NewEventFromUntrustedJSON: %w", everr) + } + respSendJoin.AuthEvents = append(respSendJoin.AuthEvents, ev) + } + // Mark the event as retried and then give the check another go. + retries[e.AuthEventID] = true + goto retryCheck + } + return fmt.Errorf("respSendJoin (after retries): %w", e) + default: + return fmt.Errorf("respSendJoin: %w", err) + } + } + return nil +}