merge master

This commit is contained in:
Matthew Hodgson 2020-09-26 23:27:39 +01:00
commit 75c3f2df8d
56 changed files with 780 additions and 179 deletions

View file

@ -3,4 +3,4 @@
<!-- Please read CONTRIBUTING.md before submitting your pull request --> <!-- Please read CONTRIBUTING.md before submitting your pull request -->
* [ ] I have added any new tests that need to pass to `testfile` as specified in [docs/sytest.md](https://github.com/matrix-org/dendrite/blob/master/docs/sytest.md) * [ ] I have added any new tests that need to pass to `testfile` as specified in [docs/sytest.md](https://github.com/matrix-org/dendrite/blob/master/docs/sytest.md)
* [ ] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/master/CONTRIBUTING.md#sign-off) * [ ] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/master/docs/CONTRIBUTING.md#sign-off)

View file

@ -38,8 +38,13 @@ global:
# The path to the signing private key file, used to sign requests and events. # The path to the signing private key file, used to sign requests and events.
private_key: matrix_key.pem private_key: matrix_key.pem
# A unique identifier for this private key. Must start with the prefix "ed25519:". # The paths and expiry timestamps (as a UNIX timestamp in millisecond precision)
key_id: ed25519:auto # to old signing private keys that were formerly in use on this domain. These
# keys will not be used for federation request or event signing, but will be
# provided to any other homeserver that asks when trying to verify old events.
# old_private_keys:
# - private_key: old_matrix_key.pem
# expired_at: 1601024554498
# How long a remote server can cache our server signing key before requesting it # How long a remote server can cache our server signing key before requesting it
# again. Increasing this number will reduce the number of requests made by other # again. Increasing this number will reduce the number of requests made by other

View file

@ -39,7 +39,15 @@ func InviteV2(
keys gomatrixserverlib.JSONVerifier, keys gomatrixserverlib.JSONVerifier,
) util.JSONResponse { ) util.JSONResponse {
inviteReq := gomatrixserverlib.InviteV2Request{} inviteReq := gomatrixserverlib.InviteV2Request{}
if err := json.Unmarshal(request.Content(), &inviteReq); err != nil { err := json.Unmarshal(request.Content(), &inviteReq)
switch err.(type) {
case gomatrixserverlib.BadJSONError:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(err.Error()),
}
case nil:
default:
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.NotJSON("The request body could not be decoded into an invite request. " + err.Error()), JSON: jsonerror.NotJSON("The request body could not be decoded into an invite request. " + err.Error()),
@ -63,10 +71,17 @@ func InviteV1(
roomVer := gomatrixserverlib.RoomVersionV1 roomVer := gomatrixserverlib.RoomVersionV1
body := request.Content() body := request.Content()
event, err := gomatrixserverlib.NewEventFromTrustedJSON(body, false, roomVer) event, err := gomatrixserverlib.NewEventFromTrustedJSON(body, false, roomVer)
if err != nil { switch err.(type) {
case gomatrixserverlib.BadJSONError:
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.NotJSON("The request body could not be decoded into an invite v1 request: " + err.Error()), JSON: jsonerror.BadJSON(err.Error()),
}
case nil:
default:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.NotJSON("The request body could not be decoded into an invite v1 request. " + err.Error()),
} }
} }
var strippedState []gomatrixserverlib.InviteV2StrippedState var strippedState []gomatrixserverlib.InviteV2StrippedState

View file

@ -29,6 +29,7 @@ import (
) )
// MakeJoin implements the /make_join API // MakeJoin implements the /make_join API
// nolint:gocyclo
func MakeJoin( func MakeJoin(
httpReq *http.Request, httpReq *http.Request,
request *gomatrixserverlib.FederationRequest, request *gomatrixserverlib.FederationRequest,
@ -79,6 +80,29 @@ func MakeJoin(
} }
} }
// Check if we think we are still joined to the room
inRoomReq := &api.QueryServerJoinedToRoomRequest{
ServerName: cfg.Matrix.ServerName,
RoomID: roomID,
}
inRoomRes := &api.QueryServerJoinedToRoomResponse{}
if err = rsAPI.QueryServerJoinedToRoom(httpReq.Context(), inRoomReq, inRoomRes); err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed")
return jsonerror.InternalServerError()
}
if !inRoomRes.RoomExists {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound(fmt.Sprintf("Room ID %q was not found on this server", roomID)),
}
}
if !inRoomRes.IsInRoom {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound(fmt.Sprintf("Room ID %q has no remaining users on this server", roomID)),
}
}
// Try building an event for the server // Try building an event for the server
builder := gomatrixserverlib.EventBuilder{ builder := gomatrixserverlib.EventBuilder{
Sender: userID, Sender: userID,

View file

@ -19,11 +19,14 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
) )
@ -133,6 +136,8 @@ func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserver
var keys gomatrixserverlib.ServerKeys var keys gomatrixserverlib.ServerKeys
keys.ServerName = cfg.Matrix.ServerName keys.ServerName = cfg.Matrix.ServerName
keys.TLSFingerprints = cfg.TLSFingerPrints
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(validUntil)
publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey) publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey)
@ -142,9 +147,15 @@ func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserver
}, },
} }
keys.TLSFingerprints = cfg.TLSFingerPrints
keys.OldVerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.OldVerifyKey{} keys.OldVerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.OldVerifyKey{}
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(validUntil) for _, oldVerifyKey := range cfg.Matrix.OldVerifyKeys {
keys.OldVerifyKeys[oldVerifyKey.KeyID] = gomatrixserverlib.OldVerifyKey{
VerifyKey: gomatrixserverlib.VerifyKey{
Key: gomatrixserverlib.Base64Bytes(oldVerifyKey.PrivateKey),
},
ExpiredTS: oldVerifyKey.ExpiredAt,
}
}
toSign, err := json.Marshal(keys.ServerKeyFields) toSign, err := json.Marshal(keys.ServerKeyFields)
if err != nil { if err != nil {
@ -160,3 +171,62 @@ func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserver
return &keys, nil return &keys, nil
} }
func NotaryKeys(
httpReq *http.Request, cfg *config.FederationAPI,
fsAPI federationSenderAPI.FederationSenderInternalAPI,
req *gomatrixserverlib.PublicKeyNotaryLookupRequest,
) util.JSONResponse {
if req == nil {
req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{}
if reqErr := httputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil {
return *reqErr
}
}
var response struct {
ServerKeys []json.RawMessage `json:"server_keys"`
}
response.ServerKeys = []json.RawMessage{}
for serverName := range req.ServerKeys {
var keys *gomatrixserverlib.ServerKeys
if serverName == cfg.Matrix.ServerName {
if k, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil {
keys = k
} else {
return util.ErrorResponse(err)
}
} else {
if k, err := fsAPI.GetServerKeys(httpReq.Context(), serverName); err == nil {
keys = &k
} else {
return util.ErrorResponse(err)
}
}
if keys == nil {
continue
}
j, err := json.Marshal(keys)
if err != nil {
logrus.WithError(err).Errorf("Failed to marshal %q response", serverName)
return jsonerror.InternalServerError()
}
js, err := gomatrixserverlib.SignJSON(
string(cfg.Matrix.ServerName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, j,
)
if err != nil {
logrus.WithError(err).Errorf("Failed to sign %q response", serverName)
return jsonerror.InternalServerError()
}
response.ServerKeys = append(response.ServerKeys, js)
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: response,
}
}

View file

@ -138,7 +138,14 @@ func SendLeave(
// Decode the event JSON from the request. // Decode the event JSON from the request.
event, err := gomatrixserverlib.NewEventFromUntrustedJSON(request.Content(), verRes.RoomVersion) event, err := gomatrixserverlib.NewEventFromUntrustedJSON(request.Content(), verRes.RoomVersion)
if err != nil { switch err.(type) {
case gomatrixserverlib.BadJSONError:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(err.Error()),
}
case nil:
default:
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()),

View file

@ -61,6 +61,26 @@ func Setup(
return LocalKeys(cfg) return LocalKeys(cfg)
}) })
notaryKeys := httputil.MakeExternalAPI("notarykeys", func(req *http.Request) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
var pkReq *gomatrixserverlib.PublicKeyNotaryLookupRequest
serverName := gomatrixserverlib.ServerName(vars["serverName"])
keyID := gomatrixserverlib.KeyID(vars["keyID"])
if serverName != "" && keyID != "" {
pkReq = &gomatrixserverlib.PublicKeyNotaryLookupRequest{
ServerKeys: map[gomatrixserverlib.ServerName]map[gomatrixserverlib.KeyID]gomatrixserverlib.PublicKeyNotaryQueryCriteria{
serverName: {
keyID: gomatrixserverlib.PublicKeyNotaryQueryCriteria{},
},
},
}
}
return NotaryKeys(req, cfg, fsAPI, pkReq)
})
// Ignore the {keyID} argument as we only have a single server key so we always // Ignore the {keyID} argument as we only have a single server key so we always
// return that key. // return that key.
// Even if we had more than one server key, we would probably still ignore the // Even if we had more than one server key, we would probably still ignore the
@ -68,6 +88,8 @@ func Setup(
v2keysmux.Handle("/server/{keyID}", localKeys).Methods(http.MethodGet) v2keysmux.Handle("/server/{keyID}", localKeys).Methods(http.MethodGet)
v2keysmux.Handle("/server/", localKeys).Methods(http.MethodGet) v2keysmux.Handle("/server/", localKeys).Methods(http.MethodGet)
v2keysmux.Handle("/server", localKeys).Methods(http.MethodGet) v2keysmux.Handle("/server", localKeys).Methods(http.MethodGet)
v2keysmux.Handle("/query", notaryKeys).Methods(http.MethodPost)
v2keysmux.Handle("/query/{serverName}/{keyID}", notaryKeys).Methods(http.MethodGet)
v1fedmux.Handle("/send/{txnID}", httputil.MakeFedAPI( v1fedmux.Handle("/send/{txnID}", httputil.MakeFedAPI(
"federation_send", cfg.Matrix.ServerName, keys, wakeup, "federation_send", cfg.Matrix.ServerName, keys, wakeup,

View file

@ -199,6 +199,15 @@ func (t *testRoomserverAPI) QueryMembershipsForRoom(
return fmt.Errorf("not implemented") return fmt.Errorf("not implemented")
} }
// Query if a server is joined to a room
func (t *testRoomserverAPI) QueryServerJoinedToRoom(
ctx context.Context,
request *api.QueryServerJoinedToRoomRequest,
response *api.QueryServerJoinedToRoomResponse,
) error {
return fmt.Errorf("not implemented")
}
// Query whether a server is allowed to see an event // Query whether a server is allowed to see an event
func (t *testRoomserverAPI) QueryServerAllowedToSeeEvent( func (t *testRoomserverAPI) QueryServerAllowedToSeeEvent(
ctx context.Context, ctx context.Context,

View file

@ -20,6 +20,8 @@ type FederationClient interface {
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error) ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error)
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
GetServerKeys(ctx context.Context, matrixServer gomatrixserverlib.ServerName) (gomatrixserverlib.ServerKeys, error)
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
} }
// FederationClientError is returned from FederationClient methods in the event of a problem. // FederationClientError is returned from FederationClient methods in the event of a problem.

View file

@ -2,6 +2,7 @@ package internal
import ( import (
"context" "context"
"sync"
"time" "time"
"github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/federationsender/api"
@ -23,6 +24,7 @@ type FederationSenderInternalAPI struct {
federation *gomatrixserverlib.FederationClient federation *gomatrixserverlib.FederationClient
keyRing *gomatrixserverlib.KeyRing keyRing *gomatrixserverlib.KeyRing
queues *queue.OutgoingQueues queues *queue.OutgoingQueues
joins sync.Map // joins currently in progress
} }
func NewFederationSenderInternalAPI( func NewFederationSenderInternalAPI(
@ -187,3 +189,27 @@ func (a *FederationSenderInternalAPI) GetEvent(
} }
return ires.(gomatrixserverlib.Transaction), nil return ires.(gomatrixserverlib.Transaction), nil
} }
func (a *FederationSenderInternalAPI) GetServerKeys(
ctx context.Context, s gomatrixserverlib.ServerName,
) (gomatrixserverlib.ServerKeys, error) {
ires, err := a.doRequest(s, func() (interface{}, error) {
return a.federation.GetServerKeys(ctx, s)
})
if err != nil {
return gomatrixserverlib.ServerKeys{}, err
}
return ires.(gomatrixserverlib.ServerKeys), nil
}
func (a *FederationSenderInternalAPI) LookupServerKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) ([]gomatrixserverlib.ServerKeys, error) {
ires, err := a.doRequest(s, func() (interface{}, error) {
return a.federation.LookupServerKeys(ctx, s, keyRequests)
})
if err != nil {
return []gomatrixserverlib.ServerKeys{}, err
}
return ires.([]gomatrixserverlib.ServerKeys), nil
}

View file

@ -37,12 +37,32 @@ func (r *FederationSenderInternalAPI) PerformDirectoryLookup(
return nil return nil
} }
type federatedJoin struct {
UserID string
RoomID string
}
// PerformJoinRequest implements api.FederationSenderInternalAPI // PerformJoinRequest implements api.FederationSenderInternalAPI
func (r *FederationSenderInternalAPI) PerformJoin( func (r *FederationSenderInternalAPI) PerformJoin(
ctx context.Context, ctx context.Context,
request *api.PerformJoinRequest, request *api.PerformJoinRequest,
response *api.PerformJoinResponse, response *api.PerformJoinResponse,
) { ) {
// Check that a join isn't already in progress for this user/room.
j := federatedJoin{request.UserID, request.RoomID}
if _, found := r.joins.Load(j); found {
response.LastError = &gomatrix.HTTPError{
Code: 429,
Message: `{
"errcode": "M_LIMIT_EXCEEDED",
"error": "There is already a federated join to this room in progress. Please wait for it to finish."
}`, // TODO: Why do none of our error types play nicely with each other?
}
return
}
r.joins.Store(j, nil)
defer r.joins.Delete(j)
// Look up the supported room versions. // Look up the supported room versions.
var supportedVersions []gomatrixserverlib.RoomVersion var supportedVersions []gomatrixserverlib.RoomVersion
for version := range version.SupportedRoomVersions() { for version := range version.SupportedRoomVersions() {
@ -186,28 +206,47 @@ func (r *FederationSenderInternalAPI) performJoinUsingServer(
} }
r.statistics.ForServer(serverName).Success() r.statistics.ForServer(serverName).Success()
// Check that the send_join response was valid. // Process the join response in a goroutine. The idea here is
joinCtx := perform.JoinContext(r.federation, r.keyRing) // that we'll try and wait for as long as possible for the work
respState, err := joinCtx.CheckSendJoinResponse( // to complete, but if the client does give up waiting, we'll
ctx, event, serverName, respSendJoin, // still continue to process the join anyway so that we don't
) // waste the effort.
if err != nil { var cancel context.CancelFunc
return fmt.Errorf("joinCtx.CheckSendJoinResponse: %w", err) ctx, cancel = context.WithCancel(context.Background())
} go func() {
defer cancel()
// If we successfully performed a send_join above then the other // Check that the send_join response was valid.
// server now thinks we're a part of the room. Send the newly joinCtx := perform.JoinContext(r.federation, r.keyRing)
// returned state to the roomserver to update our local view. respState, err := joinCtx.CheckSendJoinResponse(
headeredEvent := event.Headered(respMakeJoin.RoomVersion) ctx, event, serverName, respSendJoin,
if err = roomserverAPI.SendEventWithRewrite( )
ctx, r.rsAPI, if err != nil {
respState, logrus.WithFields(logrus.Fields{
headeredEvent, "room_id": roomID,
nil, "user_id": userID,
); err != nil { }).WithError(err).Error("Failed to process room join response")
return fmt.Errorf("r.producer.SendEventWithState: %w", err) return
} }
// If we successfully performed a send_join above then the other
// server now thinks we're a part of the room. Send the newly
// returned state to the roomserver to update our local view.
if err = roomserverAPI.SendEventWithRewrite(
ctx, r.rsAPI,
respState,
event.Headered(respMakeJoin.RoomVersion),
nil,
); err != nil {
logrus.WithFields(logrus.Fields{
"room_id": roomID,
"user_id": userID,
}).WithError(err).Error("Failed to send room join response to roomserver")
return
}
}()
<-ctx.Done()
return nil return nil
} }

View file

@ -24,13 +24,15 @@ const (
FederationSenderPerformServersAlivePath = "/federationsender/performServersAlive" FederationSenderPerformServersAlivePath = "/federationsender/performServersAlive"
FederationSenderPerformBroadcastEDUPath = "/federationsender/performBroadcastEDU" FederationSenderPerformBroadcastEDUPath = "/federationsender/performBroadcastEDU"
FederationSenderGetUserDevicesPath = "/federationsender/client/getUserDevices" FederationSenderGetUserDevicesPath = "/federationsender/client/getUserDevices"
FederationSenderClaimKeysPath = "/federationsender/client/claimKeys" FederationSenderClaimKeysPath = "/federationsender/client/claimKeys"
FederationSenderQueryKeysPath = "/federationsender/client/queryKeys" FederationSenderQueryKeysPath = "/federationsender/client/queryKeys"
FederationSenderBackfillPath = "/federationsender/client/backfill" FederationSenderBackfillPath = "/federationsender/client/backfill"
FederationSenderLookupStatePath = "/federationsender/client/lookupState" FederationSenderLookupStatePath = "/federationsender/client/lookupState"
FederationSenderLookupStateIDsPath = "/federationsender/client/lookupStateIDs" FederationSenderLookupStateIDsPath = "/federationsender/client/lookupStateIDs"
FederationSenderGetEventPath = "/federationsender/client/getEvent" FederationSenderGetEventPath = "/federationsender/client/getEvent"
FederationSenderGetServerKeysPath = "/federationsender/client/getServerKeys"
FederationSenderLookupServerKeysPath = "/federationsender/client/lookupServerKeys"
) )
// NewFederationSenderClient creates a FederationSenderInternalAPI implemented by talking to a HTTP POST API. // NewFederationSenderClient creates a FederationSenderInternalAPI implemented by talking to a HTTP POST API.
@ -372,3 +374,59 @@ func (h *httpFederationSenderInternalAPI) GetEvent(
} }
return *response.Res, nil return *response.Res, nil
} }
type getServerKeys struct {
S gomatrixserverlib.ServerName
ServerKeys gomatrixserverlib.ServerKeys
Err *api.FederationClientError
}
func (h *httpFederationSenderInternalAPI) GetServerKeys(
ctx context.Context, s gomatrixserverlib.ServerName,
) (gomatrixserverlib.ServerKeys, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "GetServerKeys")
defer span.Finish()
request := getServerKeys{
S: s,
}
var response getServerKeys
apiURL := h.federationSenderURL + FederationSenderGetServerKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return gomatrixserverlib.ServerKeys{}, err
}
if response.Err != nil {
return gomatrixserverlib.ServerKeys{}, response.Err
}
return response.ServerKeys, nil
}
type lookupServerKeys struct {
S gomatrixserverlib.ServerName
KeyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp
ServerKeys []gomatrixserverlib.ServerKeys
Err *api.FederationClientError
}
func (h *httpFederationSenderInternalAPI) LookupServerKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) ([]gomatrixserverlib.ServerKeys, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "LookupServerKeys")
defer span.Finish()
request := lookupServerKeys{
S: s,
KeyRequests: keyRequests,
}
var response lookupServerKeys
apiURL := h.federationSenderURL + FederationSenderLookupServerKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return []gomatrixserverlib.ServerKeys{}, err
}
if response.Err != nil {
return []gomatrixserverlib.ServerKeys{}, response.Err
}
return response.ServerKeys, nil
}

View file

@ -263,4 +263,48 @@ func AddRoutes(intAPI api.FederationSenderInternalAPI, internalAPIMux *mux.Route
return util.JSONResponse{Code: http.StatusOK, JSON: request} return util.JSONResponse{Code: http.StatusOK, JSON: request}
}), }),
) )
internalAPIMux.Handle(
FederationSenderGetServerKeysPath,
httputil.MakeInternalAPI("GetServerKeys", func(req *http.Request) util.JSONResponse {
var request getServerKeys
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.GetServerKeys(req.Context(), request.S)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.ServerKeys = res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
)
internalAPIMux.Handle(
FederationSenderLookupServerKeysPath,
httputil.MakeInternalAPI("LookupServerKeys", func(req *http.Request) util.JSONResponse {
var request lookupServerKeys
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.LookupServerKeys(req.Context(), request.S, request.KeyRequests)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.ServerKeys = res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
)
} }

View file

@ -231,13 +231,24 @@ func (oq *destinationQueue) backgroundSend() {
// If we are backing off this server then wait for the // If we are backing off this server then wait for the
// backoff duration to complete first, or until explicitly // backoff duration to complete first, or until explicitly
// told to retry. // told to retry.
if _, giveUp := oq.statistics.BackoffIfRequired(oq.backingOff, oq.interruptBackoff); giveUp { until, blacklisted := oq.statistics.BackoffInfo()
if blacklisted {
// It's been suggested that we should give up because the backoff // It's been suggested that we should give up because the backoff
// has exceeded a maximum allowable value. Clean up the in-memory // has exceeded a maximum allowable value. Clean up the in-memory
// buffers at this point. The PDU clean-up is already on a defer. // buffers at this point. The PDU clean-up is already on a defer.
log.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination) log.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination)
return return
} }
if until != nil && until.After(time.Now()) {
// We haven't backed off yet, so wait for the suggested amount of
// time.
duration := time.Until(*until)
log.Warnf("Backing off %q for %s", oq.destination, duration)
select {
case <-time.After(duration):
case <-oq.interruptBackoff:
}
}
// If we have pending PDUs or EDUs then construct a transaction. // If we have pending PDUs or EDUs then construct a transaction.
if pendingPDUs || pendingEDUs { if pendingPDUs || pendingEDUs {

View file

@ -44,6 +44,7 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS
server = &ServerStatistics{ server = &ServerStatistics{
statistics: s, statistics: s,
serverName: serverName, serverName: serverName,
interrupt: make(chan struct{}),
} }
s.servers[serverName] = server s.servers[serverName] = server
s.mutex.Unlock() s.mutex.Unlock()
@ -68,6 +69,7 @@ type ServerStatistics struct {
backoffStarted atomic.Bool // is the backoff started backoffStarted atomic.Bool // is the backoff started
backoffUntil atomic.Value // time.Time until this backoff interval ends backoffUntil atomic.Value // time.Time until this backoff interval ends
backoffCount atomic.Uint32 // number of times BackoffDuration has been called backoffCount atomic.Uint32 // number of times BackoffDuration has been called
interrupt chan struct{} // interrupts the backoff goroutine
successCounter atomic.Uint32 // how many times have we succeeded? successCounter atomic.Uint32 // how many times have we succeeded?
} }
@ -76,15 +78,24 @@ func (s *ServerStatistics) duration(count uint32) time.Duration {
return time.Second * time.Duration(math.Exp2(float64(count))) return time.Second * time.Duration(math.Exp2(float64(count)))
} }
// cancel will interrupt the currently active backoff.
func (s *ServerStatistics) cancel() {
s.blacklisted.Store(false)
s.backoffUntil.Store(time.Time{})
select {
case s.interrupt <- struct{}{}:
default:
}
}
// Success updates the server statistics with a new successful // Success updates the server statistics with a new successful
// attempt, which increases the sent counter and resets the idle and // attempt, which increases the sent counter and resets the idle and
// failure counters. If a host was blacklisted at this point then // failure counters. If a host was blacklisted at this point then
// we will unblacklist it. // we will unblacklist it.
func (s *ServerStatistics) Success() { func (s *ServerStatistics) Success() {
s.successCounter.Add(1) s.cancel()
s.backoffStarted.Store(false) s.successCounter.Inc()
s.backoffCount.Store(0) s.backoffCount.Store(0)
s.blacklisted.Store(false)
if s.statistics.DB != nil { if s.statistics.DB != nil {
if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil {
logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName)
@ -99,10 +110,30 @@ func (s *ServerStatistics) Success() {
// whether we have blacklisted and therefore to give up. // whether we have blacklisted and therefore to give up.
func (s *ServerStatistics) Failure() (time.Time, bool) { func (s *ServerStatistics) Failure() (time.Time, bool) {
// If we aren't already backing off, this call will start // If we aren't already backing off, this call will start
// a new backoff period. Reset the counter to 0 so that // a new backoff period. Increase the failure counter and
// we backoff only for short periods of time to start with. // start a goroutine which will wait out the backoff and
// unset the backoffStarted flag when done.
if s.backoffStarted.CAS(false, true) { if s.backoffStarted.CAS(false, true) {
s.backoffCount.Store(0) if s.backoffCount.Inc() >= s.statistics.FailuresUntilBlacklist {
s.blacklisted.Store(true)
if s.statistics.DB != nil {
if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil {
logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName)
}
}
return time.Time{}, true
}
go func() {
until, ok := s.backoffUntil.Load().(time.Time)
if ok {
select {
case <-time.After(time.Until(until)):
case <-s.interrupt:
}
}
s.backoffStarted.Store(false)
}()
} }
// Check if we have blacklisted this node. // Check if we have blacklisted this node.
@ -136,53 +167,6 @@ func (s *ServerStatistics) BackoffInfo() (*time.Time, bool) {
return nil, s.blacklisted.Load() return nil, s.blacklisted.Load()
} }
// BackoffIfRequired will block for as long as the current
// backoff requires, if needed. Otherwise it will do nothing.
// Returns the amount of time to backoff for and whether to give up or not.
func (s *ServerStatistics) BackoffIfRequired(backingOff atomic.Bool, interrupt <-chan bool) (time.Duration, bool) {
if started := s.backoffStarted.Load(); !started {
return 0, false
}
// Work out if we should be blacklisting at this point.
count := s.backoffCount.Inc()
if count >= s.statistics.FailuresUntilBlacklist {
// We've exceeded the maximum amount of times we're willing
// to back off, which is probably in the region of hours by
// now. Mark the host as blacklisted and tell the caller to
// give up.
s.blacklisted.Store(true)
if s.statistics.DB != nil {
if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil {
logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName)
}
}
return 0, true
}
// Work out when we should wait until.
duration := s.duration(count)
until := time.Now().Add(duration)
s.backoffUntil.Store(until)
// Notify the destination queue that we're backing off now.
backingOff.Store(true)
defer backingOff.Store(false)
// Work out how long we should be backing off for.
logrus.Warnf("Backing off %q for %s", s.serverName, duration)
// Wait for either an interruption or for the backoff to
// complete.
select {
case <-interrupt:
logrus.Debugf("Interrupting backoff for %q", s.serverName)
case <-time.After(duration):
}
return duration, false
}
// Blacklisted returns true if the server is blacklisted and false // Blacklisted returns true if the server is blacklisted and false
// otherwise. // otherwise.
func (s *ServerStatistics) Blacklisted() bool { func (s *ServerStatistics) Blacklisted() bool {

View file

@ -4,8 +4,6 @@ import (
"math" "math"
"testing" "testing"
"time" "time"
"go.uber.org/atomic"
) )
func TestBackoff(t *testing.T) { func TestBackoff(t *testing.T) {
@ -27,34 +25,30 @@ func TestBackoff(t *testing.T) {
server.Failure() server.Failure()
t.Logf("Backoff counter: %d", server.backoffCount.Load()) t.Logf("Backoff counter: %d", server.backoffCount.Load())
backingOff := atomic.Bool{}
// Now we're going to simulate backing off a few times to see // Now we're going to simulate backing off a few times to see
// what happens. // what happens.
for i := uint32(1); i <= 10; i++ { for i := uint32(1); i <= 10; i++ {
// Interrupt the backoff - it doesn't really matter if it
// completes but we will find out how long the backoff should
// have been.
interrupt := make(chan bool, 1)
close(interrupt)
// Get the duration.
duration, blacklist := server.BackoffIfRequired(backingOff, interrupt)
// Register another failure for good measure. This should have no // Register another failure for good measure. This should have no
// side effects since a backoff is already in progress. If it does // side effects since a backoff is already in progress. If it does
// then we'll fail. // then we'll fail.
until, blacklisted := server.Failure() until, blacklisted := server.Failure()
if time.Until(until) > duration {
t.Fatal("Failure produced unexpected side effect when it shouldn't have") // Get the duration.
} _, blacklist := server.BackoffInfo()
duration := time.Until(until).Round(time.Second)
// Unset the backoff, or otherwise our next call will think that
// there's a backoff in progress and return the same result.
server.cancel()
server.backoffStarted.Store(false)
// Check if we should be blacklisted by now. // Check if we should be blacklisted by now.
if i >= stats.FailuresUntilBlacklist { if i >= stats.FailuresUntilBlacklist {
if !blacklist { if !blacklist {
t.Fatalf("Backoff %d should have resulted in blacklist but didn't", i) t.Fatalf("Backoff %d should have resulted in blacklist but didn't", i)
} else if blacklist != blacklisted { } else if blacklist != blacklisted {
t.Fatalf("BackoffIfRequired and Failure returned different blacklist values") t.Fatalf("BackoffInfo and Failure returned different blacklist values")
} else { } else {
t.Logf("Backoff %d is blacklisted as expected", i) t.Logf("Backoff %d is blacklisted as expected", i)
continue continue

2
go.mod
View file

@ -24,7 +24,7 @@ require (
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd
github.com/matrix-org/gomatrixserverlib v0.0.0-20200907151926-38f437f2b2a6 github.com/matrix-org/gomatrixserverlib v0.0.0-20200925165243-b9780a852681
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.2 github.com/mattn/go-sqlite3 v1.14.2

4
go.sum
View file

@ -569,8 +569,8 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg=
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20200907151926-38f437f2b2a6 h1:43gla6bLt4opWY1mQkAasF/LUCipZl7x2d44TY0wf40= github.com/matrix-org/gomatrixserverlib v0.0.0-20200925165243-b9780a852681 h1:75fM7vPHiFGt+XxktT17LJD972XMtJ1n7FU1MpC08Zc=
github.com/matrix-org/gomatrixserverlib v0.0.0-20200907151926-38f437f2b2a6/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/gomatrixserverlib v0.0.0-20200925165243-b9780a852681/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4=
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo=

View file

@ -36,6 +36,9 @@ import (
jaegermetrics "github.com/uber/jaeger-lib/metrics" jaegermetrics "github.com/uber/jaeger-lib/metrics"
) )
// keyIDRegexp defines allowable characters in Key IDs.
var keyIDRegexp = regexp.MustCompile("^ed25519:[a-zA-Z0-9_]+$")
// Version is the current version of the config format. // Version is the current version of the config format.
// This will change whenever we make breaking changes to the config format. // This will change whenever we make breaking changes to the config format.
const Version = 1 const Version = 1
@ -225,10 +228,30 @@ func loadConfig(
return nil, err return nil, err
} }
if c.Global.KeyID, c.Global.PrivateKey, err = readKeyPEM(privateKeyPath, privateKeyData); err != nil { if c.Global.KeyID, c.Global.PrivateKey, err = readKeyPEM(privateKeyPath, privateKeyData, true); err != nil {
return nil, err return nil, err
} }
for i, oldPrivateKey := range c.Global.OldVerifyKeys {
var oldPrivateKeyData []byte
oldPrivateKeyPath := absPath(basePath, oldPrivateKey.PrivateKeyPath)
oldPrivateKeyData, err = readFile(oldPrivateKeyPath)
if err != nil {
return nil, err
}
// NOTSPEC: Ordinarily we should enforce key ID formatting, but since there are
// a number of private keys out there with non-compatible symbols in them due
// to lack of validation in Synapse, we won't enforce that for old verify keys.
keyID, privateKey, perr := readKeyPEM(oldPrivateKeyPath, oldPrivateKeyData, false)
if perr != nil {
return nil, perr
}
c.Global.OldVerifyKeys[i].KeyID, c.Global.OldVerifyKeys[i].PrivateKey = keyID, privateKey
}
for _, certPath := range c.FederationAPI.FederationCertificatePaths { for _, certPath := range c.FederationAPI.FederationCertificatePaths {
absCertPath := absPath(basePath, certPath) absCertPath := absPath(basePath, certPath)
var pemData []byte var pemData []byte
@ -441,7 +464,7 @@ func absPath(dir string, path Path) string {
return filepath.Join(dir, string(path)) return filepath.Join(dir, string(path))
} }
func readKeyPEM(path string, data []byte) (gomatrixserverlib.KeyID, ed25519.PrivateKey, error) { func readKeyPEM(path string, data []byte, enforceKeyIDFormat bool) (gomatrixserverlib.KeyID, ed25519.PrivateKey, error) {
for { for {
var keyBlock *pem.Block var keyBlock *pem.Block
keyBlock, data = pem.Decode(data) keyBlock, data = pem.Decode(data)
@ -459,6 +482,9 @@ func readKeyPEM(path string, data []byte) (gomatrixserverlib.KeyID, ed25519.Priv
if !strings.HasPrefix(keyID, "ed25519:") { if !strings.HasPrefix(keyID, "ed25519:") {
return "", nil, fmt.Errorf("key ID %q doesn't start with \"ed25519:\" in %q", keyID, path) return "", nil, fmt.Errorf("key ID %q doesn't start with \"ed25519:\" in %q", keyID, path)
} }
if enforceKeyIDFormat && !keyIDRegexp.MatchString(keyID) {
return "", nil, fmt.Errorf("key ID %q in %q contains illegal characters (use a-z, A-Z, 0-9 and _ only)", keyID, path)
}
_, privKey, err := ed25519.GenerateKey(bytes.NewReader(keyBlock.Bytes)) _, privKey, err := ed25519.GenerateKey(bytes.NewReader(keyBlock.Bytes))
if err != nil { if err != nil {
return "", nil, err return "", nil, err

View file

@ -20,7 +20,12 @@ type Global struct {
// An arbitrary string used to uniquely identify the PrivateKey. Must start with the // An arbitrary string used to uniquely identify the PrivateKey. Must start with the
// prefix "ed25519:". // prefix "ed25519:".
KeyID gomatrixserverlib.KeyID `yaml:"key_id"` KeyID gomatrixserverlib.KeyID `yaml:"-"`
// Information about old private keys that used to be used to sign requests and
// events on this domain. They will not be used but will be advertised to other
// servers that ask for them to help verify old events.
OldVerifyKeys []OldVerifyKeys `yaml:"old_private_keys"`
// How long a remote server can cache our server key for before requesting it again. // How long a remote server can cache our server key for before requesting it again.
// Increasing this number will reduce the number of requests made by remote servers // Increasing this number will reduce the number of requests made by remote servers
@ -60,6 +65,21 @@ func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) {
c.Metrics.Verify(configErrs, isMonolith) c.Metrics.Verify(configErrs, isMonolith)
} }
type OldVerifyKeys struct {
// Path to the private key.
PrivateKeyPath Path `yaml:"private_key"`
// The private key itself.
PrivateKey ed25519.PrivateKey `yaml:"-"`
// The key ID of the private key.
KeyID gomatrixserverlib.KeyID `yaml:"-"`
// When the private key was designed as "expired", as a UNIX timestamp
// in millisecond precision.
ExpiredAt gomatrixserverlib.Timestamp `yaml:"expired_at"`
}
// The configuration to use for Prometheus metrics // The configuration to use for Prometheus metrics
type Metrics struct { type Metrics struct {
// Whether or not the metrics are enabled // Whether or not the metrics are enabled

View file

@ -234,7 +234,7 @@ func (m mockReadFile) readFile(path string) ([]byte, error) {
} }
func TestReadKey(t *testing.T) { func TestReadKey(t *testing.T) {
keyID, _, err := readKeyPEM("path/to/key", []byte(testKey)) keyID, _, err := readKeyPEM("path/to/key", []byte(testKey), true)
if err != nil { if err != nil {
t.Error("failed to load private key:", err) t.Error("failed to load private key:", err)
} }

View file

@ -88,6 +88,14 @@ func TxStmt(transaction *sql.Tx, statement *sql.Stmt) *sql.Stmt {
return statement return statement
} }
// TxStmtContext behaves similarly to TxStmt, with support for also passing context.
func TxStmtContext(context context.Context, transaction *sql.Tx, statement *sql.Stmt) *sql.Stmt {
if transaction != nil {
statement = transaction.StmtContext(context, statement)
}
return statement
}
// Hack of the century // Hack of the century
func QueryVariadic(count int) string { func QueryVariadic(count int) string {
return QueryVariadicOffset(count, 0) return QueryVariadicOffset(count, 0)

View file

@ -25,6 +25,7 @@ import (
"math/big" "math/big"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"time" "time"
"github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/config"
@ -146,10 +147,14 @@ func NewMatrixKey(matrixKeyPath string) (err error) {
err = keyOut.Close() err = keyOut.Close()
})() })()
keyID := base64.RawURLEncoding.EncodeToString(data[:])
keyID = strings.ReplaceAll(keyID, "-", "")
keyID = strings.ReplaceAll(keyID, "_", "")
err = pem.Encode(keyOut, &pem.Block{ err = pem.Encode(keyOut, &pem.Block{
Type: "MATRIX PRIVATE KEY", Type: "MATRIX PRIVATE KEY",
Headers: map[string]string{ Headers: map[string]string{
"Key-ID": "ed25519:" + base64.RawStdEncoding.EncodeToString(data[:3]), "Key-ID": fmt.Sprintf("ed25519:%s", keyID[:6]),
}, },
Bytes: data[3:], Bytes: data[3:],
}) })

View file

@ -21,6 +21,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
) )
@ -125,7 +126,7 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
// nullable if there are no results // nullable if there are no results
var nullStream sql.NullInt32 var nullStream sql.NullInt32
err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
err = nil err = nil
} }
@ -151,7 +152,7 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
for _, key := range keys { for _, key := range keys {
now := time.Now().Unix() now := time.Now().Unix()
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
) )
if err != nil { if err != nil {
@ -162,7 +163,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
} }
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
_, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
return err return err
} }

View file

@ -21,6 +21,7 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
) )
@ -151,14 +152,14 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, txn *sql.
} }
for keyIDWithAlgo, keyJSON := range keys.KeyJSON { for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo) algo, keyID := keys.Split(keyIDWithAlgo)
_, err := txn.Stmt(s.upsertKeysStmt).ExecContext( _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -180,14 +181,14 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
) (map[string]json.RawMessage, error) { ) (map[string]json.RawMessage, error) {
var keyID string var keyID string
var keyJSON string var keyJSON string
err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
return nil, err return nil, err
} }
_, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
return map[string]json.RawMessage{ return map[string]json.RawMessage{
algorithm + ":" + keyID: json.RawMessage(keyJSON), algorithm + ":" + keyID: json.RawMessage(keyJSON),
}, err }, err

View file

@ -97,7 +97,7 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
} }
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
_, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
return err return err
} }
@ -156,7 +156,7 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
// nullable if there are no results // nullable if there are no results
var nullStream sql.NullInt32 var nullStream sql.NullInt32
err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
err = nil err = nil
} }
@ -188,7 +188,7 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
for _, key := range keys { for _, key := range keys {
now := time.Now().Unix() now := time.Now().Unix()
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
) )
if err != nil { if err != nil {

View file

@ -21,6 +21,7 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
) )
@ -153,14 +154,14 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(
} }
for keyIDWithAlgo, keyJSON := range keys.KeyJSON { for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo) algo, keyID := keys.Split(keyIDWithAlgo)
_, err := txn.Stmt(s.upsertKeysStmt).ExecContext( _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -182,14 +183,14 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
) (map[string]json.RawMessage, error) { ) (map[string]json.RawMessage, error) {
var keyID string var keyID string
var keyJSON string var keyJSON string
err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
return nil, err return nil, err
} }
_, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -95,6 +95,13 @@ type RoomserverInternalAPI interface {
response *QueryMembershipsForRoomResponse, response *QueryMembershipsForRoomResponse,
) error ) error
// Query if we think we're still in a room.
QueryServerJoinedToRoom(
ctx context.Context,
request *QueryServerJoinedToRoomRequest,
response *QueryServerJoinedToRoomResponse,
) error
// Query whether a server is allowed to see an event // Query whether a server is allowed to see an event
QueryServerAllowedToSeeEvent( QueryServerAllowedToSeeEvent(
ctx context.Context, ctx context.Context,

View file

@ -144,6 +144,16 @@ func (t *RoomserverInternalAPITrace) QueryMembershipsForRoom(
return err return err
} }
func (t *RoomserverInternalAPITrace) QueryServerJoinedToRoom(
ctx context.Context,
req *QueryServerJoinedToRoomRequest,
res *QueryServerJoinedToRoomResponse,
) error {
err := t.Impl.QueryServerJoinedToRoom(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("QueryServerJoinedToRoom req=%+v res=%+v", js(req), js(res))
return err
}
func (t *RoomserverInternalAPITrace) QueryServerAllowedToSeeEvent( func (t *RoomserverInternalAPITrace) QueryServerAllowedToSeeEvent(
ctx context.Context, ctx context.Context,
req *QueryServerAllowedToSeeEventRequest, req *QueryServerAllowedToSeeEventRequest,

View file

@ -140,6 +140,22 @@ type QueryMembershipsForRoomResponse struct {
HasBeenInRoom bool `json:"has_been_in_room"` HasBeenInRoom bool `json:"has_been_in_room"`
} }
// QueryServerJoinedToRoomRequest is a request to QueryServerJoinedToRoom
type QueryServerJoinedToRoomRequest struct {
// Server name of the server to find
ServerName gomatrixserverlib.ServerName `json:"server_name"`
// ID of the room to see if we are still joined to
RoomID string `json:"room_id"`
}
// QueryMembershipsForRoomResponse is a response to QueryServerJoinedToRoom
type QueryServerJoinedToRoomResponse struct {
// True if the room exists on the server
RoomExists bool `json:"room_exists"`
// True if we still believe that we are participating in the room
IsInRoom bool `json:"is_in_room"`
}
// QueryServerAllowedToSeeEventRequest is a request to QueryServerAllowedToSeeEvent // QueryServerAllowedToSeeEventRequest is a request to QueryServerAllowedToSeeEvent
type QueryServerAllowedToSeeEventRequest struct { type QueryServerAllowedToSeeEventRequest struct {
// The event ID to look up invites in. // The event ID to look up invites in.

View file

@ -123,15 +123,7 @@ func SendEventWithRewrite(
// We will handle an event as if it's an outlier if one of the // We will handle an event as if it's an outlier if one of the
// following conditions is true: // following conditions is true:
storeAsOutlier := false storeAsOutlier := false
if authOrStateEvent.Type() == event.Type() && *authOrStateEvent.StateKey() == *event.StateKey() { if _, ok := isCurrentState[authOrStateEvent.EventID()]; !ok {
// The event is a state event but the input event is going to
// replace it, therefore it can't be added to the state or we'll
// get duplicate state keys in the state block. We'll send it
// as an outlier because we don't know if something will be
// referring to it as an auth event, but need it to be stored
// just in case.
storeAsOutlier = true
} else if _, ok := isCurrentState[authOrStateEvent.EventID()]; !ok {
// The event is an auth event and isn't a part of the state set. // The event is an auth event and isn't a part of the state set.
// We'll send it as an outlier because we need it to be stored // We'll send it as an outlier because we need it to be stored
// in case something is referring to it as an auth event. // in case something is referring to it as an auth event.

View file

@ -16,13 +16,78 @@ package helpers
import ( import (
"context" "context"
"fmt"
"sort" "sort"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
// CheckForSoftFail returns true if the event should be soft-failed
// and false otherwise. The return error value should be checked before
// the soft-fail bool.
func CheckForSoftFail(
ctx context.Context,
db storage.Database,
event gomatrixserverlib.HeaderedEvent,
stateEventIDs []string,
) (bool, error) {
rewritesState := len(stateEventIDs) > 1
var authStateEntries []types.StateEntry
var err error
if rewritesState {
authStateEntries, err = db.StateEntriesForEventIDs(ctx, stateEventIDs)
if err != nil {
return true, fmt.Errorf("StateEntriesForEventIDs failed: %w", err)
}
} else {
// Work out if the room exists.
var roomInfo *types.RoomInfo
roomInfo, err = db.RoomInfo(ctx, event.RoomID())
if err != nil {
return false, fmt.Errorf("db.RoomNID: %w", err)
}
if roomInfo == nil || roomInfo.IsStub {
return false, nil
}
// Then get the state entries for the current state snapshot.
// We'll use this to check if the event is allowed right now.
roomState := state.NewStateResolution(db, *roomInfo)
authStateEntries, err = roomState.LoadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID)
if err != nil {
return true, fmt.Errorf("roomState.LoadStateAtSnapshot: %w", err)
}
}
// As a special case, it's possible that the room will have no
// state because we haven't received a m.room.create event yet.
// If we're now processing the first create event then never
// soft-fail it.
if len(authStateEntries) == 0 && event.Type() == gomatrixserverlib.MRoomCreate {
return false, nil
}
// Work out which of the state events we actually need.
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event.Unwrap()})
// Load the actual auth events from the database.
authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries)
if err != nil {
return true, fmt.Errorf("loadAuthEvents: %w", err)
}
// Check if the event is allowed.
if err = gomatrixserverlib.Allowed(event.Event, &authEvents); err != nil {
// return true, nil
return true, fmt.Errorf("gomatrixserverlib.Allowed: %w", err)
}
return false, nil
}
// CheckAuthEvents checks that the event passes authentication checks // CheckAuthEvents checks that the event passes authentication checks
// Returns the numeric IDs for the auth events. // Returns the numeric IDs for the auth events.
func CheckAuthEvents( func CheckAuthEvents(

View file

@ -53,6 +53,20 @@ func (r *Inputer) processRoomEvent(
isRejected = true isRejected = true
} }
var softfail bool
if input.Kind == api.KindBackfill || input.Kind == api.KindNew {
// Check that the event passes authentication checks based on the
// current room state.
softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs)
if err != nil {
logrus.WithFields(logrus.Fields{
"event_id": event.EventID(),
"type": event.Type(),
"room": event.RoomID(),
}).WithError(err).Info("Error authing soft-failed event")
}
}
// If we don't have a transaction ID then get one. // If we don't have a transaction ID then get one.
if input.TransactionID != nil { if input.TransactionID != nil {
tdID := input.TransactionID tdID := input.TransactionID
@ -88,6 +102,7 @@ func (r *Inputer) processRoomEvent(
"event_id": event.EventID(), "event_id": event.EventID(),
"type": event.Type(), "type": event.Type(),
"room": event.RoomID(), "room": event.RoomID(),
"sender": event.Sender(),
}).Debug("Stored outlier") }).Debug("Stored outlier")
return event.EventID(), nil return event.EventID(), nil
} }
@ -110,11 +125,13 @@ func (r *Inputer) processRoomEvent(
} }
// We stop here if the event is rejected: We've stored it but won't update forward extremities or notify anyone about it. // We stop here if the event is rejected: We've stored it but won't update forward extremities or notify anyone about it.
if isRejected { if isRejected || softfail {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"event_id": event.EventID(), "event_id": event.EventID(),
"type": event.Type(), "type": event.Type(),
"room": event.RoomID(), "room": event.RoomID(),
"soft_fail": softfail,
"sender": event.Sender(),
}).Debug("Stored rejected event") }).Debug("Stored rejected event")
return event.EventID(), rejectionErr return event.EventID(), rejectionErr
} }

View file

@ -227,6 +227,50 @@ func (r *Queryer) QueryMembershipsForRoom(
return nil return nil
} }
// QueryServerJoinedToRoom implements api.RoomserverInternalAPI
func (r *Queryer) QueryServerJoinedToRoom(
ctx context.Context,
request *api.QueryServerJoinedToRoomRequest,
response *api.QueryServerJoinedToRoomResponse,
) error {
info, err := r.DB.RoomInfo(ctx, request.RoomID)
if err != nil {
return fmt.Errorf("r.DB.RoomInfo: %w", err)
}
if info == nil || info.IsStub {
return nil
}
response.RoomExists = true
eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false)
if err != nil {
return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err)
}
if len(eventNIDs) == 0 {
return nil
}
events, err := r.DB.Events(ctx, eventNIDs)
if err != nil {
return fmt.Errorf("r.DB.Events: %w", err)
}
for _, e := range events {
if e.Type() == gomatrixserverlib.MRoomMember && e.StateKey() != nil {
_, serverName, err := gomatrixserverlib.SplitID('@', *e.StateKey())
if err != nil {
continue
}
if serverName == request.ServerName {
response.IsInRoom = true
break
}
}
}
return nil
}
// QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI // QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI
func (r *Queryer) QueryServerAllowedToSeeEvent( func (r *Queryer) QueryServerAllowedToSeeEvent(
ctx context.Context, ctx context.Context,

View file

@ -39,6 +39,7 @@ const (
RoomserverQueryEventsByIDPath = "/roomserver/queryEventsByID" RoomserverQueryEventsByIDPath = "/roomserver/queryEventsByID"
RoomserverQueryMembershipForUserPath = "/roomserver/queryMembershipForUser" RoomserverQueryMembershipForUserPath = "/roomserver/queryMembershipForUser"
RoomserverQueryMembershipsForRoomPath = "/roomserver/queryMembershipsForRoom" RoomserverQueryMembershipsForRoomPath = "/roomserver/queryMembershipsForRoom"
RoomserverQueryServerJoinedToRoomPath = "/roomserver/queryServerJoinedToRoomPath"
RoomserverQueryServerAllowedToSeeEventPath = "/roomserver/queryServerAllowedToSeeEvent" RoomserverQueryServerAllowedToSeeEventPath = "/roomserver/queryServerAllowedToSeeEvent"
RoomserverQueryMissingEventsPath = "/roomserver/queryMissingEvents" RoomserverQueryMissingEventsPath = "/roomserver/queryMissingEvents"
RoomserverQueryStateAndAuthChainPath = "/roomserver/queryStateAndAuthChain" RoomserverQueryStateAndAuthChainPath = "/roomserver/queryStateAndAuthChain"
@ -326,6 +327,19 @@ func (h *httpRoomserverInternalAPI) QueryMembershipsForRoom(
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// QueryMembershipsForRoom implements RoomserverQueryAPI
func (h *httpRoomserverInternalAPI) QueryServerJoinedToRoom(
ctx context.Context,
request *api.QueryServerJoinedToRoomRequest,
response *api.QueryServerJoinedToRoomResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerJoinedToRoom")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryServerJoinedToRoomPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
// QueryServerAllowedToSeeEvent implements RoomserverQueryAPI // QueryServerAllowedToSeeEvent implements RoomserverQueryAPI
func (h *httpRoomserverInternalAPI) QueryServerAllowedToSeeEvent( func (h *httpRoomserverInternalAPI) QueryServerAllowedToSeeEvent(
ctx context.Context, ctx context.Context,

View file

@ -180,6 +180,20 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(
RoomserverQueryServerJoinedToRoomPath,
httputil.MakeInternalAPI("queryServerJoinedToRoom", func(req *http.Request) util.JSONResponse {
var request api.QueryServerJoinedToRoomRequest
var response api.QueryServerJoinedToRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryServerJoinedToRoom(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle( internalAPIMux.Handle(
RoomserverQueryServerAllowedToSeeEventPath, RoomserverQueryServerAllowedToSeeEventPath,
httputil.MakeInternalAPI("queryServerAllowedToSeeEvent", func(req *http.Request) util.JSONResponse { httputil.MakeInternalAPI("queryServerAllowedToSeeEvent", func(req *http.Request) util.JSONResponse {

View file

@ -21,6 +21,7 @@ import (
"fmt" "fmt"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
@ -86,7 +87,7 @@ func (s *stateSnapshotStatements) InsertState(
for i := range stateBlockNIDs { for i := range stateBlockNIDs {
nids[i] = int64(stateBlockNIDs[i]) nids[i] = int64(stateBlockNIDs[i])
} }
err = txn.Stmt(s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID) err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID)
return return
} }

View file

@ -320,9 +320,14 @@ func (d *Database) Events(
if err != nil { if err != nil {
return nil, err return nil, err
} }
roomVersion, err = d.RoomsTable.SelectRoomVersionForRoomNID(ctx, roomNID) if roomID, ok := d.Cache.GetRoomServerRoomID(roomNID); ok {
if err != nil { roomVersion, _ = d.Cache.GetRoomVersion(roomID)
return nil, err }
if roomVersion == "" {
roomVersion, err = d.RoomsTable.SelectRoomVersionForRoomNID(ctx, roomNID)
if err != nil {
return nil, err
}
} }
result.Event, err = gomatrixserverlib.NewEventFromTrustedJSON( result.Event, err = gomatrixserverlib.NewEventFromTrustedJSON(
eventJSON.EventJSON, false, roomVersion, eventJSON.EventJSON, false, roomVersion,

View file

@ -18,6 +18,8 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"strings"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
@ -25,10 +27,15 @@ import (
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
// TODO: previous_reference_sha256 was NOT NULL before but it broke sytest because
// sytest sends no SHA256 sums in the prev_events references in the soft-fail tests.
// In Postgres an empty BYTEA field is not NULL so it's fine there. In SQLite it
// seems to care that it's empty and therefore hits a NOT NULL constraint on insert.
// We should really work out what the right thing to do here is.
const previousEventSchema = ` const previousEventSchema = `
CREATE TABLE IF NOT EXISTS roomserver_previous_events ( CREATE TABLE IF NOT EXISTS roomserver_previous_events (
previous_event_id TEXT NOT NULL, previous_event_id TEXT NOT NULL,
previous_reference_sha256 BLOB NOT NULL, previous_reference_sha256 BLOB,
event_nids TEXT NOT NULL, event_nids TEXT NOT NULL,
UNIQUE (previous_event_id, previous_reference_sha256) UNIQUE (previous_event_id, previous_reference_sha256)
); );
@ -45,6 +52,11 @@ const insertPreviousEventSQL = `
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
` `
const selectPreviousEventNIDsSQL = `
SELECT event_nids FROM roomserver_previous_events
WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
`
// Check if the event is referenced by another event in the table. // Check if the event is referenced by another event in the table.
// This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room. // This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room.
const selectPreviousEventExistsSQL = ` const selectPreviousEventExistsSQL = `
@ -55,6 +67,7 @@ const selectPreviousEventExistsSQL = `
type previousEventStatements struct { type previousEventStatements struct {
db *sql.DB db *sql.DB
insertPreviousEventStmt *sql.Stmt insertPreviousEventStmt *sql.Stmt
selectPreviousEventNIDsStmt *sql.Stmt
selectPreviousEventExistsStmt *sql.Stmt selectPreviousEventExistsStmt *sql.Stmt
} }
@ -69,6 +82,7 @@ func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
return s, shared.StatementList{ return s, shared.StatementList{
{&s.insertPreviousEventStmt, insertPreviousEventSQL}, {&s.insertPreviousEventStmt, insertPreviousEventSQL},
{&s.selectPreviousEventNIDsStmt, selectPreviousEventNIDsSQL},
{&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -80,9 +94,28 @@ func (s *previousEventStatements) InsertPreviousEvent(
previousEventReferenceSHA256 []byte, previousEventReferenceSHA256 []byte,
eventNID types.EventNID, eventNID types.EventNID,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) var eventNIDs string
_, err := stmt.ExecContext( eventNIDAsString := fmt.Sprintf("%d", eventNID)
ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), selectStmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt)
err := selectStmt.QueryRowContext(ctx, previousEventID, previousEventReferenceSHA256).Scan(&eventNIDs)
if err != sql.ErrNoRows {
return fmt.Errorf("selectStmt.QueryRowContext.Scan: %w", err)
}
var nids []string
if eventNIDs != "" {
nids = strings.Split(eventNIDs, ",")
for _, nid := range nids {
if nid == eventNIDAsString {
return nil
}
}
eventNIDs = strings.Join(append(nids, eventNIDAsString), ",")
} else {
eventNIDs = eventNIDAsString
}
insertStmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
_, err = insertStmt.ExecContext(
ctx, previousEventID, previousEventReferenceSHA256, eventNIDs,
) )
return err return err
} }

View file

@ -105,12 +105,12 @@ func (s *stateBlockStatements) BulkInsertStateData(
return 0, nil return 0, nil
} }
var stateBlockNID types.StateBlockNID var stateBlockNID types.StateBlockNID
err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID) err := sqlutil.TxStmt(txn, s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
if err != nil { if err != nil {
return 0, err return 0, err
} }
for _, entry := range entries { for _, entry := range entries {
_, err = txn.Stmt(s.insertStateDataStmt).ExecContext( _, err = sqlutil.TxStmt(txn, s.insertStateDataStmt).ExecContext(
ctx, ctx,
int64(stateBlockNID), int64(stateBlockNID),
int64(entry.EventTypeNID), int64(entry.EventTypeNID),

View file

@ -76,7 +76,7 @@ func (s *stateSnapshotStatements) InsertState(
if err != nil { if err != nil {
return return
} }
insertStmt := txn.Stmt(s.insertStateStmt) insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt)
res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON)) res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON))
if err != nil { if err != nil {
return 0, err return 0, err

View file

@ -23,7 +23,7 @@ import (
// DefaultRoomVersion contains the room version that will, by // DefaultRoomVersion contains the room version that will, by
// default, be used to create new rooms on this server. // default, be used to create new rooms on this server.
func DefaultRoomVersion() gomatrixserverlib.RoomVersion { func DefaultRoomVersion() gomatrixserverlib.RoomVersion {
return gomatrixserverlib.RoomVersionV5 return gomatrixserverlib.RoomVersionV6
} }
// RoomVersions returns a map of all known room versions to this // RoomVersions returns a map of all known room versions to this

View file

@ -20,7 +20,7 @@ type ServerKeyAPI struct {
ServerKeyValidity time.Duration ServerKeyValidity time.Duration
OurKeyRing gomatrixserverlib.KeyRing OurKeyRing gomatrixserverlib.KeyRing
FedClient *gomatrixserverlib.FederationClient FedClient gomatrixserverlib.KeyClient
} }
func (s *ServerKeyAPI) KeyRing() *gomatrixserverlib.KeyRing { func (s *ServerKeyAPI) KeyRing() *gomatrixserverlib.KeyRing {

View file

@ -26,7 +26,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.ServerKeyInternalAPI, cach
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI( func NewInternalAPI(
cfg *config.ServerKeyAPI, cfg *config.ServerKeyAPI,
fedClient *gomatrixserverlib.FederationClient, fedClient gomatrixserverlib.KeyClient,
caches *caching.Caches, caches *caching.Caches,
) api.ServerKeyInternalAPI { ) api.ServerKeyInternalAPI {
innerDB, err := storage.NewDatabase( innerDB, err := storage.NewDatabase(
@ -53,7 +53,7 @@ func NewInternalAPI(
OurKeyRing: gomatrixserverlib.KeyRing{ OurKeyRing: gomatrixserverlib.KeyRing{
KeyFetchers: []gomatrixserverlib.KeyFetcher{ KeyFetchers: []gomatrixserverlib.KeyFetcher{
&gomatrixserverlib.DirectKeyFetcher{ &gomatrixserverlib.DirectKeyFetcher{
Client: fedClient.Client, Client: fedClient,
}, },
}, },
KeyDatabase: serverKeyDB, KeyDatabase: serverKeyDB,
@ -65,7 +65,7 @@ func NewInternalAPI(
perspective := &gomatrixserverlib.PerspectiveKeyFetcher{ perspective := &gomatrixserverlib.PerspectiveKeyFetcher{
PerspectiveServerName: ps.ServerName, PerspectiveServerName: ps.ServerName,
PerspectiveServerKeys: map[gomatrixserverlib.KeyID]ed25519.PublicKey{}, PerspectiveServerKeys: map[gomatrixserverlib.KeyID]ed25519.PublicKey{},
Client: fedClient.Client, Client: fedClient,
} }
for _, key := range ps.Keys { for _, key := range ps.Keys {

View file

@ -81,7 +81,7 @@ func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremiti
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string, ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) _, err = sqlutil.TxStmt(txn, s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
return return
} }
@ -110,7 +110,7 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
ctx context.Context, txn *sql.Tx, roomID, knownEventID string, ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
return return
} }

View file

@ -160,13 +160,13 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID, ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.updateSentSendToDeviceMessagesStmt).ExecContext(ctx, token, pq.Array(nids)) _, err = sqlutil.TxStmt(txn, s.updateSentSendToDeviceMessagesStmt).ExecContext(ctx, token, pq.Array(nids))
return return
} }
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID, ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, pq.Array(nids)) _, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, pq.Array(nids))
return return
} }

View file

@ -20,6 +20,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -85,7 +86,7 @@ func (s *accountDataStatements) InsertAccountData(
if err != nil { if err != nil {
return return
} }
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType) _, err = sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
return return
} }
@ -147,7 +148,7 @@ func (s *accountDataStatements) SelectMaxAccountDataID(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
) (id int64, err error) { ) (id int64, err error) {
var nullableID sql.NullInt64 var nullableID sql.NullInt64
err = txn.Stmt(s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID) err = sqlutil.TxStmt(txn, s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID)
if nullableID.Valid { if nullableID.Valid {
id = nullableID.Int64 id = nullableID.Int64
} }

View file

@ -84,7 +84,7 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string, ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) _, err = sqlutil.TxStmt(txn, s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
return err return err
} }
@ -113,7 +113,7 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
ctx context.Context, txn *sql.Tx, roomID, knownEventID string, ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
return err return err
} }

View file

@ -52,7 +52,4 @@ Inbound federation accepts a second soft-failed event
Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state
# We don't implement lazy membership loading yet. # We don't implement lazy membership loading yet.
The only membership state included in a gapped incremental sync is for senders in the timeline The only membership state included in a gapped incremental sync is for senders in the timeline
# flakey since implementing rejected events
Inbound federation correctly soft fails events

View file

@ -472,4 +472,11 @@ We can't peek into rooms with joined history_visibility
Local users can peek by room alias Local users can peek by room alias
Peeked rooms only turn up in the sync for the device who peeked them Peeked rooms only turn up in the sync for the device who peeked them
Room state at a rejected message event is the same as its predecessor Room state at a rejected message event is the same as its predecessor
Room state at a rejected state event is the same as its predecessor Room state at a rejected state event is the same as its predecessor
Inbound federation correctly soft fails events
Inbound federation accepts a second soft-failed event
Federation key API can act as a notary server via a POST request
Federation key API can act as a notary server via a GET request
Inbound /make_join rejects attempts to join rooms where all users have left
Inbound federation rejects invites which include invalid JSON for room version 6
Inbound federation rejects invite rejections which include invalid JSON for room version 6

View file

@ -20,6 +20,7 @@ import (
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
const accountDataSchema = ` const accountDataSchema = `
@ -75,7 +76,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) insertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) (err error) { ) (err error) {
stmt := txn.Stmt(s.insertAccountDataStmt) stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content) _, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
return return
} }

View file

@ -20,6 +20,7 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -99,7 +100,7 @@ func (s *accountsStatements) insertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
) (*api.Account, error) { ) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := txn.Stmt(s.insertAccountStmt) stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
var err error var err error
if appserviceID == "" { if appserviceID == "" {
@ -162,7 +163,7 @@ func (s *accountsStatements) selectNewNumericLocalpart(
) (id int64, err error) { ) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt stmt := s.selectNewNumericLocalpartStmt
if txn != nil { if txn != nil {
stmt = txn.Stmt(stmt) stmt = sqlutil.TxStmt(txn, stmt)
} }
err = stmt.QueryRowContext(ctx).Scan(&id) err = stmt.QueryRowContext(ctx).Scan(&id)
return return

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
const profilesSchema = ` const profilesSchema = `
@ -84,7 +85,7 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
func (s *profilesStatements) insertProfile( func (s *profilesStatements) insertProfile(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart string,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "") _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
return return
} }

View file

@ -18,6 +18,8 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
const accountDataSchema = ` const accountDataSchema = `
@ -75,7 +77,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) insertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) error { ) error {
_, err := txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
return err return err
} }

View file

@ -20,6 +20,7 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -104,9 +105,9 @@ func (s *accountsStatements) insertAccount(
var err error var err error
if appserviceID == "" { if appserviceID == "" {
_, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil) _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
} else { } else {
_, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@ -163,7 +164,7 @@ func (s *accountsStatements) selectNewNumericLocalpart(
) (id int64, err error) { ) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt stmt := s.selectNewNumericLocalpartStmt
if txn != nil { if txn != nil {
stmt = txn.Stmt(stmt) stmt = sqlutil.TxStmt(txn, stmt)
} }
err = stmt.QueryRowContext(ctx).Scan(&id) err = stmt.QueryRowContext(ctx).Scan(&id)
return return

View file

@ -87,7 +87,7 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
func (s *profilesStatements) insertProfile( func (s *profilesStatements) insertProfile(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart string,
) error { ) error {
_, err := txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "") _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
return err return err
} }