mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-23 23:03:10 -06:00
merge master
This commit is contained in:
commit
75c3f2df8d
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
|
|
@ -3,4 +3,4 @@
|
||||||
<!-- Please read CONTRIBUTING.md before submitting your pull request -->
|
<!-- Please read CONTRIBUTING.md before submitting your pull request -->
|
||||||
|
|
||||||
* [ ] I have added any new tests that need to pass to `testfile` as specified in [docs/sytest.md](https://github.com/matrix-org/dendrite/blob/master/docs/sytest.md)
|
* [ ] I have added any new tests that need to pass to `testfile` as specified in [docs/sytest.md](https://github.com/matrix-org/dendrite/blob/master/docs/sytest.md)
|
||||||
* [ ] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/master/CONTRIBUTING.md#sign-off)
|
* [ ] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/master/docs/CONTRIBUTING.md#sign-off)
|
||||||
|
|
|
||||||
|
|
@ -38,8 +38,13 @@ global:
|
||||||
# The path to the signing private key file, used to sign requests and events.
|
# The path to the signing private key file, used to sign requests and events.
|
||||||
private_key: matrix_key.pem
|
private_key: matrix_key.pem
|
||||||
|
|
||||||
# A unique identifier for this private key. Must start with the prefix "ed25519:".
|
# The paths and expiry timestamps (as a UNIX timestamp in millisecond precision)
|
||||||
key_id: ed25519:auto
|
# to old signing private keys that were formerly in use on this domain. These
|
||||||
|
# keys will not be used for federation request or event signing, but will be
|
||||||
|
# provided to any other homeserver that asks when trying to verify old events.
|
||||||
|
# old_private_keys:
|
||||||
|
# - private_key: old_matrix_key.pem
|
||||||
|
# expired_at: 1601024554498
|
||||||
|
|
||||||
# How long a remote server can cache our server signing key before requesting it
|
# How long a remote server can cache our server signing key before requesting it
|
||||||
# again. Increasing this number will reduce the number of requests made by other
|
# again. Increasing this number will reduce the number of requests made by other
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,15 @@ func InviteV2(
|
||||||
keys gomatrixserverlib.JSONVerifier,
|
keys gomatrixserverlib.JSONVerifier,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
inviteReq := gomatrixserverlib.InviteV2Request{}
|
inviteReq := gomatrixserverlib.InviteV2Request{}
|
||||||
if err := json.Unmarshal(request.Content(), &inviteReq); err != nil {
|
err := json.Unmarshal(request.Content(), &inviteReq)
|
||||||
|
switch err.(type) {
|
||||||
|
case gomatrixserverlib.BadJSONError:
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.BadJSON(err.Error()),
|
||||||
|
}
|
||||||
|
case nil:
|
||||||
|
default:
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
JSON: jsonerror.NotJSON("The request body could not be decoded into an invite request. " + err.Error()),
|
JSON: jsonerror.NotJSON("The request body could not be decoded into an invite request. " + err.Error()),
|
||||||
|
|
@ -63,10 +71,17 @@ func InviteV1(
|
||||||
roomVer := gomatrixserverlib.RoomVersionV1
|
roomVer := gomatrixserverlib.RoomVersionV1
|
||||||
body := request.Content()
|
body := request.Content()
|
||||||
event, err := gomatrixserverlib.NewEventFromTrustedJSON(body, false, roomVer)
|
event, err := gomatrixserverlib.NewEventFromTrustedJSON(body, false, roomVer)
|
||||||
if err != nil {
|
switch err.(type) {
|
||||||
|
case gomatrixserverlib.BadJSONError:
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
JSON: jsonerror.NotJSON("The request body could not be decoded into an invite v1 request: " + err.Error()),
|
JSON: jsonerror.BadJSON(err.Error()),
|
||||||
|
}
|
||||||
|
case nil:
|
||||||
|
default:
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.NotJSON("The request body could not be decoded into an invite v1 request. " + err.Error()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var strippedState []gomatrixserverlib.InviteV2StrippedState
|
var strippedState []gomatrixserverlib.InviteV2StrippedState
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// MakeJoin implements the /make_join API
|
// MakeJoin implements the /make_join API
|
||||||
|
// nolint:gocyclo
|
||||||
func MakeJoin(
|
func MakeJoin(
|
||||||
httpReq *http.Request,
|
httpReq *http.Request,
|
||||||
request *gomatrixserverlib.FederationRequest,
|
request *gomatrixserverlib.FederationRequest,
|
||||||
|
|
@ -79,6 +80,29 @@ func MakeJoin(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if we think we are still joined to the room
|
||||||
|
inRoomReq := &api.QueryServerJoinedToRoomRequest{
|
||||||
|
ServerName: cfg.Matrix.ServerName,
|
||||||
|
RoomID: roomID,
|
||||||
|
}
|
||||||
|
inRoomRes := &api.QueryServerJoinedToRoomResponse{}
|
||||||
|
if err = rsAPI.QueryServerJoinedToRoom(httpReq.Context(), inRoomReq, inRoomRes); err != nil {
|
||||||
|
util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed")
|
||||||
|
return jsonerror.InternalServerError()
|
||||||
|
}
|
||||||
|
if !inRoomRes.RoomExists {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusNotFound,
|
||||||
|
JSON: jsonerror.NotFound(fmt.Sprintf("Room ID %q was not found on this server", roomID)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !inRoomRes.IsInRoom {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusNotFound,
|
||||||
|
JSON: jsonerror.NotFound(fmt.Sprintf("Room ID %q has no remaining users on this server", roomID)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Try building an event for the server
|
// Try building an event for the server
|
||||||
builder := gomatrixserverlib.EventBuilder{
|
builder := gomatrixserverlib.EventBuilder{
|
||||||
Sender: userID,
|
Sender: userID,
|
||||||
|
|
|
||||||
|
|
@ -19,11 +19,14 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
|
||||||
"github.com/matrix-org/dendrite/internal/config"
|
"github.com/matrix-org/dendrite/internal/config"
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -133,6 +136,8 @@ func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserver
|
||||||
var keys gomatrixserverlib.ServerKeys
|
var keys gomatrixserverlib.ServerKeys
|
||||||
|
|
||||||
keys.ServerName = cfg.Matrix.ServerName
|
keys.ServerName = cfg.Matrix.ServerName
|
||||||
|
keys.TLSFingerprints = cfg.TLSFingerPrints
|
||||||
|
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(validUntil)
|
||||||
|
|
||||||
publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey)
|
publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey)
|
||||||
|
|
||||||
|
|
@ -142,9 +147,15 @@ func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserver
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
keys.TLSFingerprints = cfg.TLSFingerPrints
|
|
||||||
keys.OldVerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.OldVerifyKey{}
|
keys.OldVerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.OldVerifyKey{}
|
||||||
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(validUntil)
|
for _, oldVerifyKey := range cfg.Matrix.OldVerifyKeys {
|
||||||
|
keys.OldVerifyKeys[oldVerifyKey.KeyID] = gomatrixserverlib.OldVerifyKey{
|
||||||
|
VerifyKey: gomatrixserverlib.VerifyKey{
|
||||||
|
Key: gomatrixserverlib.Base64Bytes(oldVerifyKey.PrivateKey),
|
||||||
|
},
|
||||||
|
ExpiredTS: oldVerifyKey.ExpiredAt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
toSign, err := json.Marshal(keys.ServerKeyFields)
|
toSign, err := json.Marshal(keys.ServerKeyFields)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -160,3 +171,62 @@ func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserver
|
||||||
|
|
||||||
return &keys, nil
|
return &keys, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NotaryKeys(
|
||||||
|
httpReq *http.Request, cfg *config.FederationAPI,
|
||||||
|
fsAPI federationSenderAPI.FederationSenderInternalAPI,
|
||||||
|
req *gomatrixserverlib.PublicKeyNotaryLookupRequest,
|
||||||
|
) util.JSONResponse {
|
||||||
|
if req == nil {
|
||||||
|
req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{}
|
||||||
|
if reqErr := httputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil {
|
||||||
|
return *reqErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var response struct {
|
||||||
|
ServerKeys []json.RawMessage `json:"server_keys"`
|
||||||
|
}
|
||||||
|
response.ServerKeys = []json.RawMessage{}
|
||||||
|
|
||||||
|
for serverName := range req.ServerKeys {
|
||||||
|
var keys *gomatrixserverlib.ServerKeys
|
||||||
|
if serverName == cfg.Matrix.ServerName {
|
||||||
|
if k, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil {
|
||||||
|
keys = k
|
||||||
|
} else {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if k, err := fsAPI.GetServerKeys(httpReq.Context(), serverName); err == nil {
|
||||||
|
keys = &k
|
||||||
|
} else {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if keys == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
j, err := json.Marshal(keys)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Errorf("Failed to marshal %q response", serverName)
|
||||||
|
return jsonerror.InternalServerError()
|
||||||
|
}
|
||||||
|
|
||||||
|
js, err := gomatrixserverlib.SignJSON(
|
||||||
|
string(cfg.Matrix.ServerName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, j,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Errorf("Failed to sign %q response", serverName)
|
||||||
|
return jsonerror.InternalServerError()
|
||||||
|
}
|
||||||
|
|
||||||
|
response.ServerKeys = append(response.ServerKeys, js)
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusOK,
|
||||||
|
JSON: response,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -138,7 +138,14 @@ func SendLeave(
|
||||||
|
|
||||||
// Decode the event JSON from the request.
|
// Decode the event JSON from the request.
|
||||||
event, err := gomatrixserverlib.NewEventFromUntrustedJSON(request.Content(), verRes.RoomVersion)
|
event, err := gomatrixserverlib.NewEventFromUntrustedJSON(request.Content(), verRes.RoomVersion)
|
||||||
if err != nil {
|
switch err.(type) {
|
||||||
|
case gomatrixserverlib.BadJSONError:
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.BadJSON(err.Error()),
|
||||||
|
}
|
||||||
|
case nil:
|
||||||
|
default:
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()),
|
JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()),
|
||||||
|
|
|
||||||
|
|
@ -61,6 +61,26 @@ func Setup(
|
||||||
return LocalKeys(cfg)
|
return LocalKeys(cfg)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
notaryKeys := httputil.MakeExternalAPI("notarykeys", func(req *http.Request) util.JSONResponse {
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
var pkReq *gomatrixserverlib.PublicKeyNotaryLookupRequest
|
||||||
|
serverName := gomatrixserverlib.ServerName(vars["serverName"])
|
||||||
|
keyID := gomatrixserverlib.KeyID(vars["keyID"])
|
||||||
|
if serverName != "" && keyID != "" {
|
||||||
|
pkReq = &gomatrixserverlib.PublicKeyNotaryLookupRequest{
|
||||||
|
ServerKeys: map[gomatrixserverlib.ServerName]map[gomatrixserverlib.KeyID]gomatrixserverlib.PublicKeyNotaryQueryCriteria{
|
||||||
|
serverName: {
|
||||||
|
keyID: gomatrixserverlib.PublicKeyNotaryQueryCriteria{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return NotaryKeys(req, cfg, fsAPI, pkReq)
|
||||||
|
})
|
||||||
|
|
||||||
// Ignore the {keyID} argument as we only have a single server key so we always
|
// Ignore the {keyID} argument as we only have a single server key so we always
|
||||||
// return that key.
|
// return that key.
|
||||||
// Even if we had more than one server key, we would probably still ignore the
|
// Even if we had more than one server key, we would probably still ignore the
|
||||||
|
|
@ -68,6 +88,8 @@ func Setup(
|
||||||
v2keysmux.Handle("/server/{keyID}", localKeys).Methods(http.MethodGet)
|
v2keysmux.Handle("/server/{keyID}", localKeys).Methods(http.MethodGet)
|
||||||
v2keysmux.Handle("/server/", localKeys).Methods(http.MethodGet)
|
v2keysmux.Handle("/server/", localKeys).Methods(http.MethodGet)
|
||||||
v2keysmux.Handle("/server", localKeys).Methods(http.MethodGet)
|
v2keysmux.Handle("/server", localKeys).Methods(http.MethodGet)
|
||||||
|
v2keysmux.Handle("/query", notaryKeys).Methods(http.MethodPost)
|
||||||
|
v2keysmux.Handle("/query/{serverName}/{keyID}", notaryKeys).Methods(http.MethodGet)
|
||||||
|
|
||||||
v1fedmux.Handle("/send/{txnID}", httputil.MakeFedAPI(
|
v1fedmux.Handle("/send/{txnID}", httputil.MakeFedAPI(
|
||||||
"federation_send", cfg.Matrix.ServerName, keys, wakeup,
|
"federation_send", cfg.Matrix.ServerName, keys, wakeup,
|
||||||
|
|
|
||||||
|
|
@ -199,6 +199,15 @@ func (t *testRoomserverAPI) QueryMembershipsForRoom(
|
||||||
return fmt.Errorf("not implemented")
|
return fmt.Errorf("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Query if a server is joined to a room
|
||||||
|
func (t *testRoomserverAPI) QueryServerJoinedToRoom(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryServerJoinedToRoomRequest,
|
||||||
|
response *api.QueryServerJoinedToRoomResponse,
|
||||||
|
) error {
|
||||||
|
return fmt.Errorf("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
// Query whether a server is allowed to see an event
|
// Query whether a server is allowed to see an event
|
||||||
func (t *testRoomserverAPI) QueryServerAllowedToSeeEvent(
|
func (t *testRoomserverAPI) QueryServerAllowedToSeeEvent(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,8 @@ type FederationClient interface {
|
||||||
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error)
|
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error)
|
||||||
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
|
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
|
||||||
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
|
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
|
||||||
|
GetServerKeys(ctx context.Context, matrixServer gomatrixserverlib.ServerName) (gomatrixserverlib.ServerKeys, error)
|
||||||
|
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FederationClientError is returned from FederationClient methods in the event of a problem.
|
// FederationClientError is returned from FederationClient methods in the event of a problem.
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/federationsender/api"
|
"github.com/matrix-org/dendrite/federationsender/api"
|
||||||
|
|
@ -23,6 +24,7 @@ type FederationSenderInternalAPI struct {
|
||||||
federation *gomatrixserverlib.FederationClient
|
federation *gomatrixserverlib.FederationClient
|
||||||
keyRing *gomatrixserverlib.KeyRing
|
keyRing *gomatrixserverlib.KeyRing
|
||||||
queues *queue.OutgoingQueues
|
queues *queue.OutgoingQueues
|
||||||
|
joins sync.Map // joins currently in progress
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFederationSenderInternalAPI(
|
func NewFederationSenderInternalAPI(
|
||||||
|
|
@ -187,3 +189,27 @@ func (a *FederationSenderInternalAPI) GetEvent(
|
||||||
}
|
}
|
||||||
return ires.(gomatrixserverlib.Transaction), nil
|
return ires.(gomatrixserverlib.Transaction), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *FederationSenderInternalAPI) GetServerKeys(
|
||||||
|
ctx context.Context, s gomatrixserverlib.ServerName,
|
||||||
|
) (gomatrixserverlib.ServerKeys, error) {
|
||||||
|
ires, err := a.doRequest(s, func() (interface{}, error) {
|
||||||
|
return a.federation.GetServerKeys(ctx, s)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return gomatrixserverlib.ServerKeys{}, err
|
||||||
|
}
|
||||||
|
return ires.(gomatrixserverlib.ServerKeys), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *FederationSenderInternalAPI) LookupServerKeys(
|
||||||
|
ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
|
||||||
|
) ([]gomatrixserverlib.ServerKeys, error) {
|
||||||
|
ires, err := a.doRequest(s, func() (interface{}, error) {
|
||||||
|
return a.federation.LookupServerKeys(ctx, s, keyRequests)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return []gomatrixserverlib.ServerKeys{}, err
|
||||||
|
}
|
||||||
|
return ires.([]gomatrixserverlib.ServerKeys), nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -37,12 +37,32 @@ func (r *FederationSenderInternalAPI) PerformDirectoryLookup(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type federatedJoin struct {
|
||||||
|
UserID string
|
||||||
|
RoomID string
|
||||||
|
}
|
||||||
|
|
||||||
// PerformJoinRequest implements api.FederationSenderInternalAPI
|
// PerformJoinRequest implements api.FederationSenderInternalAPI
|
||||||
func (r *FederationSenderInternalAPI) PerformJoin(
|
func (r *FederationSenderInternalAPI) PerformJoin(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *api.PerformJoinRequest,
|
request *api.PerformJoinRequest,
|
||||||
response *api.PerformJoinResponse,
|
response *api.PerformJoinResponse,
|
||||||
) {
|
) {
|
||||||
|
// Check that a join isn't already in progress for this user/room.
|
||||||
|
j := federatedJoin{request.UserID, request.RoomID}
|
||||||
|
if _, found := r.joins.Load(j); found {
|
||||||
|
response.LastError = &gomatrix.HTTPError{
|
||||||
|
Code: 429,
|
||||||
|
Message: `{
|
||||||
|
"errcode": "M_LIMIT_EXCEEDED",
|
||||||
|
"error": "There is already a federated join to this room in progress. Please wait for it to finish."
|
||||||
|
}`, // TODO: Why do none of our error types play nicely with each other?
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.joins.Store(j, nil)
|
||||||
|
defer r.joins.Delete(j)
|
||||||
|
|
||||||
// Look up the supported room versions.
|
// Look up the supported room versions.
|
||||||
var supportedVersions []gomatrixserverlib.RoomVersion
|
var supportedVersions []gomatrixserverlib.RoomVersion
|
||||||
for version := range version.SupportedRoomVersions() {
|
for version := range version.SupportedRoomVersions() {
|
||||||
|
|
@ -186,28 +206,47 @@ func (r *FederationSenderInternalAPI) performJoinUsingServer(
|
||||||
}
|
}
|
||||||
r.statistics.ForServer(serverName).Success()
|
r.statistics.ForServer(serverName).Success()
|
||||||
|
|
||||||
// Check that the send_join response was valid.
|
// Process the join response in a goroutine. The idea here is
|
||||||
joinCtx := perform.JoinContext(r.federation, r.keyRing)
|
// that we'll try and wait for as long as possible for the work
|
||||||
respState, err := joinCtx.CheckSendJoinResponse(
|
// to complete, but if the client does give up waiting, we'll
|
||||||
ctx, event, serverName, respSendJoin,
|
// still continue to process the join anyway so that we don't
|
||||||
)
|
// waste the effort.
|
||||||
if err != nil {
|
var cancel context.CancelFunc
|
||||||
return fmt.Errorf("joinCtx.CheckSendJoinResponse: %w", err)
|
ctx, cancel = context.WithCancel(context.Background())
|
||||||
}
|
go func() {
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
// If we successfully performed a send_join above then the other
|
// Check that the send_join response was valid.
|
||||||
// server now thinks we're a part of the room. Send the newly
|
joinCtx := perform.JoinContext(r.federation, r.keyRing)
|
||||||
// returned state to the roomserver to update our local view.
|
respState, err := joinCtx.CheckSendJoinResponse(
|
||||||
headeredEvent := event.Headered(respMakeJoin.RoomVersion)
|
ctx, event, serverName, respSendJoin,
|
||||||
if err = roomserverAPI.SendEventWithRewrite(
|
)
|
||||||
ctx, r.rsAPI,
|
if err != nil {
|
||||||
respState,
|
logrus.WithFields(logrus.Fields{
|
||||||
headeredEvent,
|
"room_id": roomID,
|
||||||
nil,
|
"user_id": userID,
|
||||||
); err != nil {
|
}).WithError(err).Error("Failed to process room join response")
|
||||||
return fmt.Errorf("r.producer.SendEventWithState: %w", err)
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we successfully performed a send_join above then the other
|
||||||
|
// server now thinks we're a part of the room. Send the newly
|
||||||
|
// returned state to the roomserver to update our local view.
|
||||||
|
if err = roomserverAPI.SendEventWithRewrite(
|
||||||
|
ctx, r.rsAPI,
|
||||||
|
respState,
|
||||||
|
event.Headered(respMakeJoin.RoomVersion),
|
||||||
|
nil,
|
||||||
|
); err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{
|
||||||
|
"room_id": roomID,
|
||||||
|
"user_id": userID,
|
||||||
|
}).WithError(err).Error("Failed to send room join response to roomserver")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-ctx.Done()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -24,13 +24,15 @@ const (
|
||||||
FederationSenderPerformServersAlivePath = "/federationsender/performServersAlive"
|
FederationSenderPerformServersAlivePath = "/federationsender/performServersAlive"
|
||||||
FederationSenderPerformBroadcastEDUPath = "/federationsender/performBroadcastEDU"
|
FederationSenderPerformBroadcastEDUPath = "/federationsender/performBroadcastEDU"
|
||||||
|
|
||||||
FederationSenderGetUserDevicesPath = "/federationsender/client/getUserDevices"
|
FederationSenderGetUserDevicesPath = "/federationsender/client/getUserDevices"
|
||||||
FederationSenderClaimKeysPath = "/federationsender/client/claimKeys"
|
FederationSenderClaimKeysPath = "/federationsender/client/claimKeys"
|
||||||
FederationSenderQueryKeysPath = "/federationsender/client/queryKeys"
|
FederationSenderQueryKeysPath = "/federationsender/client/queryKeys"
|
||||||
FederationSenderBackfillPath = "/federationsender/client/backfill"
|
FederationSenderBackfillPath = "/federationsender/client/backfill"
|
||||||
FederationSenderLookupStatePath = "/federationsender/client/lookupState"
|
FederationSenderLookupStatePath = "/federationsender/client/lookupState"
|
||||||
FederationSenderLookupStateIDsPath = "/federationsender/client/lookupStateIDs"
|
FederationSenderLookupStateIDsPath = "/federationsender/client/lookupStateIDs"
|
||||||
FederationSenderGetEventPath = "/federationsender/client/getEvent"
|
FederationSenderGetEventPath = "/federationsender/client/getEvent"
|
||||||
|
FederationSenderGetServerKeysPath = "/federationsender/client/getServerKeys"
|
||||||
|
FederationSenderLookupServerKeysPath = "/federationsender/client/lookupServerKeys"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewFederationSenderClient creates a FederationSenderInternalAPI implemented by talking to a HTTP POST API.
|
// NewFederationSenderClient creates a FederationSenderInternalAPI implemented by talking to a HTTP POST API.
|
||||||
|
|
@ -372,3 +374,59 @@ func (h *httpFederationSenderInternalAPI) GetEvent(
|
||||||
}
|
}
|
||||||
return *response.Res, nil
|
return *response.Res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type getServerKeys struct {
|
||||||
|
S gomatrixserverlib.ServerName
|
||||||
|
ServerKeys gomatrixserverlib.ServerKeys
|
||||||
|
Err *api.FederationClientError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpFederationSenderInternalAPI) GetServerKeys(
|
||||||
|
ctx context.Context, s gomatrixserverlib.ServerName,
|
||||||
|
) (gomatrixserverlib.ServerKeys, error) {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "GetServerKeys")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
request := getServerKeys{
|
||||||
|
S: s,
|
||||||
|
}
|
||||||
|
var response getServerKeys
|
||||||
|
apiURL := h.federationSenderURL + FederationSenderGetServerKeysPath
|
||||||
|
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
|
||||||
|
if err != nil {
|
||||||
|
return gomatrixserverlib.ServerKeys{}, err
|
||||||
|
}
|
||||||
|
if response.Err != nil {
|
||||||
|
return gomatrixserverlib.ServerKeys{}, response.Err
|
||||||
|
}
|
||||||
|
return response.ServerKeys, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type lookupServerKeys struct {
|
||||||
|
S gomatrixserverlib.ServerName
|
||||||
|
KeyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp
|
||||||
|
ServerKeys []gomatrixserverlib.ServerKeys
|
||||||
|
Err *api.FederationClientError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpFederationSenderInternalAPI) LookupServerKeys(
|
||||||
|
ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
|
||||||
|
) ([]gomatrixserverlib.ServerKeys, error) {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "LookupServerKeys")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
request := lookupServerKeys{
|
||||||
|
S: s,
|
||||||
|
KeyRequests: keyRequests,
|
||||||
|
}
|
||||||
|
var response lookupServerKeys
|
||||||
|
apiURL := h.federationSenderURL + FederationSenderLookupServerKeysPath
|
||||||
|
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
|
||||||
|
if err != nil {
|
||||||
|
return []gomatrixserverlib.ServerKeys{}, err
|
||||||
|
}
|
||||||
|
if response.Err != nil {
|
||||||
|
return []gomatrixserverlib.ServerKeys{}, response.Err
|
||||||
|
}
|
||||||
|
return response.ServerKeys, nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -263,4 +263,48 @@ func AddRoutes(intAPI api.FederationSenderInternalAPI, internalAPIMux *mux.Route
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: request}
|
return util.JSONResponse{Code: http.StatusOK, JSON: request}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
internalAPIMux.Handle(
|
||||||
|
FederationSenderGetServerKeysPath,
|
||||||
|
httputil.MakeInternalAPI("GetServerKeys", func(req *http.Request) util.JSONResponse {
|
||||||
|
var request getServerKeys
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
res, err := intAPI.GetServerKeys(req.Context(), request.S)
|
||||||
|
if err != nil {
|
||||||
|
ferr, ok := err.(*api.FederationClientError)
|
||||||
|
if ok {
|
||||||
|
request.Err = ferr
|
||||||
|
} else {
|
||||||
|
request.Err = &api.FederationClientError{
|
||||||
|
Err: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
request.ServerKeys = res
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: request}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
internalAPIMux.Handle(
|
||||||
|
FederationSenderLookupServerKeysPath,
|
||||||
|
httputil.MakeInternalAPI("LookupServerKeys", func(req *http.Request) util.JSONResponse {
|
||||||
|
var request lookupServerKeys
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
res, err := intAPI.LookupServerKeys(req.Context(), request.S, request.KeyRequests)
|
||||||
|
if err != nil {
|
||||||
|
ferr, ok := err.(*api.FederationClientError)
|
||||||
|
if ok {
|
||||||
|
request.Err = ferr
|
||||||
|
} else {
|
||||||
|
request.Err = &api.FederationClientError{
|
||||||
|
Err: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
request.ServerKeys = res
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: request}
|
||||||
|
}),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -231,13 +231,24 @@ func (oq *destinationQueue) backgroundSend() {
|
||||||
// If we are backing off this server then wait for the
|
// If we are backing off this server then wait for the
|
||||||
// backoff duration to complete first, or until explicitly
|
// backoff duration to complete first, or until explicitly
|
||||||
// told to retry.
|
// told to retry.
|
||||||
if _, giveUp := oq.statistics.BackoffIfRequired(oq.backingOff, oq.interruptBackoff); giveUp {
|
until, blacklisted := oq.statistics.BackoffInfo()
|
||||||
|
if blacklisted {
|
||||||
// It's been suggested that we should give up because the backoff
|
// It's been suggested that we should give up because the backoff
|
||||||
// has exceeded a maximum allowable value. Clean up the in-memory
|
// has exceeded a maximum allowable value. Clean up the in-memory
|
||||||
// buffers at this point. The PDU clean-up is already on a defer.
|
// buffers at this point. The PDU clean-up is already on a defer.
|
||||||
log.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination)
|
log.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if until != nil && until.After(time.Now()) {
|
||||||
|
// We haven't backed off yet, so wait for the suggested amount of
|
||||||
|
// time.
|
||||||
|
duration := time.Until(*until)
|
||||||
|
log.Warnf("Backing off %q for %s", oq.destination, duration)
|
||||||
|
select {
|
||||||
|
case <-time.After(duration):
|
||||||
|
case <-oq.interruptBackoff:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// If we have pending PDUs or EDUs then construct a transaction.
|
// If we have pending PDUs or EDUs then construct a transaction.
|
||||||
if pendingPDUs || pendingEDUs {
|
if pendingPDUs || pendingEDUs {
|
||||||
|
|
|
||||||
|
|
@ -44,6 +44,7 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS
|
||||||
server = &ServerStatistics{
|
server = &ServerStatistics{
|
||||||
statistics: s,
|
statistics: s,
|
||||||
serverName: serverName,
|
serverName: serverName,
|
||||||
|
interrupt: make(chan struct{}),
|
||||||
}
|
}
|
||||||
s.servers[serverName] = server
|
s.servers[serverName] = server
|
||||||
s.mutex.Unlock()
|
s.mutex.Unlock()
|
||||||
|
|
@ -68,6 +69,7 @@ type ServerStatistics struct {
|
||||||
backoffStarted atomic.Bool // is the backoff started
|
backoffStarted atomic.Bool // is the backoff started
|
||||||
backoffUntil atomic.Value // time.Time until this backoff interval ends
|
backoffUntil atomic.Value // time.Time until this backoff interval ends
|
||||||
backoffCount atomic.Uint32 // number of times BackoffDuration has been called
|
backoffCount atomic.Uint32 // number of times BackoffDuration has been called
|
||||||
|
interrupt chan struct{} // interrupts the backoff goroutine
|
||||||
successCounter atomic.Uint32 // how many times have we succeeded?
|
successCounter atomic.Uint32 // how many times have we succeeded?
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -76,15 +78,24 @@ func (s *ServerStatistics) duration(count uint32) time.Duration {
|
||||||
return time.Second * time.Duration(math.Exp2(float64(count)))
|
return time.Second * time.Duration(math.Exp2(float64(count)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// cancel will interrupt the currently active backoff.
|
||||||
|
func (s *ServerStatistics) cancel() {
|
||||||
|
s.blacklisted.Store(false)
|
||||||
|
s.backoffUntil.Store(time.Time{})
|
||||||
|
select {
|
||||||
|
case s.interrupt <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Success updates the server statistics with a new successful
|
// Success updates the server statistics with a new successful
|
||||||
// attempt, which increases the sent counter and resets the idle and
|
// attempt, which increases the sent counter and resets the idle and
|
||||||
// failure counters. If a host was blacklisted at this point then
|
// failure counters. If a host was blacklisted at this point then
|
||||||
// we will unblacklist it.
|
// we will unblacklist it.
|
||||||
func (s *ServerStatistics) Success() {
|
func (s *ServerStatistics) Success() {
|
||||||
s.successCounter.Add(1)
|
s.cancel()
|
||||||
s.backoffStarted.Store(false)
|
s.successCounter.Inc()
|
||||||
s.backoffCount.Store(0)
|
s.backoffCount.Store(0)
|
||||||
s.blacklisted.Store(false)
|
|
||||||
if s.statistics.DB != nil {
|
if s.statistics.DB != nil {
|
||||||
if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil {
|
if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil {
|
||||||
logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName)
|
logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName)
|
||||||
|
|
@ -99,10 +110,30 @@ func (s *ServerStatistics) Success() {
|
||||||
// whether we have blacklisted and therefore to give up.
|
// whether we have blacklisted and therefore to give up.
|
||||||
func (s *ServerStatistics) Failure() (time.Time, bool) {
|
func (s *ServerStatistics) Failure() (time.Time, bool) {
|
||||||
// If we aren't already backing off, this call will start
|
// If we aren't already backing off, this call will start
|
||||||
// a new backoff period. Reset the counter to 0 so that
|
// a new backoff period. Increase the failure counter and
|
||||||
// we backoff only for short periods of time to start with.
|
// start a goroutine which will wait out the backoff and
|
||||||
|
// unset the backoffStarted flag when done.
|
||||||
if s.backoffStarted.CAS(false, true) {
|
if s.backoffStarted.CAS(false, true) {
|
||||||
s.backoffCount.Store(0)
|
if s.backoffCount.Inc() >= s.statistics.FailuresUntilBlacklist {
|
||||||
|
s.blacklisted.Store(true)
|
||||||
|
if s.statistics.DB != nil {
|
||||||
|
if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil {
|
||||||
|
logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return time.Time{}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
until, ok := s.backoffUntil.Load().(time.Time)
|
||||||
|
if ok {
|
||||||
|
select {
|
||||||
|
case <-time.After(time.Until(until)):
|
||||||
|
case <-s.interrupt:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.backoffStarted.Store(false)
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if we have blacklisted this node.
|
// Check if we have blacklisted this node.
|
||||||
|
|
@ -136,53 +167,6 @@ func (s *ServerStatistics) BackoffInfo() (*time.Time, bool) {
|
||||||
return nil, s.blacklisted.Load()
|
return nil, s.blacklisted.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
// BackoffIfRequired will block for as long as the current
|
|
||||||
// backoff requires, if needed. Otherwise it will do nothing.
|
|
||||||
// Returns the amount of time to backoff for and whether to give up or not.
|
|
||||||
func (s *ServerStatistics) BackoffIfRequired(backingOff atomic.Bool, interrupt <-chan bool) (time.Duration, bool) {
|
|
||||||
if started := s.backoffStarted.Load(); !started {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Work out if we should be blacklisting at this point.
|
|
||||||
count := s.backoffCount.Inc()
|
|
||||||
if count >= s.statistics.FailuresUntilBlacklist {
|
|
||||||
// We've exceeded the maximum amount of times we're willing
|
|
||||||
// to back off, which is probably in the region of hours by
|
|
||||||
// now. Mark the host as blacklisted and tell the caller to
|
|
||||||
// give up.
|
|
||||||
s.blacklisted.Store(true)
|
|
||||||
if s.statistics.DB != nil {
|
|
||||||
if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil {
|
|
||||||
logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Work out when we should wait until.
|
|
||||||
duration := s.duration(count)
|
|
||||||
until := time.Now().Add(duration)
|
|
||||||
s.backoffUntil.Store(until)
|
|
||||||
|
|
||||||
// Notify the destination queue that we're backing off now.
|
|
||||||
backingOff.Store(true)
|
|
||||||
defer backingOff.Store(false)
|
|
||||||
|
|
||||||
// Work out how long we should be backing off for.
|
|
||||||
logrus.Warnf("Backing off %q for %s", s.serverName, duration)
|
|
||||||
|
|
||||||
// Wait for either an interruption or for the backoff to
|
|
||||||
// complete.
|
|
||||||
select {
|
|
||||||
case <-interrupt:
|
|
||||||
logrus.Debugf("Interrupting backoff for %q", s.serverName)
|
|
||||||
case <-time.After(duration):
|
|
||||||
}
|
|
||||||
|
|
||||||
return duration, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Blacklisted returns true if the server is blacklisted and false
|
// Blacklisted returns true if the server is blacklisted and false
|
||||||
// otherwise.
|
// otherwise.
|
||||||
func (s *ServerStatistics) Blacklisted() bool {
|
func (s *ServerStatistics) Blacklisted() bool {
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,6 @@ import (
|
||||||
"math"
|
"math"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"go.uber.org/atomic"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBackoff(t *testing.T) {
|
func TestBackoff(t *testing.T) {
|
||||||
|
|
@ -27,34 +25,30 @@ func TestBackoff(t *testing.T) {
|
||||||
server.Failure()
|
server.Failure()
|
||||||
|
|
||||||
t.Logf("Backoff counter: %d", server.backoffCount.Load())
|
t.Logf("Backoff counter: %d", server.backoffCount.Load())
|
||||||
backingOff := atomic.Bool{}
|
|
||||||
|
|
||||||
// Now we're going to simulate backing off a few times to see
|
// Now we're going to simulate backing off a few times to see
|
||||||
// what happens.
|
// what happens.
|
||||||
for i := uint32(1); i <= 10; i++ {
|
for i := uint32(1); i <= 10; i++ {
|
||||||
// Interrupt the backoff - it doesn't really matter if it
|
|
||||||
// completes but we will find out how long the backoff should
|
|
||||||
// have been.
|
|
||||||
interrupt := make(chan bool, 1)
|
|
||||||
close(interrupt)
|
|
||||||
|
|
||||||
// Get the duration.
|
|
||||||
duration, blacklist := server.BackoffIfRequired(backingOff, interrupt)
|
|
||||||
|
|
||||||
// Register another failure for good measure. This should have no
|
// Register another failure for good measure. This should have no
|
||||||
// side effects since a backoff is already in progress. If it does
|
// side effects since a backoff is already in progress. If it does
|
||||||
// then we'll fail.
|
// then we'll fail.
|
||||||
until, blacklisted := server.Failure()
|
until, blacklisted := server.Failure()
|
||||||
if time.Until(until) > duration {
|
|
||||||
t.Fatal("Failure produced unexpected side effect when it shouldn't have")
|
// Get the duration.
|
||||||
}
|
_, blacklist := server.BackoffInfo()
|
||||||
|
duration := time.Until(until).Round(time.Second)
|
||||||
|
|
||||||
|
// Unset the backoff, or otherwise our next call will think that
|
||||||
|
// there's a backoff in progress and return the same result.
|
||||||
|
server.cancel()
|
||||||
|
server.backoffStarted.Store(false)
|
||||||
|
|
||||||
// Check if we should be blacklisted by now.
|
// Check if we should be blacklisted by now.
|
||||||
if i >= stats.FailuresUntilBlacklist {
|
if i >= stats.FailuresUntilBlacklist {
|
||||||
if !blacklist {
|
if !blacklist {
|
||||||
t.Fatalf("Backoff %d should have resulted in blacklist but didn't", i)
|
t.Fatalf("Backoff %d should have resulted in blacklist but didn't", i)
|
||||||
} else if blacklist != blacklisted {
|
} else if blacklist != blacklisted {
|
||||||
t.Fatalf("BackoffIfRequired and Failure returned different blacklist values")
|
t.Fatalf("BackoffInfo and Failure returned different blacklist values")
|
||||||
} else {
|
} else {
|
||||||
t.Logf("Backoff %d is blacklisted as expected", i)
|
t.Logf("Backoff %d is blacklisted as expected", i)
|
||||||
continue
|
continue
|
||||||
|
|
|
||||||
2
go.mod
2
go.mod
|
|
@ -24,7 +24,7 @@ require (
|
||||||
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
|
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
|
||||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3
|
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd
|
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200907151926-38f437f2b2a6
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20200925165243-b9780a852681
|
||||||
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91
|
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91
|
||||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
||||||
github.com/mattn/go-sqlite3 v1.14.2
|
github.com/mattn/go-sqlite3 v1.14.2
|
||||||
|
|
|
||||||
4
go.sum
4
go.sum
|
|
@ -569,8 +569,8 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
|
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg=
|
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg=
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200907151926-38f437f2b2a6 h1:43gla6bLt4opWY1mQkAasF/LUCipZl7x2d44TY0wf40=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20200925165243-b9780a852681 h1:75fM7vPHiFGt+XxktT17LJD972XMtJ1n7FU1MpC08Zc=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200907151926-38f437f2b2a6/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20200925165243-b9780a852681/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
|
||||||
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4=
|
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4=
|
||||||
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE=
|
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE=
|
||||||
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo=
|
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo=
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,9 @@ import (
|
||||||
jaegermetrics "github.com/uber/jaeger-lib/metrics"
|
jaegermetrics "github.com/uber/jaeger-lib/metrics"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// keyIDRegexp defines allowable characters in Key IDs.
|
||||||
|
var keyIDRegexp = regexp.MustCompile("^ed25519:[a-zA-Z0-9_]+$")
|
||||||
|
|
||||||
// Version is the current version of the config format.
|
// Version is the current version of the config format.
|
||||||
// This will change whenever we make breaking changes to the config format.
|
// This will change whenever we make breaking changes to the config format.
|
||||||
const Version = 1
|
const Version = 1
|
||||||
|
|
@ -225,10 +228,30 @@ func loadConfig(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.Global.KeyID, c.Global.PrivateKey, err = readKeyPEM(privateKeyPath, privateKeyData); err != nil {
|
if c.Global.KeyID, c.Global.PrivateKey, err = readKeyPEM(privateKeyPath, privateKeyData, true); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for i, oldPrivateKey := range c.Global.OldVerifyKeys {
|
||||||
|
var oldPrivateKeyData []byte
|
||||||
|
|
||||||
|
oldPrivateKeyPath := absPath(basePath, oldPrivateKey.PrivateKeyPath)
|
||||||
|
oldPrivateKeyData, err = readFile(oldPrivateKeyPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTSPEC: Ordinarily we should enforce key ID formatting, but since there are
|
||||||
|
// a number of private keys out there with non-compatible symbols in them due
|
||||||
|
// to lack of validation in Synapse, we won't enforce that for old verify keys.
|
||||||
|
keyID, privateKey, perr := readKeyPEM(oldPrivateKeyPath, oldPrivateKeyData, false)
|
||||||
|
if perr != nil {
|
||||||
|
return nil, perr
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Global.OldVerifyKeys[i].KeyID, c.Global.OldVerifyKeys[i].PrivateKey = keyID, privateKey
|
||||||
|
}
|
||||||
|
|
||||||
for _, certPath := range c.FederationAPI.FederationCertificatePaths {
|
for _, certPath := range c.FederationAPI.FederationCertificatePaths {
|
||||||
absCertPath := absPath(basePath, certPath)
|
absCertPath := absPath(basePath, certPath)
|
||||||
var pemData []byte
|
var pemData []byte
|
||||||
|
|
@ -441,7 +464,7 @@ func absPath(dir string, path Path) string {
|
||||||
return filepath.Join(dir, string(path))
|
return filepath.Join(dir, string(path))
|
||||||
}
|
}
|
||||||
|
|
||||||
func readKeyPEM(path string, data []byte) (gomatrixserverlib.KeyID, ed25519.PrivateKey, error) {
|
func readKeyPEM(path string, data []byte, enforceKeyIDFormat bool) (gomatrixserverlib.KeyID, ed25519.PrivateKey, error) {
|
||||||
for {
|
for {
|
||||||
var keyBlock *pem.Block
|
var keyBlock *pem.Block
|
||||||
keyBlock, data = pem.Decode(data)
|
keyBlock, data = pem.Decode(data)
|
||||||
|
|
@ -459,6 +482,9 @@ func readKeyPEM(path string, data []byte) (gomatrixserverlib.KeyID, ed25519.Priv
|
||||||
if !strings.HasPrefix(keyID, "ed25519:") {
|
if !strings.HasPrefix(keyID, "ed25519:") {
|
||||||
return "", nil, fmt.Errorf("key ID %q doesn't start with \"ed25519:\" in %q", keyID, path)
|
return "", nil, fmt.Errorf("key ID %q doesn't start with \"ed25519:\" in %q", keyID, path)
|
||||||
}
|
}
|
||||||
|
if enforceKeyIDFormat && !keyIDRegexp.MatchString(keyID) {
|
||||||
|
return "", nil, fmt.Errorf("key ID %q in %q contains illegal characters (use a-z, A-Z, 0-9 and _ only)", keyID, path)
|
||||||
|
}
|
||||||
_, privKey, err := ed25519.GenerateKey(bytes.NewReader(keyBlock.Bytes))
|
_, privKey, err := ed25519.GenerateKey(bytes.NewReader(keyBlock.Bytes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,12 @@ type Global struct {
|
||||||
|
|
||||||
// An arbitrary string used to uniquely identify the PrivateKey. Must start with the
|
// An arbitrary string used to uniquely identify the PrivateKey. Must start with the
|
||||||
// prefix "ed25519:".
|
// prefix "ed25519:".
|
||||||
KeyID gomatrixserverlib.KeyID `yaml:"key_id"`
|
KeyID gomatrixserverlib.KeyID `yaml:"-"`
|
||||||
|
|
||||||
|
// Information about old private keys that used to be used to sign requests and
|
||||||
|
// events on this domain. They will not be used but will be advertised to other
|
||||||
|
// servers that ask for them to help verify old events.
|
||||||
|
OldVerifyKeys []OldVerifyKeys `yaml:"old_private_keys"`
|
||||||
|
|
||||||
// How long a remote server can cache our server key for before requesting it again.
|
// How long a remote server can cache our server key for before requesting it again.
|
||||||
// Increasing this number will reduce the number of requests made by remote servers
|
// Increasing this number will reduce the number of requests made by remote servers
|
||||||
|
|
@ -60,6 +65,21 @@ func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
||||||
c.Metrics.Verify(configErrs, isMonolith)
|
c.Metrics.Verify(configErrs, isMonolith)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OldVerifyKeys struct {
|
||||||
|
// Path to the private key.
|
||||||
|
PrivateKeyPath Path `yaml:"private_key"`
|
||||||
|
|
||||||
|
// The private key itself.
|
||||||
|
PrivateKey ed25519.PrivateKey `yaml:"-"`
|
||||||
|
|
||||||
|
// The key ID of the private key.
|
||||||
|
KeyID gomatrixserverlib.KeyID `yaml:"-"`
|
||||||
|
|
||||||
|
// When the private key was designed as "expired", as a UNIX timestamp
|
||||||
|
// in millisecond precision.
|
||||||
|
ExpiredAt gomatrixserverlib.Timestamp `yaml:"expired_at"`
|
||||||
|
}
|
||||||
|
|
||||||
// The configuration to use for Prometheus metrics
|
// The configuration to use for Prometheus metrics
|
||||||
type Metrics struct {
|
type Metrics struct {
|
||||||
// Whether or not the metrics are enabled
|
// Whether or not the metrics are enabled
|
||||||
|
|
|
||||||
|
|
@ -234,7 +234,7 @@ func (m mockReadFile) readFile(path string) ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReadKey(t *testing.T) {
|
func TestReadKey(t *testing.T) {
|
||||||
keyID, _, err := readKeyPEM("path/to/key", []byte(testKey))
|
keyID, _, err := readKeyPEM("path/to/key", []byte(testKey), true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("failed to load private key:", err)
|
t.Error("failed to load private key:", err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -88,6 +88,14 @@ func TxStmt(transaction *sql.Tx, statement *sql.Stmt) *sql.Stmt {
|
||||||
return statement
|
return statement
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TxStmtContext behaves similarly to TxStmt, with support for also passing context.
|
||||||
|
func TxStmtContext(context context.Context, transaction *sql.Tx, statement *sql.Stmt) *sql.Stmt {
|
||||||
|
if transaction != nil {
|
||||||
|
statement = transaction.StmtContext(context, statement)
|
||||||
|
}
|
||||||
|
return statement
|
||||||
|
}
|
||||||
|
|
||||||
// Hack of the century
|
// Hack of the century
|
||||||
func QueryVariadic(count int) string {
|
func QueryVariadic(count int) string {
|
||||||
return QueryVariadicOffset(count, 0)
|
return QueryVariadicOffset(count, 0)
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ import (
|
||||||
"math/big"
|
"math/big"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/config"
|
"github.com/matrix-org/dendrite/internal/config"
|
||||||
|
|
@ -146,10 +147,14 @@ func NewMatrixKey(matrixKeyPath string) (err error) {
|
||||||
err = keyOut.Close()
|
err = keyOut.Close()
|
||||||
})()
|
})()
|
||||||
|
|
||||||
|
keyID := base64.RawURLEncoding.EncodeToString(data[:])
|
||||||
|
keyID = strings.ReplaceAll(keyID, "-", "")
|
||||||
|
keyID = strings.ReplaceAll(keyID, "_", "")
|
||||||
|
|
||||||
err = pem.Encode(keyOut, &pem.Block{
|
err = pem.Encode(keyOut, &pem.Block{
|
||||||
Type: "MATRIX PRIVATE KEY",
|
Type: "MATRIX PRIVATE KEY",
|
||||||
Headers: map[string]string{
|
Headers: map[string]string{
|
||||||
"Key-ID": "ed25519:" + base64.RawStdEncoding.EncodeToString(data[:3]),
|
"Key-ID": fmt.Sprintf("ed25519:%s", keyID[:6]),
|
||||||
},
|
},
|
||||||
Bytes: data[3:],
|
Bytes: data[3:],
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ import (
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
)
|
)
|
||||||
|
|
@ -125,7 +126,7 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
|
||||||
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
|
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
|
||||||
// nullable if there are no results
|
// nullable if there are no results
|
||||||
var nullStream sql.NullInt32
|
var nullStream sql.NullInt32
|
||||||
err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
err = nil
|
err = nil
|
||||||
}
|
}
|
||||||
|
|
@ -151,7 +152,7 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID
|
||||||
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
|
_, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
|
||||||
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
|
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -162,7 +163,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
|
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||||
_, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
|
_, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
)
|
)
|
||||||
|
|
@ -151,14 +152,14 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, txn *sql.
|
||||||
}
|
}
|
||||||
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||||
algo, keyID := keys.Split(keyIDWithAlgo)
|
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||||
_, err := txn.Stmt(s.upsertKeysStmt).ExecContext(
|
_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
|
||||||
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
|
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
|
rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -180,14 +181,14 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
|
||||||
) (map[string]json.RawMessage, error) {
|
) (map[string]json.RawMessage, error) {
|
||||||
var keyID string
|
var keyID string
|
||||||
var keyJSON string
|
var keyJSON string
|
||||||
err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
_, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
||||||
return map[string]json.RawMessage{
|
return map[string]json.RawMessage{
|
||||||
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||||
}, err
|
}, err
|
||||||
|
|
|
||||||
|
|
@ -97,7 +97,7 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
|
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||||
_, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
|
_, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -156,7 +156,7 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
|
||||||
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
|
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
|
||||||
// nullable if there are no results
|
// nullable if there are no results
|
||||||
var nullStream sql.NullInt32
|
var nullStream sql.NullInt32
|
||||||
err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
err = nil
|
err = nil
|
||||||
}
|
}
|
||||||
|
|
@ -188,7 +188,7 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID
|
||||||
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
|
_, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
|
||||||
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
|
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
)
|
)
|
||||||
|
|
@ -153,14 +154,14 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(
|
||||||
}
|
}
|
||||||
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||||
algo, keyID := keys.Split(keyIDWithAlgo)
|
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||||
_, err := txn.Stmt(s.upsertKeysStmt).ExecContext(
|
_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
|
||||||
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
|
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
|
rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -182,14 +183,14 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
|
||||||
) (map[string]json.RawMessage, error) {
|
) (map[string]json.RawMessage, error) {
|
||||||
var keyID string
|
var keyID string
|
||||||
var keyJSON string
|
var keyJSON string
|
||||||
err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
_, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -95,6 +95,13 @@ type RoomserverInternalAPI interface {
|
||||||
response *QueryMembershipsForRoomResponse,
|
response *QueryMembershipsForRoomResponse,
|
||||||
) error
|
) error
|
||||||
|
|
||||||
|
// Query if we think we're still in a room.
|
||||||
|
QueryServerJoinedToRoom(
|
||||||
|
ctx context.Context,
|
||||||
|
request *QueryServerJoinedToRoomRequest,
|
||||||
|
response *QueryServerJoinedToRoomResponse,
|
||||||
|
) error
|
||||||
|
|
||||||
// Query whether a server is allowed to see an event
|
// Query whether a server is allowed to see an event
|
||||||
QueryServerAllowedToSeeEvent(
|
QueryServerAllowedToSeeEvent(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
|
|
||||||
|
|
@ -144,6 +144,16 @@ func (t *RoomserverInternalAPITrace) QueryMembershipsForRoom(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *RoomserverInternalAPITrace) QueryServerJoinedToRoom(
|
||||||
|
ctx context.Context,
|
||||||
|
req *QueryServerJoinedToRoomRequest,
|
||||||
|
res *QueryServerJoinedToRoomResponse,
|
||||||
|
) error {
|
||||||
|
err := t.Impl.QueryServerJoinedToRoom(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).WithError(err).Infof("QueryServerJoinedToRoom req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (t *RoomserverInternalAPITrace) QueryServerAllowedToSeeEvent(
|
func (t *RoomserverInternalAPITrace) QueryServerAllowedToSeeEvent(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
req *QueryServerAllowedToSeeEventRequest,
|
req *QueryServerAllowedToSeeEventRequest,
|
||||||
|
|
|
||||||
|
|
@ -140,6 +140,22 @@ type QueryMembershipsForRoomResponse struct {
|
||||||
HasBeenInRoom bool `json:"has_been_in_room"`
|
HasBeenInRoom bool `json:"has_been_in_room"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QueryServerJoinedToRoomRequest is a request to QueryServerJoinedToRoom
|
||||||
|
type QueryServerJoinedToRoomRequest struct {
|
||||||
|
// Server name of the server to find
|
||||||
|
ServerName gomatrixserverlib.ServerName `json:"server_name"`
|
||||||
|
// ID of the room to see if we are still joined to
|
||||||
|
RoomID string `json:"room_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryMembershipsForRoomResponse is a response to QueryServerJoinedToRoom
|
||||||
|
type QueryServerJoinedToRoomResponse struct {
|
||||||
|
// True if the room exists on the server
|
||||||
|
RoomExists bool `json:"room_exists"`
|
||||||
|
// True if we still believe that we are participating in the room
|
||||||
|
IsInRoom bool `json:"is_in_room"`
|
||||||
|
}
|
||||||
|
|
||||||
// QueryServerAllowedToSeeEventRequest is a request to QueryServerAllowedToSeeEvent
|
// QueryServerAllowedToSeeEventRequest is a request to QueryServerAllowedToSeeEvent
|
||||||
type QueryServerAllowedToSeeEventRequest struct {
|
type QueryServerAllowedToSeeEventRequest struct {
|
||||||
// The event ID to look up invites in.
|
// The event ID to look up invites in.
|
||||||
|
|
|
||||||
|
|
@ -123,15 +123,7 @@ func SendEventWithRewrite(
|
||||||
// We will handle an event as if it's an outlier if one of the
|
// We will handle an event as if it's an outlier if one of the
|
||||||
// following conditions is true:
|
// following conditions is true:
|
||||||
storeAsOutlier := false
|
storeAsOutlier := false
|
||||||
if authOrStateEvent.Type() == event.Type() && *authOrStateEvent.StateKey() == *event.StateKey() {
|
if _, ok := isCurrentState[authOrStateEvent.EventID()]; !ok {
|
||||||
// The event is a state event but the input event is going to
|
|
||||||
// replace it, therefore it can't be added to the state or we'll
|
|
||||||
// get duplicate state keys in the state block. We'll send it
|
|
||||||
// as an outlier because we don't know if something will be
|
|
||||||
// referring to it as an auth event, but need it to be stored
|
|
||||||
// just in case.
|
|
||||||
storeAsOutlier = true
|
|
||||||
} else if _, ok := isCurrentState[authOrStateEvent.EventID()]; !ok {
|
|
||||||
// The event is an auth event and isn't a part of the state set.
|
// The event is an auth event and isn't a part of the state set.
|
||||||
// We'll send it as an outlier because we need it to be stored
|
// We'll send it as an outlier because we need it to be stored
|
||||||
// in case something is referring to it as an auth event.
|
// in case something is referring to it as an auth event.
|
||||||
|
|
|
||||||
|
|
@ -16,13 +16,78 @@ package helpers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/state"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// CheckForSoftFail returns true if the event should be soft-failed
|
||||||
|
// and false otherwise. The return error value should be checked before
|
||||||
|
// the soft-fail bool.
|
||||||
|
func CheckForSoftFail(
|
||||||
|
ctx context.Context,
|
||||||
|
db storage.Database,
|
||||||
|
event gomatrixserverlib.HeaderedEvent,
|
||||||
|
stateEventIDs []string,
|
||||||
|
) (bool, error) {
|
||||||
|
rewritesState := len(stateEventIDs) > 1
|
||||||
|
|
||||||
|
var authStateEntries []types.StateEntry
|
||||||
|
var err error
|
||||||
|
if rewritesState {
|
||||||
|
authStateEntries, err = db.StateEntriesForEventIDs(ctx, stateEventIDs)
|
||||||
|
if err != nil {
|
||||||
|
return true, fmt.Errorf("StateEntriesForEventIDs failed: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Work out if the room exists.
|
||||||
|
var roomInfo *types.RoomInfo
|
||||||
|
roomInfo, err = db.RoomInfo(ctx, event.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("db.RoomNID: %w", err)
|
||||||
|
}
|
||||||
|
if roomInfo == nil || roomInfo.IsStub {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then get the state entries for the current state snapshot.
|
||||||
|
// We'll use this to check if the event is allowed right now.
|
||||||
|
roomState := state.NewStateResolution(db, *roomInfo)
|
||||||
|
authStateEntries, err = roomState.LoadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID)
|
||||||
|
if err != nil {
|
||||||
|
return true, fmt.Errorf("roomState.LoadStateAtSnapshot: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// As a special case, it's possible that the room will have no
|
||||||
|
// state because we haven't received a m.room.create event yet.
|
||||||
|
// If we're now processing the first create event then never
|
||||||
|
// soft-fail it.
|
||||||
|
if len(authStateEntries) == 0 && event.Type() == gomatrixserverlib.MRoomCreate {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Work out which of the state events we actually need.
|
||||||
|
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event.Unwrap()})
|
||||||
|
|
||||||
|
// Load the actual auth events from the database.
|
||||||
|
authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries)
|
||||||
|
if err != nil {
|
||||||
|
return true, fmt.Errorf("loadAuthEvents: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the event is allowed.
|
||||||
|
if err = gomatrixserverlib.Allowed(event.Event, &authEvents); err != nil {
|
||||||
|
// return true, nil
|
||||||
|
return true, fmt.Errorf("gomatrixserverlib.Allowed: %w", err)
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
// CheckAuthEvents checks that the event passes authentication checks
|
// CheckAuthEvents checks that the event passes authentication checks
|
||||||
// Returns the numeric IDs for the auth events.
|
// Returns the numeric IDs for the auth events.
|
||||||
func CheckAuthEvents(
|
func CheckAuthEvents(
|
||||||
|
|
|
||||||
|
|
@ -53,6 +53,20 @@ func (r *Inputer) processRoomEvent(
|
||||||
isRejected = true
|
isRejected = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var softfail bool
|
||||||
|
if input.Kind == api.KindBackfill || input.Kind == api.KindNew {
|
||||||
|
// Check that the event passes authentication checks based on the
|
||||||
|
// current room state.
|
||||||
|
softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{
|
||||||
|
"event_id": event.EventID(),
|
||||||
|
"type": event.Type(),
|
||||||
|
"room": event.RoomID(),
|
||||||
|
}).WithError(err).Info("Error authing soft-failed event")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// If we don't have a transaction ID then get one.
|
// If we don't have a transaction ID then get one.
|
||||||
if input.TransactionID != nil {
|
if input.TransactionID != nil {
|
||||||
tdID := input.TransactionID
|
tdID := input.TransactionID
|
||||||
|
|
@ -88,6 +102,7 @@ func (r *Inputer) processRoomEvent(
|
||||||
"event_id": event.EventID(),
|
"event_id": event.EventID(),
|
||||||
"type": event.Type(),
|
"type": event.Type(),
|
||||||
"room": event.RoomID(),
|
"room": event.RoomID(),
|
||||||
|
"sender": event.Sender(),
|
||||||
}).Debug("Stored outlier")
|
}).Debug("Stored outlier")
|
||||||
return event.EventID(), nil
|
return event.EventID(), nil
|
||||||
}
|
}
|
||||||
|
|
@ -110,11 +125,13 @@ func (r *Inputer) processRoomEvent(
|
||||||
}
|
}
|
||||||
|
|
||||||
// We stop here if the event is rejected: We've stored it but won't update forward extremities or notify anyone about it.
|
// We stop here if the event is rejected: We've stored it but won't update forward extremities or notify anyone about it.
|
||||||
if isRejected {
|
if isRejected || softfail {
|
||||||
logrus.WithFields(logrus.Fields{
|
logrus.WithFields(logrus.Fields{
|
||||||
"event_id": event.EventID(),
|
"event_id": event.EventID(),
|
||||||
"type": event.Type(),
|
"type": event.Type(),
|
||||||
"room": event.RoomID(),
|
"room": event.RoomID(),
|
||||||
|
"soft_fail": softfail,
|
||||||
|
"sender": event.Sender(),
|
||||||
}).Debug("Stored rejected event")
|
}).Debug("Stored rejected event")
|
||||||
return event.EventID(), rejectionErr
|
return event.EventID(), rejectionErr
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -227,6 +227,50 @@ func (r *Queryer) QueryMembershipsForRoom(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QueryServerJoinedToRoom implements api.RoomserverInternalAPI
|
||||||
|
func (r *Queryer) QueryServerJoinedToRoom(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryServerJoinedToRoomRequest,
|
||||||
|
response *api.QueryServerJoinedToRoomResponse,
|
||||||
|
) error {
|
||||||
|
info, err := r.DB.RoomInfo(ctx, request.RoomID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("r.DB.RoomInfo: %w", err)
|
||||||
|
}
|
||||||
|
if info == nil || info.IsStub {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
response.RoomExists = true
|
||||||
|
|
||||||
|
eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err)
|
||||||
|
}
|
||||||
|
if len(eventNIDs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
events, err := r.DB.Events(ctx, eventNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("r.DB.Events: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, e := range events {
|
||||||
|
if e.Type() == gomatrixserverlib.MRoomMember && e.StateKey() != nil {
|
||||||
|
_, serverName, err := gomatrixserverlib.SplitID('@', *e.StateKey())
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if serverName == request.ServerName {
|
||||||
|
response.IsInRoom = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI
|
// QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI
|
||||||
func (r *Queryer) QueryServerAllowedToSeeEvent(
|
func (r *Queryer) QueryServerAllowedToSeeEvent(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,7 @@ const (
|
||||||
RoomserverQueryEventsByIDPath = "/roomserver/queryEventsByID"
|
RoomserverQueryEventsByIDPath = "/roomserver/queryEventsByID"
|
||||||
RoomserverQueryMembershipForUserPath = "/roomserver/queryMembershipForUser"
|
RoomserverQueryMembershipForUserPath = "/roomserver/queryMembershipForUser"
|
||||||
RoomserverQueryMembershipsForRoomPath = "/roomserver/queryMembershipsForRoom"
|
RoomserverQueryMembershipsForRoomPath = "/roomserver/queryMembershipsForRoom"
|
||||||
|
RoomserverQueryServerJoinedToRoomPath = "/roomserver/queryServerJoinedToRoomPath"
|
||||||
RoomserverQueryServerAllowedToSeeEventPath = "/roomserver/queryServerAllowedToSeeEvent"
|
RoomserverQueryServerAllowedToSeeEventPath = "/roomserver/queryServerAllowedToSeeEvent"
|
||||||
RoomserverQueryMissingEventsPath = "/roomserver/queryMissingEvents"
|
RoomserverQueryMissingEventsPath = "/roomserver/queryMissingEvents"
|
||||||
RoomserverQueryStateAndAuthChainPath = "/roomserver/queryStateAndAuthChain"
|
RoomserverQueryStateAndAuthChainPath = "/roomserver/queryStateAndAuthChain"
|
||||||
|
|
@ -326,6 +327,19 @@ func (h *httpRoomserverInternalAPI) QueryMembershipsForRoom(
|
||||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QueryMembershipsForRoom implements RoomserverQueryAPI
|
||||||
|
func (h *httpRoomserverInternalAPI) QueryServerJoinedToRoom(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryServerJoinedToRoomRequest,
|
||||||
|
response *api.QueryServerJoinedToRoomResponse,
|
||||||
|
) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerJoinedToRoom")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.roomserverURL + RoomserverQueryServerJoinedToRoomPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||||
|
}
|
||||||
|
|
||||||
// QueryServerAllowedToSeeEvent implements RoomserverQueryAPI
|
// QueryServerAllowedToSeeEvent implements RoomserverQueryAPI
|
||||||
func (h *httpRoomserverInternalAPI) QueryServerAllowedToSeeEvent(
|
func (h *httpRoomserverInternalAPI) QueryServerAllowedToSeeEvent(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
|
|
||||||
|
|
@ -180,6 +180,20 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
internalAPIMux.Handle(
|
||||||
|
RoomserverQueryServerJoinedToRoomPath,
|
||||||
|
httputil.MakeInternalAPI("queryServerJoinedToRoom", func(req *http.Request) util.JSONResponse {
|
||||||
|
var request api.QueryServerJoinedToRoomRequest
|
||||||
|
var response api.QueryServerJoinedToRoomResponse
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
if err := r.QueryServerJoinedToRoom(req.Context(), &request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
internalAPIMux.Handle(
|
internalAPIMux.Handle(
|
||||||
RoomserverQueryServerAllowedToSeeEventPath,
|
RoomserverQueryServerAllowedToSeeEventPath,
|
||||||
httputil.MakeInternalAPI("queryServerAllowedToSeeEvent", func(req *http.Request) util.JSONResponse {
|
httputil.MakeInternalAPI("queryServerAllowedToSeeEvent", func(req *http.Request) util.JSONResponse {
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
|
@ -86,7 +87,7 @@ func (s *stateSnapshotStatements) InsertState(
|
||||||
for i := range stateBlockNIDs {
|
for i := range stateBlockNIDs {
|
||||||
nids[i] = int64(stateBlockNIDs[i])
|
nids[i] = int64(stateBlockNIDs[i])
|
||||||
}
|
}
|
||||||
err = txn.Stmt(s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID)
|
err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -320,9 +320,14 @@ func (d *Database) Events(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
roomVersion, err = d.RoomsTable.SelectRoomVersionForRoomNID(ctx, roomNID)
|
if roomID, ok := d.Cache.GetRoomServerRoomID(roomNID); ok {
|
||||||
if err != nil {
|
roomVersion, _ = d.Cache.GetRoomVersion(roomID)
|
||||||
return nil, err
|
}
|
||||||
|
if roomVersion == "" {
|
||||||
|
roomVersion, err = d.RoomsTable.SelectRoomVersionForRoomNID(ctx, roomNID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
result.Event, err = gomatrixserverlib.NewEventFromTrustedJSON(
|
result.Event, err = gomatrixserverlib.NewEventFromTrustedJSON(
|
||||||
eventJSON.EventJSON, false, roomVersion,
|
eventJSON.EventJSON, false, roomVersion,
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,8 @@ package sqlite3
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
|
|
@ -25,10 +27,15 @@ import (
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TODO: previous_reference_sha256 was NOT NULL before but it broke sytest because
|
||||||
|
// sytest sends no SHA256 sums in the prev_events references in the soft-fail tests.
|
||||||
|
// In Postgres an empty BYTEA field is not NULL so it's fine there. In SQLite it
|
||||||
|
// seems to care that it's empty and therefore hits a NOT NULL constraint on insert.
|
||||||
|
// We should really work out what the right thing to do here is.
|
||||||
const previousEventSchema = `
|
const previousEventSchema = `
|
||||||
CREATE TABLE IF NOT EXISTS roomserver_previous_events (
|
CREATE TABLE IF NOT EXISTS roomserver_previous_events (
|
||||||
previous_event_id TEXT NOT NULL,
|
previous_event_id TEXT NOT NULL,
|
||||||
previous_reference_sha256 BLOB NOT NULL,
|
previous_reference_sha256 BLOB,
|
||||||
event_nids TEXT NOT NULL,
|
event_nids TEXT NOT NULL,
|
||||||
UNIQUE (previous_event_id, previous_reference_sha256)
|
UNIQUE (previous_event_id, previous_reference_sha256)
|
||||||
);
|
);
|
||||||
|
|
@ -45,6 +52,11 @@ const insertPreviousEventSQL = `
|
||||||
VALUES ($1, $2, $3)
|
VALUES ($1, $2, $3)
|
||||||
`
|
`
|
||||||
|
|
||||||
|
const selectPreviousEventNIDsSQL = `
|
||||||
|
SELECT event_nids FROM roomserver_previous_events
|
||||||
|
WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
|
||||||
|
`
|
||||||
|
|
||||||
// Check if the event is referenced by another event in the table.
|
// Check if the event is referenced by another event in the table.
|
||||||
// This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room.
|
// This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room.
|
||||||
const selectPreviousEventExistsSQL = `
|
const selectPreviousEventExistsSQL = `
|
||||||
|
|
@ -55,6 +67,7 @@ const selectPreviousEventExistsSQL = `
|
||||||
type previousEventStatements struct {
|
type previousEventStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertPreviousEventStmt *sql.Stmt
|
insertPreviousEventStmt *sql.Stmt
|
||||||
|
selectPreviousEventNIDsStmt *sql.Stmt
|
||||||
selectPreviousEventExistsStmt *sql.Stmt
|
selectPreviousEventExistsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -69,6 +82,7 @@ func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
|
||||||
|
|
||||||
return s, shared.StatementList{
|
return s, shared.StatementList{
|
||||||
{&s.insertPreviousEventStmt, insertPreviousEventSQL},
|
{&s.insertPreviousEventStmt, insertPreviousEventSQL},
|
||||||
|
{&s.selectPreviousEventNIDsStmt, selectPreviousEventNIDsSQL},
|
||||||
{&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL},
|
{&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
@ -80,9 +94,28 @@ func (s *previousEventStatements) InsertPreviousEvent(
|
||||||
previousEventReferenceSHA256 []byte,
|
previousEventReferenceSHA256 []byte,
|
||||||
eventNID types.EventNID,
|
eventNID types.EventNID,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
|
var eventNIDs string
|
||||||
_, err := stmt.ExecContext(
|
eventNIDAsString := fmt.Sprintf("%d", eventNID)
|
||||||
ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID),
|
selectStmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt)
|
||||||
|
err := selectStmt.QueryRowContext(ctx, previousEventID, previousEventReferenceSHA256).Scan(&eventNIDs)
|
||||||
|
if err != sql.ErrNoRows {
|
||||||
|
return fmt.Errorf("selectStmt.QueryRowContext.Scan: %w", err)
|
||||||
|
}
|
||||||
|
var nids []string
|
||||||
|
if eventNIDs != "" {
|
||||||
|
nids = strings.Split(eventNIDs, ",")
|
||||||
|
for _, nid := range nids {
|
||||||
|
if nid == eventNIDAsString {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
eventNIDs = strings.Join(append(nids, eventNIDAsString), ",")
|
||||||
|
} else {
|
||||||
|
eventNIDs = eventNIDAsString
|
||||||
|
}
|
||||||
|
insertStmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
|
||||||
|
_, err = insertStmt.ExecContext(
|
||||||
|
ctx, previousEventID, previousEventReferenceSHA256, eventNIDs,
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -105,12 +105,12 @@ func (s *stateBlockStatements) BulkInsertStateData(
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
var stateBlockNID types.StateBlockNID
|
var stateBlockNID types.StateBlockNID
|
||||||
err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
|
err := sqlutil.TxStmt(txn, s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
_, err = txn.Stmt(s.insertStateDataStmt).ExecContext(
|
_, err = sqlutil.TxStmt(txn, s.insertStateDataStmt).ExecContext(
|
||||||
ctx,
|
ctx,
|
||||||
int64(stateBlockNID),
|
int64(stateBlockNID),
|
||||||
int64(entry.EventTypeNID),
|
int64(entry.EventTypeNID),
|
||||||
|
|
|
||||||
|
|
@ -76,7 +76,7 @@ func (s *stateSnapshotStatements) InsertState(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
insertStmt := txn.Stmt(s.insertStateStmt)
|
insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt)
|
||||||
res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON))
|
res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ import (
|
||||||
// DefaultRoomVersion contains the room version that will, by
|
// DefaultRoomVersion contains the room version that will, by
|
||||||
// default, be used to create new rooms on this server.
|
// default, be used to create new rooms on this server.
|
||||||
func DefaultRoomVersion() gomatrixserverlib.RoomVersion {
|
func DefaultRoomVersion() gomatrixserverlib.RoomVersion {
|
||||||
return gomatrixserverlib.RoomVersionV5
|
return gomatrixserverlib.RoomVersionV6
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoomVersions returns a map of all known room versions to this
|
// RoomVersions returns a map of all known room versions to this
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ type ServerKeyAPI struct {
|
||||||
ServerKeyValidity time.Duration
|
ServerKeyValidity time.Duration
|
||||||
|
|
||||||
OurKeyRing gomatrixserverlib.KeyRing
|
OurKeyRing gomatrixserverlib.KeyRing
|
||||||
FedClient *gomatrixserverlib.FederationClient
|
FedClient gomatrixserverlib.KeyClient
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServerKeyAPI) KeyRing() *gomatrixserverlib.KeyRing {
|
func (s *ServerKeyAPI) KeyRing() *gomatrixserverlib.KeyRing {
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.ServerKeyInternalAPI, cach
|
||||||
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
|
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
|
||||||
func NewInternalAPI(
|
func NewInternalAPI(
|
||||||
cfg *config.ServerKeyAPI,
|
cfg *config.ServerKeyAPI,
|
||||||
fedClient *gomatrixserverlib.FederationClient,
|
fedClient gomatrixserverlib.KeyClient,
|
||||||
caches *caching.Caches,
|
caches *caching.Caches,
|
||||||
) api.ServerKeyInternalAPI {
|
) api.ServerKeyInternalAPI {
|
||||||
innerDB, err := storage.NewDatabase(
|
innerDB, err := storage.NewDatabase(
|
||||||
|
|
@ -53,7 +53,7 @@ func NewInternalAPI(
|
||||||
OurKeyRing: gomatrixserverlib.KeyRing{
|
OurKeyRing: gomatrixserverlib.KeyRing{
|
||||||
KeyFetchers: []gomatrixserverlib.KeyFetcher{
|
KeyFetchers: []gomatrixserverlib.KeyFetcher{
|
||||||
&gomatrixserverlib.DirectKeyFetcher{
|
&gomatrixserverlib.DirectKeyFetcher{
|
||||||
Client: fedClient.Client,
|
Client: fedClient,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
KeyDatabase: serverKeyDB,
|
KeyDatabase: serverKeyDB,
|
||||||
|
|
@ -65,7 +65,7 @@ func NewInternalAPI(
|
||||||
perspective := &gomatrixserverlib.PerspectiveKeyFetcher{
|
perspective := &gomatrixserverlib.PerspectiveKeyFetcher{
|
||||||
PerspectiveServerName: ps.ServerName,
|
PerspectiveServerName: ps.ServerName,
|
||||||
PerspectiveServerKeys: map[gomatrixserverlib.KeyID]ed25519.PublicKey{},
|
PerspectiveServerKeys: map[gomatrixserverlib.KeyID]ed25519.PublicKey{},
|
||||||
Client: fedClient.Client,
|
Client: fedClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, key := range ps.Keys {
|
for _, key := range ps.Keys {
|
||||||
|
|
|
||||||
|
|
@ -81,7 +81,7 @@ func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremiti
|
||||||
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
|
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
|
||||||
ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
|
ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
|
_, err = sqlutil.TxStmt(txn, s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -110,7 +110,7 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
|
||||||
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
|
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
|
||||||
ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
|
ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
|
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -160,13 +160,13 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
|
||||||
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
|
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
|
||||||
ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID,
|
ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = txn.Stmt(s.updateSentSendToDeviceMessagesStmt).ExecContext(ctx, token, pq.Array(nids))
|
_, err = sqlutil.TxStmt(txn, s.updateSentSendToDeviceMessagesStmt).ExecContext(ctx, token, pq.Array(nids))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
|
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
|
||||||
ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID,
|
ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = txn.Stmt(s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, pq.Array(nids))
|
_, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, pq.Array(nids))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
@ -85,7 +86,7 @@ func (s *accountDataStatements) InsertAccountData(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
|
_, err = sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -147,7 +148,7 @@ func (s *accountDataStatements) SelectMaxAccountDataID(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
) (id int64, err error) {
|
) (id int64, err error) {
|
||||||
var nullableID sql.NullInt64
|
var nullableID sql.NullInt64
|
||||||
err = txn.Stmt(s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID)
|
err = sqlutil.TxStmt(txn, s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID)
|
||||||
if nullableID.Valid {
|
if nullableID.Valid {
|
||||||
id = nullableID.Int64
|
id = nullableID.Int64
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -84,7 +84,7 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities
|
||||||
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
|
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
|
||||||
ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
|
ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
|
_, err = sqlutil.TxStmt(txn, s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -113,7 +113,7 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
|
||||||
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
|
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
|
||||||
ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
|
ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
|
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -52,7 +52,4 @@ Inbound federation accepts a second soft-failed event
|
||||||
Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state
|
Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state
|
||||||
|
|
||||||
# We don't implement lazy membership loading yet.
|
# We don't implement lazy membership loading yet.
|
||||||
The only membership state included in a gapped incremental sync is for senders in the timeline
|
The only membership state included in a gapped incremental sync is for senders in the timeline
|
||||||
|
|
||||||
# flakey since implementing rejected events
|
|
||||||
Inbound federation correctly soft fails events
|
|
||||||
|
|
@ -472,4 +472,11 @@ We can't peek into rooms with joined history_visibility
|
||||||
Local users can peek by room alias
|
Local users can peek by room alias
|
||||||
Peeked rooms only turn up in the sync for the device who peeked them
|
Peeked rooms only turn up in the sync for the device who peeked them
|
||||||
Room state at a rejected message event is the same as its predecessor
|
Room state at a rejected message event is the same as its predecessor
|
||||||
Room state at a rejected state event is the same as its predecessor
|
Room state at a rejected state event is the same as its predecessor
|
||||||
|
Inbound federation correctly soft fails events
|
||||||
|
Inbound federation accepts a second soft-failed event
|
||||||
|
Federation key API can act as a notary server via a POST request
|
||||||
|
Federation key API can act as a notary server via a GET request
|
||||||
|
Inbound /make_join rejects attempts to join rooms where all users have left
|
||||||
|
Inbound federation rejects invites which include invalid JSON for room version 6
|
||||||
|
Inbound federation rejects invite rejections which include invalid JSON for room version 6
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
const accountDataSchema = `
|
const accountDataSchema = `
|
||||||
|
|
@ -75,7 +76,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
|
||||||
func (s *accountDataStatements) insertAccountData(
|
func (s *accountDataStatements) insertAccountData(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := txn.Stmt(s.insertAccountDataStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
|
||||||
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
|
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
|
|
@ -99,7 +100,7 @@ func (s *accountsStatements) insertAccount(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
|
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
stmt := txn.Stmt(s.insertAccountStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if appserviceID == "" {
|
if appserviceID == "" {
|
||||||
|
|
@ -162,7 +163,7 @@ func (s *accountsStatements) selectNewNumericLocalpart(
|
||||||
) (id int64, err error) {
|
) (id int64, err error) {
|
||||||
stmt := s.selectNewNumericLocalpartStmt
|
stmt := s.selectNewNumericLocalpartStmt
|
||||||
if txn != nil {
|
if txn != nil {
|
||||||
stmt = txn.Stmt(stmt)
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
}
|
}
|
||||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
err = stmt.QueryRowContext(ctx).Scan(&id)
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
const profilesSchema = `
|
const profilesSchema = `
|
||||||
|
|
@ -84,7 +85,7 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
|
||||||
func (s *profilesStatements) insertProfile(
|
func (s *profilesStatements) insertProfile(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string,
|
ctx context.Context, txn *sql.Tx, localpart string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
_, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,8 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
const accountDataSchema = `
|
const accountDataSchema = `
|
||||||
|
|
@ -75,7 +77,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
|
||||||
func (s *accountDataStatements) insertAccountData(
|
func (s *accountDataStatements) insertAccountData(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
||||||
) error {
|
) error {
|
||||||
_, err := txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
|
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
|
|
@ -104,9 +105,9 @@ func (s *accountsStatements) insertAccount(
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if appserviceID == "" {
|
if appserviceID == "" {
|
||||||
_, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
|
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
|
||||||
} else {
|
} else {
|
||||||
_, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
|
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -163,7 +164,7 @@ func (s *accountsStatements) selectNewNumericLocalpart(
|
||||||
) (id int64, err error) {
|
) (id int64, err error) {
|
||||||
stmt := s.selectNewNumericLocalpartStmt
|
stmt := s.selectNewNumericLocalpartStmt
|
||||||
if txn != nil {
|
if txn != nil {
|
||||||
stmt = txn.Stmt(stmt)
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
}
|
}
|
||||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
err = stmt.QueryRowContext(ctx).Scan(&id)
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -87,7 +87,7 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
|
||||||
func (s *profilesStatements) insertProfile(
|
func (s *profilesStatements) insertProfile(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string,
|
ctx context.Context, txn *sql.Tx, localpart string,
|
||||||
) error {
|
) error {
|
||||||
_, err := txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue